source: trunk/pg.py @ 781

Last change on this file since 781 was 781, checked in by cito, 4 years ago

Add full support for PostgreSQL array types

At the core of this patch is a fast parser for the peculiar syntax of
literal array expressions in PostgreSQL that was added to the C module.
This is not trivial, because PostgreSQL arrays can be multidimensional
and the syntax is different from Python and SQL expressions.

The Python pg and pgdb modules make use of this parser so that they can
return database columns containing PostgreSQL arrays to Python as lists.
Also added quoting methods that allow passing PostgreSQL arrays as lists
to insert()/update() and execute/executemany(). These methods are simpler
and were implemented in Python but needed support from the regex module.

The patch also makes getresult() in pg automatically return bytea
values in unescaped form as bytes strings. Before, it was necessary to
call unescape_bytea manually. The pgdb module did this already.

The patch includes some more refactorings and simplifications regarding
the quoting and casting in pg and pgdb.

Some references to antique PostgreSQL types that are not used any more
in the supported PostgreSQL versions have been removed.

Also added documentation and tests for the new features.

  • Property svn:keywords set to Id
File size: 59.7 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 781 2016-01-25 20:44:52Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40from operator import itemgetter
41from re import compile as regex
42from json import loads as jsondecode, dumps as jsonencode
43
try:
    basestring = basestring  # Python 2: keep the builtin string base class
except NameError:  # Python >= 3.0 has no basestring builtin
    basestring = (str, bytes)
48
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict


    class AttrDict(dict):
        """Simple read-only ordered dictionary for storing attribute names.

        Fallback implementation for Python versions without OrderedDict.
        The key order is kept in a separate list, therefore the object
        must be created from a sequence of (key, value) pairs; passing
        a plain dict (which has no defined order) raises TypeError.
        """

        def __init__(self, *args, **kw):
            if len(args) > 1 or kw:
                raise TypeError
            items = args[0] if args else []
            if isinstance(items, dict):
                raise TypeError  # a plain dict has no defined order
            items = list(items)
            self._keys = [item[0] for item in items]
            dict.__init__(self, items)
            self._read_only = True
            # Disable all mutating dict methods on this instance.
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[key] for key in self]

        def items(self):
            return [(key, self[key]) for key in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        def copy(self):
            """Return a read-only copy with the same key order."""
            return self.__class__(self.items())

        def __reduce__(self):
            # The default reduction used by copy.copy() and pickle creates
            # the object first and re-inserts the items afterwards, which
            # is refused by the read-only __setitem__.
            return self.__class__, (self.items(),)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Simple read-only ordered dictionary for storing attribute names."""

        def __init__(self, *args, **kw):
            # Allow OrderedDict.__init__ to insert the initial items.
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # Disable all mutating dict methods on this instance.
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        def copy(self):
            """Return a read-only copy with the same key order."""
            # OrderedDict.copy() (the C implementation in Python 3.5+)
            # creates an empty instance of this class and then inserts
            # the items, which is refused by the read-only __setitem__.
            return self.__class__(self)

        def __reduce__(self):
            # Same problem with the default reduction used by copy/pickle.
            return self.__class__, (list(self.items()),)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
133
134
135# Auxiliary classes and functions that are independent from a DB connection:
136
137def _oid_key(table):
138    """Build oid key from a table name."""
139    return 'oid(%s)' % table
140
141
142class _SimpleType(dict):
143    """Dictionary mapping pg_type names to simple type names."""
144
145    _types = {'bool': 'bool',
146        'bytea': 'bytea',
147        'date': 'date interval time timetz timestamp timestamptz'
148            ' abstime reltime',  # these are very old
149        'float': 'float4 float8',
150        'int': 'cid int2 int4 int8 oid xid',
151        'json': 'json jsonb',
152        'num': 'numeric',
153        'money': 'money',
154        'text': 'bpchar char name text varchar'}
155
156    def __init__(self):
157        for typ, keys in self._types.items():
158            for key in keys.split():
159                self[key] = typ
160                self['_%s' % key] = '%s[]' % typ
161
162    @staticmethod
163    def __missing__(key):
164        return 'text'
165
166_simpletype = _SimpleType()
167
168
169class _Literal(str):
170    """Wrapper class for literal SQL."""
171
172
173def _namedresult(q):
174    """Get query result as named tuples."""
175    row = namedtuple('Row', q.listfields())
176    return [row(*r) for r in q.getresult()]
177
178
179class _MemoryQuery:
180    """Class that embodies a given query result."""
181
182    def __init__(self, result, fields):
183        """Create query from given result rows and field names."""
184        self.result = result
185        self.fields = fields
186
187    def listfields(self):
188        """Return the stored field names of this query."""
189        return self.fields
190
191    def getresult(self):
192        """Return the stored result of this query."""
193        return self.result
194
195
def _db_error(msg, cls=DatabaseError):
    """Create an exception of type cls (DatabaseError by default).

    The sqlstate attribute of the new exception is set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
201
202
def _int_error(msg):
    """Create an InternalError for msg with an empty sqlstate."""
    return _db_error(msg, cls=InternalError)
206
207
def _prg_error(msg):
    """Create a ProgrammingError for msg with an empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
211
212
# Initialize the C module: hand over the Python-level helper objects
# it uses (the named tuple factory defined above, the Decimal type and
# the standard JSON decoder).
set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
218
219
220# The notification handler
221
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is escaped, so it may contain single quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                q += ", '%s'" % self._escape_payload(db, payload)
            return db.query(q)

    @staticmethod
    def _escape_payload(db, payload):
        """Escape the payload for use inside a single-quoted SQL literal."""
        payload = '%s' % (payload,)
        escape = getattr(db, 'escape_string', None)
        if escape is not None:
            # let the connection escape according to the server settings
            return escape(payload)
        # plain SQL escaping (assumes standard_conforming_strings is on)
        return payload.replace("'", "''")

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
341
342
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its deprecated traditional name."""
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
348
349
350# The actual PostGreSQL database connection interface:
351
352class DB(object):
353    """Wrapper class for the _pg connection type."""
354
355    def __init__(self, *args, **kw):
356        """Create a new connection
357
358        You can pass either the connection parameters or an existing
359        _pg or pgdb connection. This allows you to use the methods
360        of the classic pg interface with a DB-API 2 pgdb connection.
361        """
362        if not args and len(kw) == 1:
363            db = kw.get('db')
364        elif not kw and len(args) == 1:
365            db = args[0]
366        else:
367            db = None
368        if db:
369            if isinstance(db, DB):
370                db = db.db
371            else:
372                try:
373                    db = db._cnx
374                except AttributeError:
375                    pass
376        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
377            db = connect(*args, **kw)
378            self._closeable = True
379        else:
380            self._closeable = False
381        self.db = db
382        self.dbname = db.db
383        self._regtypes = False
384        self._attnames = {}
385        self._pkeys = {}
386        self._privileges = {}
387        self._args = args, kw
388        self.debug = None  # For debugging scripts, this can be set
389            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
390            # * to a file object to write debug statements or
391            # * to a callable object which takes a string argument
392            # * to any other true value to just print debug statements
393
394    def __getattr__(self, name):
395        # All undefined members are same as in underlying connection:
396        if self.db:
397            return getattr(self.db, name)
398        else:
399            raise _int_error('Connection is not valid')
400
401    def __dir__(self):
402        # Custom dir function including the attributes of the connection:
403        attrs = set(self.__class__.__dict__)
404        attrs.update(self.__dict__)
405        attrs.update(dir(self.db))
406        return sorted(attrs)
407
408    # Context manager methods
409
410    def __enter__(self):
411        """Enter the runtime context. This will start a transactio."""
412        self.begin()
413        return self
414
415    def __exit__(self, et, ev, tb):
416        """Exit the runtime context. This will end the transaction."""
417        if et is None and ev is None and tb is None:
418            self.commit()
419        else:
420            self.rollback()
421
422    # Auxiliary methods
423
424    def _do_debug(self, *args):
425        """Print a debug message"""
426        if self.debug:
427            s = '\n'.join(str(arg) for arg in args)
428            if isinstance(self.debug, basestring):
429                print(self.debug % s)
430            elif hasattr(self.debug, 'write'):
431                self.debug.write(s + '\n')
432            elif callable(self.debug):
433                self.debug(s)
434            else:
435                print(s)
436
437    def _escape_qualified_name(self, s):
438        """Escape a qualified name.
439
440        Escapes the name for use as an SQL identifier, unless the
441        name contains a dot, in which case the name is ambiguous
442        (could be a qualified name or just a name with a dot in it)
443        and must be quoted manually by the caller.
444        """
445        if '.' not in s:
446            s = self.escape_identifier(s)
447        return s
448
449    @staticmethod
450    def _make_bool(d):
451        """Get boolean value corresponding to d."""
452        return bool(d) if get_bool() else ('t' if d else 'f')
453
454    _bool_true_values = frozenset('t true 1 y yes on'.split())
455
456    def _prepare_bool(self, d):
457        """Prepare a boolean parameter."""
458        if isinstance(d, basestring):
459            if not d:
460                return None
461            d = d.lower() in self._bool_true_values
462        return 't' if d else 'f'
463
464    _date_literals = frozenset('current_date current_time'
465        ' current_timestamp localtime localtimestamp'.split())
466
467    def _prepare_date(self, d):
468        """Prepare a date parameter."""
469        if not d:
470            return None
471        if isinstance(d, basestring) and d.lower() in self._date_literals:
472            return _Literal(d)
473        return d
474
475    _num_types = frozenset('int float num money'
476        ' int2 int4 int8 float4 float8 numeric money'.split())
477
478    @staticmethod
479    def _prepare_num(d):
480        """Prepare a numeric parameter."""
481        if not d and d != 0:
482            return None
483        return d
484
485    _prepare_int = _prepare_float = _prepare_money = _prepare_num
486
487    def _prepare_bytea(self, d):
488        """Prepare a bytea parameter."""
489        return self.escape_bytea(d)
490
491    def _prepare_json(self, d):
492        """Prepare a json parameter."""
493        if not d:
494            return None
495        if isinstance(d, basestring):
496            return d
497        return self.encode_json(d)
498
499    _re_array_escape = regex(r'(["\\])')
500    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
501
502    def _prepare_bool_array(self, d):
503        """Prepare a bool array parameter."""
504        if isinstance(d, list):
505            return '{%s}' % ','.join(self._prepare_bool_array(v) for v in d)
506        if d is None:
507            return 'null'
508        if isinstance(d, basestring):
509            if not d:
510                return 'null'
511            d = d.lower() in self._bool_true_values
512        return 't' if d else 'f'
513
514    def _prepare_num_array(self, d):
515        """Prepare a numeric array parameter."""
516        if isinstance(d, list):
517            return '{%s}' % ','.join(self._prepare_num_array(v) for v in d)
518        if not d and d != 0:
519            return 'null'
520        return str(d)
521
522    _prepare_int_array = _prepare_float_array = _prepare_money_array = \
523            _prepare_num_array
524
525    def _prepare_text_array(self, d):
526        """Prepare a text array parameter."""
527        if isinstance(d, list):
528            return '{%s}' % ','.join(self._prepare_text_array(v) for v in d)
529        if d is None:
530            return 'null'
531        if not d:
532            return '""'
533        d = str(d)
534        if self._re_array_quote.search(d):
535            d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
536        return d
537
538    def _prepare_bytea_array(self, d):
539        """Prepare a bytea array parameter."""
540        if isinstance(d, list):
541            return '{%s}' % ','.join(self._prepare_bytea_array(v) for v in d)
542        if d is None:
543            return 'null'
544        return self.escape_bytea(d).replace('\\', '\\\\')
545
546    def _prepare_json_array(self, d):
547        """Prepare a json array parameter."""
548        if isinstance(d, list):
549            return '{%s}' % ','.join(self._prepare_json_array(v) for v in d)
550        if not d:
551            return 'null'
552        if not isinstance(d, basestring):
553            d = self.encode_json(d)
554        if self._re_array_quote.search(d):
555            d = '"%s"' % self._re_array_escape.sub(r'\\\1', d)
556        return d
557
558    def _prepare_param(self, value, typ, params):
559        """Prepare and add a parameter to the list."""
560        if isinstance(value, _Literal):
561            return value
562        if value is not None and typ != 'text':
563            if typ.endswith('[]'):
564                if isinstance(value, list):
565                    prepare = getattr(self, '_prepare_%s_array' % typ[:-2])
566                    value = prepare(value)
567                elif isinstance(value, basestring):
568                    value = value.strip()
569                    if not value.startswith('{') or not value.endswith('}'):
570                        if value[:5].lower() == 'array':
571                            value = value[5:].lstrip()
572                        if value.startswith('[') and value.endswith(']'):
573                            value = _Literal('ARRAY%s' % value)
574                        else:
575                            raise ValueError(
576                                'Invalid array expression: %s' % value)
577                else:
578                    raise ValueError('Invalid array parameter: %s' % value)
579            else:
580                prepare = getattr(self, '_prepare_%s' % typ)
581                value = prepare(value)
582            if isinstance(value, _Literal):
583                return value
584        params.append(value)
585        return '$%d' % len(params)
586
587    def _list_params(self, params):
588        """Create a human readable parameter list."""
589        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
590
591    @staticmethod
592    def _prepare_qualified_param(name, param):
593        """Quote parameter representing a qualified name.
594
595        Escapes the name for use as an SQL parameter, unless the
596        name contains a dot, in which case the name is ambiguous
597        (could be a qualified name or just a name with a dot in it)
598        and must be quoted manually by the caller.
599
600        """
601        if isinstance(param, int):
602            param = "$%d" % param
603        if '.' not in name:
604            param = 'quote_ident(%s)' % (param,)
605        return param
606
607    # Public methods
608
    # escape_string and escape_bytea are available as methods,
    # so we provide unescape_bytea as a (static) method as well
    # (presumably the function star-imported from the _pg module).
    unescape_bytea = staticmethod(unescape_bytea)
612
613    def decode_json(self, s):
614        """Decode a JSON string coming from the database."""
615        return (get_jsondecode() or jsondecode)(s)
616
617    def encode_json(self, d):
618        """Encode a JSON string for use within SQL."""
619        return jsonencode(d)
620
621    def close(self):
622        """Close the database connection."""
623        # Wraps shared library function so we can track state.
624        if self._closeable:
625            if self.db:
626                self.db.close()
627                self.db = None
628            else:
629                raise _int_error('Connection already closed')
630
631    def reset(self):
632        """Reset connection with current parameters.
633
634        All derived queries and large objects derived from this connection
635        will not be usable after this call.
636
637        """
638        if self.db:
639            self.db.reset()
640        else:
641            raise _int_error('Connection already closed')
642
643    def reopen(self):
644        """Reopen connection to the database.
645
646        Used in case we need another connection to the same database.
647        Note that we can still reopen a database that we have closed.
648
649        """
650        # There is no such shared library function.
651        if self._closeable:
652            db = connect(*self._args[0], **self._args[1])
653            if self.db:
654                self.db.close()
655            self.db = db
656
657    def begin(self, mode=None):
658        """Begin a transaction."""
659        qstr = 'BEGIN'
660        if mode:
661            qstr += ' ' + mode
662        return self.query(qstr)
663
664    start = begin
665
666    def commit(self):
667        """Commit the current transaction."""
668        return self.query('COMMIT')
669
670    end = commit
671
672    def rollback(self, name=None):
673        """Roll back the current transaction."""
674        qstr = 'ROLLBACK'
675        if name:
676            qstr += ' TO ' + name
677        return self.query(qstr)
678
679    abort = rollback
680
681    def savepoint(self, name):
682        """Define a new savepoint within the current transaction."""
683        return self.query('SAVEPOINT ' + name)
684
685    def release(self, name):
686        """Destroy a previously defined savepoint."""
687        return self.query('RELEASE ' + name)
688
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises TypeError for an invalid or missing parameter specification.
        """
        # Choose the result container: None means that a single value is
        # returned; a passed-in dict is filled in place.
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # For dict results, map each normalized name to the original key;
        # otherwise collect the normalized names in order.
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' overrides everything else: return a new dict built
                # from the first two columns of each SHOW ALL row.
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # Only reached if the loop did not break (no 'all' given):
            # query each parameter separately.
            # NOTE(review): the name is put into the SQL text unquoted,
            # so it must not come from untrusted input.
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
746
747    def set_parameter(self, parameter, value=None, local=False):
748        """Set the value of a run-time parameter.
749
750        If the parameter and the value are strings, the run-time parameter
751        will be set to that value.  If no value or None is passed as a value,
752        then the run-time parameter will be restored to its default value.
753
754        You can set several parameters at once by passing a list of parameter
755        names, together with a single value that all parameters should be
756        set to or with a corresponding list of values.  You can also pass
757        the parameters as a set if you only provide a single value.
758        Finally, you can pass a dict with parameter names as keys.  In this
759        case, you should not pass a value, since the values for the parameters
760        will be taken from the dict.
761
762        By passing the special name 'all' as the parameter, you can reset
763        all existing settable run-time parameters to their default values.
764
765        If you set local to True, then the command takes effect for only the
766        current transaction.  After commit() or rollback(), the session-level
767        setting takes effect again.  Setting local to True will appear to
768        have no effect if it is executed outside a transaction, since the
769        transaction will end immediately.
770        """
771        if isinstance(parameter, basestring):
772            parameter = {parameter: value}
773        elif isinstance(parameter, (list, tuple)):
774            if isinstance(value, (list, tuple)):
775                parameter = dict(zip(parameter, value))
776            else:
777                parameter = dict.fromkeys(parameter, value)
778        elif isinstance(parameter, (set, frozenset)):
779            if isinstance(value, (list, tuple, set, frozenset)):
780                value = set(value)
781                if len(value) == 1:
782                    value = value.pop()
783            if not(value is None or isinstance(value, basestring)):
784                raise ValueError('A single value must be specified'
785                    ' when parameter is a set')
786            parameter = dict.fromkeys(parameter, value)
787        elif isinstance(parameter, dict):
788            if value is not None:
789                raise ValueError('A value must not be specified'
790                    ' when parameter is a dictionary')
791        else:
792            raise TypeError(
793                'The parameter must be a string, list, set or dict')
794        if not parameter:
795            raise TypeError('No parameter has been specified')
796        params = {}
797        for key, value in parameter.items():
798            param = key.strip().lower() if isinstance(
799                key, basestring) else None
800            if not param:
801                raise TypeError('Invalid parameter')
802            if param == 'all':
803                if value is not None:
804                    raise ValueError('A value must ot be specified'
805                        " when parameter is 'all'")
806                params = {'all': None}
807                break
808            params[param] = value
809        local = ' LOCAL' if local else ''
810        for param, value in params.items():
811            if value is None:
812                q = 'RESET%s %s' % (local, param)
813            else:
814                q = 'SET%s %s TO %s' % (local, param, value)
815            self._do_debug(q)
816            self.db.query(q)
817
818    def query(self, command, *args):
819        """Execute a SQL command string.
820
821        This method simply sends a SQL query to the database.  If the query is
822        an insert statement that inserted exactly one row into a table that
823        has OIDs, the return value is the OID of the newly inserted row.
824        If the query is an update or delete statement, or an insert statement
825        that did not insert exactly one row in a table with OIDs, then the
826        number of rows affected is returned as a string.  If it is a statement
827        that returns rows as a result (usually a select statement, but maybe
828        also an "insert/update ... returning" statement), this method returns
829        a Query object that can be accessed via getresult() or dictresult()
830        or simply printed.  Otherwise, it returns `None`.
831
832        The query can contain numbered parameters of the form $1 in place
833        of any data constant.  Arguments given after the query string will
834        be substituted for the corresponding numbered parameter.  Parameter
835        values can also be given as a single list or tuple argument.
836        """
837        # Wraps shared library function for debugging.
838        if not self.db:
839            raise _int_error('Connection is not valid')
840        if args:
841            self._do_debug(command, args)
842            return self.db.query(command, args)
843        self._do_debug(command)
844        return self.db.query(command)
845
846    def pkey(self, table, composite=False, flush=False):
847        """Get or set the primary key of a table.
848
849        Single primary keys are returned as strings unless you
850        set the composite flag.  Composite primary keys are always
851        represented as tuples.  Note that this raises a KeyError
852        if the table does not have a primary key.
853
854        If flush is set then the internal cache for primary keys will
855        be flushed.  This may be necessary after the database schema or
856        the search path has been changed.
857        """
858        pkeys = self._pkeys
859        if flush:
860            pkeys.clear()
861            self._do_debug('The pkey cache has been flushed')
862        try:  # cache lookup
863            pkey = pkeys[table]
864        except KeyError:  # cache miss, check the database
865            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
866                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
867                " AND a.attnum = ANY(i.indkey)"
868                " AND NOT a.attisdropped"
869                " WHERE i.indrelid=%s::regclass"
870                " AND i.indisprimary ORDER BY a.attnum") % (
871                    self._prepare_qualified_param(table, 1),)
872            pkey = self.db.query(q, (table,)).getresult()
873            if not pkey:
874                raise KeyError('Table %s has no primary key' % table)
875            # we want to use the order defined in the primary key index here,
876            # not the order as defined by the columns in the table
877            if len(pkey) > 1:
878                indkey = [int(k) for k in pkey[0][2].split()]
879                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
880                pkey = tuple(row[0] for row in pkey)
881            else:
882                pkey = pkey[0][0]
883            pkeys[table] = pkey  # cache it
884        if composite and not isinstance(pkey, tuple):
885            pkey = (pkey,)
886        return pkey
887
888    def get_databases(self):
889        """Get list of databases in the system."""
890        return [s[0] for s in
891            self.db.query('SELECT datname FROM pg_database').getresult()]
892
893    def get_relations(self, kinds=None):
894        """Get list of relations in connected database of specified kinds.
895
896        If kinds is None or empty, all kinds of relations are returned.
897        Otherwise kinds can be a string or sequence of type letters
898        specifying which kind of relations you want to list.
899        """
900        where = " AND r.relkind IN (%s)" % ','.join(
901            ["'%s'" % k for k in kinds]) if kinds else ''
902        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
903            " FROM pg_class r"
904            " JOIN pg_namespace s ON s.oid = r.relnamespace"
905            " WHERE s.nspname NOT SIMILAR"
906            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
907            " ORDER BY s.nspname, r.relname") % where
908        return [r[0] for r in self.db.query(q).getresult()]
909
910    def get_tables(self):
911        """Return list of tables in connected database."""
912        return self.get_relations('r')
913
914    def get_attnames(self, table, flush=False):
915        """Given the name of a table, dig out the set of attribute names.
916
917        Returns a read-only dictionary of attribute names (the names are
918        the keys, the values are the names of the attributes' types)
919        with the column names in the proper order if you iterate over it.
920
921        If flush is set, then the internal cache for attribute names will
922        be flushed. This may be necessary after the database schema or
923        the search path has been changed.
924
925        By default, only a limited number of simple types will be returned.
926        You can get the regular types after calling use_regtypes(True).
927        """
928        attnames = self._attnames
929        if flush:
930            attnames.clear()
931            self._do_debug('The attnames cache has been flushed')
932        try:  # cache lookup
933            names = attnames[table]
934        except KeyError:  # cache miss, check the database
935            q = ("SELECT a.attname, t.typname%s"
936                " FROM pg_attribute a"
937                " JOIN pg_type t ON t.oid = a.atttypid"
938                " WHERE a.attrelid = %s::regclass"
939                " AND (a.attnum > 0 OR a.attname = 'oid')"
940                " AND NOT a.attisdropped ORDER BY a.attnum") % (
941                    '::regtype' if self._regtypes else '',
942                    self._prepare_qualified_param(table, 1))
943            names = self.db.query(q, (table,)).getresult()
944            if not self._regtypes:
945                names = ((name, _simpletype[typ]) for name, typ in names)
946            names = AttrDict(names)
947            attnames[table] = names  # cache it
948        return names
949
950    def use_regtypes(self, regtypes=None):
951        """Use regular type names instead of simplified type names."""
952        if regtypes is None:
953            return self._regtypes
954        else:
955            regtypes = bool(regtypes)
956            if regtypes != self._regtypes:
957                self._regtypes = regtypes
958                self._attnames.clear()
959            return regtypes
960
961    def has_table_privilege(self, table, privilege='select'):
962        """Check whether current user has specified table privilege."""
963        privilege = privilege.lower()
964        try:  # ask cache
965            return self._privileges[(table, privilege)]
966        except KeyError:  # cache miss, ask the database
967            q = "SELECT has_table_privilege(%s, $2)" % (
968                self._prepare_qualified_param(table, 1),)
969            q = self.db.query(q, (table, privilege))
970            ret = q.getresult()[0][0] == self._make_bool(True)
971            self._privileges[(table, privilege)] = ret  # cache it
972            return ret
973
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        If no keyname is given and the table has no primary key, or the
        passed dictionary lacks values for some primary key columns, the
        OID passed in the dictionary is used as the key instead, if the
        table has OIDs.

        Raises a KeyError if the dictionary lacks values for the key
        columns, or if the number of passed values does not match the
        number of key columns.  Raises the error built by _prg_error if the
        table has no primary key and no OID can be used, and the error
        built by _db_error if no matching row exists.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # the munged key under which the OID is stored in the dictionary
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # accept the munged OID from the dictionary as the real OID
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        if not isinstance(row, dict):
            # a single value or a sequence of values for the key columns;
            # turn it into a new dictionary keyed by the column names
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = []
        # each call appends a value to params and returns its placeholder
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        # select the oid explicitly as well when the table has one
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), param(row[k], attnames[k])) for k in keyname)
        # the plain 'oid' key is only used internally; keep the OID
        # under the munged key in the dictionary that is returned
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the dictionary, munging the oid
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1048
1049    def insert(self, table, row=None, **kw):
1050        """Insert a row into a database table.
1051
1052        This method inserts a row into a table.  The name of the table must
1053        be passed as the first parameter.  The other parameters are used for
1054        providing the data of the row that shall be inserted into the table.
1055        If a dictionary is supplied as the second parameter, it starts with
1056        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1057        is updated from the keywords.
1058
1059        The dictionary is then reloaded with the values actually inserted in
1060        order to pick up values modified by rules, triggers, etc.
1061        """
1062        if table.endswith('*'):  # hint for descendant tables can be ignored
1063            table = table[:-1].rstrip()
1064        if row is None:
1065            row = {}
1066        row.update(kw)
1067        if 'oid' in row:
1068            del row['oid']  # do not insert oid
1069        attnames = self.get_attnames(table)
1070        qoid = _oid_key(table) if 'oid' in attnames else None
1071        params = []
1072        param = partial(self._prepare_param, params=params)
1073        col = self.escape_identifier
1074        names, values = [], []
1075        for n in attnames:
1076            if n in row:
1077                names.append(col(n))
1078                values.append(param(row[n], attnames[n]))
1079        if not names:
1080            raise _prg_error('No column found that can be inserted')
1081        names, values = ', '.join(names), ', '.join(values)
1082        ret = 'oid, *' if qoid else '*'
1083        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1084            self._escape_qualified_name(table), names, values, ret)
1085        self._do_debug(q, params)
1086        q = self.db.query(q, params)
1087        res = q.dictresult()
1088        if res:  # this should always be true
1089            for n, value in res[0].items():
1090                if qoid and n == 'oid':
1091                    n = qoid
1092                row[n] = value
1093        return row
1094
1095    def update(self, table, row=None, **kw):
1096        """Update an existing row in a database table.
1097
1098        Similar to insert but updates an existing row.  The update is based
1099        on the primary key of the table or the OID value as munged by get
1100        or passed as keyword.
1101
1102        The dictionary is then modified to reflect any changes caused by the
1103        update due to triggers, rules, default values, etc.
1104        """
1105        if table.endswith('*'):
1106            table = table[:-1].rstrip()  # need parent table name
1107        attnames = self.get_attnames(table)
1108        qoid = _oid_key(table) if 'oid' in attnames else None
1109        if row is None:
1110            row = {}
1111        elif 'oid' in row:
1112            del row['oid']  # only accept oid key from named args for safety
1113        row.update(kw)
1114        if qoid and qoid in row and 'oid' not in row:
1115            row['oid'] = row[qoid]
1116        try:  # try using the primary key
1117            keyname = self.pkey(table, True)
1118        except KeyError:  # the table has no primary key
1119            # try using the oid instead
1120            if qoid and 'oid' in row:
1121                keyname = ('oid',)
1122            else:
1123                raise _prg_error('Table %s has no primary key' % table)
1124        else:  # the table has a primary key
1125            # check whether all key columns have values
1126            if not set(keyname).issubset(row):
1127                # try using the oid instead
1128                if qoid and 'oid' in row:
1129                    keyname = ('oid',)
1130                else:
1131                    raise KeyError('Missing primary key in row')
1132        params = []
1133        param = partial(self._prepare_param, params=params)
1134        col = self.escape_identifier
1135        where = ' AND '.join('%s = %s' % (
1136            col(k), param(row[k], attnames[k])) for k in keyname)
1137        if 'oid' in row:
1138            if qoid:
1139                row[qoid] = row['oid']
1140            del row['oid']
1141        values = []
1142        keyname = set(keyname)
1143        for n in attnames:
1144            if n in row and n not in keyname:
1145                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
1146        if not values:
1147            return row
1148        values = ', '.join(values)
1149        ret = 'oid, *' if qoid else '*'
1150        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
1151            self._escape_qualified_name(table), values, where, ret)
1152        self._do_debug(q, params)
1153        q = self.db.query(q, params)
1154        res = q.dictresult()
1155        if res:  # may be empty when row does not exist
1156            for n, value in res[0].items():
1157                if qoid and n == 'oid':
1158                    n = qoid
1159                row[n] = value
1160        return row
1161
1162    def upsert(self, table, row=None, **kw):
1163        """Insert a row into a database table with conflict resolution
1164
1165        This method inserts a row into a table, but instead of raising a
1166        ProgrammingError exception in case a row with the same primary key
1167        already exists, an update will be executed instead.  This will be
1168        performed as a single atomic operation on the database, so race
1169        conditions can be avoided.
1170
1171        Like the insert method, the first parameter is the name of the
1172        table and the second parameter can be used to pass the values to
1173        be inserted as a dictionary.
1174
1175        Unlike the insert und update statement, keyword parameters are not
1176        used to modify the dictionary, but to specify which columns shall
1177        be updated in case of a conflict, and in which way:
1178
1179        A value of False or None means the column shall not be updated,
1180        a value of True means the column shall be updated with the value
1181        that has been proposed for insertion, i.e. has been passed as value
1182        in the dictionary.  Columns that are not specified by keywords but
1183        appear as keys in the dictionary are also updated like in the case
1184        keywords had been passed with the value True.
1185
1186        So if in the case of a conflict you want to update every column that
1187        has been passed in the dictionary row , you would call upsert(table, row).
1188        If you don't want to do anything in case of a conflict, i.e. leave
1189        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
1190
1191        If you need more fine-grained control of what gets updated, you can
1192        also pass strings in the keyword parameters.  These strings will
1193        be used as SQL expressions for the update columns.  In these
1194        expressions you can refer to the value that already exists in
1195        the table by prefixing the column name with "included.", and to
1196        the value that has been proposed for insertion by prefixing the
1197        column name with the "excluded."
1198
1199        The dictionary is modified in any case to reflect the values in
1200        the database after the operation has completed.
1201
1202        Note: The method uses the PostgreSQL "upsert" feature which is
1203        only available since PostgreSQL 9.5.
1204        """
1205        if table.endswith('*'):  # hint for descendant tables can be ignored
1206            table = table[:-1].rstrip()
1207        if row is None:
1208            row = {}
1209        if 'oid' in row:
1210            del row['oid']  # do not insert oid
1211        if 'oid' in kw:
1212            del kw['oid']  # do not update oid
1213        attnames = self.get_attnames(table)
1214        qoid = _oid_key(table) if 'oid' in attnames else None
1215        params = []
1216        param = partial(self._prepare_param,params=params)
1217        col = self.escape_identifier
1218        names, values, updates = [], [], []
1219        for n in attnames:
1220            if n in row:
1221                names.append(col(n))
1222                values.append(param(row[n], attnames[n]))
1223        names, values = ', '.join(names), ', '.join(values)
1224        try:
1225            keyname = self.pkey(table, True)
1226        except KeyError:
1227            raise _prg_error('Table %s has no primary key' % table)
1228        target = ', '.join(col(k) for k in keyname)
1229        update = []
1230        keyname = set(keyname)
1231        keyname.add('oid')
1232        for n in attnames:
1233            if n not in keyname:
1234                value = kw.get(n, True)
1235                if value:
1236                    if not isinstance(value, basestring):
1237                        value = 'excluded.%s' % col(n)
1238                    update.append('%s = %s' % (col(n), value))
1239        if not values:
1240            return row
1241        do = 'update set %s' % ', '.join(update) if update else 'nothing'
1242        ret = 'oid, *' if qoid else '*'
1243        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
1244            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
1245                self._escape_qualified_name(table), names, values,
1246                target, do, ret)
1247        self._do_debug(q, params)
1248        try:
1249            q = self.db.query(q, params)
1250        except ProgrammingError:
1251            if self.server_version < 90500:
1252                raise _prg_error(
1253                    'Upsert operation is not supported by PostgreSQL version')
1254            raise  # re-raise original error
1255        res = q.dictresult()
1256        if res:  # may be empty with "do nothing"
1257            for n, value in res[0].items():
1258                if qoid and n == 'oid':
1259                    n = qoid
1260                row[n] = value
1261        else:
1262            self.get(table, row)
1263        return row
1264
1265    def clear(self, table, row=None):
1266        """Clear all the attributes to values determined by the types.
1267
1268        Numeric types are set to 0, Booleans are set to false, and everything
1269        else is set to the empty string.  If the row argument is present,
1270        it is used as the row dictionary and any entries matching attribute
1271        names are cleared with everything else left unchanged.
1272        """
1273        # At some point we will need a way to get defaults from a table.
1274        if row is None:
1275            row = {}  # empty if argument is not present
1276        attnames = self.get_attnames(table)
1277        for n, t in attnames.items():
1278            if n == 'oid':
1279                continue
1280            if t in self._num_types:
1281                row[n] = 0
1282            elif t == 'bool':
1283                row[n] = self._make_bool(False)
1284            else:
1285                row[n] = ''
1286        return row
1287
1288    def delete(self, table, row=None, **kw):
1289        """Delete an existing row in a database table.
1290
1291        This method deletes the row from a table.  It deletes based on the
1292        primary key of the table or the OID value as munged by get() or
1293        passed as keyword.
1294
1295        The return value is the number of deleted rows (i.e. 0 if the row
1296        did not exist and 1 if the row was deleted).
1297
1298        Note that if the row cannot be deleted because e.g. it is still
1299        referenced by another table, this method raises a ProgrammingError.
1300        """
1301        if table.endswith('*'):  # hint for descendant tables can be ignored
1302            table = table[:-1].rstrip()
1303        attnames = self.get_attnames(table)
1304        qoid = _oid_key(table) if 'oid' in attnames else None
1305        if row is None:
1306            row = {}
1307        elif 'oid' in row:
1308            del row['oid']  # only accept oid key from named args for safety
1309        row.update(kw)
1310        if qoid and qoid in row and 'oid' not in row:
1311            row['oid'] = row[qoid]
1312        try:  # try using the primary key
1313            keyname = self.pkey(table, True)
1314        except KeyError:  # the table has no primary key
1315            # try using the oid instead
1316            if qoid and 'oid' in row:
1317                keyname = ('oid',)
1318            else:
1319                raise _prg_error('Table %s has no primary key' % table)
1320        else:  # the table has a primary key
1321            # check whether all key columns have values
1322            if not set(keyname).issubset(row):
1323                # try using the oid instead
1324                if qoid and 'oid' in row:
1325                    keyname = ('oid',)
1326                else:
1327                    raise KeyError('Missing primary key in row')
1328        params = []
1329        param = partial(self._prepare_param, params=params)
1330        col = self.escape_identifier
1331        where = ' AND '.join('%s = %s' % (
1332            col(k), param(row[k], attnames[k])) for k in keyname)
1333        if 'oid' in row:
1334            if qoid:
1335                row[qoid] = row['oid']
1336            del row['oid']
1337        q = 'DELETE FROM %s WHERE %s' % (
1338            self._escape_qualified_name(table), where)
1339        self._do_debug(q, params)
1340        res = self.db.query(q, params)
1341        return int(res)
1342
1343    def truncate(self, table, restart=False, cascade=False, only=False):
1344        """Empty a table or set of tables.
1345
1346        This method quickly removes all rows from the given table or set
1347        of tables.  It has the same effect as an unqualified DELETE on each
1348        table, but since it does not actually scan the tables it is faster.
1349        Furthermore, it reclaims disk space immediately, rather than requiring
1350        a subsequent VACUUM operation. This is most useful on large tables.
1351
1352        If restart is set to True, sequences owned by columns of the truncated
1353        table(s) are automatically restarted.  If cascade is set to True, it
1354        also truncates all tables that have foreign-key references to any of
1355        the named tables.  If the parameter only is not set to True, all the
1356        descendant tables (if any) will also be truncated. Optionally, a '*'
1357        can be specified after the table name to explicitly indicate that
1358        descendant tables are included.
1359        """
1360        if isinstance(table, basestring):
1361            only = {table: only}
1362            table = [table]
1363        elif isinstance(table, (list, tuple)):
1364            if isinstance(only, (list, tuple)):
1365                only = dict(zip(table, only))
1366            else:
1367                only = dict.fromkeys(table, only)
1368        elif isinstance(table, (set, frozenset)):
1369            only = dict.fromkeys(table, only)
1370        else:
1371            raise TypeError('The table must be a string, list or set')
1372        if not (restart is None or isinstance(restart, (bool, int))):
1373            raise TypeError('Invalid type for the restart option')
1374        if not (cascade is None or isinstance(cascade, (bool, int))):
1375            raise TypeError('Invalid type for the cascade option')
1376        tables = []
1377        for t in table:
1378            u = only.get(t)
1379            if not (u is None or isinstance(u, (bool, int))):
1380                raise TypeError('Invalid type for the only option')
1381            if t.endswith('*'):
1382                if u:
1383                    raise ValueError(
1384                        'Contradictory table name and only options')
1385                t = t[:-1].rstrip()
1386            t = self._escape_qualified_name(t)
1387            if u:
1388                t = 'ONLY %s' % t
1389            tables.append(t)
1390        q = ['TRUNCATE', ', '.join(tables)]
1391        if restart:
1392            q.append('RESTART IDENTITY')
1393        if cascade:
1394            q.append('CASCADE')
1395        q = ' '.join(q)
1396        self._do_debug(q)
1397        return self.db.query(q)
1398
1399    def get_as_list(self, table, what=None, where=None,
1400            order=None, limit=None, offset=None, scalar=False):
1401        """Get a table as a list.
1402
1403        This gets a convenient representation of the table as a list
1404        of named tuples in Python.  You only need to pass the name of
1405        the table (or any other SQL expression returning rows).  Note that
1406        by default this will return the full content of the table which
1407        can be huge and overflow your memory.  However, you can control
1408        the amount of data returned using the other optional parameters.
1409
1410        The parameter 'what' can restrict the query to only return a
1411        subset of the table columns.  It can be a string, list or a tuple.
1412        The parameter 'where' can restrict the query to only return a
1413        subset of the table rows.  It can be a string, list or a tuple
1414        of SQL expressions that all need to be fulfilled.  The parameter
1415        'order' specifies the ordering of the rows.  It can also be a
1416        other string, list or a tuple.  If no ordering is specified,
1417        the result will be ordered by the primary key(s) or all columns
1418        if no primary key exists.  You can set 'order' to False if you
1419        don't care about the ordering.  The parameters 'limit' and 'offset'
1420        can be integers specifying the maximum number of rows returned
1421        and a number of rows skipped over.
1422
1423        If you set the 'scalar' option to True, then instead of the
1424        named tuples you will get the first items of these tuples.
1425        This is useful if the result has only one column anyway.
1426        """
1427        if not table:
1428            raise TypeError('The table name is missing')
1429        if what:
1430            if isinstance(what, (list, tuple)):
1431                what = ', '.join(map(str, what))
1432            if order is None:
1433                order = what
1434        else:
1435            what = '*'
1436        q = ['SELECT', what, 'FROM', table]
1437        if where:
1438            if isinstance(where, (list, tuple)):
1439                where = ' AND '.join(map(str, where))
1440            q.extend(['WHERE', where])
1441        if order is None:
1442            try:
1443                order = self.pkey(table, True)
1444            except (KeyError, ProgrammingError):
1445                try:
1446                    order = list(self.get_attnames(table))
1447                except (KeyError, ProgrammingError):
1448                    pass
1449        if order:
1450            if isinstance(order, (list, tuple)):
1451                order = ', '.join(map(str, order))
1452            q.extend(['ORDER BY', order])
1453        if limit:
1454            q.append('LIMIT %d' % limit)
1455        if offset:
1456            q.append('OFFSET %d' % offset)
1457        q = ' '.join(q)
1458        self._do_debug(q)
1459        q = self.db.query(q)
1460        res = q.namedresult()
1461        if res and scalar:
1462            res = [row[0] for row in res]
1463        return res
1464
1465    def get_as_dict(self, table, keyname=None, what=None, where=None,
1466            order=None, limit=None, offset=None, scalar=False):
1467        """Get a table as a dictionary.
1468
1469        This method is similar to get_as_list(), but returns the table
1470        as a Python dict instead of a Python list, which can be even
1471        more convenient. The primary key column(s) of the table will
1472        be used as the keys of the dictionary, while the other column(s)
1473        will be the corresponding values.  The keys will be named tuples
1474        if the table has a composite primary key.  The rows will be also
1475        named tuples unless the 'scalar' option has been set to True.
1476        With the optional parameter 'keyname' you can specify an alternative
1477        set of columns to be used as the keys of the dictionary.  It must
1478        be set as a string, list or a tuple.
1479
1480        If the Python version supports it, the dictionary will be an
1481        OrderedDict using the order specified with the 'order' parameter
1482        or the key column(s) if not specified.  You can set 'order' to False
1483        if you don't care about the ordering.  In this case the returned
1484        dictionary will be an ordinary one.
1485        """
1486        if not table:
1487            raise TypeError('The table name is missing')
1488        if not keyname:
1489            try:
1490                keyname = self.pkey(table, True)
1491            except (KeyError, ProgrammingError):
1492                raise _prg_error('Table %s has no primary key' % table)
1493        if isinstance(keyname, basestring):
1494            keyname = [keyname]
1495        elif not isinstance(keyname, (list, tuple)):
1496            raise KeyError('The keyname must be a string, list or tuple')
1497        if what:
1498            if isinstance(what, (list, tuple)):
1499                what = ', '.join(map(str, what))
1500            if order is None:
1501                order = what
1502        else:
1503            what = '*'
1504        q = ['SELECT', what, 'FROM', table]
1505        if where:
1506            if isinstance(where, (list, tuple)):
1507                where = ' AND '.join(map(str, where))
1508            q.extend(['WHERE', where])
1509        if order is None:
1510            order = keyname
1511        if order:
1512            if isinstance(order, (list, tuple)):
1513                order = ', '.join(map(str, order))
1514            q.extend(['ORDER BY', order])
1515        if limit:
1516            q.append('LIMIT %d' % limit)
1517        if offset:
1518            q.append('OFFSET %d' % offset)
1519        q = ' '.join(q)
1520        self._do_debug(q)
1521        q = self.db.query(q)
1522        res = q.getresult()
1523        cls = OrderedDict if order else dict
1524        if not res:
1525            return cls()
1526        keyset = set(keyname)
1527        fields = q.listfields()
1528        if not keyset.issubset(fields):
1529            raise KeyError('Missing keyname in row')
1530        keyind, rowind = [], []
1531        for i, f in enumerate(fields):
1532            (keyind if f in keyset else rowind).append(i)
1533        keytuple = len(keyind) > 1
1534        getkey = itemgetter(*keyind)
1535        keys = map(getkey, res)
1536        if scalar:
1537            rowind = rowind[:1]
1538            rowtuple = False
1539        else:
1540            rowtuple = len(rowind) > 1
1541        if scalar or rowtuple:
1542            getrow = itemgetter(*rowind)
1543        else:
1544            rowind = rowind[0]
1545            getrow = lambda row: (row[rowind],)
1546            rowtuple = True
1547        rows = map(getrow, res)
1548        if keytuple or rowtuple:
1549            namedresult = get_namedresult()
1550            if namedresult:
1551                if keytuple:
1552                    keys = namedresult(_MemoryQuery(keys, keyname))
1553                if rowtuple:
1554                    fields = [f for f in fields if f not in keyset]
1555                    rows = namedresult(_MemoryQuery(rows, fields))
1556        return cls(zip(keys, rows))
1557
1558    def notification_handler(self,
1559            event, callback, arg_dict=None, timeout=None, stop_event=None):
1560        """Get notification handler that will run the given callback."""
1561        return NotificationHandler(self,
1562            event, callback, arg_dict, timeout, stop_event)
1563
1564
# If run as a script, print the PyGreSQL version and the module docstring.

if __name__ == '__main__':
    # 'version' is not defined in this chunk; presumably it is provided
    # by the _pg extension module imported at the top of the file.
    # The %s formatting also inserts the space that was missing before.
    print('PyGreSQL version %s' % version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.