source: trunk/pg.py @ 785

Last change on this file since 785 was 785, checked in by cito, 3 years ago

Make all tests run again with Python 3

  • Property svn:keywords set to Id
File size: 59.7 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 785 2016-01-26 17:50:41Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40from operator import itemgetter
41from re import compile as regex
42from json import loads as jsondecode, dumps as jsonencode
43
try:
    basestring
except NameError:
    # Python 3 has no basestring builtin anymore; a tuple of both string
    # types can be passed to isinstance() in the same way.
    basestring = str, bytes
48
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Simple read-only ordered dictionary for storing attribute names."""

        def __init__(self, *args, **kw):
            # only a single sequence of (key, value) pairs is accepted,
            # since a plain dict would not give us a defined order
            if len(args) > 1 or kw:
                raise TypeError
            items = args[0] if args else []
            if isinstance(items, dict):
                raise TypeError
            items = list(items)
            self._keys = [item[0] for item in items]
            dict.__init__(self, items)
            self._read_only = True
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __reduce__(self):
            # The default reduction creates an empty instance and adds the
            # items afterwards, which fails for a read-only dictionary, so
            # copy and pickle by passing the items to the constructor.
            return self.__class__, (self.items(),)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[key] for key in self]

        def items(self):
            return [(key, self[key]) for key in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Simple read-only ordered dictionary for storing attribute names."""

        def __init__(self, *args, **kw):
            # allow OrderedDict.__init__ to add the items via __setitem__
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error
            # the order can also be changed in place (Python >= 3.2)
            self.move_to_end = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        def __reduce__(self):
            # The default reduction creates an empty instance and adds the
            # items afterwards, which fails for a read-only dictionary, so
            # copy and pickle by passing the items to the constructor.
            return self.__class__, (list(self.items()),)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
133
134
135# Auxiliary classes and functions that are independent from a DB connection:
136
137def _oid_key(table):
138    """Build oid key from a table name."""
139    return 'oid(%s)' % table
140
141
142class _SimpleType(dict):
143    """Dictionary mapping pg_type names to simple type names."""
144
145    _types = {'bool': 'bool',
146        'bytea': 'bytea',
147        'date': 'date interval time timetz timestamp timestamptz'
148            ' abstime reltime',  # these are very old
149        'float': 'float4 float8',
150        'int': 'cid int2 int4 int8 oid xid',
151        'json': 'json jsonb',
152        'num': 'numeric',
153        'money': 'money',
154        'text': 'bpchar char name text varchar'}
155
156    def __init__(self):
157        for typ, keys in self._types.items():
158            for key in keys.split():
159                self[key] = typ
160                self['_%s' % key] = '%s[]' % typ
161
162    @staticmethod
163    def __missing__(key):
164        return 'text'
165
166_simpletype = _SimpleType()
167
168
169class _Literal(str):
170    """Wrapper class for literal SQL."""
171
172
173def _namedresult(q):
174    """Get query result as named tuples."""
175    row = namedtuple('Row', q.listfields())
176    return [row(*r) for r in q.getresult()]
177
178
179class _MemoryQuery:
180    """Class that embodies a given query result."""
181
182    def __init__(self, result, fields):
183        """Create query from given result rows and field names."""
184        self.result = result
185        self.fields = fields
186
187    def listfields(self):
188        """Return the stored field names of this query."""
189        return self.fields
190
191    def getresult(self):
192        """Return the stored result of this query."""
193        return self.result
194
195
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The exception gets the message msg and a sqlstate attribute of None.
    """
    exception = cls(msg)
    exception.sqlstate = None
    return exception
201
202
def _int_error(msg):
    """Create an InternalError with the given message and no sqlstate."""
    return _db_error(msg, cls=InternalError)
206
207
def _prg_error(msg):
    """Create a ProgrammingError with the given message and no sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
211
212
213# Initialize the C module
214
# Register Python-level helpers with the C module (the set_* functions
# presumably come from the "from _pg import *" above): the factory that
# turns query results into named tuples, the Decimal class (presumably
# used for numeric values - TODO confirm in _pg) and the JSON decoder.
set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
218
219
220# The notification handler
221
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.  It is
        sent as an escaped string literal, so it may contain quotes and
        backslashes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # Use an escape string literal (E'...'), doubling quotes
                # and backslashes, so that the payload cannot break the
                # command regardless of standard_conforming_strings.
                payload = '%s' % (payload,)
                payload = payload.replace('\\', '\\\\').replace("'", "''")
                q += ", E'%s'" % payload
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications is
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
341
342
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its traditional name (deprecated).

    All arguments are passed on to NotificationHandler unchanged.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
348
349
# The actual PostgreSQL database connection interface:
351
352class DB(object):
353    """Wrapper class for the _pg connection type."""
354
355    def __init__(self, *args, **kw):
356        """Create a new connection
357
358        You can pass either the connection parameters or an existing
359        _pg or pgdb connection. This allows you to use the methods
360        of the classic pg interface with a DB-API 2 pgdb connection.
361        """
362        if not args and len(kw) == 1:
363            db = kw.get('db')
364        elif not kw and len(args) == 1:
365            db = args[0]
366        else:
367            db = None
368        if db:
369            if isinstance(db, DB):
370                db = db.db
371            else:
372                try:
373                    db = db._cnx
374                except AttributeError:
375                    pass
376        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
377            db = connect(*args, **kw)
378            self._closeable = True
379        else:
380            self._closeable = False
381        self.db = db
382        self.dbname = db.db
383        self._regtypes = False
384        self._attnames = {}
385        self._pkeys = {}
386        self._privileges = {}
387        self._args = args, kw
388        self.debug = None  # For debugging scripts, this can be set
389            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
390            # * to a file object to write debug statements or
391            # * to a callable object which takes a string argument
392            # * to any other true value to just print debug statements
393
394    def __getattr__(self, name):
395        # All undefined members are same as in underlying connection:
396        if self.db:
397            return getattr(self.db, name)
398        else:
399            raise _int_error('Connection is not valid')
400
401    def __dir__(self):
402        # Custom dir function including the attributes of the connection:
403        attrs = set(self.__class__.__dict__)
404        attrs.update(self.__dict__)
405        attrs.update(dir(self.db))
406        return sorted(attrs)
407
408    # Context manager methods
409
410    def __enter__(self):
411        """Enter the runtime context. This will start a transactio."""
412        self.begin()
413        return self
414
415    def __exit__(self, et, ev, tb):
416        """Exit the runtime context. This will end the transaction."""
417        if et is None and ev is None and tb is None:
418            self.commit()
419        else:
420            self.rollback()
421
422    # Auxiliary methods
423
424    def _do_debug(self, *args):
425        """Print a debug message"""
426        if self.debug:
427            s = '\n'.join(str(arg) for arg in args)
428            if isinstance(self.debug, basestring):
429                print(self.debug % s)
430            elif hasattr(self.debug, 'write'):
431                self.debug.write(s + '\n')
432            elif callable(self.debug):
433                self.debug(s)
434            else:
435                print(s)
436
437    def _escape_qualified_name(self, s):
438        """Escape a qualified name.
439
440        Escapes the name for use as an SQL identifier, unless the
441        name contains a dot, in which case the name is ambiguous
442        (could be a qualified name or just a name with a dot in it)
443        and must be quoted manually by the caller.
444        """
445        if '.' not in s:
446            s = self.escape_identifier(s)
447        return s
448
449    @staticmethod
450    def _make_bool(d):
451        """Get boolean value corresponding to d."""
452        return bool(d) if get_bool() else ('t' if d else 'f')
453
454    _bool_true_values = frozenset('t true 1 y yes on'.split())
455
456    def _prepare_bool(self, d):
457        """Prepare a boolean parameter."""
458        if isinstance(d, basestring):
459            if not d:
460                return None
461            d = d.lower() in self._bool_true_values
462        return 't' if d else 'f'
463
464    _date_literals = frozenset('current_date current_time'
465        ' current_timestamp localtime localtimestamp'.split())
466
467    def _prepare_date(self, d):
468        """Prepare a date parameter."""
469        if not d:
470            return None
471        if isinstance(d, basestring) and d.lower() in self._date_literals:
472            return _Literal(d)
473        return d
474
475    _num_types = frozenset('int float num money'
476        ' int2 int4 int8 float4 float8 numeric money'.split())
477
478    @staticmethod
479    def _prepare_num(d):
480        """Prepare a numeric parameter."""
481        if not d and d != 0:
482            return None
483        return d
484
485    _prepare_int = _prepare_float = _prepare_money = _prepare_num
486
487    def _prepare_bytea(self, d):
488        """Prepare a bytea parameter."""
489        return self.escape_bytea(d)
490
491    def _prepare_json(self, d):
492        """Prepare a json parameter."""
493        if not d:
494            return None
495        if isinstance(d, basestring):
496            return d
497        return self.encode_json(d)
498
499    _re_array_escape = regex(r'(["\\])')
500    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
501
502    def _prepare_bool_array(self, d):
503        """Prepare a bool array parameter."""
504        if isinstance(d, list):
505            return '{%s}' % ','.join(self._prepare_bool_array(v) for v in d)
506        if d is None:
507            return 'null'
508        if isinstance(d, basestring):
509            if not d:
510                return 'null'
511            d = d.lower() in self._bool_true_values
512        return 't' if d else 'f'
513
514    def _prepare_num_array(self, d):
515        """Prepare a numeric array parameter."""
516        if isinstance(d, list):
517            return '{%s}' % ','.join(self._prepare_num_array(v) for v in d)
518        if not d and d != 0:
519            return 'null'
520        return str(d)
521
522    _prepare_int_array = _prepare_float_array = _prepare_money_array = \
523            _prepare_num_array
524
525    def _prepare_text_array(self, d):
526        """Prepare a text array parameter."""
527        if isinstance(d, list):
528            return '{%s}' % ','.join(self._prepare_text_array(v) for v in d)
529        if d is None:
530            return 'null'
531        if not d:
532            return '""'
533        d = str(d)
534        if self._re_array_quote.search(d):
535            d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
536        return d
537
538    def _prepare_bytea_array(self, d):
539        """Prepare a bytea array parameter."""
540        if isinstance(d, list):
541            return b'{' + b','.join(
542                self._prepare_bytea_array(v) for v in d) + b'}'
543        if d is None:
544            return b'null'
545        return self.escape_bytea(d).replace(b'\\', b'\\\\')
546
547    def _prepare_json_array(self, d):
548        """Prepare a json array parameter."""
549        if isinstance(d, list):
550            return '{%s}' % ','.join(self._prepare_json_array(v) for v in d)
551        if not d:
552            return 'null'
553        if not isinstance(d, basestring):
554            d = self.encode_json(d)
555        if self._re_array_quote.search(d):
556            d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
557        return d
558
559    def _prepare_param(self, value, typ, params):
560        """Prepare and add a parameter to the list."""
561        if isinstance(value, _Literal):
562            return value
563        if value is not None and typ != 'text':
564            if typ.endswith('[]'):
565                if isinstance(value, list):
566                    prepare = getattr(self, '_prepare_%s_array' % typ[:-2])
567                    value = prepare(value)
568                elif isinstance(value, basestring):
569                    value = value.strip()
570                    if not value.startswith('{') or not value.endswith('}'):
571                        if value[:5].lower() == 'array':
572                            value = value[5:].lstrip()
573                        if value.startswith('[') and value.endswith(']'):
574                            value = _Literal('ARRAY%s' % value)
575                        else:
576                            raise ValueError(
577                                'Invalid array expression: %s' % value)
578                else:
579                    raise ValueError('Invalid array parameter: %s' % value)
580            else:
581                prepare = getattr(self, '_prepare_%s' % typ)
582                value = prepare(value)
583            if isinstance(value, _Literal):
584                return value
585        params.append(value)
586        return '$%d' % len(params)
587
588    def _list_params(self, params):
589        """Create a human readable parameter list."""
590        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
591
592    @staticmethod
593    def _prepare_qualified_param(name, param):
594        """Quote parameter representing a qualified name.
595
596        Escapes the name for use as an SQL parameter, unless the
597        name contains a dot, in which case the name is ambiguous
598        (could be a qualified name or just a name with a dot in it)
599        and must be quoted manually by the caller.
600
601        """
602        if isinstance(param, int):
603            param = "$%d" % param
604        if '.' not in name:
605            param = 'quote_ident(%s)' % (param,)
606        return param
607
608    # Public methods
609
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well.
    # NOTE: the right-hand side is the module-level unescape_bytea
    # (presumably provided by "from _pg import *"); wrapping it in
    # staticmethod prevents it from being bound to the DB instance, so
    # db.unescape_bytea(s) calls it with s as its only argument.
    unescape_bytea = staticmethod(unescape_bytea)
613
614    def decode_json(self, s):
615        """Decode a JSON string coming from the database."""
616        return (get_jsondecode() or jsondecode)(s)
617
618    def encode_json(self, d):
619        """Encode a JSON string for use within SQL."""
620        return jsonencode(d)
621
622    def close(self):
623        """Close the database connection."""
624        # Wraps shared library function so we can track state.
625        if self._closeable:
626            if self.db:
627                self.db.close()
628                self.db = None
629            else:
630                raise _int_error('Connection already closed')
631
632    def reset(self):
633        """Reset connection with current parameters.
634
635        All derived queries and large objects derived from this connection
636        will not be usable after this call.
637
638        """
639        if self.db:
640            self.db.reset()
641        else:
642            raise _int_error('Connection already closed')
643
644    def reopen(self):
645        """Reopen connection to the database.
646
647        Used in case we need another connection to the same database.
648        Note that we can still reopen a database that we have closed.
649
650        """
651        # There is no such shared library function.
652        if self._closeable:
653            db = connect(*self._args[0], **self._args[1])
654            if self.db:
655                self.db.close()
656            self.db = db
657
658    def begin(self, mode=None):
659        """Begin a transaction."""
660        qstr = 'BEGIN'
661        if mode:
662            qstr += ' ' + mode
663        return self.query(qstr)
664
665    start = begin
666
667    def commit(self):
668        """Commit the current transaction."""
669        return self.query('COMMIT')
670
671    end = commit
672
673    def rollback(self, name=None):
674        """Roll back the current transaction."""
675        qstr = 'ROLLBACK'
676        if name:
677            qstr += ' TO ' + name
678        return self.query(qstr)
679
680    abort = rollback
681
682    def savepoint(self, name):
683        """Define a new savepoint within the current transaction."""
684        return self.query('SAVEPOINT ' + name)
685
686    def release(self, name):
687        """Destroy a previously defined savepoint."""
688        return self.query('RELEASE ' + name)
689
690    def get_parameter(self, parameter):
691        """Get the value of a run-time parameter.
692
693        If the parameter is a string, the return value will also be a string
694        that is the current setting of the run-time parameter with that name.
695
696        You can get several parameters at once by passing a list, set or dict.
697        When passing a list of parameter names, the return value will be a
698        corresponding list of parameter settings.  When passing a set of
699        parameter names, a new dict will be returned, mapping these parameter
700        names to their settings.  Finally, if you pass a dict as parameter,
701        its values will be set to the current parameter settings corresponding
702        to its keys.
703
704        By passing the special name 'all' as the parameter, you can get a dict
705        of all existing configuration parameters.
706        """
707        if isinstance(parameter, basestring):
708            parameter = [parameter]
709            values = None
710        elif isinstance(parameter, (list, tuple)):
711            values = []
712        elif isinstance(parameter, (set, frozenset)):
713            values = {}
714        elif isinstance(parameter, dict):
715            values = parameter
716        else:
717            raise TypeError(
718                'The parameter must be a string, list, set or dict')
719        if not parameter:
720            raise TypeError('No parameter has been specified')
721        params = {} if isinstance(values, dict) else []
722        for key in parameter:
723            param = key.strip().lower() if isinstance(
724                key, basestring) else None
725            if not param:
726                raise TypeError('Invalid parameter')
727            if param == 'all':
728                q = 'SHOW ALL'
729                values = self.db.query(q).getresult()
730                values = dict(value[:2] for value in values)
731                break
732            if isinstance(values, dict):
733                params[param] = key
734            else:
735                params.append(param)
736        else:
737            for param in params:
738                q = 'SHOW %s' % (param,)
739                value = self.db.query(q).getresult()[0][0]
740                if values is None:
741                    values = value
742                elif isinstance(values, list):
743                    values.append(value)
744                else:
745                    values[params[param]] = value
746        return values
747
748    def set_parameter(self, parameter, value=None, local=False):
749        """Set the value of a run-time parameter.
750
751        If the parameter and the value are strings, the run-time parameter
752        will be set to that value.  If no value or None is passed as a value,
753        then the run-time parameter will be restored to its default value.
754
755        You can set several parameters at once by passing a list of parameter
756        names, together with a single value that all parameters should be
757        set to or with a corresponding list of values.  You can also pass
758        the parameters as a set if you only provide a single value.
759        Finally, you can pass a dict with parameter names as keys.  In this
760        case, you should not pass a value, since the values for the parameters
761        will be taken from the dict.
762
763        By passing the special name 'all' as the parameter, you can reset
764        all existing settable run-time parameters to their default values.
765
766        If you set local to True, then the command takes effect for only the
767        current transaction.  After commit() or rollback(), the session-level
768        setting takes effect again.  Setting local to True will appear to
769        have no effect if it is executed outside a transaction, since the
770        transaction will end immediately.
771        """
772        if isinstance(parameter, basestring):
773            parameter = {parameter: value}
774        elif isinstance(parameter, (list, tuple)):
775            if isinstance(value, (list, tuple)):
776                parameter = dict(zip(parameter, value))
777            else:
778                parameter = dict.fromkeys(parameter, value)
779        elif isinstance(parameter, (set, frozenset)):
780            if isinstance(value, (list, tuple, set, frozenset)):
781                value = set(value)
782                if len(value) == 1:
783                    value = value.pop()
784            if not(value is None or isinstance(value, basestring)):
785                raise ValueError('A single value must be specified'
786                    ' when parameter is a set')
787            parameter = dict.fromkeys(parameter, value)
788        elif isinstance(parameter, dict):
789            if value is not None:
790                raise ValueError('A value must not be specified'
791                    ' when parameter is a dictionary')
792        else:
793            raise TypeError(
794                'The parameter must be a string, list, set or dict')
795        if not parameter:
796            raise TypeError('No parameter has been specified')
797        params = {}
798        for key, value in parameter.items():
799            param = key.strip().lower() if isinstance(
800                key, basestring) else None
801            if not param:
802                raise TypeError('Invalid parameter')
803            if param == 'all':
804                if value is not None:
805                    raise ValueError('A value must ot be specified'
806                        " when parameter is 'all'")
807                params = {'all': None}
808                break
809            params[param] = value
810        local = ' LOCAL' if local else ''
811        for param, value in params.items():
812            if value is None:
813                q = 'RESET%s %s' % (local, param)
814            else:
815                q = 'SET%s %s TO %s' % (local, param, value)
816            self._do_debug(q)
817            self.db.query(q)
818
819    def query(self, command, *args):
820        """Execute a SQL command string.
821
822        This method simply sends a SQL query to the database.  If the query is
823        an insert statement that inserted exactly one row into a table that
824        has OIDs, the return value is the OID of the newly inserted row.
825        If the query is an update or delete statement, or an insert statement
826        that did not insert exactly one row in a table with OIDs, then the
827        number of rows affected is returned as a string.  If it is a statement
828        that returns rows as a result (usually a select statement, but maybe
829        also an "insert/update ... returning" statement), this method returns
830        a Query object that can be accessed via getresult() or dictresult()
831        or simply printed.  Otherwise, it returns `None`.
832
833        The query can contain numbered parameters of the form $1 in place
834        of any data constant.  Arguments given after the query string will
835        be substituted for the corresponding numbered parameter.  Parameter
836        values can also be given as a single list or tuple argument.
837        """
838        # Wraps shared library function for debugging.
839        if not self.db:
840            raise _int_error('Connection is not valid')
841        if args:
842            self._do_debug(command, args)
843            return self.db.query(command, args)
844        self._do_debug(command)
845        return self.db.query(command)
846
847    def pkey(self, table, composite=False, flush=False):
848        """Get or set the primary key of a table.
849
850        Single primary keys are returned as strings unless you
851        set the composite flag.  Composite primary keys are always
852        represented as tuples.  Note that this raises a KeyError
853        if the table does not have a primary key.
854
855        If flush is set then the internal cache for primary keys will
856        be flushed.  This may be necessary after the database schema or
857        the search path has been changed.
858        """
859        pkeys = self._pkeys
860        if flush:
861            pkeys.clear()
862            self._do_debug('The pkey cache has been flushed')
863        try:  # cache lookup
864            pkey = pkeys[table]
865        except KeyError:  # cache miss, check the database
866            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
867                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
868                " AND a.attnum = ANY(i.indkey)"
869                " AND NOT a.attisdropped"
870                " WHERE i.indrelid=%s::regclass"
871                " AND i.indisprimary ORDER BY a.attnum") % (
872                    self._prepare_qualified_param(table, 1),)
873            pkey = self.db.query(q, (table,)).getresult()
874            if not pkey:
875                raise KeyError('Table %s has no primary key' % table)
876            # we want to use the order defined in the primary key index here,
877            # not the order as defined by the columns in the table
878            if len(pkey) > 1:
879                indkey = [int(k) for k in pkey[0][2].split()]
880                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
881                pkey = tuple(row[0] for row in pkey)
882            else:
883                pkey = pkey[0][0]
884            pkeys[table] = pkey  # cache it
885        if composite and not isinstance(pkey, tuple):
886            pkey = (pkey,)
887        return pkey
888
889    def get_databases(self):
890        """Get list of databases in the system."""
891        return [s[0] for s in
892            self.db.query('SELECT datname FROM pg_database').getresult()]
893
894    def get_relations(self, kinds=None):
895        """Get list of relations in connected database of specified kinds.
896
897        If kinds is None or empty, all kinds of relations are returned.
898        Otherwise kinds can be a string or sequence of type letters
899        specifying which kind of relations you want to list.
900        """
901        where = " AND r.relkind IN (%s)" % ','.join(
902            ["'%s'" % k for k in kinds]) if kinds else ''
903        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
904            " FROM pg_class r"
905            " JOIN pg_namespace s ON s.oid = r.relnamespace"
906            " WHERE s.nspname NOT SIMILAR"
907            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
908            " ORDER BY s.nspname, r.relname") % where
909        return [r[0] for r in self.db.query(q).getresult()]
910
911    def get_tables(self):
912        """Return list of tables in connected database."""
913        return self.get_relations('r')
914
915    def get_attnames(self, table, flush=False):
916        """Given the name of a table, dig out the set of attribute names.
917
918        Returns a read-only dictionary of attribute names (the names are
919        the keys, the values are the names of the attributes' types)
920        with the column names in the proper order if you iterate over it.
921
922        If flush is set, then the internal cache for attribute names will
923        be flushed. This may be necessary after the database schema or
924        the search path has been changed.
925
926        By default, only a limited number of simple types will be returned.
927        You can get the regular types after calling use_regtypes(True).
928        """
929        attnames = self._attnames
930        if flush:
931            attnames.clear()
932            self._do_debug('The attnames cache has been flushed')
933        try:  # cache lookup
934            names = attnames[table]
935        except KeyError:  # cache miss, check the database
936            q = ("SELECT a.attname, t.typname%s"
937                " FROM pg_attribute a"
938                " JOIN pg_type t ON t.oid = a.atttypid"
939                " WHERE a.attrelid = %s::regclass"
940                " AND (a.attnum > 0 OR a.attname = 'oid')"
941                " AND NOT a.attisdropped ORDER BY a.attnum") % (
942                    '::regtype' if self._regtypes else '',
943                    self._prepare_qualified_param(table, 1))
944            names = self.db.query(q, (table,)).getresult()
945            if not self._regtypes:
946                names = ((name, _simpletype[typ]) for name, typ in names)
947            names = AttrDict(names)
948            attnames[table] = names  # cache it
949        return names
950
951    def use_regtypes(self, regtypes=None):
952        """Use regular type names instead of simplified type names."""
953        if regtypes is None:
954            return self._regtypes
955        else:
956            regtypes = bool(regtypes)
957            if regtypes != self._regtypes:
958                self._regtypes = regtypes
959                self._attnames.clear()
960            return regtypes
961
962    def has_table_privilege(self, table, privilege='select'):
963        """Check whether current user has specified table privilege."""
964        privilege = privilege.lower()
965        try:  # ask cache
966            return self._privileges[(table, privilege)]
967        except KeyError:  # cache miss, ask the database
968            q = "SELECT has_table_privilege(%s, $2)" % (
969                self._prepare_qualified_param(table, 1),)
970            q = self.db.query(q, (table, privilege))
971            ret = q.getresult()[0][0] == self._make_bool(True)
972            self._privileges[(table, privilege)] = ret  # cache it
973            return ret
974
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        If no keyname is given and the table has no primary key, the OID in
        the row dictionary is used as key.  If that is missing too, an error
        created by _prg_error() is raised.  A KeyError is raised when key
        values are missing from the row or their number does not match the
        keyname.  An error created by _db_error() is raised when no
        matching row exists.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # name of the munged oid key, or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)  # normalize a single column name to a tuple
        # accept the munged oid key as the plain 'oid' value for the lookup
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # turn a single key value or a tuple/list of key values into a dict
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = []
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), param(row[k], attnames[k])) for k in keyname)
        # the caller's dictionary should only carry the munged oid key
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row dict, munging the oid key
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1049
1050    def insert(self, table, row=None, **kw):
1051        """Insert a row into a database table.
1052
1053        This method inserts a row into a table.  The name of the table must
1054        be passed as the first parameter.  The other parameters are used for
1055        providing the data of the row that shall be inserted into the table.
1056        If a dictionary is supplied as the second parameter, it starts with
1057        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1058        is updated from the keywords.
1059
1060        The dictionary is then reloaded with the values actually inserted in
1061        order to pick up values modified by rules, triggers, etc.
1062        """
1063        if table.endswith('*'):  # hint for descendant tables can be ignored
1064            table = table[:-1].rstrip()
1065        if row is None:
1066            row = {}
1067        row.update(kw)
1068        if 'oid' in row:
1069            del row['oid']  # do not insert oid
1070        attnames = self.get_attnames(table)
1071        qoid = _oid_key(table) if 'oid' in attnames else None
1072        params = []
1073        param = partial(self._prepare_param, params=params)
1074        col = self.escape_identifier
1075        names, values = [], []
1076        for n in attnames:
1077            if n in row:
1078                names.append(col(n))
1079                values.append(param(row[n], attnames[n]))
1080        if not names:
1081            raise _prg_error('No column found that can be inserted')
1082        names, values = ', '.join(names), ', '.join(values)
1083        ret = 'oid, *' if qoid else '*'
1084        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1085            self._escape_qualified_name(table), names, values, ret)
1086        self._do_debug(q, params)
1087        q = self.db.query(q, params)
1088        res = q.dictresult()
1089        if res:  # this should always be true
1090            for n, value in res[0].items():
1091                if qoid and n == 'oid':
1092                    n = qoid
1093                row[n] = value
1094        return row
1095
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the primary key of the table or the OID value as munged by get
        or passed as keyword.

        The dictionary is then modified to reflect any changes caused by the
        update due to triggers, rules, default values, etc.

        A plain 'oid' key in the passed dictionary is discarded.  Only an oid
        passed as keyword, or stored under the munged oid key, is used.  The
        passed dictionary is updated in place with the keyword arguments and
        returned.  If it holds no columns besides the key columns, no
        statement is executed.  If no row matched, the dictionary is returned
        without being refreshed.

        An error created by _prg_error() is raised if the table has no
        primary key and no oid is available.  A KeyError is raised if a
        primary key value is missing and no oid is available.
        """
        if table.endswith('*'):
            table = table[:-1].rstrip()  # need parent table name
        attnames = self.get_attnames(table)
        # name of the munged oid key, or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # accept the munged oid key as the plain 'oid' value for the lookup
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        try:  # try using the primary key
            keyname = self.pkey(table, True)
        except KeyError:  # the table has no primary key
            # try using the oid instead
            if qoid and 'oid' in row:
                keyname = ('oid',)
            else:
                raise _prg_error('Table %s has no primary key' % table)
        else:  # the table has a primary key
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                # try using the oid instead
                if qoid and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise KeyError('Missing primary key in row')
        params = []
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        where = ' AND '.join('%s = %s' % (
            col(k), param(row[k], attnames[k])) for k in keyname)
        # the caller's dictionary should only carry the munged oid key
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        values = []
        keyname = set(keyname)  # key columns are used in WHERE, not in SET
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
        if not values:
            return row  # nothing to update
        values = ', '.join(values)
        ret = 'oid, *' if qoid else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            # copy the stored values into the row dict, munging the oid key
            for n, value in res[0].items():
                if qoid and n == 'oid':
                    n = qoid
                row[n] = value
        return row
1162
1163    def upsert(self, table, row=None, **kw):
1164        """Insert a row into a database table with conflict resolution
1165
1166        This method inserts a row into a table, but instead of raising a
1167        ProgrammingError exception in case a row with the same primary key
1168        already exists, an update will be executed instead.  This will be
1169        performed as a single atomic operation on the database, so race
1170        conditions can be avoided.
1171
1172        Like the insert method, the first parameter is the name of the
1173        table and the second parameter can be used to pass the values to
1174        be inserted as a dictionary.
1175
1176        Unlike the insert und update statement, keyword parameters are not
1177        used to modify the dictionary, but to specify which columns shall
1178        be updated in case of a conflict, and in which way:
1179
1180        A value of False or None means the column shall not be updated,
1181        a value of True means the column shall be updated with the value
1182        that has been proposed for insertion, i.e. has been passed as value
1183        in the dictionary.  Columns that are not specified by keywords but
1184        appear as keys in the dictionary are also updated like in the case
1185        keywords had been passed with the value True.
1186
1187        So if in the case of a conflict you want to update every column that
1188        has been passed in the dictionary row , you would call upsert(table, row).
1189        If you don't want to do anything in case of a conflict, i.e. leave
1190        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
1191
1192        If you need more fine-grained control of what gets updated, you can
1193        also pass strings in the keyword parameters.  These strings will
1194        be used as SQL expressions for the update columns.  In these
1195        expressions you can refer to the value that already exists in
1196        the table by prefixing the column name with "included.", and to
1197        the value that has been proposed for insertion by prefixing the
1198        column name with the "excluded."
1199
1200        The dictionary is modified in any case to reflect the values in
1201        the database after the operation has completed.
1202
1203        Note: The method uses the PostgreSQL "upsert" feature which is
1204        only available since PostgreSQL 9.5.
1205        """
1206        if table.endswith('*'):  # hint for descendant tables can be ignored
1207            table = table[:-1].rstrip()
1208        if row is None:
1209            row = {}
1210        if 'oid' in row:
1211            del row['oid']  # do not insert oid
1212        if 'oid' in kw:
1213            del kw['oid']  # do not update oid
1214        attnames = self.get_attnames(table)
1215        qoid = _oid_key(table) if 'oid' in attnames else None
1216        params = []
1217        param = partial(self._prepare_param,params=params)
1218        col = self.escape_identifier
1219        names, values, updates = [], [], []
1220        for n in attnames:
1221            if n in row:
1222                names.append(col(n))
1223                values.append(param(row[n], attnames[n]))
1224        names, values = ', '.join(names), ', '.join(values)
1225        try:
1226            keyname = self.pkey(table, True)
1227        except KeyError:
1228            raise _prg_error('Table %s has no primary key' % table)
1229        target = ', '.join(col(k) for k in keyname)
1230        update = []
1231        keyname = set(keyname)
1232        keyname.add('oid')
1233        for n in attnames:
1234            if n not in keyname:
1235                value = kw.get(n, True)
1236                if value:
1237                    if not isinstance(value, basestring):
1238                        value = 'excluded.%s' % col(n)
1239                    update.append('%s = %s' % (col(n), value))
1240        if not values:
1241            return row
1242        do = 'update set %s' % ', '.join(update) if update else 'nothing'
1243        ret = 'oid, *' if qoid else '*'
1244        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
1245            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
1246                self._escape_qualified_name(table), names, values,
1247                target, do, ret)
1248        self._do_debug(q, params)
1249        try:
1250            q = self.db.query(q, params)
1251        except ProgrammingError:
1252            if self.server_version < 90500:
1253                raise _prg_error(
1254                    'Upsert operation is not supported by PostgreSQL version')
1255            raise  # re-raise original error
1256        res = q.dictresult()
1257        if res:  # may be empty with "do nothing"
1258            for n, value in res[0].items():
1259                if qoid and n == 'oid':
1260                    n = qoid
1261                row[n] = value
1262        else:
1263            self.get(table, row)
1264        return row
1265
1266    def clear(self, table, row=None):
1267        """Clear all the attributes to values determined by the types.
1268
1269        Numeric types are set to 0, Booleans are set to false, and everything
1270        else is set to the empty string.  If the row argument is present,
1271        it is used as the row dictionary and any entries matching attribute
1272        names are cleared with everything else left unchanged.
1273        """
1274        # At some point we will need a way to get defaults from a table.
1275        if row is None:
1276            row = {}  # empty if argument is not present
1277        attnames = self.get_attnames(table)
1278        for n, t in attnames.items():
1279            if n == 'oid':
1280                continue
1281            if t in self._num_types:
1282                row[n] = 0
1283            elif t == 'bool':
1284                row[n] = self._make_bool(False)
1285            else:
1286                row[n] = ''
1287        return row
1288
    def delete(self, table, row=None, **kw):
        """Delete an existing row in a database table.

        This method deletes the row from a table.  It deletes based on the
        primary key of the table or the OID value as munged by get() or
        passed as keyword.

        The return value is the number of deleted rows (i.e. 0 if the row
        did not exist and 1 if the row was deleted).

        Note that if the row cannot be deleted because e.g. it is still
        referenced by another table, this method raises a ProgrammingError.

        A plain 'oid' key in the passed dictionary is discarded.  Only an oid
        passed as keyword, or stored under the munged oid key, is used.  The
        passed dictionary is updated in place with the keyword arguments.

        An error created by _prg_error() is raised if the table has no
        primary key and no oid is available.  A KeyError is raised if a
        primary key value is missing and no oid is available.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # name of the munged oid key, or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # accept the munged oid key as the plain 'oid' value for the lookup
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        try:  # try using the primary key
            keyname = self.pkey(table, True)
        except KeyError:  # the table has no primary key
            # try using the oid instead
            if qoid and 'oid' in row:
                keyname = ('oid',)
            else:
                raise _prg_error('Table %s has no primary key' % table)
        else:  # the table has a primary key
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                # try using the oid instead
                if qoid and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise KeyError('Missing primary key in row')
        params = []
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        where = ' AND '.join('%s = %s' % (
            col(k), param(row[k], attnames[k])) for k in keyname)
        # the caller's dictionary should only carry the munged oid key
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'DELETE FROM %s WHERE %s' % (
            self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        res = self.db.query(q, params)
        # the affected row count comes back as a string, see query()
        return int(res)
1343
1344    def truncate(self, table, restart=False, cascade=False, only=False):
1345        """Empty a table or set of tables.
1346
1347        This method quickly removes all rows from the given table or set
1348        of tables.  It has the same effect as an unqualified DELETE on each
1349        table, but since it does not actually scan the tables it is faster.
1350        Furthermore, it reclaims disk space immediately, rather than requiring
1351        a subsequent VACUUM operation. This is most useful on large tables.
1352
1353        If restart is set to True, sequences owned by columns of the truncated
1354        table(s) are automatically restarted.  If cascade is set to True, it
1355        also truncates all tables that have foreign-key references to any of
1356        the named tables.  If the parameter only is not set to True, all the
1357        descendant tables (if any) will also be truncated. Optionally, a '*'
1358        can be specified after the table name to explicitly indicate that
1359        descendant tables are included.
1360        """
1361        if isinstance(table, basestring):
1362            only = {table: only}
1363            table = [table]
1364        elif isinstance(table, (list, tuple)):
1365            if isinstance(only, (list, tuple)):
1366                only = dict(zip(table, only))
1367            else:
1368                only = dict.fromkeys(table, only)
1369        elif isinstance(table, (set, frozenset)):
1370            only = dict.fromkeys(table, only)
1371        else:
1372            raise TypeError('The table must be a string, list or set')
1373        if not (restart is None or isinstance(restart, (bool, int))):
1374            raise TypeError('Invalid type for the restart option')
1375        if not (cascade is None or isinstance(cascade, (bool, int))):
1376            raise TypeError('Invalid type for the cascade option')
1377        tables = []
1378        for t in table:
1379            u = only.get(t)
1380            if not (u is None or isinstance(u, (bool, int))):
1381                raise TypeError('Invalid type for the only option')
1382            if t.endswith('*'):
1383                if u:
1384                    raise ValueError(
1385                        'Contradictory table name and only options')
1386                t = t[:-1].rstrip()
1387            t = self._escape_qualified_name(t)
1388            if u:
1389                t = 'ONLY %s' % t
1390            tables.append(t)
1391        q = ['TRUNCATE', ', '.join(tables)]
1392        if restart:
1393            q.append('RESTART IDENTITY')
1394        if cascade:
1395            q.append('CASCADE')
1396        q = ' '.join(q)
1397        self._do_debug(q)
1398        return self.db.query(q)
1399
1400    def get_as_list(self, table, what=None, where=None,
1401            order=None, limit=None, offset=None, scalar=False):
1402        """Get a table as a list.
1403
1404        This gets a convenient representation of the table as a list
1405        of named tuples in Python.  You only need to pass the name of
1406        the table (or any other SQL expression returning rows).  Note that
1407        by default this will return the full content of the table which
1408        can be huge and overflow your memory.  However, you can control
1409        the amount of data returned using the other optional parameters.
1410
1411        The parameter 'what' can restrict the query to only return a
1412        subset of the table columns.  It can be a string, list or a tuple.
1413        The parameter 'where' can restrict the query to only return a
1414        subset of the table rows.  It can be a string, list or a tuple
1415        of SQL expressions that all need to be fulfilled.  The parameter
1416        'order' specifies the ordering of the rows.  It can also be a
1417        other string, list or a tuple.  If no ordering is specified,
1418        the result will be ordered by the primary key(s) or all columns
1419        if no primary key exists.  You can set 'order' to False if you
1420        don't care about the ordering.  The parameters 'limit' and 'offset'
1421        can be integers specifying the maximum number of rows returned
1422        and a number of rows skipped over.
1423
1424        If you set the 'scalar' option to True, then instead of the
1425        named tuples you will get the first items of these tuples.
1426        This is useful if the result has only one column anyway.
1427        """
1428        if not table:
1429            raise TypeError('The table name is missing')
1430        if what:
1431            if isinstance(what, (list, tuple)):
1432                what = ', '.join(map(str, what))
1433            if order is None:
1434                order = what
1435        else:
1436            what = '*'
1437        q = ['SELECT', what, 'FROM', table]
1438        if where:
1439            if isinstance(where, (list, tuple)):
1440                where = ' AND '.join(map(str, where))
1441            q.extend(['WHERE', where])
1442        if order is None:
1443            try:
1444                order = self.pkey(table, True)
1445            except (KeyError, ProgrammingError):
1446                try:
1447                    order = list(self.get_attnames(table))
1448                except (KeyError, ProgrammingError):
1449                    pass
1450        if order:
1451            if isinstance(order, (list, tuple)):
1452                order = ', '.join(map(str, order))
1453            q.extend(['ORDER BY', order])
1454        if limit:
1455            q.append('LIMIT %d' % limit)
1456        if offset:
1457            q.append('OFFSET %d' % offset)
1458        q = ' '.join(q)
1459        self._do_debug(q)
1460        q = self.db.query(q)
1461        res = q.namedresult()
1462        if res and scalar:
1463            res = [row[0] for row in res]
1464        return res
1465
1466    def get_as_dict(self, table, keyname=None, what=None, where=None,
1467            order=None, limit=None, offset=None, scalar=False):
1468        """Get a table as a dictionary.
1469
1470        This method is similar to get_as_list(), but returns the table
1471        as a Python dict instead of a Python list, which can be even
1472        more convenient. The primary key column(s) of the table will
1473        be used as the keys of the dictionary, while the other column(s)
1474        will be the corresponding values.  The keys will be named tuples
1475        if the table has a composite primary key.  The rows will be also
1476        named tuples unless the 'scalar' option has been set to True.
1477        With the optional parameter 'keyname' you can specify an alternative
1478        set of columns to be used as the keys of the dictionary.  It must
1479        be set as a string, list or a tuple.
1480
1481        If the Python version supports it, the dictionary will be an
1482        OrderedDict using the order specified with the 'order' parameter
1483        or the key column(s) if not specified.  You can set 'order' to False
1484        if you don't care about the ordering.  In this case the returned
1485        dictionary will be an ordinary one.
1486        """
1487        if not table:
1488            raise TypeError('The table name is missing')
1489        if not keyname:
1490            try:
1491                keyname = self.pkey(table, True)
1492            except (KeyError, ProgrammingError):
1493                raise _prg_error('Table %s has no primary key' % table)
1494        if isinstance(keyname, basestring):
1495            keyname = [keyname]
1496        elif not isinstance(keyname, (list, tuple)):
1497            raise KeyError('The keyname must be a string, list or tuple')
1498        if what:
1499            if isinstance(what, (list, tuple)):
1500                what = ', '.join(map(str, what))
1501            if order is None:
1502                order = what
1503        else:
1504            what = '*'
1505        q = ['SELECT', what, 'FROM', table]
1506        if where:
1507            if isinstance(where, (list, tuple)):
1508                where = ' AND '.join(map(str, where))
1509            q.extend(['WHERE', where])
1510        if order is None:
1511            order = keyname
1512        if order:
1513            if isinstance(order, (list, tuple)):
1514                order = ', '.join(map(str, order))
1515            q.extend(['ORDER BY', order])
1516        if limit:
1517            q.append('LIMIT %d' % limit)
1518        if offset:
1519            q.append('OFFSET %d' % offset)
1520        q = ' '.join(q)
1521        self._do_debug(q)
1522        q = self.db.query(q)
1523        res = q.getresult()
1524        cls = OrderedDict if order else dict
1525        if not res:
1526            return cls()
1527        keyset = set(keyname)
1528        fields = q.listfields()
1529        if not keyset.issubset(fields):
1530            raise KeyError('Missing keyname in row')
1531        keyind, rowind = [], []
1532        for i, f in enumerate(fields):
1533            (keyind if f in keyset else rowind).append(i)
1534        keytuple = len(keyind) > 1
1535        getkey = itemgetter(*keyind)
1536        keys = map(getkey, res)
1537        if scalar:
1538            rowind = rowind[:1]
1539            rowtuple = False
1540        else:
1541            rowtuple = len(rowind) > 1
1542        if scalar or rowtuple:
1543            getrow = itemgetter(*rowind)
1544        else:
1545            rowind = rowind[0]
1546            getrow = lambda row: (row[rowind],)
1547            rowtuple = True
1548        rows = map(getrow, res)
1549        if keytuple or rowtuple:
1550            namedresult = get_namedresult()
1551            if namedresult:
1552                if keytuple:
1553                    keys = namedresult(_MemoryQuery(keys, keyname))
1554                if rowtuple:
1555                    fields = [f for f in fields if f not in keyset]
1556                    rows = namedresult(_MemoryQuery(rows, fields))
1557        return cls(zip(keys, rows))
1558
1559    def notification_handler(self,
1560            event, callback, arg_dict=None, timeout=None, stop_event=None):
1561        """Get notification handler that will run the given callback."""
1562        return NotificationHandler(self,
1563            event, callback, arg_dict, timeout, stop_event)
1564
1565
# If run as a script, print the version and the module docstring.

if __name__ == '__main__':
    # Pass the version as a separate argument so that print() inserts a
    # space between the label and the version number; the former string
    # concatenation ('PyGreSQL version' + version) ran them together.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.