source: trunk/pg.py @ 883

Last change on this file since 883 was 883, checked in by cito, 3 years ago

Fix issue with adaptation of empty arrays in pg

  • Property svn:keywords set to Id
File size: 89.7 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 883 2016-08-10 13:57:50Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function, division
31
32from _pg import *
33
34__version__ = version
35
36import select
37import warnings
38import weakref
39
40from datetime import date, time, datetime, timedelta, tzinfo
41from decimal import Decimal
42from math import isnan, isinf
43from collections import namedtuple
44from operator import itemgetter
45from functools import partial
46from re import compile as regex
47from json import loads as jsondecode, dumps as jsonencode
48from uuid import UUID
49
# Python 3 removed the ``long`` and ``basestring`` builtins; provide
# module level stand-ins so that the rest of the module can use them.
try:
    long
except NameError:  # running under Python 3
    long = int

try:
    basestring
except NameError:  # running under Python 3
    basestring = (str, bytes)
59
60
61# Auxiliary classes and functions that are independent from a DB connection:
62
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Read-only dictionary keeping attribute names in insertion order.

        Fallback for Python versions without OrderedDict: the order of the
        keys is remembered in a separate list.  Only a single sequence of
        key/value pairs is accepted; keyword arguments and dicts raise a
        TypeError (their order would be undefined).
        """

        def __init__(self, *args, **kw):
            if kw or len(args) > 1:
                raise TypeError
            pairs = args[0] if args else []
            if isinstance(pairs, dict):
                raise TypeError
            pairs = list(pairs)
            self._keys = [pair[0] for pair in pairs]
            dict.__init__(self, pairs)
            self._read_only = True
            # shadow all mutating methods on the instance
            error = self._read_only_error
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, error)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[name] for name in self]

        def items(self):
            return [(name, self[name]) for name in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Read-only ordered dictionary for storing attribute names."""

        def __init__(self, *args, **kw):
            # writing must be possible while OrderedDict stores the initial
            # items, since it may do so by calling __setitem__
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow all mutating methods on the instance
            error = self._read_only_error
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, error)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
147
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given function."""
        parameters = signature(func).parameters
        return [name for name in parameters]
159
try:
    from datetime import timezone
except ImportError:  # Python < 3.2

    class timezone(tzinfo):
        """Simple timezone implementation with a fixed UTC offset.

        If no name is given, a name of the form 'UTC+HH:MM' is derived
        from the offset.
        """

        def __init__(self, offset, name=None):
            self.offset = offset
            if not name:
                minutes = self.offset.days * 1440 + self.offset.seconds // 60
                # Determine the sign separately: for offsets of less than
                # an hour the hours are zero and cannot carry the sign
                # (-30 minutes must become 'UTC-00:30', not 'UTC+00:30').
                sign = '-' if minutes < 0 else '+'
                hours, minutes = divmod(abs(minutes), 60)
                name = 'UTC%s%02d:%02d' % (sign, hours, minutes)
            self.name = name

        def utcoffset(self, dt):
            return self.offset

        def tzname(self, dt):
            return self.name

        def dst(self, dt):
            return None

    timezone.utc = timezone(timedelta(0), 'UTC')

    _has_timezone = False
else:
    _has_timezone = True
193
# time zones used in Postgres timestamptz output
_timezones = {
    'CET': '+0100', 'EET': '+0200', 'EST': '-0500',
    'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700',
    'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000'}


def _timezone_as_offset(tz):
    """Convert a time zone designation to an offset string like '+0100'.

    Numeric offsets such as '+02' or '-05:30' are normalized to four
    digits; abbreviations are looked up in _timezones, and unknown
    abbreviations are taken as '+0000'.
    """
    if not tz.startswith(('+', '-')):
        return _timezones.get(tz, '+0000')
    if len(tz) < 5:
        return tz + '00'
    return tz.replace(':', '')
206
207
def _get_timezone(tz):
    """Return a fixed-offset timezone object for the given time zone.

    The normalized offset string (like '+0100') serves as the zone name.
    """
    offset = _timezone_as_offset(tz)
    minutes = int(offset[1:3]) * 60 + int(offset[3:5])
    if offset.startswith('-'):
        minutes = -minutes
    return timezone(timedelta(minutes=minutes), offset)
214
215
def _oid_key(table):
    """Build the oid key for the given table name, e.g. 'oid(foo)'."""
    template = 'oid(%s)'
    return template % table
219
220
class _SimpleTypes(dict):
    """Dictionary mapping pg_type names to simple type names.

    Every base type name is also registered with a leading underscore
    (the pg_type naming of array types), mapped to the simple name with
    '[]' appended.  Names that are not registered map to 'text'.
    """

    _types = {
        'bool': 'bool',
        'bytea': 'bytea',
        'date': 'date interval time timetz timestamp timestamptz'
                ' abstime reltime',  # these are very old
        'float': 'float4 float8',
        'int': 'cid int2 int4 int8 oid xid',
        'hstore': 'hstore',
        'json': 'json jsonb',
        'uuid': 'uuid',
        'num': 'numeric',
        'money': 'money',
        'text': 'bpchar char name text varchar'}

    def __init__(self):
        for simple_name, pg_names in self._types.items():
            for pg_name in pg_names.split():
                self.update({pg_name: simple_name,
                             '_' + pg_name: simple_name + '[]'})

    def __missing__(self, key):
        # unknown types are treated as text (without caching them)
        return 'text'

_simpletypes = _SimpleTypes()
245
246
def _quote_if_unqualified(param, name):
    """Quote parameter representing a qualified name.

    Puts a quote_ident() call around the given parameter unless the name
    contains a dot.  Such a name is ambiguous (it could be a qualified name
    or just a name with a dot in it) and must be quoted manually by the
    caller.  Names that are not strings leave the parameter unchanged.
    """
    if not isinstance(name, basestring) or '.' in name:
        return param
    return 'quote_ident(%s)' % (param,)
258
259
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    The adapt() function used for the values must be assigned to the
    instance as attribute before add() is called.
    """

    def add(self, value, typ=None):
        """Typecast value with known database type and build parameter list.

        Literal values are returned as they are.  Any other value is
        appended to the list and a '$n' placeholder for it is returned.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
274
275
class Bytea(bytes):
    """Bytes subclass marking a value as PostgreSQL bytea."""
278
279
class Hstore(dict):
    """Wrapper class for marking hstore values.

    Calling str() on an instance returns the hstore input syntax, with
    pairs like key=>value separated by commas.
    """

    # Keys and values must be double quoted if they would be read as NULL,
    # or if they contain whitespace, separators or escaped characters.
    # The hstore parser also ends unquoted text at tabs and newlines, so
    # they need quoting just like blanks.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>"\\]')

    @classmethod
    def _quote(cls, s):
        """Quote a single hstore key or value (None becomes NULL)."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        # Backslashes are escape characters in hstore, so they must be
        # escaped themselves.  This is done first, so that the backslashes
        # added for escaping double quotes are not doubled again.
        s = s.replace('\\', '\\\\').replace('"', '\\"')
        if cls._re_quote.search(s):
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
299
300
class Json:
    """Mark a Python object for adaptation as a JSON value.

    The wrapped object is kept unchanged in the obj attribute.
    """

    def __init__(self, obj):
        self.obj = obj
306
307
class Literal(str):
    """String subclass marking literal SQL that is used as it is."""
310
311
class Adapter:
    """Class providing methods for adapting parameters to the database.

    Parameters can either be adapted for being passed separately from the
    SQL command (see adapt() and parameter_list()), or for being put inline
    into the SQL command as quoted literals (see adapt_inline()).
    """

    _bool_true_values = frozenset('t true 1 y yes on'.split())

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    # array elements must be quoted when they contain braces, commas,
    # double quotes, backslashes or whitespace, or would be read as NULL
    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    # record fields must be quoted when they contain parentheses, commas,
    # double quotes or backslashes; an unquoted closing parenthesis would
    # terminate the record literal prematurely
    _re_record_quote = regex(r'[(),"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        # NOTE(review): only a weak proxy to the database wrapper is kept,
        # presumably to avoid a reference cycle between the two objects
        self.db = weakref.proxy(db)

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Strings are compared case-insensitively with the true values
        't', 'true', '1', 'y', 'yes' and 'on'; an empty string gives None.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        Empty values give None.  Date and time objects are always kept,
        even those that are false in a boolean context like timedelta(0),
        since they are proper values.  Special date literals such as
        'current_date' are returned as Literal to be put into the SQL.
        """
        if not v and not isinstance(v, (date, time, timedelta)):
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter (empty values except zero give None)."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.db.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Empty values give None and strings are taken as already encoded
        JSON; other objects are encoded with db.encode_json().
        """
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.db.encode_json(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter (nested lists are allowed)."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if v is None:
            return 'null'
        # convert to a string before checking for emptiness, so that
        # values like 0 or timedelta(0) do not become empty strings
        v = str(v)
        if not v:
            return '""'
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter (the result is a bytes object)."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(item) for item in v) + b'}'
        if v is None:
            return b'null'
        # backslashes of the escaped value must be doubled inside arrays
        return self.db.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.db.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        The tuple v is adapted field by field using the attribute types of
        typ and formatted as a composite literal such as '(1,"a,b",)'.
        Raises a TypeError if the number of fields does not match.
        """
        attr_types = self.get_attnames(typ).values()
        if len(attr_types) != len(v):
            # v is a tuple, so it must be wrapped for the % operator
            raise TypeError('Record parameter %s has wrong size' % (v,))
        adapt = self.adapt
        fields = []
        for field, attr_type in zip(v, attr_types):
            field = adapt(field, attr_type)
            if field is None:
                field = ''  # an empty unquoted field is read as NULL
            else:
                if isinstance(field, bytes):
                    if str is not bytes:
                        field = field.decode('ascii')
                else:
                    field = str(field)
                # Check for emptiness only after converting to a string,
                # so that falsy values like 0 are not taken as empty.
                if not field:
                    field = '""'  # empty strings must be quoted
                elif self._re_record_quote.search(field):
                    field = '"%s"' % self._re_record_escape.sub(
                        r'\\\1', field)
            fields.append(field)
        return '(%s)' % ','.join(fields)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is given, it is guessed from the Python type of the
        value.  Objects can provide a __pg_str__(typ) method that converts
        them before the type specific adaptation takes place.  None and
        Literal values are returned unchanged.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            # Look up the hook explicitly instead of catching AttributeError
            # around the call, which would also hide errors raised inside
            # an existing __pg_str__ method.
            pg_str = getattr(value, '__pg_str__', None)
            if pg_str is not None:
                value = pg_str(typ)
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns a simple type name, a record type for tuples, or None
        if the type cannot be guessed.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, list):
            return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type
            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))
            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array.

        Returns the type guessed for the first element (searching nested
        lists) for which a type can be guessed, or None.
        """
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Numbers are returned as Python numbers, everything else as SQL
        text.  Objects can provide a __pg_repr__() method for adaptation;
        for other unknown types an InterfaceError is raised.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.db.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # The wrapper only carries the object in its obj attribute;
            # like in _adapt_json(), strings are taken as encoded JSON.
            value = value.obj
            if not isinstance(value, basestring):
                value = self.db.encode_json(value)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.db.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            if not value and not nested:
                # PostgreSQL rejects "ARRAY[]" since it cannot determine
                # the element type; an untyped literal is resolved from
                # the context instead (and can still be cast explicitly)
                return "'{}'"
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        # Look up the hook explicitly instead of catching AttributeError
        # around the call, which would also hide errors raised inside
        # an existing __pg_repr__ method.
        pg_repr = getattr(value, '__pg_repr__', None)
        if pg_repr is None:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        value = pg_repr()
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        The values can be passed as a list or tuple for positional (%s)
        placeholders or as a dict for named (%(name)s) placeholders; types,
        if given, must be passed in the same form.  If inline is set, the
        values are put into the command as quoted literals; otherwise they
        are replaced with $n placeholders and collected in the parameters.

        Returns the formatted command and the list of parameters.
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command %= tuple(literals)
        elif isinstance(values, dict):
            # we want to allow extra keys in the dictionary,
            # so we first must find the values actually used in the command:
            # a key is used if formatting fails when it is missing
            used_values = {}
            literals = dict.fromkeys(values, '')
            # Iterate over a copy of the keys, since every key is removed
            # and re-added below; doing this while iterating over the dict
            # itself breaks the iteration on newer Python versions.
            for key in list(literals):
                del literals[key]
                try:
                    command % literals
                except KeyError:
                    used_values[key] = values[key]
                literals[key] = ''
            values = used_values
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                if types:
                    if not isinstance(types, dict):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types.get(key))
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command %= literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
652
653
def cast_bool(value):
    """Cast a boolean value.

    Gives True if the value starts with 't' and False otherwise, unless
    get_bool() returns a false value, in which case the value is returned
    unchanged.
    """
    if get_bool():
        return value[0] == 't'
    return value
659
660
def cast_json(value):
    """Cast a JSON value using the decoder returned by get_jsondecode().

    If no decoder is set, the value is returned unchanged.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
667
668
def cast_num(value):
    """Cast a numeric value with the type from get_decimal(), else float."""
    number_type = get_decimal() or float
    return number_type(value)
672
673
def cast_money(value):
    """Cast a money value.

    The decimal point is taken from get_decimal_point(); if it is not set,
    the value is returned unchanged.  An opening parenthesis is turned into
    a minus sign, and all characters except digits, the decimal point and
    the minus sign (such as currency symbols and thousands separators) are
    removed.  The result is converted with the type from get_decimal(),
    or float if that is not set.
    """
    point = get_decimal_point()
    if not point:
        return value
    value = value.replace('(', '-')
    # Strip everything else before normalizing the decimal point: when the
    # point is a comma, the thousands separator may be a dot (as in
    # '1.234,56'), which would otherwise look like a second decimal point.
    value = ''.join(c for c in value if c.isdigit() or c in (point, '-'))
    if point != '.':
        value = value.replace(point, '.')
    return (get_decimal() or float)(value)
684
685
def cast_int2vector(value):
    """Cast an int2vector value (blank separated integers) to a list."""
    return list(map(int, value.split()))
689
690
def cast_date(value, connection):
    """Cast a date value.

    Infinite dates and BC dates are mapped to date.min or date.max, and so
    are dates whose first part is longer than ten characters (years beyond
    9999).  Other values are parsed with connection.date_format().
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value in ('-infinity', 'infinity'):
        return date.min if value.startswith('-') else date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
709
710
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    fmt = '%H:%M:%S'
    if len(value) > 8:
        fmt += '.%f'
    return datetime.strptime(value, fmt).time()
715
716
# splits a time zone suffix starting with a sign off the preceding text
_re_timezone = regex(r'(.*)([+-].*)')
718
719
def cast_timetz(value):
    """Cast a timetz value.

    A time zone suffix such as '+02' or '-05:30' is split off the time;
    values without such a suffix are taken as '+0000'.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if not _has_timezone:
        naive = datetime.strptime(value, fmt).timetz()
        return naive.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    return datetime.strptime(value, fmt + '%z').timetz()
734
735
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    '-infinity' and BC timestamps give datetime.min; 'infinity' and years
    too large for datetime give datetime.max.  Other values are parsed
    according to connection.date_format().
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the first token and expect month and day (in the order of
        # the date format), the time of day and then the year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        fmt = [day_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        fmt = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(fmt))
757
758
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    '-infinity' and BC timestamps give datetime.min; 'infinity' and years
    too large for datetime give datetime.max.  Other values are parsed
    according to connection.date_format(), including the time zone.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the first token and expect month and day (in the order of
        # the date format), the time of day, the year and the time zone
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        fmt = [day_fmt, time_fmt, '%Y']
        parts, tz = parts[:-1], parts[-1]
    else:
        if not date_fmt.startswith('%Y-'):
            # the time zone comes as a separate last token
            parts, tz = parts[:-1], parts[-1]
        else:
            # the time zone is attached to the time of day
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        fmt = [date_fmt, time_fmt]
    if not _has_timezone:
        naive = datetime.strptime(' '.join(parts), ' '.join(fmt))
        return naive.replace(tzinfo=_get_timezone(tz))
    parts.append(_timezone_as_offset(tz))
    fmt.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(fmt))
794
795
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(groups):
    """Return the groups of an interval match as a list of strings.

    Unmatched groups are replaced with '0'.  The last group must be the
    fractional part of the seconds; it is scaled to six digits, so that
    it can be read as a number of microseconds.
    """
    fields = [d or '0' for d in groups]
    # PostgreSQL strips trailing zeros from fractional seconds, so the
    # fraction in '00:00:01.5' stands for 500000 microseconds, not for 5.
    fields[-1] = fields[-1][:6].ljust(6, '0')
    return fields


def cast_interval(value):
    """Cast an interval value.

    Years are counted as 365 days and months as 30 days.
    Raises a ValueError if the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = _interval_fields(m.groups())
        secs_ago = m.pop(5) == '-'
        m = [int(d) for d in m]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = _interval_fields(m.groups()[:8]), m.group(9)
            secs_ago = m.pop(5) == '-'
            m = [-int(d) for d in m] if ago else [int(d) for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = _interval_fields(m.groups())
                hours_ago = m.pop(3) == '-'
                m = [int(d) for d in m]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = _interval_fields(m.groups())
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = [int(d) for d in m]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
881
882
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Note that the basic types are already handled by the C extension.
    They only need to be handled here as record or array components.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': UUID,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # array type: derive the cast from the cast of the base type
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                self[typ] = cast
        else:
            # possibly a composite type: build a named record cast
            attnames = self.get_attnames(typ)
            if attnames:
                casts = [self[v.pgtype] for v in attnames.values()]
                cast = self.create_record_cast(typ, attnames, casts)
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument."""
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.

        The cached array typecasts of the specified type(s) are reset as
        well. They are derived from the cast of the base type, which may
        have been a custom cast.
        """
        if typ is None:
            self.clear()
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.pop(t, None)
                # like set(), also drop the derived array cast
                self.pop('_%s' % t, None)

    @classmethod
    def get_default(cls, typ):
        """Get the default typecast function for the given database type."""
        return cls.defaults.get(typ)

    @classmethod
    def set_default(cls, typ, cast):
        """Set a default typecast function for the given database type(s)."""
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = cls.defaults
        if cast is None:
            for t in typ:
                defaults.pop(t, None)
                defaults.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                defaults[t] = cast
                defaults.pop('_%s' % t, None)

    def get_attnames(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with the get_attnames() method of DbTypes.
        """
        return {}

    def dateformat(self):
        """Return the current date format.

        This method will be replaced with the dateformat() method of DbTypes.
        """
        return '%Y-%m-%d'

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
1034
1035
def get_typecast(typ):
    """Return the global (default) typecast function for a database type.

    None is returned when no default typecast is registered for the type.
    """
    default_cast = Typecasts.get_default(typ)
    return default_cast
1039
1040
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The type can be a single type name or a list of type names.
    Passing None as cast removes the global typecast for the type(s).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1048
1049
class DbType(str):
    """Class augmenting the simple type name with additional info.

    The following additional information is provided:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Get names and types of the fields of a composite type."""
        # _get_attnames is attached to each instance by its creator
        lookup = self._get_attnames
        return lookup(self)
1070
1071
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    This cache maps type OIDs and names to DbType objects containing
    information on the associated database type.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize type cache for connection."""
        super(DbTypes, self).__init__()
        self._db = weakref.proxy(db)
        self._regtypes = False
        self._typecasts = Typecasts()
        self._typecasts.get_attnames = self.get_attnames
        self._typecasts.connection = self._db
        if db.server_version < 80400:
            # older remote databases (not officially supported)
            self._query_pg_type = (
                "SELECT oid, typname, typname::text::regtype,"
                " typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")
        else:
            self._query_pg_type = (
                "SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")

    def add(self, oid, pgtype, regtype,
            typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        If the type OID is already cached, the cached DbType is returned.
        Note that the newly created DbType is not stored in the cache here.
        """
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid = oid
        typ.simple = simple
        typ.pgtype = pgtype
        typ.regtype = regtype
        typ.typtype = typtype
        typ.category = category
        typ.delim = delim
        typ.relid = relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        Raises a KeyError if the type cannot be found in the database.
        """
        try:
            q = self._query_pg_type % (_quote_if_unqualified('$1', key),)
            res = self._db.query(q, (key,)).getresult()
        except ProgrammingError:
            res = None
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = res[0]
        typ = self.add(*res)
        # cache the type under both its OID and its PostgreSQL name
        self[typ.oid] = self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or not a composite type.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None
        return self._db.get_attnames(typ.relid, with_oid=False)

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        The type can be passed as a DbType or as anything that can be
        looked up in this cache (e.g. a type OID or type name).  In both
        cases, the typecast function is looked up by the PostgreSQL type
        name (pgtype), since the typecasts are keyed by that name.
        Values with an unknown type or no special typecast are returned
        unchanged.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if isinstance(typ, DbType):
            # The string value of a DbType is the simple or regular type
            # name, which is not what the typecasts are keyed by.
            typ = getattr(typ, 'pgtype', typ)
        else:
            typ = self.get(typ)
            if typ:
                typ = typ.pgtype
        cast = self.get_typecast(typ) if typ else None
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
1177
1178
1179def _namedresult(q):
1180    """Get query result as named tuples."""
1181    row = namedtuple('Row', q.listfields())
1182    return [row(*r) for r in q.getresult()]
1183
1184
1185class _MemoryQuery:
1186    """Class that embodies a given query result."""
1187
1188    def __init__(self, result, fields):
1189        """Create query from given result rows and field names."""
1190        self.result = result
1191        self.fields = fields
1192
1193    def listfields(self):
1194        """Return the stored field names of this query."""
1195        return self.fields
1196
1197    def getresult(self):
1198        """Return the stored result of this query."""
1199        return self.result
1200
1201
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The exception gets the given message and its sqlstate is set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
1207
1208
def _int_error(msg):
    """Create an InternalError with the given message and no sqlstate."""
    return _db_error(msg, cls=InternalError)
1212
1213
def _prg_error(msg):
    """Create a ProgrammingError with the given message and no sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
1217
1218
# Initialize the C module

# Register Python-level objects with the C extension module _pg:
# the function that builds named tuple rows, the Decimal class and
# the JSON decoding function.  How _pg uses them is defined there;
# presumably for named results, numeric and json values (TODO confirm).
set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
1224
1225
1226# The notification handler
1227
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.  The
        payload is escaped with the escape_string() method of the database
        connection, so it may contain quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.

        Returns the result of the notify query, or None when the handler
        is not listening.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # The payload is embedded as a string literal, so it must be
                # escaped; otherwise quotes break the statement or inject SQL.
                payload = db.escape_string('%s' % (payload,))
                q += ", '%s'" % (payload,)
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1347
1348
def pgnotify(*args, **kw):
    """Same as NotificationHandler, under the traditional name.

    Deprecated: emits a DeprecationWarning and returns a new
    NotificationHandler created with the given arguments.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1354
1355
1356# The actual PostGreSQL database connection interface:
1357
1358class DB:
1359    """Wrapper class for the _pg connection type."""
1360
    def __init__(self, *args, **kw):
        """Create a new connection

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing connection can be passed as the only positional
        argument or as the only keyword argument db.  Another DB instance
        can be passed as well, in which case its connection is shared.
        Connections passed in this way are not closed by close() or when
        this object is deleted; only connections opened here are.
        """
        # a single argument (positional, or the keyword db) may be an
        # existing connection instead of connection parameters
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                db = db.db
            else:
                # presumably a pgdb connection wrapping a _pg connection
                # in its _cnx attribute (TODO confirm against pgdb)
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # anything without db and query attributes is not treated as a
        # connection, so the arguments are used as connection parameters
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        # remember the arguments so that reopen() can connect again
        self._args = args, kw
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        if db.server_version < 80400:
            # support older remote data bases
            # (these lack pg_type.typcategory, so null is selected instead)
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype,"
                " t.typtype, null as typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        else:
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        # install the type cache's typecast method as the cast hook of the
        # connection (close() and __del__ remove it again)
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1419
1420    def __getattr__(self, name):
1421        # All undefined members are same as in underlying connection:
1422        if self.db:
1423            return getattr(self.db, name)
1424        else:
1425            raise _int_error('Connection is not valid')
1426
1427    def __dir__(self):
1428        # Custom dir function including the attributes of the connection:
1429        attrs = set(self.__class__.__dict__)
1430        attrs.update(self.__dict__)
1431        attrs.update(dir(self.db))
1432        return sorted(attrs)
1433
1434    # Context manager methods
1435
1436    def __enter__(self):
1437        """Enter the runtime context. This will start a transaction."""
1438        self.begin()
1439        return self
1440
1441    def __exit__(self, et, ev, tb):
1442        """Exit the runtime context. This will end the transaction."""
1443        if et is None and ev is None and tb is None:
1444            self.commit()
1445        else:
1446            self.rollback()
1447
1448    def __del__(self):
1449        try:
1450            db = self.db
1451        except AttributeError:
1452            db = None
1453        if db:
1454            db.set_cast_hook(None)
1455            if self._closeable:
1456                db.close()
1457
1458    # Auxiliary methods
1459
1460    def _do_debug(self, *args):
1461        """Print a debug message"""
1462        if self.debug:
1463            s = '\n'.join(str(arg) for arg in args)
1464            if isinstance(self.debug, basestring):
1465                print(self.debug % s)
1466            elif hasattr(self.debug, 'write'):
1467                self.debug.write(s + '\n')
1468            elif callable(self.debug):
1469                self.debug(s)
1470            else:
1471                print(s)
1472
1473    def _escape_qualified_name(self, s):
1474        """Escape a qualified name.
1475
1476        Escapes the name for use as an SQL identifier, unless the
1477        name contains a dot, in which case the name is ambiguous
1478        (could be a qualified name or just a name with a dot in it)
1479        and must be quoted manually by the caller.
1480        """
1481        if '.' not in s:
1482            s = self.escape_identifier(s)
1483        return s
1484
1485    @staticmethod
1486    def _make_bool(d):
1487        """Get boolean value corresponding to d."""
1488        return bool(d) if get_bool() else ('t' if d else 'f')
1489
1490    def _list_params(self, params):
1491        """Create a human readable parameter list."""
1492        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1493
1494    # Public methods
1495
    # escape_string and escape_bytea exist as methods, so for symmetry the
    # module-level function unescape_bytea is offered as a method as well;
    # staticmethod is used because that function takes no self argument
    unescape_bytea = staticmethod(unescape_bytea)
1499
1500    def decode_json(self, s):
1501        """Decode a JSON string coming from the database."""
1502        return (get_jsondecode() or jsondecode)(s)
1503
1504    def encode_json(self, d):
1505        """Encode a JSON string for use within SQL."""
1506        return jsonencode(d)
1507
1508    def close(self):
1509        """Close the database connection."""
1510        # Wraps shared library function so we can track state.
1511        if self._closeable:
1512            if self.db:
1513                self.db.set_cast_hook(None)
1514                self.db.close()
1515                self.db = None
1516            else:
1517                raise _int_error('Connection already closed')
1518
1519    def reset(self):
1520        """Reset connection with current parameters.
1521
1522        All derived queries and large objects derived from this connection
1523        will not be usable after this call.
1524
1525        """
1526        if self.db:
1527            self.db.reset()
1528        else:
1529            raise _int_error('Connection already closed')
1530
1531    def reopen(self):
1532        """Reopen connection to the database.
1533
1534        Used in case we need another connection to the same database.
1535        Note that we can still reopen a database that we have closed.
1536
1537        """
1538        # There is no such shared library function.
1539        if self._closeable:
1540            db = connect(*self._args[0], **self._args[1])
1541            if self.db:
1542                self.db.set_cast_hook(None)
1543                self.db.close()
1544            self.db = db
1545
1546    def begin(self, mode=None):
1547        """Begin a transaction."""
1548        qstr = 'BEGIN'
1549        if mode:
1550            qstr += ' ' + mode
1551        return self.query(qstr)
1552
1553    start = begin
1554
1555    def commit(self):
1556        """Commit the current transaction."""
1557        return self.query('COMMIT')
1558
1559    end = commit
1560
1561    def rollback(self, name=None):
1562        """Roll back the current transaction."""
1563        qstr = 'ROLLBACK'
1564        if name:
1565            qstr += ' TO ' + name
1566        return self.query(qstr)
1567
1568    abort = rollback
1569
1570    def savepoint(self, name):
1571        """Define a new savepoint within the current transaction."""
1572        return self.query('SAVEPOINT ' + name)
1573
1574    def release(self, name):
1575        """Destroy a previously defined savepoint."""
1576        return self.query('RELEASE ' + name)
1577
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.  In this case, a new dict is
        returned even if a list or dict has been passed, and a passed dict is
        not updated.

        Raises a TypeError if the parameter has a wrong type, is empty, or
        contains a name that is not a non-empty string.
        """
        # values is the result container: None for a single name (then the
        # single setting is returned), a list for a list/tuple, a new dict
        # for a set, and the passed dict itself (updated in place) for a dict
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # for dict results, params maps each normalized name to the key
        # that was passed in; otherwise it is a list of normalized names
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # build a new dict from the first two columns of each row
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # only runs if 'all' was not requested (no break above)
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1635
1636    def set_parameter(self, parameter, value=None, local=False):
1637        """Set the value of a run-time parameter.
1638
1639        If the parameter and the value are strings, the run-time parameter
1640        will be set to that value.  If no value or None is passed as a value,
1641        then the run-time parameter will be restored to its default value.
1642
1643        You can set several parameters at once by passing a list of parameter
1644        names, together with a single value that all parameters should be
1645        set to or with a corresponding list of values.  You can also pass
1646        the parameters as a set if you only provide a single value.
1647        Finally, you can pass a dict with parameter names as keys.  In this
1648        case, you should not pass a value, since the values for the parameters
1649        will be taken from the dict.
1650
1651        By passing the special name 'all' as the parameter, you can reset
1652        all existing settable run-time parameters to their default values.
1653
1654        If you set local to True, then the command takes effect for only the
1655        current transaction.  After commit() or rollback(), the session-level
1656        setting takes effect again.  Setting local to True will appear to
1657        have no effect if it is executed outside a transaction, since the
1658        transaction will end immediately.
1659        """
1660        if isinstance(parameter, basestring):
1661            parameter = {parameter: value}
1662        elif isinstance(parameter, (list, tuple)):
1663            if isinstance(value, (list, tuple)):
1664                parameter = dict(zip(parameter, value))
1665            else:
1666                parameter = dict.fromkeys(parameter, value)
1667        elif isinstance(parameter, (set, frozenset)):
1668            if isinstance(value, (list, tuple, set, frozenset)):
1669                value = set(value)
1670                if len(value) == 1:
1671                    value = value.pop()
1672            if not(value is None or isinstance(value, basestring)):
1673                raise ValueError('A single value must be specified'
1674                    ' when parameter is a set')
1675            parameter = dict.fromkeys(parameter, value)
1676        elif isinstance(parameter, dict):
1677            if value is not None:
1678                raise ValueError('A value must not be specified'
1679                    ' when parameter is a dictionary')
1680        else:
1681            raise TypeError(
1682                'The parameter must be a string, list, set or dict')
1683        if not parameter:
1684            raise TypeError('No parameter has been specified')
1685        params = {}
1686        for key, value in parameter.items():
1687            param = key.strip().lower() if isinstance(
1688                key, basestring) else None
1689            if not param:
1690                raise TypeError('Invalid parameter')
1691            if param == 'all':
1692                if value is not None:
1693                    raise ValueError('A value must ot be specified'
1694                        " when parameter is 'all'")
1695                params = {'all': None}
1696                break
1697            params[param] = value
1698        local = ' LOCAL' if local else ''
1699        for param, value in params.items():
1700            if value is None:
1701                q = 'RESET%s %s' % (local, param)
1702            else:
1703                q = 'SET%s %s TO %s' % (local, param, value)
1704            self._do_debug(q)
1705            self.db.query(q)
1706
1707    def query(self, command, *args):
1708        """Execute a SQL command string.
1709
1710        This method simply sends a SQL query to the database.  If the query is
1711        an insert statement that inserted exactly one row into a table that
1712        has OIDs, the return value is the OID of the newly inserted row.
1713        If the query is an update or delete statement, or an insert statement
1714        that did not insert exactly one row in a table with OIDs, then the
1715        number of rows affected is returned as a string.  If it is a statement
1716        that returns rows as a result (usually a select statement, but maybe
1717        also an "insert/update ... returning" statement), this method returns
1718        a Query object that can be accessed via getresult() or dictresult()
1719        or simply printed.  Otherwise, it returns `None`.
1720
1721        The query can contain numbered parameters of the form $1 in place
1722        of any data constant.  Arguments given after the query string will
1723        be substituted for the corresponding numbered parameter.  Parameter
1724        values can also be given as a single list or tuple argument.
1725        """
1726        # Wraps shared library function for debugging.
1727        if not self.db:
1728            raise _int_error('Connection is not valid')
1729        if args:
1730            self._do_debug(command, args)
1731            return self.db.query(command, args)
1732        self._do_debug(command)
1733        return self.db.query(command)
1734
1735    def query_formatted(self, command, parameters, types=None, inline=False):
1736        """Execute a formatted SQL command string.
1737
1738        Similar to query, but using Python format placeholders of the form
1739        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1740        The parameters must be passed as a tuple, list or dict.  You can
1741        also pass a corresponding tuple, list or dict of database types in
1742        order to format the parameters properly in case there is ambiguity.
1743
1744        If you set inline to True, the parameters will be sent to the database
1745        embedded in the SQL command, otherwise they will be sent separately.
1746        """
1747        return self.query(*self.adapter.format_query(
1748            command, parameters, types, inline))
1749
1750    def pkey(self, table, composite=False, flush=False):
1751        """Get or set the primary key of a table.
1752
1753        Single primary keys are returned as strings unless you
1754        set the composite flag.  Composite primary keys are always
1755        represented as tuples.  Note that this raises a KeyError
1756        if the table does not have a primary key.
1757
1758        If flush is set then the internal cache for primary keys will
1759        be flushed.  This may be necessary after the database schema or
1760        the search path has been changed.
1761        """
1762        pkeys = self._pkeys
1763        if flush:
1764            pkeys.clear()
1765            self._do_debug('The pkey cache has been flushed')
1766        try:  # cache lookup
1767            pkey = pkeys[table]
1768        except KeyError:  # cache miss, check the database
1769            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1770                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1771                " AND a.attnum = ANY(i.indkey)"
1772                " AND NOT a.attisdropped"
1773                " WHERE i.indrelid=%s::regclass"
1774                " AND i.indisprimary ORDER BY a.attnum") % (
1775                    _quote_if_unqualified('$1', table),)
1776            pkey = self.db.query(q, (table,)).getresult()
1777            if not pkey:
1778                raise KeyError('Table %s has no primary key' % table)
1779            # we want to use the order defined in the primary key index here,
1780            # not the order as defined by the columns in the table
1781            if len(pkey) > 1:
1782                indkey = pkey[0][2]
1783                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1784                pkey = tuple(row[0] for row in pkey)
1785            else:
1786                pkey = pkey[0][0]
1787            pkeys[table] = pkey  # cache it
1788        if composite and not isinstance(pkey, tuple):
1789            pkey = (pkey,)
1790        return pkey
1791
1792    def get_databases(self):
1793        """Get list of databases in the system."""
1794        return [s[0] for s in
1795            self.db.query('SELECT datname FROM pg_database').getresult()]
1796
1797    def get_relations(self, kinds=None, system=False):
1798        """Get list of relations in connected database of specified kinds.
1799
1800        If kinds is None or empty, all kinds of relations are returned.
1801        Otherwise kinds can be a string or sequence of type letters
1802        specifying which kind of relations you want to list.
1803
1804        Set the system flag if you want to get the system relations as well.
1805        """
1806        where = []
1807        if kinds:
1808            where.append("r.relkind IN (%s)" %
1809                ','.join("'%s'" % k for k in kinds))
1810        if not system:
1811            where.append("s.nspname NOT SIMILAR"
1812                " TO 'pg/_%|information/_schema' ESCAPE '/'")
1813        where = " WHERE %s" % ' AND '.join(where) if where else ''
1814        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1815            " FROM pg_class r"
1816            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
1817            " ORDER BY s.nspname, r.relname") % where
1818        return [r[0] for r in self.db.query(q).getresult()]
1819
1820    def get_tables(self, system=False):
1821        """Return list of tables in connected database.
1822
1823        Set the system flag if you want to get the system tables as well.
1824        """
1825        return self.get_relations('r', system)
1826
1827    def get_attnames(self, table, with_oid=True, flush=False):
1828        """Given the name of a table, dig out the set of attribute names.
1829
1830        Returns a read-only dictionary of attribute names (the names are
1831        the keys, the values are the names of the attributes' types)
1832        with the column names in the proper order if you iterate over it.
1833
1834        If flush is set, then the internal cache for attribute names will
1835        be flushed. This may be necessary after the database schema or
1836        the search path has been changed.
1837
1838        By default, only a limited number of simple types will be returned.
1839        You can get the regular types after calling use_regtypes(True).
1840        """
1841        attnames = self._attnames
1842        if flush:
1843            attnames.clear()
1844            self._do_debug('The attnames cache has been flushed')
1845        try:  # cache lookup
1846            names = attnames[table]
1847        except KeyError:  # cache miss, check the database
1848            q = "a.attnum > 0"
1849            if with_oid:
1850                q = "(%s OR a.attname = 'oid')" % q
1851            q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
1852            names = self.db.query(q, (table,)).getresult()
1853            types = self.dbtypes
1854            names = ((name[0], types.add(*name[1:])) for name in names)
1855            names = AttrDict(names)
1856            attnames[table] = names  # cache it
1857        return names
1858
1859    def use_regtypes(self, regtypes=None):
1860        """Use regular type names instead of simplified type names."""
1861        if regtypes is None:
1862            return self.dbtypes._regtypes
1863        else:
1864            regtypes = bool(regtypes)
1865            if regtypes != self.dbtypes._regtypes:
1866                self.dbtypes._regtypes = regtypes
1867                self._attnames.clear()
1868                self.dbtypes.clear()
1869            return regtypes
1870
1871    def has_table_privilege(self, table, privilege='select', flush=False):
1872        """Check whether current user has specified table privilege.
1873
1874        If flush is set, then the internal cache for table privileges will
1875        be flushed. This may be necessary after privileges have been changed.
1876        """
1877        privileges = self._privileges
1878        if flush:
1879            privileges.clear()
1880            self._do_debug('The privileges cache has been flushed')
1881        privilege = privilege.lower()
1882        try:  # ask cache
1883            ret = privileges[table, privilege]
1884        except KeyError:  # cache miss, ask the database
1885            q = "SELECT has_table_privilege(%s, $2)" % (
1886                _quote_if_unqualified('$1', table),)
1887            q = self.db.query(q, (table, privilege))
1888            ret = q.getresult()[0][0] == self._make_bool(True)
1889            privileges[table, privilege] = ret  # cache it
1890        return ret
1891
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        If no keyname is given and the table has no primary key, or the
        row dictionary lacks values for some primary key columns, the OID
        is used as the key instead when it is available in the row.

        Raises a KeyError if the row does not supply values for the key
        columns, or if the number of key columns and values differ.  If
        no key can be determined at all, the error created by _prg_error()
        is raised.  If no matching row exists, the error created by
        _db_error() is raised.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # munged key for the OID in row dicts, None if there is no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)  # normalize a single column name to a tuple
        # accept the munged OID key as the source of a plain 'oid' key
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        if not isinstance(row, dict):
            # a scalar or sequence of key values: build a new row dict
            # mapping the key columns to these values
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add  # adds a value and returns its placeholder
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # the plain 'oid' key was only needed for the WHERE clause;
        # in the row it is kept under the munged key only
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row, storing the oid column
        # under the munged key
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1966
1967    def insert(self, table, row=None, **kw):
1968        """Insert a row into a database table.
1969
1970        This method inserts a row into a table.  The name of the table must
1971        be passed as the first parameter.  The other parameters are used for
1972        providing the data of the row that shall be inserted into the table.
1973        If a dictionary is supplied as the second parameter, it starts with
1974        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1975        is updated from the keywords.
1976
1977        The dictionary is then reloaded with the values actually inserted in
1978        order to pick up values modified by rules, triggers, etc.
1979        """
1980        if table.endswith('*'):  # hint for descendant tables can be ignored
1981            table = table[:-1].rstrip()
1982        if row is None:
1983            row = {}
1984        row.update(kw)
1985        if 'oid' in row:
1986            del row['oid']  # do not insert oid
1987        attnames = self.get_attnames(table)
1988        qoid = _oid_key(table) if 'oid' in attnames else None
1989        params = self.adapter.parameter_list()
1990        adapt = params.add
1991        col = self.escape_identifier
1992        names, values = [], []
1993        for n in attnames:
1994            if n in row:
1995                names.append(col(n))
1996                values.append(adapt(row[n], attnames[n]))
1997        if not names:
1998            raise _prg_error('No column found that can be inserted')
1999        names, values = ', '.join(names), ', '.join(values)
2000        ret = 'oid, *' if qoid else '*'
2001        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
2002            self._escape_qualified_name(table), names, values, ret)
2003        self._do_debug(q, params)
2004        q = self.db.query(q, params)
2005        res = q.dictresult()
2006        if res:  # this should always be true
2007            for n, value in res[0].items():
2008                if qoid and n == 'oid':
2009                    n = qoid
2010                row[n] = value
2011        return row
2012
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert, but updates an existing row.  The update is based
        on the primary key of the table or the OID value as munged by get()
        or passed as keyword.  The OID will take precedence if provided, so
        that it is possible to update the primary key itself.

        The dictionary is then modified to reflect any changes caused by the
        update due to triggers, rules, default values, etc.

        An 'oid' key in the passed dictionary is discarded; only an 'oid'
        keyword argument or the munged OID key is used.  If no column other
        than the key columns is given, no query is sent and the dictionary
        is returned as is.  Raises a KeyError if values for primary key
        columns are missing, and the error created by _prg_error() if the
        table has no primary key and no OID is given.
        """
        if table.endswith('*'):
            table = table[:-1].rstrip()  # need parent table name
        attnames = self.get_attnames(table)
        # munged key for the OID in row dicts, None if there is no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # accept the munged OID key as the source of a plain 'oid' key
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if qoid and 'oid' in row:  # try using the oid
            keyname = ('oid',)
        else:  # try using the primary key
            try:
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                raise _prg_error('Table %s has no primary key' % table)
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                raise KeyError('Missing value for primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add  # adds a value and returns its placeholder
        col = self.escape_identifier
        # the key values are added to params first, so they come before
        # the values of the SET clause in the parameter list
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # the plain 'oid' key was only needed for the WHERE clause;
        # in the row it is kept under the munged key only
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        values = []
        keyname = set(keyname)
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
        if not values:
            return row  # nothing to update besides the key columns
        values = ', '.join(values)
        ret = 'oid, *' if qoid else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            for n, value in res[0].items():
                if qoid and n == 'oid':
                    n = qoid
                row[n] = value
        return row
2074
2075    def upsert(self, table, row=None, **kw):
2076        """Insert a row into a database table with conflict resolution
2077
2078        This method inserts a row into a table, but instead of raising a
2079        ProgrammingError exception in case a row with the same primary key
2080        already exists, an update will be executed instead.  This will be
2081        performed as a single atomic operation on the database, so race
2082        conditions can be avoided.
2083
2084        Like the insert method, the first parameter is the name of the
2085        table and the second parameter can be used to pass the values to
2086        be inserted as a dictionary.
2087
2088        Unlike the insert und update statement, keyword parameters are not
2089        used to modify the dictionary, but to specify which columns shall
2090        be updated in case of a conflict, and in which way:
2091
2092        A value of False or None means the column shall not be updated,
2093        a value of True means the column shall be updated with the value
2094        that has been proposed for insertion, i.e. has been passed as value
2095        in the dictionary.  Columns that are not specified by keywords but
2096        appear as keys in the dictionary are also updated like in the case
2097        keywords had been passed with the value True.
2098
2099        So if in the case of a conflict you want to update every column that
2100        has been passed in the dictionary row, you would call upsert(table, row).
2101        If you don't want to do anything in case of a conflict, i.e. leave
2102        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2103
2104        If you need more fine-grained control of what gets updated, you can
2105        also pass strings in the keyword parameters.  These strings will
2106        be used as SQL expressions for the update columns.  In these
2107        expressions you can refer to the value that already exists in
2108        the table by prefixing the column name with "included.", and to
2109        the value that has been proposed for insertion by prefixing the
2110        column name with the "excluded."
2111
2112        The dictionary is modified in any case to reflect the values in
2113        the database after the operation has completed.
2114
2115        Note: The method uses the PostgreSQL "upsert" feature which is
2116        only available since PostgreSQL 9.5.
2117        """
2118        if table.endswith('*'):  # hint for descendant tables can be ignored
2119            table = table[:-1].rstrip()
2120        if row is None:
2121            row = {}
2122        if 'oid' in row:
2123            del row['oid']  # do not insert oid
2124        if 'oid' in kw:
2125            del kw['oid']  # do not update oid
2126        attnames = self.get_attnames(table)
2127        qoid = _oid_key(table) if 'oid' in attnames else None
2128        params = self.adapter.parameter_list()
2129        adapt = params.add
2130        col = self.escape_identifier
2131        names, values, updates = [], [], []
2132        for n in attnames:
2133            if n in row:
2134                names.append(col(n))
2135                values.append(adapt(row[n], attnames[n]))
2136        names, values = ', '.join(names), ', '.join(values)
2137        try:
2138            keyname = self.pkey(table, True)
2139        except KeyError:
2140            raise _prg_error('Table %s has no primary key' % table)
2141        target = ', '.join(col(k) for k in keyname)
2142        update = []
2143        keyname = set(keyname)
2144        keyname.add('oid')
2145        for n in attnames:
2146            if n not in keyname:
2147                value = kw.get(n, True)
2148                if value:
2149                    if not isinstance(value, basestring):
2150                        value = 'excluded.%s' % col(n)
2151                    update.append('%s = %s' % (col(n), value))
2152        if not values:
2153            return row
2154        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2155        ret = 'oid, *' if qoid else '*'
2156        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2157            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2158                self._escape_qualified_name(table), names, values,
2159                target, do, ret)
2160        self._do_debug(q, params)
2161        try:
2162            q = self.db.query(q, params)
2163        except ProgrammingError:
2164            if self.server_version < 90500:
2165                raise _prg_error(
2166                    'Upsert operation is not supported by PostgreSQL version')
2167            raise  # re-raise original error
2168        res = q.dictresult()
2169        if res:  # may be empty with "do nothing"
2170            for n, value in res[0].items():
2171                if qoid and n == 'oid':
2172                    n = qoid
2173                row[n] = value
2174        else:
2175            self.get(table, row)
2176        return row
2177
2178    def clear(self, table, row=None):
2179        """Clear all the attributes to values determined by the types.
2180
2181        Numeric types are set to 0, Booleans are set to false, and everything
2182        else is set to the empty string.  If the row argument is present,
2183        it is used as the row dictionary and any entries matching attribute
2184        names are cleared with everything else left unchanged.
2185        """
2186        # At some point we will need a way to get defaults from a table.
2187        if row is None:
2188            row = {}  # empty if argument is not present
2189        attnames = self.get_attnames(table)
2190        for n, t in attnames.items():
2191            if n == 'oid':
2192                continue
2193            t = t.simple
2194            if t in DbTypes._num_types:
2195                row[n] = 0
2196            elif t == 'bool':
2197                row[n] = self._make_bool(False)
2198            else:
2199                row[n] = ''
2200        return row
2201
2202    def delete(self, table, row=None, **kw):
2203        """Delete an existing row in a database table.
2204
2205        This method deletes the row from a table.  It deletes based on the
2206        primary key of the table or the OID value as munged by get() or
2207        passed as keyword.  The OID will take precedence if provided.
2208
2209        The return value is the number of deleted rows (i.e. 0 if the row
2210        did not exist and 1 if the row was deleted).
2211
2212        Note that if the row cannot be deleted because e.g. it is still
2213        referenced by another table, this method raises a ProgrammingError.
2214        """
2215        if table.endswith('*'):  # hint for descendant tables can be ignored
2216            table = table[:-1].rstrip()
2217        attnames = self.get_attnames(table)
2218        qoid = _oid_key(table) if 'oid' in attnames else None
2219        if row is None:
2220            row = {}
2221        elif 'oid' in row:
2222            del row['oid']  # only accept oid key from named args for safety
2223        row.update(kw)
2224        if qoid and qoid in row and 'oid' not in row:
2225            row['oid'] = row[qoid]
2226        if qoid and 'oid' in row:  # try using the oid
2227            keyname = ('oid',)
2228        else:  # try using the primary key
2229            try:
2230                keyname = self.pkey(table, True)
2231            except KeyError:  # the table has no primary key
2232                raise _prg_error('Table %s has no primary key' % table)
2233            # check whether all key columns have values
2234            if not set(keyname).issubset(row):
2235                raise KeyError('Missing value for primary key in row')
2236        params = self.adapter.parameter_list()
2237        adapt = params.add
2238        col = self.escape_identifier
2239        where = ' AND '.join('%s = %s' % (
2240            col(k), adapt(row[k], attnames[k])) for k in keyname)
2241        if 'oid' in row:
2242            if qoid:
2243                row[qoid] = row['oid']
2244            del row['oid']
2245        q = 'DELETE FROM %s WHERE %s' % (
2246            self._escape_qualified_name(table), where)
2247        self._do_debug(q, params)
2248        res = self.db.query(q, params)
2249        return int(res)
2250
2251    def truncate(self, table, restart=False, cascade=False, only=False):
2252        """Empty a table or set of tables.
2253
2254        This method quickly removes all rows from the given table or set
2255        of tables.  It has the same effect as an unqualified DELETE on each
2256        table, but since it does not actually scan the tables it is faster.
2257        Furthermore, it reclaims disk space immediately, rather than requiring
2258        a subsequent VACUUM operation. This is most useful on large tables.
2259
2260        If restart is set to True, sequences owned by columns of the truncated
2261        table(s) are automatically restarted.  If cascade is set to True, it
2262        also truncates all tables that have foreign-key references to any of
2263        the named tables.  If the parameter only is not set to True, all the
2264        descendant tables (if any) will also be truncated. Optionally, a '*'
2265        can be specified after the table name to explicitly indicate that
2266        descendant tables are included.
2267        """
2268        if isinstance(table, basestring):
2269            only = {table: only}
2270            table = [table]
2271        elif isinstance(table, (list, tuple)):
2272            if isinstance(only, (list, tuple)):
2273                only = dict(zip(table, only))
2274            else:
2275                only = dict.fromkeys(table, only)
2276        elif isinstance(table, (set, frozenset)):
2277            only = dict.fromkeys(table, only)
2278        else:
2279            raise TypeError('The table must be a string, list or set')
2280        if not (restart is None or isinstance(restart, (bool, int))):
2281            raise TypeError('Invalid type for the restart option')
2282        if not (cascade is None or isinstance(cascade, (bool, int))):
2283            raise TypeError('Invalid type for the cascade option')
2284        tables = []
2285        for t in table:
2286            u = only.get(t)
2287            if not (u is None or isinstance(u, (bool, int))):
2288                raise TypeError('Invalid type for the only option')
2289            if t.endswith('*'):
2290                if u:
2291                    raise ValueError(
2292                        'Contradictory table name and only options')
2293                t = t[:-1].rstrip()
2294            t = self._escape_qualified_name(t)
2295            if u:
2296                t = 'ONLY %s' % t
2297            tables.append(t)
2298        q = ['TRUNCATE', ', '.join(tables)]
2299        if restart:
2300            q.append('RESTART IDENTITY')
2301        if cascade:
2302            q.append('CASCADE')
2303        q = ' '.join(q)
2304        self._do_debug(q)
2305        return self.db.query(q)
2306
2307    def get_as_list(self, table, what=None, where=None,
2308            order=None, limit=None, offset=None, scalar=False):
2309        """Get a table as a list.
2310
2311        This gets a convenient representation of the table as a list
2312        of named tuples in Python.  You only need to pass the name of
2313        the table (or any other SQL expression returning rows).  Note that
2314        by default this will return the full content of the table which
2315        can be huge and overflow your memory.  However, you can control
2316        the amount of data returned using the other optional parameters.
2317
2318        The parameter 'what' can restrict the query to only return a
2319        subset of the table columns.  It can be a string, list or a tuple.
2320        The parameter 'where' can restrict the query to only return a
2321        subset of the table rows.  It can be a string, list or a tuple
2322        of SQL expressions that all need to be fulfilled.  The parameter
2323        'order' specifies the ordering of the rows.  It can also be a
2324        other string, list or a tuple.  If no ordering is specified,
2325        the result will be ordered by the primary key(s) or all columns
2326        if no primary key exists.  You can set 'order' to False if you
2327        don't care about the ordering.  The parameters 'limit' and 'offset'
2328        can be integers specifying the maximum number of rows returned
2329        and a number of rows skipped over.
2330
2331        If you set the 'scalar' option to True, then instead of the
2332        named tuples you will get the first items of these tuples.
2333        This is useful if the result has only one column anyway.
2334        """
2335        if not table:
2336            raise TypeError('The table name is missing')
2337        if what:
2338            if isinstance(what, (list, tuple)):
2339                what = ', '.join(map(str, what))
2340            if order is None:
2341                order = what
2342        else:
2343            what = '*'
2344        q = ['SELECT', what, 'FROM', table]
2345        if where:
2346            if isinstance(where, (list, tuple)):
2347                where = ' AND '.join(map(str, where))
2348            q.extend(['WHERE', where])
2349        if order is None:
2350            try:
2351                order = self.pkey(table, True)
2352            except (KeyError, ProgrammingError):
2353                try:
2354                    order = list(self.get_attnames(table))
2355                except (KeyError, ProgrammingError):
2356                    pass
2357        if order:
2358            if isinstance(order, (list, tuple)):
2359                order = ', '.join(map(str, order))
2360            q.extend(['ORDER BY', order])
2361        if limit:
2362            q.append('LIMIT %d' % limit)
2363        if offset:
2364            q.append('OFFSET %d' % offset)
2365        q = ' '.join(q)
2366        self._do_debug(q)
2367        q = self.db.query(q)
2368        res = q.namedresult()
2369        if res and scalar:
2370            res = [row[0] for row in res]
2371        return res
2372
2373    def get_as_dict(self, table, keyname=None, what=None, where=None,
2374            order=None, limit=None, offset=None, scalar=False):
2375        """Get a table as a dictionary.
2376
2377        This method is similar to get_as_list(), but returns the table
2378        as a Python dict instead of a Python list, which can be even
2379        more convenient. The primary key column(s) of the table will
2380        be used as the keys of the dictionary, while the other column(s)
2381        will be the corresponding values.  The keys will be named tuples
2382        if the table has a composite primary key.  The rows will be also
2383        named tuples unless the 'scalar' option has been set to True.
2384        With the optional parameter 'keyname' you can specify an alternative
2385        set of columns to be used as the keys of the dictionary.  It must
2386        be set as a string, list or a tuple.
2387
2388        If the Python version supports it, the dictionary will be an
2389        OrderedDict using the order specified with the 'order' parameter
2390        or the key column(s) if not specified.  You can set 'order' to False
2391        if you don't care about the ordering.  In this case the returned
2392        dictionary will be an ordinary one.
2393        """
2394        if not table:
2395            raise TypeError('The table name is missing')
2396        if not keyname:
2397            try:
2398                keyname = self.pkey(table, True)
2399            except (KeyError, ProgrammingError):
2400                raise _prg_error('Table %s has no primary key' % table)
2401        if isinstance(keyname, basestring):
2402            keyname = [keyname]
2403        elif not isinstance(keyname, (list, tuple)):
2404            raise KeyError('The keyname must be a string, list or tuple')
2405        if what:
2406            if isinstance(what, (list, tuple)):
2407                what = ', '.join(map(str, what))
2408            if order is None:
2409                order = what
2410        else:
2411            what = '*'
2412        q = ['SELECT', what, 'FROM', table]
2413        if where:
2414            if isinstance(where, (list, tuple)):
2415                where = ' AND '.join(map(str, where))
2416            q.extend(['WHERE', where])
2417        if order is None:
2418            order = keyname
2419        if order:
2420            if isinstance(order, (list, tuple)):
2421                order = ', '.join(map(str, order))
2422            q.extend(['ORDER BY', order])
2423        if limit:
2424            q.append('LIMIT %d' % limit)
2425        if offset:
2426            q.append('OFFSET %d' % offset)
2427        q = ' '.join(q)
2428        self._do_debug(q)
2429        q = self.db.query(q)
2430        res = q.getresult()
2431        cls = OrderedDict if order else dict
2432        if not res:
2433            return cls()
2434        keyset = set(keyname)
2435        fields = q.listfields()
2436        if not keyset.issubset(fields):
2437            raise KeyError('Missing keyname in row')
2438        keyind, rowind = [], []
2439        for i, f in enumerate(fields):
2440            (keyind if f in keyset else rowind).append(i)
2441        keytuple = len(keyind) > 1
2442        getkey = itemgetter(*keyind)
2443        keys = map(getkey, res)
2444        if scalar:
2445            rowind = rowind[:1]
2446            rowtuple = False
2447        else:
2448            rowtuple = len(rowind) > 1
2449        if scalar or rowtuple:
2450            getrow = itemgetter(*rowind)
2451        else:
2452            rowind = rowind[0]
2453            getrow = lambda row: (row[rowind],)
2454            rowtuple = True
2455        rows = map(getrow, res)
2456        if keytuple or rowtuple:
2457            namedresult = get_namedresult()
2458            if namedresult:
2459                if keytuple:
2460                    keys = namedresult(_MemoryQuery(keys, keyname))
2461                if rowtuple:
2462                    fields = [f for f in fields if f not in keyset]
2463                    rows = namedresult(_MemoryQuery(rows, fields))
2464        return cls(zip(keys, rows))
2465
2466    def notification_handler(self,
2467            event, callback, arg_dict=None, timeout=None, stop_event=None):
2468        """Get notification handler that will run the given callback."""
2469        return NotificationHandler(self,
2470            event, callback, arg_dict, timeout, stop_event)
2471
2472
# if run as script, print some information

if __name__ == '__main__':
    # Passing two arguments lets print() put a space between the label and
    # the version number; string concatenation omitted it ("version5.0").
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.