source: trunk/pg.py @ 886

Last change on this file since 886 was 886, checked in by cito, 3 years ago

Fix infinite recursion issue

The getattr method assumes that the "db" attribute is always set
(either to None or to the underlying connection). We ensure this
by setting the class attribute "db" to None.

  • Property svn:keywords set to Id
File size: 89.7 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 886 2016-09-08 16:04:53Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function, division
31
32from _pg import *
33
34__version__ = version
35
36import select
37import warnings
38import weakref
39
40from datetime import date, time, datetime, timedelta, tzinfo
41from decimal import Decimal
42from math import isnan, isinf
43from collections import namedtuple
44from operator import itemgetter
45from functools import partial
46from re import compile as regex
47from json import loads as jsondecode, dumps as jsonencode
48from uuid import UUID
49
# Aliases for builtins that exist only in Python 2, so that the rest of
# the module can use them uniformly on Python 2 and Python 3.
try:
    long
    basestring
except NameError:  # Python >= 3.0
    long = int
    basestring = (str, bytes)
59
60
61# Auxiliary classes and functions that are independent from a DB connection:
62
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Simple read-only ordered dictionary for storing attribute names.

        Fallback without OrderedDict: the key order is kept in a separate
        list, so the items must be passed as an iterable of (key, value)
        pairs. A plain dict is rejected because its order is undefined.
        """

        def __init__(self, *args, **kw):
            if len(args) > 1 or kw:
                raise TypeError
            items = args[0] if args else []
            if isinstance(items, dict):
                raise TypeError
            items = list(items)
            self._keys = [item[0] for item in items]
            dict.__init__(self, items)
            self._read_only = True
            # shadow all mutating methods on the instance
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[key] for key in self]

        def items(self):
            return [(key, self[key]) for key in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        def copy(self):
            """Return a read-only copy with the same items in the same order."""
            return self.__class__(self.items())

        def __reduce__(self):
            """Support copy and pickle by rebuilding from the item list.

            The default mechanism would restore the items through
            __setitem__, which is blocked on read-only instances.
            """
            return self.__class__, (self.items(),)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Simple read-only ordered dictionary for storing attribute names."""

        def __init__(self, *args, **kw):
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow all mutating methods on the instance
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error
            # reordering the keys is a mutation as well
            self.move_to_end = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        def copy(self):
            """Return a read-only copy with the same items in the same order.

            OrderedDict.copy() would fill an empty instance of this class
            through __setitem__, which is blocked once it is read-only.
            """
            return self.__class__(self.items())

        def __reduce__(self):
            """Support copy and pickle by rebuilding from the item list.

            OrderedDict.__reduce__ would restore the items through
            __setitem__, which is blocked on read-only instances.
            """
            return self.__class__, (list(self.items()),)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
147
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the names of the named positional arguments of func."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the names of all parameters of func, in order."""
        params = signature(func).parameters
        return [name for name in params]
159
try:
    from datetime import timezone
except ImportError:  # Python < 3.2

    class timezone(tzinfo):
        """Simple fixed-offset timezone implementation.

        Minimal stand-in for datetime.timezone on old Python versions.
        If no name is given, it is derived from the offset as 'UTC+HH:MM'.
        """

        def __init__(self, offset, name=None):
            self.offset = offset
            if not name:
                minutes = self.offset.days * 1440 + self.offset.seconds // 60
                # format the sign separately, so that offsets between
                # -1 hour and 0 (e.g. -00:30) do not lose their sign
                sign = '-' if minutes < 0 else '+'
                hours, minutes = divmod(abs(minutes), 60)
                name = 'UTC%s%02d:%02d' % (sign, hours, minutes)
            self.name = name

        def utcoffset(self, dt):
            return self.offset

        def tzname(self, dt):
            return self.name

        def dst(self, dt):
            return None

    timezone.utc = timezone(timedelta(0), 'UTC')

    _has_timezone = False
else:
    _has_timezone = True
193
# Time zone abbreviations that may appear in Postgres timestamptz output,
# mapped to their UTC offsets.  The daylight saving time abbreviations of
# the listed zones are included too, because abbreviations missing here
# are looked up with a default of '+0000' and would silently become UTC.
_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
    UCT='+0000', UTC='+0000', WET='+0000',
    CEST='+0200', EEST='+0300', EDT='-0400',
    MDT='-0600', MEST='+0200', WEST='+0100')
198
199
def _timezone_as_offset(tz):
    """Convert a Postgres time zone into a compact offset like '+0200'.

    Numeric zones such as '+02' or '-05:30' are padded or have their
    colons removed; named zones are looked up in _timezones, with
    '+0000' as the fallback for unknown names.
    """
    if not tz.startswith(('+', '-')):
        return _timezones.get(tz, '+0000')
    if len(tz) >= 5:
        return tz.replace(':', '')
    return tz + '00'
206
207
def _get_timezone(tz):
    """Create a fixed-offset timezone object for a Postgres time zone.

    The timezone is named after the normalized offset string.
    """
    offset = _timezone_as_offset(tz)
    sign, hh, mm = offset[0], offset[1:3], offset[3:5]
    total = int(hh) * 60 + int(mm)
    if sign == '-':
        total = -total
    return timezone(timedelta(minutes=total), offset)
214
215
def _oid_key(table):
    """Return the oid key for the given table name, e.g. 'oid(mytable)'."""
    key_template = 'oid(%s)'
    return key_template % table
219
220
class _SimpleTypes(dict):
    """Dictionary mapping pg_type names to simple type names.

    Array types, whose names are the element type name prefixed with an
    underscore, map to the simple name with '[]' appended.
    Any unknown type name is treated as 'text'.
    """

    _types = {
        'bool': 'bool',
        'bytea': 'bytea',
        'date': 'date interval time timetz timestamp timestamptz'
            ' abstime reltime',  # these are very old
        'float': 'float4 float8',
        'int': 'cid int2 int4 int8 oid xid',
        'hstore': 'hstore',
        'json': 'json jsonb',
        'uuid': 'uuid',
        'num': 'numeric',
        'money': 'money',
        'text': 'bpchar char name text varchar',
    }

    def __init__(self):
        for simple, names in self._types.items():
            for name in names.split():
                self[name] = simple
                self['_' + name] = simple + '[]'

    # this could be a static method in Python > 2.6
    def __missing__(self, key):
        """Fall back to 'text' for unknown types (the result is not cached)."""
        return 'text'


_simpletypes = _SimpleTypes()
245
246
def _quote_if_unqualified(param, name):
    """Quote parameter representing a qualified name.

    Wraps the given parameter in a quote_ident() call, unless the name is
    not a string or contains a dot.  A dotted name is ambiguous (it could
    be a qualified name or just a name containing a dot), so it must be
    quoted manually by the caller.
    """
    if not isinstance(name, basestring) or '.' in name:
        return param
    return 'quote_ident(%s)' % (param,)
258
259
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    The list collects the adapted parameter values of a query.  Its adapt
    attribute must be set to a function adapt(value, typ) before add() is
    used; Adapter.parameter_list() does this.
    """

    def add(self, value, typ=None):
        """Typecast value with known database type and build parameter list.

        The value is adapted first.  A Literal result is returned unchanged,
        to be put into the SQL command directly.  Any other result is
        appended to the list, and its placeholder ($1, $2, ...) is returned.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
274
275
class Bytea(bytes):
    """Marker type for binary data that is to be passed as bytea."""
278
279
class Hstore(dict):
    """Wrapper class for marking hstore values.

    str() of an instance gives the hstore text representation
    'key=>value,...' with keys and values quoted where necessary.
    """

    # Strings that must be double-quoted: NULL in any letter case (it
    # would otherwise be read as a null value), strings containing any
    # whitespace or the delimiters , = >, and strings containing escapes.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>"\\]')

    @classmethod
    def _quote(cls, s):
        """Quote a single hstore key or value (None becomes NULL)."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        # backslashes must be escaped first, before they are
        # introduced as escape characters for the double quotes
        s = s.replace('\\', '\\\\').replace('"', '\\"')
        if cls._re_quote.search(s):
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
299
300
class Json:
    """Marker wrapper for a Python object that is to be passed as JSON.

    The wrapped object is stored unchanged in the obj attribute.
    """

    def __init__(self, obj):
        self.obj = obj
306
307
class Literal(str):
    """Marker type for SQL text that is inserted verbatim, without quoting."""
310
311
class Adapter:
    """Class providing methods for adapting parameters to the database.

    Values are either adapted to be passed separately as query parameters
    (adapt, parameter_list, format_query) or to be put inline into the
    SQL command (adapt_inline, format_query with inline=True).
    """

    _bool_true_values = frozenset('t true 1 y yes on'.split())

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    # an unquoted ')' would end the record literal prematurely
    _re_record_quote = regex(r'[(),"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        # only a weak proxy to the database wrapper is kept
        # (presumably to avoid a reference cycle -- TODO confirm)
        self.db = weakref.proxy(db)

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter."""
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter."""
        if not v:
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.db.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter."""
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.db.encode_json(v)

    @staticmethod
    def _adapt_hstore(v):
        """Adapt a hstore parameter given as dict or as hstore text."""
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        if isinstance(v, dict):
            return str(Hstore(v))
        raise TypeError('Hstore parameter %s has wrong type' % (v,))

    @staticmethod
    def _adapt_uuid(v):
        """Adapt a UUID parameter given as UUID object or as text."""
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return str(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if not v:
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array
    # UUIDs appear in arrays in their text representation
    _adapt_uuid_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(v) for v in v) + b'}'
        if v is None:
            return b'null'
        return self.db.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.db.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    @classmethod
    def _adapt_hstore_array(cls, v):
        """Adapt a hstore array parameter (elements as dicts or text)."""
        if isinstance(v, list):
            adapt = cls._adapt_hstore_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if isinstance(v, dict):
            v = str(Hstore(v))
        return cls._adapt_text_array(v)

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Raises TypeError if the number of values does not match the
        number of attributes of the record type.
        """
        typ = self.get_attnames(typ).values()
        if len(typ) != len(v):
            # v is a tuple, so it must be wrapped for %-formatting
            raise TypeError('Record parameter %s has wrong size' % (v,))
        adapt = self.adapt
        value = []
        for v, t in zip(v, typ):
            v = adapt(v, t)
            if v is None:
                v = ''  # an empty field means NULL in a record literal
            else:
                if isinstance(v, bytes):
                    if str is not bytes:
                        v = v.decode('ascii')
                else:
                    v = str(v)
                # test for emptiness only after the conversion to text,
                # so that falsy values like 0 are not turned into ""
                if not v:
                    v = '""'
                elif self._re_record_quote.search(v):
                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
            value.append(v)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is given, it is guessed from the value.  None and
        Literal values are returned unchanged.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            try:
                value = value.__pg_str__(typ)
            except AttributeError:
                pass
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns None if the type cannot be guessed.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, list):
            return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type
            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))
            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array."""
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Raises InterfaceError if the type of the value is not supported.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.db.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # the plain Json wrapper has no encode attribute; only use it
            # if a (sub)class provides one that encodes the value itself
            encode = getattr(value, 'encode', None)
            if encode:
                return encode()
            # encode the wrapped object, the JSON text is quoted below
            value = self.db.encode_json(value.obj)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.db.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        Returns the formatted command and the list of parameters.
        Raises ValueError if types are given together with inline=True,
        and TypeError if values and types do not match.
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command %= tuple(literals)
        elif isinstance(values, dict):
            # we want to allow extra keys in the dictionary,
            # so we first must find the values actually used in the command:
            # a key is used if formatting fails with a KeyError without it
            used_values = {}
            literals = dict.fromkeys(values, '')
            # iterate over a snapshot of the keys, because each key is
            # removed and re-inserted in the loop body (changing the keys
            # of a dict while iterating over it can revisit keys, and newer
            # Python versions raise a RuntimeError in this case)
            for key in list(literals):
                del literals[key]
                try:
                    command % literals
                except KeyError:
                    used_values[key] = values[key]
                literals[key] = ''
            values = used_values
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                if types:
                    if not isinstance(types, dict):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types.get(key))
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command %= literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
652
653
def cast_bool(value):
    """Cast a boolean value.

    The raw string is returned unchanged unless bool casting is enabled
    (get_bool() is true); then the result is True if the value starts
    with 't' and False otherwise.
    """
    if get_bool():
        return value[0] == 't'
    return value
659
660
def cast_json(value):
    """Cast a JSON value.

    Uses the decoder returned by get_jsondecode(); if there is none,
    the JSON text is returned unchanged.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
667
668
def cast_num(value):
    """Cast a numeric value.

    Uses the type returned by get_decimal() if it is set, otherwise float.
    """
    number = get_decimal() or float
    return number(value)
672
673
def cast_money(value):
    """Cast a money value.

    If get_decimal_point() returns no decimal point, the string is returned
    unchanged.  Otherwise the decimal point is normalized to '.', all other
    characters except digits and minus signs (currency symbols, thousands
    separators) are removed, an opening parenthesis (accounting notation
    for negative amounts) becomes a minus sign, and the result is converted
    with the type returned by get_decimal(), or with float.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        # a '.' cannot be the decimal point here (it may be the thousands
        # separator, as in '1.234,56'), so drop it before the real decimal
        # point is normalized to '.'; otherwise it would be kept below
        value = value.replace('.', '').replace(point, '.')
    value = value.replace('(', '-')
    value = ''.join(c for c in value if c.isdigit() or c in '.-')
    return (get_decimal() or float)(value)
684
685
def cast_int2vector(value):
    """Cast an int2vector value (whitespace separated integers) to a list."""
    return list(map(int, value.split()))
689
690
def cast_date(value, connection):
    """Cast a date value, using the date format of the connection.

    Infinite dates and BC dates are mapped to date.min or date.max, and so
    are dates whose date part is longer than 10 characters (years beyond
    9999), since these cannot be represented by Python dates.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value in ('-infinity', 'infinity'):
        return date.min if value.startswith('-') else date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
709
710
def cast_time(value):
    """Cast a time value 'HH:MM:SS' with optional fractional seconds."""
    fmt = '%H:%M:%S'
    if len(value) > 8:
        fmt += '.%f'  # fractional seconds are present
    return datetime.strptime(value, fmt).time()
715
716
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value (a time followed by a numeric UTC offset).

    A value without an offset is taken as UTC.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S'
    if len(value) > 8:
        fmt += '.%f'  # fractional seconds are present
    if not _has_timezone:
        # without datetime.timezone (Python < 3.2), parse a naive time
        # and attach the time zone afterwards
        naive = datetime.strptime(value, fmt).timetz()
        return naive.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
734
735
def cast_timestamp(value, connection):
    """Cast a timestamp value, using the date format of the connection.

    Infinite and BC timestamps are mapped to datetime.min or datetime.max,
    and so are timestamps with years beyond 9999.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # Postgres style, e.g. 'Wed Dec 17 07:37:16 1997':
        # drop the weekday, the year comes last
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        formats = [day_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
757
758
def cast_timestamptz(value, connection):
    """Cast a timestamptz value, using the date format of the connection.

    Infinite and BC timestamps are mapped to datetime.min or datetime.max,
    and so are timestamps with years beyond 9999.  The time zone part is
    split off and turned into the tzinfo of the result.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # Postgres style, e.g. 'Wed Dec 17 07:37:16 1997 PST':
        # drop the weekday, the time zone comes last
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        formats = [day_fmt, time_fmt, '%Y']
        parts, tz = parts[:-1], parts[-1]
    else:
        if date_fmt.startswith('%Y-'):
            # ISO style: the numeric offset is attached to the time
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            # other date styles: the time zone is the last word
            parts, tz = parts[:-1], parts[-1]
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    if not _has_timezone:
        # without datetime.timezone (Python < 3.2), parse a naive
        # datetime and attach the time zone afterwards
        naive = datetime.strptime(' '.join(parts), ' '.join(formats))
        return naive.replace(tzinfo=_get_timezone(tz))
    parts.append(_timezone_as_offset(tz))
    formats.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
794
795
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(groups):
    """Turn the groups of an interval match into a list of digit strings.

    Missing groups become '0'.  The last group must hold the fractional
    digits of the seconds.  These are a decimal fraction, so they are padded
    or truncated to exactly six digits to give microseconds: the '5' in
    '01.5' means 500000 microseconds, not 5.
    """
    fields = [d or '0' for d in groups]
    fields[-1] = fields[-1][:6].ljust(6, '0')
    return fields


def cast_interval(value):
    """Cast an interval value.

    Accepts the output formats iso_8601, postgres_verbose, postgres and
    sql_standard.  Years and months are converted to days as 365 and 30
    days, respectively.

    Raises ValueError if the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = _interval_fields(m.groups())
        secs_ago = m.pop(5) == '-'
        m = [int(d) for d in m]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = _interval_fields(m.groups()[:8]), m.group(9)
            secs_ago = m.pop(5) == '-'
            m = [-int(d) for d in m] if ago else [int(d) for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = _interval_fields(m.groups())
                hours_ago = m.pop(3) == '-'
                m = [int(d) for d in m]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = _interval_fields(m.groups())
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = [int(d) for d in m]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
881
882
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Note that the basic types are already handled by the C extension.
    They only need to be handled here as record or array components.

    Keys are PostgreSQL type names (pgtype); array types use the
    element type name prefixed with an underscore (e.g. '_int4').
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': UUID,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        For array types, an (uncached) array cast is returned even
        when the element type has no special cast function.

        Raises a TypeError if typ is not a string.
        """
        if not isinstance(typ, str):
            # wrap typ in a tuple so that tuple values format correctly
            raise TypeError('Invalid type: %s' % (typ,))
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # array type: build the cast from the element type's cast
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                self[typ] = cast
        else:
            # possibly a composite type: build a record cast from the
            # casts of its fields
            attnames = self.get_attnames(typ)
            if attnames:
                casts = [self[v.pgtype] for v in attnames.values()]
                cast = self.create_record_cast(typ, attnames, casts)
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument.

        Returns True if 'connection' is one of the parameter names after
        the first one, False if the signature cannot be inspected.
        """
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary.

        The connection is bound with functools.partial only when this
        instance has a connection and the cast function accepts one.
        """
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type.

        Returns default if no special cast function exists.
        """
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        The typ may be a single type name or an iterable of type names.
        Passing None as cast removes the typecast function(s).  The cached
        cast of the corresponding array type is dropped in either case.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                # the array cast was built from the old element cast
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.
        Resetting a type also drops the cached cast of the corresponding
        array type, since that cast was built from the old element cast.
        """
        if typ is None:
            self.clear()
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.pop(t, None)
                # keep consistent with set(): the array cast closes over
                # the element cast and must be rebuilt as well
                self.pop('_%s' % t, None)

    @classmethod
    def get_default(cls, typ):
        """Get the default typecast function for the given database type."""
        return cls.defaults.get(typ)

    @classmethod
    def set_default(cls, typ, cast):
        """Set a default typecast function for the given database type(s).

        Passing None as cast removes the default typecast function(s).
        """
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = cls.defaults
        if cast is None:
            for t in typ:
                defaults.pop(t, None)
                defaults.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                defaults[t] = cast
                defaults.pop('_%s' % t, None)

    def get_attnames(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with the get_attnames() method of DbTypes.
        """
        return {}

    def dateformat(self):
        """Return the current date format.

        This default implementation always returns the ISO format.
        NOTE(review): the original docstring says this is replaced with a
        dateformat() method of DbTypes, but DbTypes.__init__ only replaces
        get_attnames -- confirm.
        """
        return '%Y-%m-%d'

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
1034
1035
def get_typecast(typ):
    """Return the global typecast function for a single database type.

    The lookup goes to the class-level defaults of Typecasts;
    None is returned when no typecast is registered for the type.
    """
    default_cast = Typecasts.get_default(typ)
    return default_cast
1039
1040
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The typ may be a single type name or an iterable of type names.
    Passing None as the cast removes the global typecast function(s).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1048
1049
class DbType(str):
    """A type name string that carries extra PostgreSQL type information.

    Besides being usable as a plain type name, instances have these
    attributes:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Return the names and types of the fields of a composite type.

        The work is delegated to the lookup function stored in the
        _get_attnames attribute, which gets this type as its argument.
        """
        lookup = self._get_attnames
        return lookup(self)
1070
1071
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    This cache maps type OIDs and names to DbType objects containing
    information on the associated database type.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize type cache for connection."""
        super(DbTypes, self).__init__()
        # weak proxy, presumably to avoid a reference cycle with the DB
        # instance that holds this cache
        self._db = weakref.proxy(db)
        self._regtypes = False
        self._typecasts = Typecasts()
        # let the typecasts look up the fields of composite types
        self._typecasts.get_attnames = self.get_attnames
        self._typecasts.connection = self._db
        if db.server_version < 80400:
            # older remote databases (not officially supported)
            self._query_pg_type = (
                "SELECT oid, typname, typname::text::regtype,"
                " typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")
        else:
            self._query_pg_type = (
                "SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")

    def add(self, oid, pgtype, regtype,
            typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        If the oid is already cached, the cached DbType is returned.
        Otherwise a new DbType is created; it is not stored in the cache.
        """
        if oid in self:
            return self[oid]
        # composite types (with a corresponding table) are records
        simple = 'record' if relid else _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid = oid
        typ.simple = simple
        typ.pgtype = pgtype
        typ.regtype = regtype
        typ.typtype = typtype
        typ.category = category
        typ.delim = delim
        typ.relid = relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The result is cached under both its oid and its pgtype name.
        Raises a KeyError if the type cannot be found.
        """
        try:
            q = self._query_pg_type % (_quote_if_unqualified('$1', key),)
            res = self._db.query(q, (key,)).getresult()
        except ProgrammingError:
            res = None
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = res[0]
        typ = self.add(*res)
        self[typ.oid] = self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached.

        Returns default if the type cannot be found.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or is not a composite type
        (has no corresponding table).
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None
        return self._db.get_attnames(typ.relid, with_oid=False)

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        The typ can be a key of this cache (type name or oid) or a DbType.
        Typecast functions are always looked up by the PostgreSQL type
        name (pgtype), since that is how the Typecasts are keyed.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if isinstance(typ, DbType):
            # the string value of a DbType is the simple or regular type
            # name, not the pgtype under which typecasts are registered
            typ = typ.pgtype
        else:
            typ = self.get(typ)
            if typ:
                typ = typ.pgtype
        cast = self.get_typecast(typ) if typ else None
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
1177
1178
1179def _namedresult(q):
1180    """Get query result as named tuples."""
1181    row = namedtuple('Row', q.listfields())
1182    return [row(*r) for r in q.getresult()]
1183
1184
1185class _MemoryQuery:
1186    """Class that embodies a given query result."""
1187
1188    def __init__(self, result, fields):
1189        """Create query from given result rows and field names."""
1190        self.result = result
1191        self.fields = fields
1192
1193    def listfields(self):
1194        """Return the stored field names of this query."""
1195        return self.fields
1196
1197    def getresult(self):
1198        """Return the stored result of this query."""
1199        return self.result
1200
1201
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The returned exception instance has its sqlstate attribute set to None.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
1207
1208
def _int_error(msg):
    """Create an InternalError for msg (with sqlstate set to None)."""
    return _db_error(msg, cls=InternalError)
1212
1213
def _prg_error(msg):
    """Create a ProgrammingError for msg (with sqlstate set to None)."""
    return _db_error(msg, cls=ProgrammingError)
1217
1218
# Initialize the C module: register Python-level helpers with the _pg
# extension (the setters are defined in the C module, not visible here)

set_namedresult(_namedresult)  # builds lists of 'Row' named tuples
set_decimal(Decimal)  # presumably the class used for numeric values
set_jsondecode(jsondecode)  # json.loads, presumably for JSON values
1224
1225
1226# The notification handler
1227
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # stop listening when the handler is garbage collected
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is converted to a string and escaped as an SQL
        literal by the connection, so it may safely contain quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Returns the result of the notify query, or None without sending
        anything if the handler is not listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # Escape the payload instead of interpolating it raw into
                # the statement, which broke on embedded quotes and
                # allowed SQL injection.  The '%s' conversion keeps the
                # previous string conversion of non-string payloads.
                q += ', %s' % db.escape_literal('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications is
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        A DatabaseError is raised (after unlistening) if a notification
        for an unexpected channel is received.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            # wait until the connection socket becomes readable
            # (or the timeout expires) unless we are only polling
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1347
1348
def pgnotify(*args, **kw):
    """Deprecated alias for creating a NotificationHandler.

    Issues a DeprecationWarning (attributed to the caller) and passes all
    arguments through to NotificationHandler.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1354
1355
# The actual PostgreSQL database connection interface:
1357
1358class DB:
1359    """Wrapper class for the _pg connection type."""
1360
1361    db = None  # invalid fallback for underlying connection
1362
    def __init__(self, *args, **kw):
        """Create a new connection

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.
        """
        # A single positional argument or a single 'db' keyword argument
        # may be an existing connection that should be wrapped.
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # share the underlying connection of another DB instance
                db = db.db
            else:
                try:
                    # presumably a pgdb connection, which keeps the
                    # underlying _pg connection in its _cnx attribute
                    db = db._cnx
                except AttributeError:
                    pass
        # Anything that does not look like a _pg connection (for instance
        # a database name given as the only argument) means that we open
        # a new connection ourselves; only such connections are closed by
        # close() and __del__, which check the _closeable flag.
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        # remembered so that reopen() can connect with the same arguments
        self._args = args, kw
        # self.db must be set here already, since DbTypes(self) reads
        # self.server_version through __getattr__
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        if db.server_version < 80400:
            # support older remote data bases
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype,"
                " t.typtype, null as typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        else:
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        # let the connection use our type cache for special typecasts
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1421
1422    def __getattr__(self, name):
1423        # All undefined members are same as in underlying connection:
1424        if self.db:
1425            return getattr(self.db, name)
1426        else:
1427            raise _int_error('Connection is not valid')
1428
1429    def __dir__(self):
1430        # Custom dir function including the attributes of the connection:
1431        attrs = set(self.__class__.__dict__)
1432        attrs.update(self.__dict__)
1433        attrs.update(dir(self.db))
1434        return sorted(attrs)
1435
1436    # Context manager methods
1437
1438    def __enter__(self):
1439        """Enter the runtime context. This will start a transaction."""
1440        self.begin()
1441        return self
1442
1443    def __exit__(self, et, ev, tb):
1444        """Exit the runtime context. This will end the transaction."""
1445        if et is None and ev is None and tb is None:
1446            self.commit()
1447        else:
1448            self.rollback()
1449
1450    def __del__(self):
1451        try:
1452            db = self.db
1453        except AttributeError:
1454            db = None
1455        if db:
1456            db.set_cast_hook(None)
1457            if self._closeable:
1458                db.close()
1459
1460    # Auxiliary methods
1461
1462    def _do_debug(self, *args):
1463        """Print a debug message"""
1464        if self.debug:
1465            s = '\n'.join(str(arg) for arg in args)
1466            if isinstance(self.debug, basestring):
1467                print(self.debug % s)
1468            elif hasattr(self.debug, 'write'):
1469                self.debug.write(s + '\n')
1470            elif callable(self.debug):
1471                self.debug(s)
1472            else:
1473                print(s)
1474
1475    def _escape_qualified_name(self, s):
1476        """Escape a qualified name.
1477
1478        Escapes the name for use as an SQL identifier, unless the
1479        name contains a dot, in which case the name is ambiguous
1480        (could be a qualified name or just a name with a dot in it)
1481        and must be quoted manually by the caller.
1482        """
1483        if '.' not in s:
1484            s = self.escape_identifier(s)
1485        return s
1486
1487    @staticmethod
1488    def _make_bool(d):
1489        """Get boolean value corresponding to d."""
1490        return bool(d) if get_bool() else ('t' if d else 'f')
1491
1492    def _list_params(self, params):
1493        """Create a human readable parameter list."""
1494        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1495
1496    # Public methods
1497
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (wrapped as a staticmethod, so it is called without self)
    unescape_bytea = staticmethod(unescape_bytea)
1501
1502    def decode_json(self, s):
1503        """Decode a JSON string coming from the database."""
1504        return (get_jsondecode() or jsondecode)(s)
1505
1506    def encode_json(self, d):
1507        """Encode a JSON string for use within SQL."""
1508        return jsonencode(d)
1509
1510    def close(self):
1511        """Close the database connection."""
1512        # Wraps shared library function so we can track state.
1513        if self._closeable:
1514            if self.db:
1515                self.db.set_cast_hook(None)
1516                self.db.close()
1517                self.db = None
1518            else:
1519                raise _int_error('Connection already closed')
1520
1521    def reset(self):
1522        """Reset connection with current parameters.
1523
1524        All derived queries and large objects derived from this connection
1525        will not be usable after this call.
1526
1527        """
1528        if self.db:
1529            self.db.reset()
1530        else:
1531            raise _int_error('Connection already closed')
1532
1533    def reopen(self):
1534        """Reopen connection to the database.
1535
1536        Used in case we need another connection to the same database.
1537        Note that we can still reopen a database that we have closed.
1538
1539        """
1540        # There is no such shared library function.
1541        if self._closeable:
1542            db = connect(*self._args[0], **self._args[1])
1543            if self.db:
1544                self.db.set_cast_hook(None)
1545                self.db.close()
1546            self.db = db
1547
1548    def begin(self, mode=None):
1549        """Begin a transaction."""
1550        qstr = 'BEGIN'
1551        if mode:
1552            qstr += ' ' + mode
1553        return self.query(qstr)
1554
1555    start = begin
1556
1557    def commit(self):
1558        """Commit the current transaction."""
1559        return self.query('COMMIT')
1560
1561    end = commit
1562
1563    def rollback(self, name=None):
1564        """Roll back the current transaction."""
1565        qstr = 'ROLLBACK'
1566        if name:
1567            qstr += ' TO ' + name
1568        return self.query(qstr)
1569
1570    abort = rollback
1571
1572    def savepoint(self, name):
1573        """Define a new savepoint within the current transaction."""
1574        return self.query('SAVEPOINT ' + name)
1575
1576    def release(self, name):
1577        """Destroy a previously defined savepoint."""
1578        return self.query('RELEASE ' + name)
1579
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises a TypeError for an invalid or empty parameter specification.
        """
        # Choose the result container: None for a single name (the bare
        # value is returned), a list for a list or tuple, a new dict for
        # a set, and the passed dict itself for a dict.
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # For dict results, map each normalized name to the original key,
        # so that the value is stored under the key the caller used.
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' overrides everything: return a new dict built from
                # the first two columns of SHOW ALL (presumably name and
                # setting)
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # no 'all' was requested: query each parameter separately
            for param in params:
                # NOTE(review): param is interpolated into the SQL without
                # escaping; it is only stripped and lowercased above
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1637
1638    def set_parameter(self, parameter, value=None, local=False):
1639        """Set the value of a run-time parameter.
1640
1641        If the parameter and the value are strings, the run-time parameter
1642        will be set to that value.  If no value or None is passed as a value,
1643        then the run-time parameter will be restored to its default value.
1644
1645        You can set several parameters at once by passing a list of parameter
1646        names, together with a single value that all parameters should be
1647        set to or with a corresponding list of values.  You can also pass
1648        the parameters as a set if you only provide a single value.
1649        Finally, you can pass a dict with parameter names as keys.  In this
1650        case, you should not pass a value, since the values for the parameters
1651        will be taken from the dict.
1652
1653        By passing the special name 'all' as the parameter, you can reset
1654        all existing settable run-time parameters to their default values.
1655
1656        If you set local to True, then the command takes effect for only the
1657        current transaction.  After commit() or rollback(), the session-level
1658        setting takes effect again.  Setting local to True will appear to
1659        have no effect if it is executed outside a transaction, since the
1660        transaction will end immediately.
1661        """
1662        if isinstance(parameter, basestring):
1663            parameter = {parameter: value}
1664        elif isinstance(parameter, (list, tuple)):
1665            if isinstance(value, (list, tuple)):
1666                parameter = dict(zip(parameter, value))
1667            else:
1668                parameter = dict.fromkeys(parameter, value)
1669        elif isinstance(parameter, (set, frozenset)):
1670            if isinstance(value, (list, tuple, set, frozenset)):
1671                value = set(value)
1672                if len(value) == 1:
1673                    value = value.pop()
1674            if not(value is None or isinstance(value, basestring)):
1675                raise ValueError('A single value must be specified'
1676                    ' when parameter is a set')
1677            parameter = dict.fromkeys(parameter, value)
1678        elif isinstance(parameter, dict):
1679            if value is not None:
1680                raise ValueError('A value must not be specified'
1681                    ' when parameter is a dictionary')
1682        else:
1683            raise TypeError(
1684                'The parameter must be a string, list, set or dict')
1685        if not parameter:
1686            raise TypeError('No parameter has been specified')
1687        params = {}
1688        for key, value in parameter.items():
1689            param = key.strip().lower() if isinstance(
1690                key, basestring) else None
1691            if not param:
1692                raise TypeError('Invalid parameter')
1693            if param == 'all':
1694                if value is not None:
1695                    raise ValueError('A value must ot be specified'
1696                        " when parameter is 'all'")
1697                params = {'all': None}
1698                break
1699            params[param] = value
1700        local = ' LOCAL' if local else ''
1701        for param, value in params.items():
1702            if value is None:
1703                q = 'RESET%s %s' % (local, param)
1704            else:
1705                q = 'SET%s %s TO %s' % (local, param, value)
1706            self._do_debug(q)
1707            self.db.query(q)
1708
1709    def query(self, command, *args):
1710        """Execute a SQL command string.
1711
1712        This method simply sends a SQL query to the database.  If the query is
1713        an insert statement that inserted exactly one row into a table that
1714        has OIDs, the return value is the OID of the newly inserted row.
1715        If the query is an update or delete statement, or an insert statement
1716        that did not insert exactly one row in a table with OIDs, then the
1717        number of rows affected is returned as a string.  If it is a statement
1718        that returns rows as a result (usually a select statement, but maybe
1719        also an "insert/update ... returning" statement), this method returns
1720        a Query object that can be accessed via getresult() or dictresult()
1721        or simply printed.  Otherwise, it returns `None`.
1722
1723        The query can contain numbered parameters of the form $1 in place
1724        of any data constant.  Arguments given after the query string will
1725        be substituted for the corresponding numbered parameter.  Parameter
1726        values can also be given as a single list or tuple argument.
1727        """
1728        # Wraps shared library function for debugging.
1729        if not self.db:
1730            raise _int_error('Connection is not valid')
1731        if args:
1732            self._do_debug(command, args)
1733            return self.db.query(command, args)
1734        self._do_debug(command)
1735        return self.db.query(command)
1736
1737    def query_formatted(self, command, parameters, types=None, inline=False):
1738        """Execute a formatted SQL command string.
1739
1740        Similar to query, but using Python format placeholders of the form
1741        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1742        The parameters must be passed as a tuple, list or dict.  You can
1743        also pass a corresponding tuple, list or dict of database types in
1744        order to format the parameters properly in case there is ambiguity.
1745
1746        If you set inline to True, the parameters will be sent to the database
1747        embedded in the SQL command, otherwise they will be sent separately.
1748        """
1749        return self.query(*self.adapter.format_query(
1750            command, parameters, types, inline))
1751
1752    def pkey(self, table, composite=False, flush=False):
1753        """Get or set the primary key of a table.
1754
1755        Single primary keys are returned as strings unless you
1756        set the composite flag.  Composite primary keys are always
1757        represented as tuples.  Note that this raises a KeyError
1758        if the table does not have a primary key.
1759
1760        If flush is set then the internal cache for primary keys will
1761        be flushed.  This may be necessary after the database schema or
1762        the search path has been changed.
1763        """
1764        pkeys = self._pkeys
1765        if flush:
1766            pkeys.clear()
1767            self._do_debug('The pkey cache has been flushed')
1768        try:  # cache lookup
1769            pkey = pkeys[table]
1770        except KeyError:  # cache miss, check the database
1771            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1772                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1773                " AND a.attnum = ANY(i.indkey)"
1774                " AND NOT a.attisdropped"
1775                " WHERE i.indrelid=%s::regclass"
1776                " AND i.indisprimary ORDER BY a.attnum") % (
1777                    _quote_if_unqualified('$1', table),)
1778            pkey = self.db.query(q, (table,)).getresult()
1779            if not pkey:
1780                raise KeyError('Table %s has no primary key' % table)
1781            # we want to use the order defined in the primary key index here,
1782            # not the order as defined by the columns in the table
1783            if len(pkey) > 1:
1784                indkey = pkey[0][2]
1785                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1786                pkey = tuple(row[0] for row in pkey)
1787            else:
1788                pkey = pkey[0][0]
1789            pkeys[table] = pkey  # cache it
1790        if composite and not isinstance(pkey, tuple):
1791            pkey = (pkey,)
1792        return pkey
1793
1794    def get_databases(self):
1795        """Get list of databases in the system."""
1796        return [s[0] for s in
1797            self.db.query('SELECT datname FROM pg_database').getresult()]
1798
1799    def get_relations(self, kinds=None, system=False):
1800        """Get list of relations in connected database of specified kinds.
1801
1802        If kinds is None or empty, all kinds of relations are returned.
1803        Otherwise kinds can be a string or sequence of type letters
1804        specifying which kind of relations you want to list.
1805
1806        Set the system flag if you want to get the system relations as well.
1807        """
1808        where = []
1809        if kinds:
1810            where.append("r.relkind IN (%s)" %
1811                ','.join("'%s'" % k for k in kinds))
1812        if not system:
1813            where.append("s.nspname NOT SIMILAR"
1814                " TO 'pg/_%|information/_schema' ESCAPE '/'")
1815        where = " WHERE %s" % ' AND '.join(where) if where else ''
1816        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1817            " FROM pg_class r"
1818            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
1819            " ORDER BY s.nspname, r.relname") % where
1820        return [r[0] for r in self.db.query(q).getresult()]
1821
1822    def get_tables(self, system=False):
1823        """Return list of tables in connected database.
1824
1825        Set the system flag if you want to get the system tables as well.
1826        """
1827        return self.get_relations('r', system)
1828
1829    def get_attnames(self, table, with_oid=True, flush=False):
1830        """Given the name of a table, dig out the set of attribute names.
1831
1832        Returns a read-only dictionary of attribute names (the names are
1833        the keys, the values are the names of the attributes' types)
1834        with the column names in the proper order if you iterate over it.
1835
1836        If flush is set, then the internal cache for attribute names will
1837        be flushed. This may be necessary after the database schema or
1838        the search path has been changed.
1839
1840        By default, only a limited number of simple types will be returned.
1841        You can get the regular types after calling use_regtypes(True).
1842        """
1843        attnames = self._attnames
1844        if flush:
1845            attnames.clear()
1846            self._do_debug('The attnames cache has been flushed')
1847        try:  # cache lookup
1848            names = attnames[table]
1849        except KeyError:  # cache miss, check the database
1850            q = "a.attnum > 0"
1851            if with_oid:
1852                q = "(%s OR a.attname = 'oid')" % q
1853            q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
1854            names = self.db.query(q, (table,)).getresult()
1855            types = self.dbtypes
1856            names = ((name[0], types.add(*name[1:])) for name in names)
1857            names = AttrDict(names)
1858            attnames[table] = names  # cache it
1859        return names
1860
1861    def use_regtypes(self, regtypes=None):
1862        """Use regular type names instead of simplified type names."""
1863        if regtypes is None:
1864            return self.dbtypes._regtypes
1865        else:
1866            regtypes = bool(regtypes)
1867            if regtypes != self.dbtypes._regtypes:
1868                self.dbtypes._regtypes = regtypes
1869                self._attnames.clear()
1870                self.dbtypes.clear()
1871            return regtypes
1872
1873    def has_table_privilege(self, table, privilege='select', flush=False):
1874        """Check whether current user has specified table privilege.
1875
1876        If flush is set, then the internal cache for table privileges will
1877        be flushed. This may be necessary after privileges have been changed.
1878        """
1879        privileges = self._privileges
1880        if flush:
1881            privileges.clear()
1882            self._do_debug('The privileges cache has been flushed')
1883        privilege = privilege.lower()
1884        try:  # ask cache
1885            ret = privileges[table, privilege]
1886        except KeyError:  # cache miss, ask the database
1887            q = "SELECT has_table_privilege(%s, $2)" % (
1888                _quote_if_unqualified('$1', table),)
1889            q = self.db.query(q, (table, privilege))
1890            ret = q.getresult()[0][0] == self._make_bool(True)
1891            privileges[table, privilege] = ret  # cache it
1892        return ret
1893
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        Errors: if no keyname is given and the table has neither a primary
        key nor a usable OID value in the row, the error built by
        _prg_error() is raised.  A KeyError is raised if key values are
        missing from the row or their number does not match the keyname.
        If no matching row exists, the error built by _db_error() is raised.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # qoid is the munged dictionary key used for the OID of this table,
        # or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # accept an OID passed under its munged name as a plain 'oid' key
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # a scalar or sequence of key values is turned into a new dictionary
        # mapping the key column names to these values
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # the WHERE clause has been built, so the plain 'oid' key is no
        # longer needed; keep the value only under the munged name
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row, storing the oid column
        # under its munged name
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1968
1969    def insert(self, table, row=None, **kw):
1970        """Insert a row into a database table.
1971
1972        This method inserts a row into a table.  The name of the table must
1973        be passed as the first parameter.  The other parameters are used for
1974        providing the data of the row that shall be inserted into the table.
1975        If a dictionary is supplied as the second parameter, it starts with
1976        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1977        is updated from the keywords.
1978
1979        The dictionary is then reloaded with the values actually inserted in
1980        order to pick up values modified by rules, triggers, etc.
1981        """
1982        if table.endswith('*'):  # hint for descendant tables can be ignored
1983            table = table[:-1].rstrip()
1984        if row is None:
1985            row = {}
1986        row.update(kw)
1987        if 'oid' in row:
1988            del row['oid']  # do not insert oid
1989        attnames = self.get_attnames(table)
1990        qoid = _oid_key(table) if 'oid' in attnames else None
1991        params = self.adapter.parameter_list()
1992        adapt = params.add
1993        col = self.escape_identifier
1994        names, values = [], []
1995        for n in attnames:
1996            if n in row:
1997                names.append(col(n))
1998                values.append(adapt(row[n], attnames[n]))
1999        if not names:
2000            raise _prg_error('No column found that can be inserted')
2001        names, values = ', '.join(names), ', '.join(values)
2002        ret = 'oid, *' if qoid else '*'
2003        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
2004            self._escape_qualified_name(table), names, values, ret)
2005        self._do_debug(q, params)
2006        q = self.db.query(q, params)
2007        res = q.dictresult()
2008        if res:  # this should always be true
2009            for n, value in res[0].items():
2010                if qoid and n == 'oid':
2011                    n = qoid
2012                row[n] = value
2013        return row
2014
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert, but updates an existing row.  The update is based
        on the primary key of the table or the OID value as munged by get()
        or passed as keyword.  The OID will take precedence if provided, so
        that it is possible to update the primary key itself.

        The dictionary is then modified to reflect any changes caused by the
        update due to triggers, rules, default values, etc.

        Errors: if no OID is used and the table has no primary key, the
        error built by _prg_error() is raised; a KeyError is raised if the
        row lacks values for some of the primary key columns.  If there
        is nothing to update besides the key, the row is returned as is
        without querying the database.
        """
        if table.endswith('*'):
            table = table[:-1].rstrip()  # need parent table name
        attnames = self.get_attnames(table)
        # qoid is the munged dictionary key used for the OID of this table,
        # or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # an OID stored under the munged key (e.g. by get()) is accepted
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if qoid and 'oid' in row:  # try using the oid
            keyname = ('oid',)
        else:  # try using the primary key
            try:
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                raise _prg_error('Table %s has no primary key' % table)
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                raise KeyError('Missing value for primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        # the key values are added to params before the SET values
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # keep the OID only under its munged name in the returned row
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        # set all columns present in the row, except the key columns;
        # when keyed by oid, this includes the primary key columns
        values = []
        keyname = set(keyname)
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
        if not values:
            return row
        values = ', '.join(values)
        ret = 'oid, *' if qoid else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            for n, value in res[0].items():
                if qoid and n == 'oid':
                    n = qoid
                row[n] = value
        return row
2076
2077    def upsert(self, table, row=None, **kw):
2078        """Insert a row into a database table with conflict resolution
2079
2080        This method inserts a row into a table, but instead of raising a
2081        ProgrammingError exception in case a row with the same primary key
2082        already exists, an update will be executed instead.  This will be
2083        performed as a single atomic operation on the database, so race
2084        conditions can be avoided.
2085
2086        Like the insert method, the first parameter is the name of the
2087        table and the second parameter can be used to pass the values to
2088        be inserted as a dictionary.
2089
2090        Unlike the insert und update statement, keyword parameters are not
2091        used to modify the dictionary, but to specify which columns shall
2092        be updated in case of a conflict, and in which way:
2093
2094        A value of False or None means the column shall not be updated,
2095        a value of True means the column shall be updated with the value
2096        that has been proposed for insertion, i.e. has been passed as value
2097        in the dictionary.  Columns that are not specified by keywords but
2098        appear as keys in the dictionary are also updated like in the case
2099        keywords had been passed with the value True.
2100
2101        So if in the case of a conflict you want to update every column that
2102        has been passed in the dictionary row, you would call upsert(table, row).
2103        If you don't want to do anything in case of a conflict, i.e. leave
2104        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2105
2106        If you need more fine-grained control of what gets updated, you can
2107        also pass strings in the keyword parameters.  These strings will
2108        be used as SQL expressions for the update columns.  In these
2109        expressions you can refer to the value that already exists in
2110        the table by prefixing the column name with "included.", and to
2111        the value that has been proposed for insertion by prefixing the
2112        column name with the "excluded."
2113
2114        The dictionary is modified in any case to reflect the values in
2115        the database after the operation has completed.
2116
2117        Note: The method uses the PostgreSQL "upsert" feature which is
2118        only available since PostgreSQL 9.5.
2119        """
2120        if table.endswith('*'):  # hint for descendant tables can be ignored
2121            table = table[:-1].rstrip()
2122        if row is None:
2123            row = {}
2124        if 'oid' in row:
2125            del row['oid']  # do not insert oid
2126        if 'oid' in kw:
2127            del kw['oid']  # do not update oid
2128        attnames = self.get_attnames(table)
2129        qoid = _oid_key(table) if 'oid' in attnames else None
2130        params = self.adapter.parameter_list()
2131        adapt = params.add
2132        col = self.escape_identifier
2133        names, values, updates = [], [], []
2134        for n in attnames:
2135            if n in row:
2136                names.append(col(n))
2137                values.append(adapt(row[n], attnames[n]))
2138        names, values = ', '.join(names), ', '.join(values)
2139        try:
2140            keyname = self.pkey(table, True)
2141        except KeyError:
2142            raise _prg_error('Table %s has no primary key' % table)
2143        target = ', '.join(col(k) for k in keyname)
2144        update = []
2145        keyname = set(keyname)
2146        keyname.add('oid')
2147        for n in attnames:
2148            if n not in keyname:
2149                value = kw.get(n, True)
2150                if value:
2151                    if not isinstance(value, basestring):
2152                        value = 'excluded.%s' % col(n)
2153                    update.append('%s = %s' % (col(n), value))
2154        if not values:
2155            return row
2156        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2157        ret = 'oid, *' if qoid else '*'
2158        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2159            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2160                self._escape_qualified_name(table), names, values,
2161                target, do, ret)
2162        self._do_debug(q, params)
2163        try:
2164            q = self.db.query(q, params)
2165        except ProgrammingError:
2166            if self.server_version < 90500:
2167                raise _prg_error(
2168                    'Upsert operation is not supported by PostgreSQL version')
2169            raise  # re-raise original error
2170        res = q.dictresult()
2171        if res:  # may be empty with "do nothing"
2172            for n, value in res[0].items():
2173                if qoid and n == 'oid':
2174                    n = qoid
2175                row[n] = value
2176        else:
2177            self.get(table, row)
2178        return row
2179
2180    def clear(self, table, row=None):
2181        """Clear all the attributes to values determined by the types.
2182
2183        Numeric types are set to 0, Booleans are set to false, and everything
2184        else is set to the empty string.  If the row argument is present,
2185        it is used as the row dictionary and any entries matching attribute
2186        names are cleared with everything else left unchanged.
2187        """
2188        # At some point we will need a way to get defaults from a table.
2189        if row is None:
2190            row = {}  # empty if argument is not present
2191        attnames = self.get_attnames(table)
2192        for n, t in attnames.items():
2193            if n == 'oid':
2194                continue
2195            t = t.simple
2196            if t in DbTypes._num_types:
2197                row[n] = 0
2198            elif t == 'bool':
2199                row[n] = self._make_bool(False)
2200            else:
2201                row[n] = ''
2202        return row
2203
    def delete(self, table, row=None, **kw):
        """Delete an existing row in a database table.

        This method deletes the row from a table.  It deletes based on the
        primary key of the table or the OID value as munged by get() or
        passed as keyword.  The OID will take precedence if provided.

        The return value is the number of deleted rows (i.e. 0 if the row
        did not exist and 1 if the row was deleted).

        Note that if the row cannot be deleted because e.g. it is still
        referenced by another table, this method raises a ProgrammingError.

        Other errors: if no OID is used and the table has no primary key,
        the error built by _prg_error() is raised; a KeyError is raised if
        the row lacks values for some of the primary key columns.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # qoid is the munged dictionary key used for the OID of this table,
        # or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # an OID stored under the munged key (e.g. by get()) is accepted
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if qoid and 'oid' in row:  # try using the oid
            keyname = ('oid',)
        else:  # try using the primary key
            try:
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                raise _prg_error('Table %s has no primary key' % table)
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                raise KeyError('Missing value for primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # keep the OID only under its munged name in the passed row
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'DELETE FROM %s WHERE %s' % (
            self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        res = self.db.query(q, params)
        # for a delete statement, query() returns the number of affected
        # rows as a string
        return int(res)
2252
2253    def truncate(self, table, restart=False, cascade=False, only=False):
2254        """Empty a table or set of tables.
2255
2256        This method quickly removes all rows from the given table or set
2257        of tables.  It has the same effect as an unqualified DELETE on each
2258        table, but since it does not actually scan the tables it is faster.
2259        Furthermore, it reclaims disk space immediately, rather than requiring
2260        a subsequent VACUUM operation. This is most useful on large tables.
2261
2262        If restart is set to True, sequences owned by columns of the truncated
2263        table(s) are automatically restarted.  If cascade is set to True, it
2264        also truncates all tables that have foreign-key references to any of
2265        the named tables.  If the parameter only is not set to True, all the
2266        descendant tables (if any) will also be truncated. Optionally, a '*'
2267        can be specified after the table name to explicitly indicate that
2268        descendant tables are included.
2269        """
2270        if isinstance(table, basestring):
2271            only = {table: only}
2272            table = [table]
2273        elif isinstance(table, (list, tuple)):
2274            if isinstance(only, (list, tuple)):
2275                only = dict(zip(table, only))
2276            else:
2277                only = dict.fromkeys(table, only)
2278        elif isinstance(table, (set, frozenset)):
2279            only = dict.fromkeys(table, only)
2280        else:
2281            raise TypeError('The table must be a string, list or set')
2282        if not (restart is None or isinstance(restart, (bool, int))):
2283            raise TypeError('Invalid type for the restart option')
2284        if not (cascade is None or isinstance(cascade, (bool, int))):
2285            raise TypeError('Invalid type for the cascade option')
2286        tables = []
2287        for t in table:
2288            u = only.get(t)
2289            if not (u is None or isinstance(u, (bool, int))):
2290                raise TypeError('Invalid type for the only option')
2291            if t.endswith('*'):
2292                if u:
2293                    raise ValueError(
2294                        'Contradictory table name and only options')
2295                t = t[:-1].rstrip()
2296            t = self._escape_qualified_name(t)
2297            if u:
2298                t = 'ONLY %s' % t
2299            tables.append(t)
2300        q = ['TRUNCATE', ', '.join(tables)]
2301        if restart:
2302            q.append('RESTART IDENTITY')
2303        if cascade:
2304            q.append('CASCADE')
2305        q = ' '.join(q)
2306        self._do_debug(q)
2307        return self.db.query(q)
2308
2309    def get_as_list(self, table, what=None, where=None,
2310            order=None, limit=None, offset=None, scalar=False):
2311        """Get a table as a list.
2312
2313        This gets a convenient representation of the table as a list
2314        of named tuples in Python.  You only need to pass the name of
2315        the table (or any other SQL expression returning rows).  Note that
2316        by default this will return the full content of the table which
2317        can be huge and overflow your memory.  However, you can control
2318        the amount of data returned using the other optional parameters.
2319
2320        The parameter 'what' can restrict the query to only return a
2321        subset of the table columns.  It can be a string, list or a tuple.
2322        The parameter 'where' can restrict the query to only return a
2323        subset of the table rows.  It can be a string, list or a tuple
2324        of SQL expressions that all need to be fulfilled.  The parameter
2325        'order' specifies the ordering of the rows.  It can also be a
2326        other string, list or a tuple.  If no ordering is specified,
2327        the result will be ordered by the primary key(s) or all columns
2328        if no primary key exists.  You can set 'order' to False if you
2329        don't care about the ordering.  The parameters 'limit' and 'offset'
2330        can be integers specifying the maximum number of rows returned
2331        and a number of rows skipped over.
2332
2333        If you set the 'scalar' option to True, then instead of the
2334        named tuples you will get the first items of these tuples.
2335        This is useful if the result has only one column anyway.
2336        """
2337        if not table:
2338            raise TypeError('The table name is missing')
2339        if what:
2340            if isinstance(what, (list, tuple)):
2341                what = ', '.join(map(str, what))
2342            if order is None:
2343                order = what
2344        else:
2345            what = '*'
2346        q = ['SELECT', what, 'FROM', table]
2347        if where:
2348            if isinstance(where, (list, tuple)):
2349                where = ' AND '.join(map(str, where))
2350            q.extend(['WHERE', where])
2351        if order is None:
2352            try:
2353                order = self.pkey(table, True)
2354            except (KeyError, ProgrammingError):
2355                try:
2356                    order = list(self.get_attnames(table))
2357                except (KeyError, ProgrammingError):
2358                    pass
2359        if order:
2360            if isinstance(order, (list, tuple)):
2361                order = ', '.join(map(str, order))
2362            q.extend(['ORDER BY', order])
2363        if limit:
2364            q.append('LIMIT %d' % limit)
2365        if offset:
2366            q.append('OFFSET %d' % offset)
2367        q = ' '.join(q)
2368        self._do_debug(q)
2369        q = self.db.query(q)
2370        res = q.namedresult()
2371        if res and scalar:
2372            res = [row[0] for row in res]
2373        return res
2374
2375    def get_as_dict(self, table, keyname=None, what=None, where=None,
2376            order=None, limit=None, offset=None, scalar=False):
2377        """Get a table as a dictionary.
2378
2379        This method is similar to get_as_list(), but returns the table
2380        as a Python dict instead of a Python list, which can be even
2381        more convenient. The primary key column(s) of the table will
2382        be used as the keys of the dictionary, while the other column(s)
2383        will be the corresponding values.  The keys will be named tuples
2384        if the table has a composite primary key.  The rows will be also
2385        named tuples unless the 'scalar' option has been set to True.
2386        With the optional parameter 'keyname' you can specify an alternative
2387        set of columns to be used as the keys of the dictionary.  It must
2388        be set as a string, list or a tuple.
2389
2390        If the Python version supports it, the dictionary will be an
2391        OrderedDict using the order specified with the 'order' parameter
2392        or the key column(s) if not specified.  You can set 'order' to False
2393        if you don't care about the ordering.  In this case the returned
2394        dictionary will be an ordinary one.
2395        """
2396        if not table:
2397            raise TypeError('The table name is missing')
2398        if not keyname:
2399            try:
2400                keyname = self.pkey(table, True)
2401            except (KeyError, ProgrammingError):
2402                raise _prg_error('Table %s has no primary key' % table)
2403        if isinstance(keyname, basestring):
2404            keyname = [keyname]
2405        elif not isinstance(keyname, (list, tuple)):
2406            raise KeyError('The keyname must be a string, list or tuple')
2407        if what:
2408            if isinstance(what, (list, tuple)):
2409                what = ', '.join(map(str, what))
2410            if order is None:
2411                order = what
2412        else:
2413            what = '*'
2414        q = ['SELECT', what, 'FROM', table]
2415        if where:
2416            if isinstance(where, (list, tuple)):
2417                where = ' AND '.join(map(str, where))
2418            q.extend(['WHERE', where])
2419        if order is None:
2420            order = keyname
2421        if order:
2422            if isinstance(order, (list, tuple)):
2423                order = ', '.join(map(str, order))
2424            q.extend(['ORDER BY', order])
2425        if limit:
2426            q.append('LIMIT %d' % limit)
2427        if offset:
2428            q.append('OFFSET %d' % offset)
2429        q = ' '.join(q)
2430        self._do_debug(q)
2431        q = self.db.query(q)
2432        res = q.getresult()
2433        cls = OrderedDict if order else dict
2434        if not res:
2435            return cls()
2436        keyset = set(keyname)
2437        fields = q.listfields()
2438        if not keyset.issubset(fields):
2439            raise KeyError('Missing keyname in row')
2440        keyind, rowind = [], []
2441        for i, f in enumerate(fields):
2442            (keyind if f in keyset else rowind).append(i)
2443        keytuple = len(keyind) > 1
2444        getkey = itemgetter(*keyind)
2445        keys = map(getkey, res)
2446        if scalar:
2447            rowind = rowind[:1]
2448            rowtuple = False
2449        else:
2450            rowtuple = len(rowind) > 1
2451        if scalar or rowtuple:
2452            getrow = itemgetter(*rowind)
2453        else:
2454            rowind = rowind[0]
2455            getrow = lambda row: (row[rowind],)
2456            rowtuple = True
2457        rows = map(getrow, res)
2458        if keytuple or rowtuple:
2459            namedresult = get_namedresult()
2460            if namedresult:
2461                if keytuple:
2462                    keys = namedresult(_MemoryQuery(keys, keyname))
2463                if rowtuple:
2464                    fields = [f for f in fields if f not in keyset]
2465                    rows = namedresult(_MemoryQuery(rows, fields))
2466        return cls(zip(keys, rows))
2467
2468    def notification_handler(self,
2469            event, callback, arg_dict=None, timeout=None, stop_event=None):
2470        """Get notification handler that will run the given callback."""
2471        return NotificationHandler(self,
2472            event, callback, arg_dict, timeout, stop_event)
2473
2474
# if run as script, print the version and the module documentation

if __name__ == '__main__':
    # separate the label from the version number with a space
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.