source: trunk/pg.py @ 793

Last change on this file since 793 was 793, checked in by cito, 3 years ago

Improve quoting and typecasting in the pg module

Larger refactoring of the code for adapting and typecasting in the pg module.
Things are now a lot cleaner and clearer.

The _Adapt class is responsible for all adapting of Python objects to their
PostgreSQL equivalents when sending data to the database. The typecasting
from PostgreSQL on output happens in the C module, except for the typecasting
of records which is new and provided by the _CastRecord class.

The classic module also did not work properly when regular type names were
switched on with use_regtypes(True), since the adapting of types relied on
the PyGreSQL type names. This has been solved by adding a new _PgType class
that is essentially the old type name, but augmented with all the
information necessary to adapt types, particularly record types.

All tests in test_classic_dbwrapper now run twice, using opposite settings
for the various configuration settings like use_bool() or use_regtypes(),
in order to make sure that no internal functions rely on default settings.

  • Property svn:keywords set to Id
File size: 64.5 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 793 2016-01-28 20:07:41Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40from operator import itemgetter
41from re import compile as regex
42from json import loads as jsondecode, dumps as jsonencode
43
44try:
45    basestring
46except NameError:  # Python >= 3.0
47    basestring = (str, bytes)
48
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    # No ordered dictionary available: OrderedDict degrades to a plain dict
    # and AttrDict keeps track of the key order by itself.
    OrderedDict = dict

    class AttrDict(dict):
        """Read-only dictionary of attribute names that keeps its key order.

        The instance is built once from a single iterable of (key, value)
        pairs; afterwards every mutating operation raises a TypeError.
        """

        def __init__(self, *args, **kw):
            # Only one positional iterable of pairs is accepted, because
            # keyword arguments or a plain dict carry no reliable order.
            if kw or len(args) > 1:
                raise TypeError
            pairs = args[0] if args else []
            if isinstance(pairs, dict):
                raise TypeError
            pairs = list(pairs)
            self._keys = [pair[0] for pair in pairs]
            dict.__init__(self, pairs)
            self._read_only = True
            # Shadow the mutating dict methods on the instance.
            refuse = self._read_only_error
            self.clear = self.update = refuse
            self.pop = self.setdefault = self.popitem = refuse

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[name] for name in self]

        def items(self):
            return [(name, self[name]) for name in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Read-only dictionary of attribute names that keeps its key order.

        The instance is filled by the OrderedDict constructor; afterwards
        every mutating operation raises a TypeError.
        """

        def __init__(self, *args, **kw):
            # The flag must exist and be off while the base class fills the
            # dictionary, since that may go through __setitem__.
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # Shadow the mutating dict methods on the instance.
            refuse = self._read_only_error
            self.clear = self.update = refuse
            self.pop = self.setdefault = self.popitem = refuse

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
133
134
135# Auxiliary classes and functions that are independent from a DB connection:
136
137def _oid_key(table):
138    """Build oid key from a table name."""
139    return 'oid(%s)' % table
140
141
142class _SimpleType(dict):
143    """Dictionary mapping pg_type names to simple type names."""
144
145    _types = {'bool': 'bool',
146        'bytea': 'bytea',
147        'date': 'date interval time timetz timestamp timestamptz'
148            ' abstime reltime',  # these are very old
149        'float': 'float4 float8',
150        'int': 'cid int2 int4 int8 oid xid',
151        'json': 'json jsonb',
152        'num': 'numeric',
153        'money': 'money',
154        'text': 'bpchar char name text varchar'}
155
156    def __init__(self):
157        for typ, keys in self._types.items():
158            for key in keys.split():
159                self[key] = typ
160                self['_%s' % key] = '%s[]' % typ
161
162    @staticmethod
163    def __missing__(key):
164        return 'text'
165
166_simpletype = _SimpleType()
167
168
169class _Adapt:
170    """Mixin providing methods for adapting records and record elements.
171
172    This is used when passing values from one of the higher level DB
173    methods as parameters for a query.
174
175    This class must be mixed in to a connection class, because it needs
176    connection specific methods such as escape_bytea().
177    """
178
179    _bool_true_values = frozenset('t true 1 y yes on'.split())
180
181    _date_literals = frozenset('current_date current_time'
182        ' current_timestamp localtime localtimestamp'.split())
183
184    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
185    _re_record_quote = regex(r'[(,"\\]')
186    _re_array_escape = _re_record_escape = regex(r'(["\\])')
187
188    @classmethod
189    def _adapt_bool(cls, v):
190        """Adapt a boolean parameter."""
191        if isinstance(v, basestring):
192            if not v:
193                return None
194            v = v.lower() in cls._bool_true_values
195        return 't' if v else 'f'
196
197    @classmethod
198    def _adapt_date(cls, v):
199        """Adapt a date parameter."""
200        if not v:
201            return None
202        if isinstance(v, basestring) and v.lower() in cls._date_literals:
203            return _Literal(v)
204        return v
205
206    @staticmethod
207    def _adapt_num(v):
208        """Adapt a numeric parameter."""
209        if not v and v != 0:
210            return None
211        return v
212
213    _adapt_int = _adapt_float = _adapt_money = _adapt_num
214
215    def _adapt_bytea(self, v):
216        """Adapt a bytea parameter."""
217        return self.escape_bytea(v)
218
219    def _adapt_json(self, v):
220        """Adapt a json parameter."""
221        if not v:
222            return None
223        if isinstance(v, basestring):
224            return v
225        return self.encode_json(v)
226
227    @classmethod
228    def _adapt_text_array(cls, v):
229        """Adapt a text type array parameter."""
230        if isinstance(v, list):
231            adapt = cls._adapt_text_array
232            return '{%s}' % ','.join(adapt(v) for v in v)
233        if v is None:
234            return 'null'
235        if not v:
236            return '""'
237        v = str(v)
238        if cls._re_array_quote.search(v):
239            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
240        return v
241
242    _adapt_date_array = _adapt_text_array
243
244    @classmethod
245    def _adapt_bool_array(cls, v):
246        """Adapt a boolean array parameter."""
247        if isinstance(v, list):
248            adapt = cls._adapt_bool_array
249            return '{%s}' % ','.join(adapt(v) for v in v)
250        if v is None:
251            return 'null'
252        if isinstance(v, basestring):
253            if not v:
254                return 'null'
255            v = v.lower() in cls._bool_true_values
256        return 't' if v else 'f'
257
258    @classmethod
259    def _adapt_num_array(cls, v):
260        """Adapt a numeric array parameter."""
261        if isinstance(v, list):
262            adapt = cls._adapt_num_array
263            return '{%s}' % ','.join(adapt(v) for v in v)
264        if not v and v != 0:
265            return 'null'
266        return str(v)
267
268    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
269            _adapt_num_array
270
271    def _adapt_bytea_array(self, v):
272        """Adapt a bytea array parameter."""
273        if isinstance(v, list):
274            return b'{' + b','.join(
275                self._adapt_bytea_array(v) for v in v) + b'}'
276        if v is None:
277            return b'null'
278        return self.escape_bytea(v).replace(b'\\', b'\\\\')
279
280    def _adapt_json_array(self, v):
281        """Adapt a json array parameter."""
282        if isinstance(v, list):
283            adapt = self._adapt_json_array
284            return '{%s}' % ','.join(adapt(v) for v in v)
285        if not v:
286            return 'null'
287        if not isinstance(v, basestring):
288            v = self.encode_json(v)
289        if self._re_array_quote.search(v):
290            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
291        return v
292
293    def _adapt_record(self, v, typ):
294        """Adapt a record parameter with given type."""
295        typ = typ.attnames.values()
296        if len(typ) != len(v):
297            raise TypeError('Record parameter %s has wrong size' % v)
298        return '(%s)' % ','.join(getattr(self,
299            '_adapt_record_%s' % t.simple)(v) for v, t in zip(v, typ))
300
301    @classmethod
302    def _adapt_record_text(cls, v):
303        """Adapt a text type record component."""
304        if v is None:
305            return ''
306        if not v:
307            return '""'
308        v = str(v)
309        if cls._re_record_quote.search(v):
310            v = '"%s"' % cls._re_record_escape.sub(r'\\\1', v)
311        return v
312
313    _adapt_record_date = _adapt_record_text
314
315    @classmethod
316    def _adapt_record_bool(cls, v):
317        """Adapt a boolean record component."""
318        if v is None:
319            return ''
320        if isinstance(v, basestring):
321            if not v:
322                return ''
323            v = v.lower() in cls._bool_true_values
324        return 't' if v else 'f'
325
326    @staticmethod
327    def _adapt_record_num(v):
328        """Adapt a numeric record component."""
329        if not v and v != 0:
330            return ''
331        return str(v)
332
333    _adapt_record_int = _adapt_record_float = _adapt_record_money = \
334        _adapt_record_num
335
336    def _adapt_record_bytea(self, v):
337        if v is None:
338            return ''
339        v = self.escape_bytea(v)
340        if bytes is not str and isinstance(v, bytes):
341            v = v.decode('ascii')
342        return v.replace('\\', '\\\\')
343
344    def _adapt_record_json(self, v):
345        """Adapt a bytea record component."""
346        if not v:
347            return ''
348        if not isinstance(v, basestring):
349            v = self.encode_json(v)
350        if self._re_array_quote.search(v):
351            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
352        return v
353
354    def _adapt_param(self, value, typ, params):
355        """Adapt and add a parameter to the list."""
356        if isinstance(value, _Literal):
357            return value
358        if value is not None:
359            simple = typ.simple
360            if simple == 'text':
361                pass
362            elif simple == 'record':
363                if isinstance(value, tuple):
364                    value = self._adapt_record(value, typ)
365            elif simple.endswith('[]'):
366                if isinstance(value, list):
367                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
368                    value = adapt(value)
369            else:
370                adapt = getattr(self, '_adapt_%s' % simple)
371                value = adapt(value)
372                if isinstance(value, _Literal):
373                    return value
374        params.append(value)
375        return '$%d' % len(params)
376
377
378class _CastRecord:
379    """Class providing methods for casting records and record elements.
380
381    This is needed when getting result values from one of the higher level DB
382    methods, since the lower level query method only casts the other types.
383    """
384
385    @staticmethod
386    def cast_bool(v):
387        if not get_bool():
388            return v
389        return v[0] == 't'
390
391    @staticmethod
392    def cast_bytea(v):
393        return unescape_bytea(v)
394
395    @staticmethod
396    def cast_float(v):
397        return float(v)
398
399    @staticmethod
400    def cast_int(v):
401        return int(v)
402
403    @staticmethod
404    def cast_json(v):
405        cast = get_jsondecode()
406        if not cast:
407            return v
408        return cast(v)
409
410    @staticmethod
411    def cast_num(v):
412        return (get_decimal() or float)(v)
413
414    @staticmethod
415    def cast_money(v):
416        point = get_decimal_point()
417        if not point:
418            return v
419        if point != '.':
420            v = v.replace(point, '.')
421        v = v.replace('(', '-')
422        v = ''.join(c for c in v if c.isdigit() or c in '.-')
423        return (get_decimal() or float)(v)
424
425    @classmethod
426    def cast(cls, v, typ):
427        types = typ.attnames.values()
428        cast = [getattr(cls, 'cast_%s' % t.simple, None) for t in types]
429        v = cast_record(v, cast)
430        return typ.namedtuple(*v)
431
432
433class _PgType(str):
434    """Class augmenting the simple type name with additional info."""
435
436    _num_types = frozenset('int float num money'
437        ' int2 int4 int8 float4 float8 numeric money'.split())
438
439    @classmethod
440    def create(cls, db, pgtype, regtype, typrelid):
441        """Create a PostgreSQL type name with additional info."""
442        simple = 'record' if typrelid else _simpletype[pgtype]
443        self = cls(regtype if db._regtypes else simple)
444        self.db = db
445        self.simple = simple
446        self.pgtype = pgtype
447        self.regtype = regtype
448        self.typrelid = typrelid
449        self._attnames = self._namedtuple = None
450        return self
451
452    @property
453    def attnames(self):
454        """Get names and types of the fields of a composite type."""
455        if not self.typrelid:
456            return None
457        if not self._attnames:
458            self._attnames = self.db.get_attnames(self.typrelid)
459        return self._attnames
460
461    @property
462    def namedtuple(self):
463        """Return named tuple class representing a composite type."""
464        if not self._namedtuple:
465            self._namedtuple = namedtuple(self, self.attnames)
466        return self._namedtuple
467
468    def cast(self, value):
469        if value is not None and self.typrelid:
470            value = _CastRecord.cast(value, self)
471        return value
472
473
474class _Literal(str):
475    """Wrapper class for literal SQL."""
476
477
478def _namedresult(q):
479    """Get query result as named tuples."""
480    row = namedtuple('Row', q.listfields())
481    return [row(*r) for r in q.getresult()]
482
483
484class _MemoryQuery:
485    """Class that embodies a given query result."""
486
487    def __init__(self, result, fields):
488        """Create query from given result rows and field names."""
489        self.result = result
490        self.fields = fields
491
492    def listfields(self):
493        """Return the stored field names of this query."""
494        return self.fields
495
496    def getresult(self):
497        """Return the stored result of this query."""
498        return self.result
499
500
def _db_error(msg, cls=DatabaseError):
    """Create an error of the given class with sqlstate set to None.

    The error is returned, not raised; by default it is a DatabaseError.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
506
507
def _int_error(msg):
    """Create an InternalError with the given message (not raised here)."""
    return _db_error(msg, cls=InternalError)
511
512
def _prg_error(msg):
    """Create a ProgrammingError with the given message (not raised here)."""
    return _db_error(msg, cls=ProgrammingError)
516
517
518# Initialize the C module
519
# Hand the Python level helpers over to the C module: the function that
# builds named tuple results, the Decimal class and the JSON decoder.
# The latter two are read back in this module via get_decimal() and
# get_jsondecode() when casting record components.
set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
523
524
525# The notification handler
526
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the handler will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Returns the result of the query, or None if the handler is not
        listening, in which case nothing is sent.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # The payload is sent as a string constant, so embedded
                # single quotes must be doubled; otherwise the payload
                # could end the constant and inject arbitrary SQL.
                # NOTE(review): backslashes are left alone, which assumes
                # that standard_conforming_strings is on -- confirm.
                q += ", '%s'" % ('%s' % payload).replace("'", "''")
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        A database error is raised if a notification for any other event
        arrives; the handler stops listening in this case.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            # wait on the socket of the connection when not polling
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                # fetch all notifications that are pending right now
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        # the callback is still invoked for the stop event
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
646
647
def pgnotify(*args, **kw):
    """Create a NotificationHandler (deprecated traditional name).

    All arguments are passed on to NotificationHandler after a
    DeprecationWarning has been issued for the caller.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
653
654
655# The actual PostgreSQL database connection interface:
656
657class DB(_Adapt):
658    """Wrapper class for the _pg connection type."""
659
    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        A connection that has been passed in is only wrapped; a connection
        is opened (and marked as closeable by this wrapper) only when no
        usable connection object has been passed.
        """
        # A single argument, given as keyword 'db' or as the only positional
        # argument, may be an existing connection instead of a parameter.
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # unwrap another DB wrapper
                db = db.db
            else:
                # presumably pgdb connections keep the underlying
                # connection in _cnx -- verify against the pgdb module
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # Anything that does not have both a 'db' and a 'query' attribute
        # is taken as connection parameters.
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        # kept for reopen()
        self._args = args, kw
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
698
699    def __getattr__(self, name):
700        # All undefined members are same as in underlying connection:
701        if self.db:
702            return getattr(self.db, name)
703        else:
704            raise _int_error('Connection is not valid')
705
706    def __dir__(self):
707        # Custom dir function including the attributes of the connection:
708        attrs = set(self.__class__.__dict__)
709        attrs.update(self.__dict__)
710        attrs.update(dir(self.db))
711        return sorted(attrs)
712
713    # Context manager methods
714
715    def __enter__(self):
716        """Enter the runtime context. This will start a transactio."""
717        self.begin()
718        return self
719
720    def __exit__(self, et, ev, tb):
721        """Exit the runtime context. This will end the transaction."""
722        if et is None and ev is None and tb is None:
723            self.commit()
724        else:
725            self.rollback()
726
727    # Auxiliary methods
728
729    def _do_debug(self, *args):
730        """Print a debug message"""
731        if self.debug:
732            s = '\n'.join(str(arg) for arg in args)
733            if isinstance(self.debug, basestring):
734                print(self.debug % s)
735            elif hasattr(self.debug, 'write'):
736                self.debug.write(s + '\n')
737            elif callable(self.debug):
738                self.debug(s)
739            else:
740                print(s)
741
742    def _escape_qualified_name(self, s):
743        """Escape a qualified name.
744
745        Escapes the name for use as an SQL identifier, unless the
746        name contains a dot, in which case the name is ambiguous
747        (could be a qualified name or just a name with a dot in it)
748        and must be quoted manually by the caller.
749        """
750        if '.' not in s:
751            s = self.escape_identifier(s)
752        return s
753
754    @staticmethod
755    def _make_bool(d):
756        """Get boolean value corresponding to d."""
757        return bool(d) if get_bool() else ('t' if d else 'f')
758
759    def _list_params(self, params):
760        """Create a human readable parameter list."""
761        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
762
763    @staticmethod
764    def _adapt_qualified_param(name, param):
765        """Quote parameter representing a qualified name.
766
767        Escapes the name for use as an SQL parameter, unless the
768        name contains a dot, in which case the name is ambiguous
769        (could be a qualified name or just a name with a dot in it)
770        and must be quoted manually by the caller.
771
772        """
773        if isinstance(param, int):
774            param = "$%d" % param
775        if isinstance(name, basestring) and '.' not in name:
776            param = 'quote_ident(%s)' % (param,)
777        return param
778
779    # Public methods
780
781    # escape_string and escape_bytea exist as methods,
782    # so we define unescape_bytea as a method as well
783    unescape_bytea = staticmethod(unescape_bytea)
784
785    def decode_json(self, s):
786        """Decode a JSON string coming from the database."""
787        return (get_jsondecode() or jsondecode)(s)
788
789    def encode_json(self, d):
790        """Encode a JSON string for use within SQL."""
791        return jsonencode(d)
792
793    def close(self):
794        """Close the database connection."""
795        # Wraps shared library function so we can track state.
796        if self._closeable:
797            if self.db:
798                self.db.close()
799                self.db = None
800            else:
801                raise _int_error('Connection already closed')
802
803    def reset(self):
804        """Reset connection with current parameters.
805
806        All derived queries and large objects derived from this connection
807        will not be usable after this call.
808
809        """
810        if self.db:
811            self.db.reset()
812        else:
813            raise _int_error('Connection already closed')
814
815    def reopen(self):
816        """Reopen connection to the database.
817
818        Used in case we need another connection to the same database.
819        Note that we can still reopen a database that we have closed.
820
821        """
822        # There is no such shared library function.
823        if self._closeable:
824            db = connect(*self._args[0], **self._args[1])
825            if self.db:
826                self.db.close()
827            self.db = db
828
829    def begin(self, mode=None):
830        """Begin a transaction."""
831        qstr = 'BEGIN'
832        if mode:
833            qstr += ' ' + mode
834        return self.query(qstr)
835
836    start = begin
837
838    def commit(self):
839        """Commit the current transaction."""
840        return self.query('COMMIT')
841
842    end = commit
843
844    def rollback(self, name=None):
845        """Roll back the current transaction."""
846        qstr = 'ROLLBACK'
847        if name:
848            qstr += ' TO ' + name
849        return self.query(qstr)
850
851    abort = rollback
852
853    def savepoint(self, name):
854        """Define a new savepoint within the current transaction."""
855        return self.query('SAVEPOINT ' + name)
856
857    def release(self, name):
858        """Destroy a previously defined savepoint."""
859        return self.query('RELEASE ' + name)
860
861    def get_parameter(self, parameter):
862        """Get the value of a run-time parameter.
863
864        If the parameter is a string, the return value will also be a string
865        that is the current setting of the run-time parameter with that name.
866
867        You can get several parameters at once by passing a list, set or dict.
868        When passing a list of parameter names, the return value will be a
869        corresponding list of parameter settings.  When passing a set of
870        parameter names, a new dict will be returned, mapping these parameter
871        names to their settings.  Finally, if you pass a dict as parameter,
872        its values will be set to the current parameter settings corresponding
873        to its keys.
874
875        By passing the special name 'all' as the parameter, you can get a dict
876        of all existing configuration parameters.
877        """
878        if isinstance(parameter, basestring):
879            parameter = [parameter]
880            values = None
881        elif isinstance(parameter, (list, tuple)):
882            values = []
883        elif isinstance(parameter, (set, frozenset)):
884            values = {}
885        elif isinstance(parameter, dict):
886            values = parameter
887        else:
888            raise TypeError(
889                'The parameter must be a string, list, set or dict')
890        if not parameter:
891            raise TypeError('No parameter has been specified')
892        params = {} if isinstance(values, dict) else []
893        for key in parameter:
894            param = key.strip().lower() if isinstance(
895                key, basestring) else None
896            if not param:
897                raise TypeError('Invalid parameter')
898            if param == 'all':
899                q = 'SHOW ALL'
900                values = self.db.query(q).getresult()
901                values = dict(value[:2] for value in values)
902                break
903            if isinstance(values, dict):
904                params[param] = key
905            else:
906                params.append(param)
907        else:
908            for param in params:
909                q = 'SHOW %s' % (param,)
910                value = self.db.query(q).getresult()[0][0]
911                if values is None:
912                    values = value
913                elif isinstance(values, list):
914                    values.append(value)
915                else:
916                    values[params[param]] = value
917        return values
918
919    def set_parameter(self, parameter, value=None, local=False):
920        """Set the value of a run-time parameter.
921
922        If the parameter and the value are strings, the run-time parameter
923        will be set to that value.  If no value or None is passed as a value,
924        then the run-time parameter will be restored to its default value.
925
926        You can set several parameters at once by passing a list of parameter
927        names, together with a single value that all parameters should be
928        set to or with a corresponding list of values.  You can also pass
929        the parameters as a set if you only provide a single value.
930        Finally, you can pass a dict with parameter names as keys.  In this
931        case, you should not pass a value, since the values for the parameters
932        will be taken from the dict.
933
934        By passing the special name 'all' as the parameter, you can reset
935        all existing settable run-time parameters to their default values.
936
937        If you set local to True, then the command takes effect for only the
938        current transaction.  After commit() or rollback(), the session-level
939        setting takes effect again.  Setting local to True will appear to
940        have no effect if it is executed outside a transaction, since the
941        transaction will end immediately.
942        """
943        if isinstance(parameter, basestring):
944            parameter = {parameter: value}
945        elif isinstance(parameter, (list, tuple)):
946            if isinstance(value, (list, tuple)):
947                parameter = dict(zip(parameter, value))
948            else:
949                parameter = dict.fromkeys(parameter, value)
950        elif isinstance(parameter, (set, frozenset)):
951            if isinstance(value, (list, tuple, set, frozenset)):
952                value = set(value)
953                if len(value) == 1:
954                    value = value.pop()
955            if not(value is None or isinstance(value, basestring)):
956                raise ValueError('A single value must be specified'
957                    ' when parameter is a set')
958            parameter = dict.fromkeys(parameter, value)
959        elif isinstance(parameter, dict):
960            if value is not None:
961                raise ValueError('A value must not be specified'
962                    ' when parameter is a dictionary')
963        else:
964            raise TypeError(
965                'The parameter must be a string, list, set or dict')
966        if not parameter:
967            raise TypeError('No parameter has been specified')
968        params = {}
969        for key, value in parameter.items():
970            param = key.strip().lower() if isinstance(
971                key, basestring) else None
972            if not param:
973                raise TypeError('Invalid parameter')
974            if param == 'all':
975                if value is not None:
976                    raise ValueError('A value must ot be specified'
977                        " when parameter is 'all'")
978                params = {'all': None}
979                break
980            params[param] = value
981        local = ' LOCAL' if local else ''
982        for param, value in params.items():
983            if value is None:
984                q = 'RESET%s %s' % (local, param)
985            else:
986                q = 'SET%s %s TO %s' % (local, param, value)
987            self._do_debug(q)
988            self.db.query(q)
989
990    def query(self, command, *args):
991        """Execute a SQL command string.
992
993        This method simply sends a SQL query to the database.  If the query is
994        an insert statement that inserted exactly one row into a table that
995        has OIDs, the return value is the OID of the newly inserted row.
996        If the query is an update or delete statement, or an insert statement
997        that did not insert exactly one row in a table with OIDs, then the
998        number of rows affected is returned as a string.  If it is a statement
999        that returns rows as a result (usually a select statement, but maybe
1000        also an "insert/update ... returning" statement), this method returns
1001        a Query object that can be accessed via getresult() or dictresult()
1002        or simply printed.  Otherwise, it returns `None`.
1003
1004        The query can contain numbered parameters of the form $1 in place
1005        of any data constant.  Arguments given after the query string will
1006        be substituted for the corresponding numbered parameter.  Parameter
1007        values can also be given as a single list or tuple argument.
1008        """
1009        # Wraps shared library function for debugging.
1010        if not self.db:
1011            raise _int_error('Connection is not valid')
1012        if args:
1013            self._do_debug(command, args)
1014            return self.db.query(command, args)
1015        self._do_debug(command)
1016        return self.db.query(command)
1017
1018    def pkey(self, table, composite=False, flush=False):
1019        """Get or set the primary key of a table.
1020
1021        Single primary keys are returned as strings unless you
1022        set the composite flag.  Composite primary keys are always
1023        represented as tuples.  Note that this raises a KeyError
1024        if the table does not have a primary key.
1025
1026        If flush is set then the internal cache for primary keys will
1027        be flushed.  This may be necessary after the database schema or
1028        the search path has been changed.
1029        """
1030        pkeys = self._pkeys
1031        if flush:
1032            pkeys.clear()
1033            self._do_debug('The pkey cache has been flushed')
1034        try:  # cache lookup
1035            pkey = pkeys[table]
1036        except KeyError:  # cache miss, check the database
1037            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1038                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1039                " AND a.attnum = ANY(i.indkey)"
1040                " AND NOT a.attisdropped"
1041                " WHERE i.indrelid=%s::regclass"
1042                " AND i.indisprimary ORDER BY a.attnum") % (
1043                    self._adapt_qualified_param(table, 1),)
1044            pkey = self.db.query(q, (table,)).getresult()
1045            if not pkey:
1046                raise KeyError('Table %s has no primary key' % table)
1047            # we want to use the order defined in the primary key index here,
1048            # not the order as defined by the columns in the table
1049            if len(pkey) > 1:
1050                indkey = [int(k) for k in pkey[0][2].split()]
1051                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1052                pkey = tuple(row[0] for row in pkey)
1053            else:
1054                pkey = pkey[0][0]
1055            pkeys[table] = pkey  # cache it
1056        if composite and not isinstance(pkey, tuple):
1057            pkey = (pkey,)
1058        return pkey
1059
1060    def get_databases(self):
1061        """Get list of databases in the system."""
1062        return [s[0] for s in
1063            self.db.query('SELECT datname FROM pg_database').getresult()]
1064
1065    def get_relations(self, kinds=None):
1066        """Get list of relations in connected database of specified kinds.
1067
1068        If kinds is None or empty, all kinds of relations are returned.
1069        Otherwise kinds can be a string or sequence of type letters
1070        specifying which kind of relations you want to list.
1071        """
1072        where = " AND r.relkind IN (%s)" % ','.join(
1073            ["'%s'" % k for k in kinds]) if kinds else ''
1074        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1075            " FROM pg_class r"
1076            " JOIN pg_namespace s ON s.oid = r.relnamespace"
1077            " WHERE s.nspname NOT SIMILAR"
1078            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
1079            " ORDER BY s.nspname, r.relname") % where
1080        return [r[0] for r in self.db.query(q).getresult()]
1081
1082    def get_tables(self):
1083        """Return list of tables in connected database."""
1084        return self.get_relations('r')
1085
1086    def get_attnames(self, table, flush=False):
1087        """Given the name of a table, dig out the set of attribute names.
1088
1089        Returns a read-only dictionary of attribute names (the names are
1090        the keys, the values are the names of the attributes' types)
1091        with the column names in the proper order if you iterate over it.
1092
1093        If flush is set, then the internal cache for attribute names will
1094        be flushed. This may be necessary after the database schema or
1095        the search path has been changed.
1096
1097        By default, only a limited number of simple types will be returned.
1098        You can get the regular types after calling use_regtypes(True).
1099        """
1100        attnames = self._attnames
1101        if flush:
1102            attnames.clear()
1103            self._do_debug('The attnames cache has been flushed')
1104        try:  # cache lookup
1105            names = attnames[table]
1106        except KeyError:  # cache miss, check the database
1107            q = ("SELECT a.attname, t.typname, t.typname::regtype, t.typrelid"
1108                " FROM pg_attribute a"
1109                " JOIN pg_type t ON t.oid = a.atttypid"
1110                " WHERE a.attrelid = %s::regclass"
1111                " AND (a.attnum > 0 OR a.attname = 'oid')"
1112                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1113                    self._adapt_qualified_param(table, 1))
1114            names = self.db.query(q, (table,)).getresult()
1115            names = ((name, _PgType.create(self, pgtype, regtype, typrelid))
1116                for name, pgtype, regtype, typrelid in names)
1117            names = AttrDict(names)
1118            attnames[table] = names  # cache it
1119        return names
1120
1121    def use_regtypes(self, regtypes=None):
1122        """Use regular type names instead of simplified type names."""
1123        if regtypes is None:
1124            return self._regtypes
1125        else:
1126            regtypes = bool(regtypes)
1127            if regtypes != self._regtypes:
1128                self._regtypes = regtypes
1129                self._attnames.clear()
1130            return regtypes
1131
1132    def has_table_privilege(self, table, privilege='select'):
1133        """Check whether current user has specified table privilege."""
1134        privilege = privilege.lower()
1135        try:  # ask cache
1136            return self._privileges[(table, privilege)]
1137        except KeyError:  # cache miss, ask the database
1138            q = "SELECT has_table_privilege(%s, $2)" % (
1139                self._adapt_qualified_param(table, 1),)
1140            q = self.db.query(q, (table, privilege))
1141            ret = q.getresult()[0][0] == self._make_bool(True)
1142            self._privileges[(table, privilege)] = ret  # cache it
1143            return ret
1144
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        A KeyError is raised if a dictionary row lacks values for the
        primary key (and no OID can be used instead) or if the number of
        passed values differs from the number of key columns.  The errors
        created by _prg_error() and _db_error() are raised if the table
        has no usable key or if no matching record exists, respectively.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # name of the munged oid key, or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # temporarily expose the munged oid under the plain name 'oid'
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # turn a plain value or sequence of values into a dictionary
        # that maps the key columns to these values
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = []
        # every call of param() is handed the shared params list
        param = partial(self._adapt_param, params=params)
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), param(row[k], attnames[k])) for k in keyname)
        # the plain 'oid' key is not kept in the row, only the munged one
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row; the oid is stored under
        # its munged name as is, all other values are passed through the
        # cast() method of the column type
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            else:
                value = attnames[n].cast(value)
            row[n] = value
        return row
1221
1222    def insert(self, table, row=None, **kw):
1223        """Insert a row into a database table.
1224
1225        This method inserts a row into a table.  The name of the table must
1226        be passed as the first parameter.  The other parameters are used for
1227        providing the data of the row that shall be inserted into the table.
1228        If a dictionary is supplied as the second parameter, it starts with
1229        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1230        is updated from the keywords.
1231
1232        The dictionary is then reloaded with the values actually inserted in
1233        order to pick up values modified by rules, triggers, etc.
1234        """
1235        if table.endswith('*'):  # hint for descendant tables can be ignored
1236            table = table[:-1].rstrip()
1237        if row is None:
1238            row = {}
1239        row.update(kw)
1240        if 'oid' in row:
1241            del row['oid']  # do not insert oid
1242        attnames = self.get_attnames(table)
1243        qoid = _oid_key(table) if 'oid' in attnames else None
1244        params = []
1245        param = partial(self._adapt_param, params=params)
1246        col = self.escape_identifier
1247        names, values = [], []
1248        for n in attnames:
1249            if n in row:
1250                names.append(col(n))
1251                values.append(param(row[n], attnames[n]))
1252        if not names:
1253            raise _prg_error('No column found that can be inserted')
1254        names, values = ', '.join(names), ', '.join(values)
1255        ret = 'oid, *' if qoid else '*'
1256        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1257            self._escape_qualified_name(table), names, values, ret)
1258        self._do_debug(q, params)
1259        q = self.db.query(q, params)
1260        res = q.dictresult()
1261        if res:  # this should always be true
1262            for n, value in res[0].items():
1263                if qoid and n == 'oid':
1264                    n = qoid
1265                else:
1266                    value = attnames[n].cast(value)
1267                row[n] = value
1268        return row
1269
1270    def update(self, table, row=None, **kw):
1271        """Update an existing row in a database table.
1272
1273        Similar to insert but updates an existing row.  The update is based
1274        on the primary key of the table or the OID value as munged by get
1275        or passed as keyword.
1276
1277        The dictionary is then modified to reflect any changes caused by the
1278        update due to triggers, rules, default values, etc.
1279        """
1280        if table.endswith('*'):
1281            table = table[:-1].rstrip()  # need parent table name
1282        attnames = self.get_attnames(table)
1283        qoid = _oid_key(table) if 'oid' in attnames else None
1284        if row is None:
1285            row = {}
1286        elif 'oid' in row:
1287            del row['oid']  # only accept oid key from named args for safety
1288        row.update(kw)
1289        if qoid and qoid in row and 'oid' not in row:
1290            row['oid'] = row[qoid]
1291        try:  # try using the primary key
1292            keyname = self.pkey(table, True)
1293        except KeyError:  # the table has no primary key
1294            # try using the oid instead
1295            if qoid and 'oid' in row:
1296                keyname = ('oid',)
1297            else:
1298                raise _prg_error('Table %s has no primary key' % table)
1299        else:  # the table has a primary key
1300            # check whether all key columns have values
1301            if not set(keyname).issubset(row):
1302                # try using the oid instead
1303                if qoid and 'oid' in row:
1304                    keyname = ('oid',)
1305                else:
1306                    raise KeyError('Missing primary key in row')
1307        params = []
1308        param = partial(self._adapt_param, params=params)
1309        col = self.escape_identifier
1310        where = ' AND '.join('%s = %s' % (
1311            col(k), param(row[k], attnames[k])) for k in keyname)
1312        if 'oid' in row:
1313            if qoid:
1314                row[qoid] = row['oid']
1315            del row['oid']
1316        values = []
1317        keyname = set(keyname)
1318        for n in attnames:
1319            if n in row and n not in keyname:
1320                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
1321        if not values:
1322            return row
1323        values = ', '.join(values)
1324        ret = 'oid, *' if qoid else '*'
1325        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
1326            self._escape_qualified_name(table), values, where, ret)
1327        self._do_debug(q, params)
1328        q = self.db.query(q, params)
1329        res = q.dictresult()
1330        if res:  # may be empty when row does not exist
1331            for n, value in res[0].items():
1332                if qoid and n == 'oid':
1333                    n = qoid
1334                else:
1335                    value = attnames[n].cast(value)
1336                row[n] = value
1337        return row
1338
1339    def upsert(self, table, row=None, **kw):
1340        """Insert a row into a database table with conflict resolution
1341
1342        This method inserts a row into a table, but instead of raising a
1343        ProgrammingError exception in case a row with the same primary key
1344        already exists, an update will be executed instead.  This will be
1345        performed as a single atomic operation on the database, so race
1346        conditions can be avoided.
1347
1348        Like the insert method, the first parameter is the name of the
1349        table and the second parameter can be used to pass the values to
1350        be inserted as a dictionary.
1351
1352        Unlike the insert und update statement, keyword parameters are not
1353        used to modify the dictionary, but to specify which columns shall
1354        be updated in case of a conflict, and in which way:
1355
1356        A value of False or None means the column shall not be updated,
1357        a value of True means the column shall be updated with the value
1358        that has been proposed for insertion, i.e. has been passed as value
1359        in the dictionary.  Columns that are not specified by keywords but
1360        appear as keys in the dictionary are also updated like in the case
1361        keywords had been passed with the value True.
1362
1363        So if in the case of a conflict you want to update every column that
1364        has been passed in the dictionary row , you would call upsert(table, row).
1365        If you don't want to do anything in case of a conflict, i.e. leave
1366        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
1367
1368        If you need more fine-grained control of what gets updated, you can
1369        also pass strings in the keyword parameters.  These strings will
1370        be used as SQL expressions for the update columns.  In these
1371        expressions you can refer to the value that already exists in
1372        the table by prefixing the column name with "included.", and to
1373        the value that has been proposed for insertion by prefixing the
1374        column name with the "excluded."
1375
1376        The dictionary is modified in any case to reflect the values in
1377        the database after the operation has completed.
1378
1379        Note: The method uses the PostgreSQL "upsert" feature which is
1380        only available since PostgreSQL 9.5.
1381        """
1382        if table.endswith('*'):  # hint for descendant tables can be ignored
1383            table = table[:-1].rstrip()
1384        if row is None:
1385            row = {}
1386        if 'oid' in row:
1387            del row['oid']  # do not insert oid
1388        if 'oid' in kw:
1389            del kw['oid']  # do not update oid
1390        attnames = self.get_attnames(table)
1391        qoid = _oid_key(table) if 'oid' in attnames else None
1392        params = []
1393        param = partial(self._adapt_param,params=params)
1394        col = self.escape_identifier
1395        names, values, updates = [], [], []
1396        for n in attnames:
1397            if n in row:
1398                names.append(col(n))
1399                values.append(param(row[n], attnames[n]))
1400        names, values = ', '.join(names), ', '.join(values)
1401        try:
1402            keyname = self.pkey(table, True)
1403        except KeyError:
1404            raise _prg_error('Table %s has no primary key' % table)
1405        target = ', '.join(col(k) for k in keyname)
1406        update = []
1407        keyname = set(keyname)
1408        keyname.add('oid')
1409        for n in attnames:
1410            if n not in keyname:
1411                value = kw.get(n, True)
1412                if value:
1413                    if not isinstance(value, basestring):
1414                        value = 'excluded.%s' % col(n)
1415                    update.append('%s = %s' % (col(n), value))
1416        if not values:
1417            return row
1418        do = 'update set %s' % ', '.join(update) if update else 'nothing'
1419        ret = 'oid, *' if qoid else '*'
1420        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
1421            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
1422                self._escape_qualified_name(table), names, values,
1423                target, do, ret)
1424        self._do_debug(q, params)
1425        try:
1426            q = self.db.query(q, params)
1427        except ProgrammingError:
1428            if self.server_version < 90500:
1429                raise _prg_error(
1430                    'Upsert operation is not supported by PostgreSQL version')
1431            raise  # re-raise original error
1432        res = q.dictresult()
1433        if res:  # may be empty with "do nothing"
1434            for n, value in res[0].items():
1435                if qoid and n == 'oid':
1436                    n = qoid
1437                else:
1438                    value = attnames[n].cast(value)
1439                row[n] = value
1440        else:
1441            self.get(table, row)
1442        return row
1443
1444    def clear(self, table, row=None):
1445        """Clear all the attributes to values determined by the types.
1446
1447        Numeric types are set to 0, Booleans are set to false, and everything
1448        else is set to the empty string.  If the row argument is present,
1449        it is used as the row dictionary and any entries matching attribute
1450        names are cleared with everything else left unchanged.
1451        """
1452        # At some point we will need a way to get defaults from a table.
1453        if row is None:
1454            row = {}  # empty if argument is not present
1455        attnames = self.get_attnames(table)
1456        for n, t in attnames.items():
1457            if n == 'oid':
1458                continue
1459            t = t.simple
1460            if t in _PgType._num_types:
1461                row[n] = 0
1462            elif t == 'bool':
1463                row[n] = self._make_bool(False)
1464            else:
1465                row[n] = ''
1466        return row
1467
1468    def delete(self, table, row=None, **kw):
1469        """Delete an existing row in a database table.
1470
1471        This method deletes the row from a table.  It deletes based on the
1472        primary key of the table or the OID value as munged by get() or
1473        passed as keyword.
1474
1475        The return value is the number of deleted rows (i.e. 0 if the row
1476        did not exist and 1 if the row was deleted).
1477
1478        Note that if the row cannot be deleted because e.g. it is still
1479        referenced by another table, this method raises a ProgrammingError.
1480        """
1481        if table.endswith('*'):  # hint for descendant tables can be ignored
1482            table = table[:-1].rstrip()
1483        attnames = self.get_attnames(table)
1484        qoid = _oid_key(table) if 'oid' in attnames else None
1485        if row is None:
1486            row = {}
1487        elif 'oid' in row:
1488            del row['oid']  # only accept oid key from named args for safety
1489        row.update(kw)
1490        if qoid and qoid in row and 'oid' not in row:
1491            row['oid'] = row[qoid]
1492        try:  # try using the primary key
1493            keyname = self.pkey(table, True)
1494        except KeyError:  # the table has no primary key
1495            # try using the oid instead
1496            if qoid and 'oid' in row:
1497                keyname = ('oid',)
1498            else:
1499                raise _prg_error('Table %s has no primary key' % table)
1500        else:  # the table has a primary key
1501            # check whether all key columns have values
1502            if not set(keyname).issubset(row):
1503                # try using the oid instead
1504                if qoid and 'oid' in row:
1505                    keyname = ('oid',)
1506                else:
1507                    raise KeyError('Missing primary key in row')
1508        params = []
1509        param = partial(self._adapt_param, params=params)
1510        col = self.escape_identifier
1511        where = ' AND '.join('%s = %s' % (
1512            col(k), param(row[k], attnames[k])) for k in keyname)
1513        if 'oid' in row:
1514            if qoid:
1515                row[qoid] = row['oid']
1516            del row['oid']
1517        q = 'DELETE FROM %s WHERE %s' % (
1518            self._escape_qualified_name(table), where)
1519        self._do_debug(q, params)
1520        res = self.db.query(q, params)
1521        return int(res)
1522
1523    def truncate(self, table, restart=False, cascade=False, only=False):
1524        """Empty a table or set of tables.
1525
1526        This method quickly removes all rows from the given table or set
1527        of tables.  It has the same effect as an unqualified DELETE on each
1528        table, but since it does not actually scan the tables it is faster.
1529        Furthermore, it reclaims disk space immediately, rather than requiring
1530        a subsequent VACUUM operation. This is most useful on large tables.
1531
1532        If restart is set to True, sequences owned by columns of the truncated
1533        table(s) are automatically restarted.  If cascade is set to True, it
1534        also truncates all tables that have foreign-key references to any of
1535        the named tables.  If the parameter only is not set to True, all the
1536        descendant tables (if any) will also be truncated. Optionally, a '*'
1537        can be specified after the table name to explicitly indicate that
1538        descendant tables are included.
1539        """
1540        if isinstance(table, basestring):
1541            only = {table: only}
1542            table = [table]
1543        elif isinstance(table, (list, tuple)):
1544            if isinstance(only, (list, tuple)):
1545                only = dict(zip(table, only))
1546            else:
1547                only = dict.fromkeys(table, only)
1548        elif isinstance(table, (set, frozenset)):
1549            only = dict.fromkeys(table, only)
1550        else:
1551            raise TypeError('The table must be a string, list or set')
1552        if not (restart is None or isinstance(restart, (bool, int))):
1553            raise TypeError('Invalid type for the restart option')
1554        if not (cascade is None or isinstance(cascade, (bool, int))):
1555            raise TypeError('Invalid type for the cascade option')
1556        tables = []
1557        for t in table:
1558            u = only.get(t)
1559            if not (u is None or isinstance(u, (bool, int))):
1560                raise TypeError('Invalid type for the only option')
1561            if t.endswith('*'):
1562                if u:
1563                    raise ValueError(
1564                        'Contradictory table name and only options')
1565                t = t[:-1].rstrip()
1566            t = self._escape_qualified_name(t)
1567            if u:
1568                t = 'ONLY %s' % t
1569            tables.append(t)
1570        q = ['TRUNCATE', ', '.join(tables)]
1571        if restart:
1572            q.append('RESTART IDENTITY')
1573        if cascade:
1574            q.append('CASCADE')
1575        q = ' '.join(q)
1576        self._do_debug(q)
1577        return self.db.query(q)
1578
1579    def get_as_list(self, table, what=None, where=None,
1580            order=None, limit=None, offset=None, scalar=False):
1581        """Get a table as a list.
1582
1583        This gets a convenient representation of the table as a list
1584        of named tuples in Python.  You only need to pass the name of
1585        the table (or any other SQL expression returning rows).  Note that
1586        by default this will return the full content of the table which
1587        can be huge and overflow your memory.  However, you can control
1588        the amount of data returned using the other optional parameters.
1589
1590        The parameter 'what' can restrict the query to only return a
1591        subset of the table columns.  It can be a string, list or a tuple.
1592        The parameter 'where' can restrict the query to only return a
1593        subset of the table rows.  It can be a string, list or a tuple
1594        of SQL expressions that all need to be fulfilled.  The parameter
1595        'order' specifies the ordering of the rows.  It can also be a
1596        other string, list or a tuple.  If no ordering is specified,
1597        the result will be ordered by the primary key(s) or all columns
1598        if no primary key exists.  You can set 'order' to False if you
1599        don't care about the ordering.  The parameters 'limit' and 'offset'
1600        can be integers specifying the maximum number of rows returned
1601        and a number of rows skipped over.
1602
1603        If you set the 'scalar' option to True, then instead of the
1604        named tuples you will get the first items of these tuples.
1605        This is useful if the result has only one column anyway.
1606        """
1607        if not table:
1608            raise TypeError('The table name is missing')
1609        if what:
1610            if isinstance(what, (list, tuple)):
1611                what = ', '.join(map(str, what))
1612            if order is None:
1613                order = what
1614        else:
1615            what = '*'
1616        q = ['SELECT', what, 'FROM', table]
1617        if where:
1618            if isinstance(where, (list, tuple)):
1619                where = ' AND '.join(map(str, where))
1620            q.extend(['WHERE', where])
1621        if order is None:
1622            try:
1623                order = self.pkey(table, True)
1624            except (KeyError, ProgrammingError):
1625                try:
1626                    order = list(self.get_attnames(table))
1627                except (KeyError, ProgrammingError):
1628                    pass
1629        if order:
1630            if isinstance(order, (list, tuple)):
1631                order = ', '.join(map(str, order))
1632            q.extend(['ORDER BY', order])
1633        if limit:
1634            q.append('LIMIT %d' % limit)
1635        if offset:
1636            q.append('OFFSET %d' % offset)
1637        q = ' '.join(q)
1638        self._do_debug(q)
1639        q = self.db.query(q)
1640        res = q.namedresult()
1641        if res and scalar:
1642            res = [row[0] for row in res]
1643        return res
1644
1645    def get_as_dict(self, table, keyname=None, what=None, where=None,
1646            order=None, limit=None, offset=None, scalar=False):
1647        """Get a table as a dictionary.
1648
1649        This method is similar to get_as_list(), but returns the table
1650        as a Python dict instead of a Python list, which can be even
1651        more convenient. The primary key column(s) of the table will
1652        be used as the keys of the dictionary, while the other column(s)
1653        will be the corresponding values.  The keys will be named tuples
1654        if the table has a composite primary key.  The rows will be also
1655        named tuples unless the 'scalar' option has been set to True.
1656        With the optional parameter 'keyname' you can specify an alternative
1657        set of columns to be used as the keys of the dictionary.  It must
1658        be set as a string, list or a tuple.
1659
1660        If the Python version supports it, the dictionary will be an
1661        OrderedDict using the order specified with the 'order' parameter
1662        or the key column(s) if not specified.  You can set 'order' to False
1663        if you don't care about the ordering.  In this case the returned
1664        dictionary will be an ordinary one.
1665        """
1666        if not table:
1667            raise TypeError('The table name is missing')
1668        if not keyname:
1669            try:
1670                keyname = self.pkey(table, True)
1671            except (KeyError, ProgrammingError):
1672                raise _prg_error('Table %s has no primary key' % table)
1673        if isinstance(keyname, basestring):
1674            keyname = [keyname]
1675        elif not isinstance(keyname, (list, tuple)):
1676            raise KeyError('The keyname must be a string, list or tuple')
1677        if what:
1678            if isinstance(what, (list, tuple)):
1679                what = ', '.join(map(str, what))
1680            if order is None:
1681                order = what
1682        else:
1683            what = '*'
1684        q = ['SELECT', what, 'FROM', table]
1685        if where:
1686            if isinstance(where, (list, tuple)):
1687                where = ' AND '.join(map(str, where))
1688            q.extend(['WHERE', where])
1689        if order is None:
1690            order = keyname
1691        if order:
1692            if isinstance(order, (list, tuple)):
1693                order = ', '.join(map(str, order))
1694            q.extend(['ORDER BY', order])
1695        if limit:
1696            q.append('LIMIT %d' % limit)
1697        if offset:
1698            q.append('OFFSET %d' % offset)
1699        q = ' '.join(q)
1700        self._do_debug(q)
1701        q = self.db.query(q)
1702        res = q.getresult()
1703        cls = OrderedDict if order else dict
1704        if not res:
1705            return cls()
1706        keyset = set(keyname)
1707        fields = q.listfields()
1708        if not keyset.issubset(fields):
1709            raise KeyError('Missing keyname in row')
1710        keyind, rowind = [], []
1711        for i, f in enumerate(fields):
1712            (keyind if f in keyset else rowind).append(i)
1713        keytuple = len(keyind) > 1
1714        getkey = itemgetter(*keyind)
1715        keys = map(getkey, res)
1716        if scalar:
1717            rowind = rowind[:1]
1718            rowtuple = False
1719        else:
1720            rowtuple = len(rowind) > 1
1721        if scalar or rowtuple:
1722            getrow = itemgetter(*rowind)
1723        else:
1724            rowind = rowind[0]
1725            getrow = lambda row: (row[rowind],)
1726            rowtuple = True
1727        rows = map(getrow, res)
1728        if keytuple or rowtuple:
1729            namedresult = get_namedresult()
1730            if namedresult:
1731                if keytuple:
1732                    keys = namedresult(_MemoryQuery(keys, keyname))
1733                if rowtuple:
1734                    fields = [f for f in fields if f not in keyset]
1735                    rows = namedresult(_MemoryQuery(rows, fields))
1736        return cls(zip(keys, rows))
1737
1738    def notification_handler(self,
1739            event, callback, arg_dict=None, timeout=None, stop_event=None):
1740        """Get notification handler that will run the given callback."""
1741        return NotificationHandler(self,
1742            event, callback, arg_dict, timeout, stop_event)
1743
1744
1745# if run as script, print some information
1746
if __name__ == '__main__':
    # the label needs a trailing blank, otherwise it runs into the version
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.