source: trunk/pg.py @ 879

Last change on this file since 879 was 879, checked in by cito, 3 years ago

Enable garbage collection after closing DB instance

Note that there is still a problem if the DB instance is not closed.

  • Property svn:keywords set to Id
File size: 89.6 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 879 2016-07-23 12:06:19Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function, division
31
32from _pg import *
33
34__version__ = version
35
36import select
37import warnings
38
39from datetime import date, time, datetime, timedelta, tzinfo
40from decimal import Decimal
41from math import isnan, isinf
42from collections import namedtuple
43from operator import itemgetter
44from functools import partial
45from re import compile as regex
46from json import loads as jsondecode, dumps as jsonencode
47from uuid import UUID
48
# Python 3 has no separate ``long`` type, so alias the name to ``int`` there.
try:
    long
except NameError:  # Python >= 3.0
    long = int

# Python 3 has no ``basestring`` either; use a tuple that is suitable for
# isinstance() checks and covers both text and byte strings.
try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)
58
59
60# Auxiliary classes and functions that are independent from a DB connection:
61
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Read-only dictionary of attribute names that keeps their order.

        This variant is used when OrderedDict is not available.  The order
        of the keys is remembered in a separate list.
        """

        def __init__(self, *args, **kw):
            # Only one positional iterable of (key, value) pairs is allowed;
            # keyword arguments and dicts are rejected.
            if kw or len(args) > 1:
                raise TypeError
            pairs = args[0] if args else []
            if isinstance(pairs, dict):
                raise TypeError
            pairs = list(pairs)
            self._keys = [pair[0] for pair in pairs]
            dict.__init__(self, pairs)
            self._read_only = True
            # shadow all mutating methods on the instance
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, self._read_only_error)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            """Return a new list of the keys in their original order."""
            return self._keys[:]

        def values(self):
            """Return the values in the order of their keys."""
            result = []
            for key in self:
                result.append(self[key])
            return result

        def items(self):
            """Return the (key, value) pairs in the order of the keys."""
            result = []
            for key in self:
                result.append((key, self[key]))
            return result

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            """Raise the error used for all attempts to modify the dict."""
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Read-only dictionary of attribute names that keeps their order."""

        def __init__(self, *args, **kw):
            # the dictionary must stay writable while the base class fills it
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow all mutating methods on the instance
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, self._read_only_error)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            """Raise the error used for all attempts to modify the dict."""
            raise TypeError('This object is read-only')
146
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given function."""
        parameters = signature(func).parameters
        return [name for name in parameters]
158
try:
    from datetime import timezone
except ImportError:  # Python < 3.2

    class timezone(tzinfo):
        """Minimal tzinfo implementation with a fixed offset from UTC."""

        def __init__(self, offset, name=None):
            self.offset = offset
            if not name:
                # derive a name like "UTC+01:00" from the offset
                total = offset.days * 1440 + offset.seconds // 60
                hours, minutes = divmod(abs(total), 60)
                if total < 0:
                    hours = -hours
                name = 'UTC%+03d:%02d' % (hours, minutes)
            self.name = name

        def utcoffset(self, dt):
            return self.offset

        def tzname(self, dt):
            return self.name

        def dst(self, dt):
            return None

    timezone.utc = timezone(timedelta(0), 'UTC')

    # flag telling the cast functions whether the native timezone class
    # from the datetime module is available
    _has_timezone = False
else:
    _has_timezone = True
192
# Named time zones that can appear in Postgres timestamptz output,
# mapped to their offsets from UTC in +HHMM/-HHMM notation.
_timezones = {
    'CET': '+0100',
    'EET': '+0200',
    'EST': '-0500',
    'GMT': '+0000',
    'HST': '-1000',
    'MET': '+0100',
    'MST': '-0700',
    'UCT': '+0000',
    'UTC': '+0000',
    'WET': '+0000',
}
197
198
def _timezone_as_offset(tz):
    """Convert a time zone from Postgres output to an offset like +HHMM.

    Numeric zones are padded with minutes or stripped of their colon.
    Named zones are looked up in the table of known zones, and unknown
    names are treated as UTC.
    """
    if not tz.startswith(('+', '-')):
        return _timezones.get(tz, '+0000')
    if len(tz) >= 5:
        return tz.replace(':', '')
    # only the hours were given
    return tz + '00'
205
206
def _get_timezone(tz):
    """Create a timezone object for a time zone from Postgres output."""
    offset = _timezone_as_offset(tz)
    hours = int(offset[1:3])
    minutes = 60 * hours + int(offset[3:5])
    if offset[0] == '-':
        minutes = -minutes
    # the normalized offset also serves as the name of the time zone
    return timezone(timedelta(minutes=minutes), offset)
213
214
def _oid_key(table):
    """Return the key under which the oid of a row in a table is stored."""
    key = 'oid(%s)' % table
    return key
218
219
class _SimpleTypes(dict):
    """Dictionary mapping pg_type names to simple type names.

    Each pg_type name is also registered with a leading underscore,
    mapped to the simple type name with "[]" appended.  All names that
    are not registered are regarded as text.
    """

    # simple type names with the space separated pg_type names they cover
    _types = {'bool': 'bool',
        'bytea': 'bytea',
        'date': 'date interval time timetz timestamp timestamptz'
            ' abstime reltime',  # these are very old
        'float': 'float4 float8',
        'int': 'cid int2 int4 int8 oid xid',
        'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid',
        'num': 'numeric', 'money': 'money',
        'text': 'bpchar char name text varchar'}

    def __init__(self):
        for simple, names in self._types.items():
            array = '%s[]' % simple
            for name in names.split():
                self[name] = simple
                self['_%s' % name] = array

    # this could be a static method in Python > 2.6
    def __missing__(self, key):
        """Return the default type without storing it in the dictionary."""
        return 'text'

_simpletypes = _SimpleTypes()
244
245
def _quote_if_unqualified(param, name):
    """Quote parameter representing a qualified name.

    Wraps the given parameter in a quote_ident() call, unless the name
    contains a dot.  In that case the name is ambiguous (it could be a
    qualified name or just a name with a dot in it) and must be quoted
    manually by the caller.  The parameter is also returned unchanged
    if the name is not a string.
    """
    if not isinstance(name, basestring) or '.' in name:
        return param
    return 'quote_ident(%s)' % (param,)
257
258
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    The adapt() method used by add() is not defined in this class;
    it must be set on the instance before add() is called.
    """

    def add(self, value, typ=None):
        """Typecast value with known database type and build parameter list.

        Literal values are returned as they are and are not added to the
        list.  All other values are appended to the list, and a positional
        placeholder referring to the new entry is returned instead.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            # placeholders are numbered starting with 1
            return '$%d' % len(self)
        return adapted
273
274
class Bytea(bytes):
    """Marker class for byte strings that shall be treated as bytea."""
277
278
class Hstore(dict):
    """Marker class for dictionaries that shall be treated as hstore."""

    # strings that must be put in double quotes (the word NULL in any
    # case, and strings containing a space, comma, "=" or ">")
    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')

    @classmethod
    def _quote(cls, s):
        """Quote a key or value for use in the hstore text format."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        escaped = s.replace('"', '\\"')
        if not cls._re_quote.search(escaped):
            return escaped
        return '"%s"' % escaped

    def __str__(self):
        """Return the hstore text format of this dictionary."""
        quote = self._quote
        pairs = ['%s=>%s' % (quote(key), quote(val))
            for key, val in self.items()]
        return ','.join(pairs)
298
299
class Json:
    """Marker class for objects that shall be treated as JSON values."""

    def __init__(self, obj):
        # the wrapped object is only stored here, not encoded
        self.obj = obj
305
306
class Literal(str):
    """Marker class for strings that shall be treated as literal SQL."""
309
310
class Adapter:
    """Class providing methods for adapting parameters to the database.

    Values can be adapted either for being sent separately from the
    command as query parameters (see adapt()), or for being put directly
    into the SQL command (see adapt_inline()).  The format_query() method
    combines a command with its values using one of these two ways.
    """

    # lower case strings that are regarded as true boolean values
    _bool_true_values = frozenset('t true 1 y yes on'.split())

    # special values that are passed as SQL literals instead of parameters
    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    # patterns for elements that must be quoted or escaped
    # inside array and record literals
    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    _re_record_quote = regex(r'[(,"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        """Create an adapter for the given database wrapper.

        The wrapper must provide an encode_json() method and a db attribute
        with the methods escape_bytea() and escape_string().
        """
        self.db = db
        self.encode_json = db.encode_json
        db = db.db
        self.escape_bytea = db.escape_bytea
        self.escape_string = db.escape_string

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Empty strings are adapted to None, other strings are regarded
        as true if they are one of the known true values.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        False values are adapted to None, and the special date values
        like current_date are returned as literals.
        """
        if not v:
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter.

        False values are adapted to None, except when they equal zero.
        """
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        False values are adapted to None, strings are passed unchanged,
        and all other values are encoded as JSON.
        """
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.encode_json(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if not v:
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(v) for v in v) + b'}'
        if v is None:
            return b'null'
        # backslashes must be doubled inside the array literal
        return self.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Raises a TypeError if the number of values does not match
        the number of attributes of the type.
        """
        typ = self.get_attnames(typ).values()
        if len(typ) != len(v):
            raise TypeError('Record parameter %s has wrong size' % v)
        adapt = self.adapt
        value = []
        for v, t in zip(v, typ):
            v = adapt(v, t)
            if v is None:
                v = ''
            elif not v:
                v = '""'
            else:
                if isinstance(v, bytes):
                    if str is not bytes:
                        v = v.decode('ascii')
                else:
                    v = str(v)
                if self._re_record_quote.search(v):
                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
            value.append(v)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is passed, it will be guessed from the value.
        None and literal values are returned unchanged.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            # give the value the chance to adapt itself first
            try:
                value = value.__pg_str__(typ)
            except AttributeError:
                pass
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type.

        An empty dictionary is returned if the type is not a DbType.
        """
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns None if the type of the value is not recognized.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, basestring):
            return 'text'
        # bool must be checked before int, since bool is a subclass of int
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, list):
            return '%s[]' % cls.guess_simple_base_type(value)
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type
            # the attributes are named by position and typed by guessing
            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))
            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array.

        Returns the type guessed for the first element (searching nested
        lists) for which a type can be guessed, or None if there is none.
        """
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        The nested flag is set for lists inside of other lists, which
        are then written without the ARRAY keyword.  Raises an
        InterfaceError for values that cannot be adapted.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # The Json wrapper only has an obj attribute, so an encode()
            # hook must be treated as optional here.
            encode = getattr(value, 'encode', None)
            if encode:
                return encode()
            # Adapt the wrapped object, not the wrapper itself; like in
            # _adapt_json(), strings are taken as already encoded JSON.
            obj = value.obj
            if isinstance(obj, basestring):
                value = obj
            else:
                value = self.encode_json(obj)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        The values can be passed as a list or tuple for positional or as
        a dict for named format specifiers in the command.  With inline
        set, the values are put into the command as quoted SQL; otherwise
        they are replaced with placeholders and collected as parameters.

        Returns a tuple with the formatted command and the parameter list.
        Raises a ValueError if types are passed together with inline, and
        a TypeError if values and types do not match or the values are
        not passed as a list, tuple or dict.
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command %= tuple(literals)
        elif isinstance(values, dict):
            # we want to allow extra keys in the dictionary,
            # so we first must find the values actually used in the command
            used_values = {}
            literals = dict.fromkeys(values, '')
            # Iterate over the passed values here and not over the probe
            # dictionary, since its keys are removed and added again in
            # the loop, which must not be done to a dict being iterated.
            for key in values:
                del literals[key]
                try:
                    command % literals
                except KeyError:
                    # the command needs this key
                    used_values[key] = values[key]
                literals[key] = ''
            values = used_values
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                # sort the keys so that the placeholders are numbered
                # in a reproducible order
                if types:
                    if not isinstance(types, dict):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types.get(key))
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command %= literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
655
656
def cast_bool(value):
    """Cast a boolean value.

    The value is returned unchanged if get_bool() returns a false value.
    """
    if get_bool():
        return value[0] == 't'
    return value
662
663
def cast_json(value):
    """Cast a JSON value.

    The value is returned unchanged if get_jsondecode() does not
    return a decoding function.
    """
    decode = get_jsondecode()
    if decode:
        return decode(value)
    return value
670
671
def cast_num(value):
    """Cast a numeric value.

    Uses the type returned by get_decimal(), or float if there is none.
    """
    numeric_type = get_decimal() or float
    return numeric_type(value)
675
676
def cast_money(value):
    """Cast a money value.

    The value is returned unchanged if get_decimal_point() returns a false
    value.  Otherwise it is reduced to digits, decimal point and minus sign
    and converted with the type returned by get_decimal(), or with float.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        value = value.replace(point, '.')
    # an opening parenthesis marks a negative amount
    value = value.replace('(', '-')
    kept = [char for char in value if char.isdigit() or char in '.-']
    value = ''.join(kept)
    numeric_type = get_decimal() or float
    return numeric_type(value)
687
688
def cast_int2vector(value):
    """Cast an int2vector value to a list of integers."""
    return list(map(int, value.split()))
692
693
def cast_date(value, connection):
    """Cast a date value.

    The format for parsing the value is taken from the date_format()
    method of the connection.  Infinite values, dates before the common
    era and dates with more than ten characters are mapped to the
    smallest or largest possible date.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:
        return date.max
    date_format = connection.date_format()
    parsed = datetime.strptime(day, date_format)
    return parsed.date()
712
713
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    if len(value) > 8:
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    parsed = datetime.strptime(value, fmt)
    return parsed.time()
718
719
# splits a time with a numeric time zone into the time and the zone
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    Values without a numeric time zone are taken as UTC.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    if len(value) > 8:
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    if not _has_timezone:
        # parse the time without zone and attach the time zone afterwards
        parsed = datetime.strptime(value, fmt).timetz()
        return parsed.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
737
738
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The date format is taken from the date_format() method of the
    connection.  Infinite values, timestamps before the common era and
    timestamps with overlong year or date parts are mapped to the
    smallest or largest possible datetime.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_format = connection.date_format()
    if date_format.endswith('-%Y') and len(parts) > 2:
        # the first part is skipped; the rest is month and day
        # (in the order given by the date format), time and year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        formats = ['%d %b' if date_format.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S', '%Y']
    else:
        # the parts are the date and the time
        if len(parts[0]) > 10:
            return datetime.max
        formats = [date_format,
            '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S']
    return datetime.strptime(' '.join(parts), ' '.join(formats))
760
761
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    The date format is taken from the date_format() method of the
    connection.  Infinite values, timestamps before the common era and
    timestamps with overlong year or date parts are mapped to the
    smallest or largest possible datetime.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_format = connection.date_format()
    if date_format.endswith('-%Y') and len(parts) > 2:
        # the first part is skipped; the rest is month and day (in the
        # order given by the date format), time, year and time zone
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        formats = ['%d %b' if date_format.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S', '%Y']
        tz = parts.pop()
    else:
        if date_format.startswith('%Y-'):
            # the time zone is appended directly to the time
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            # the time zone is the last part of the value
            tz = parts.pop()
        if len(parts[0]) > 10:
            return datetime.max
        formats = [date_format,
            '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S']
    if not _has_timezone:
        # parse the value without zone and attach the time zone afterwards
        parsed = datetime.strptime(' '.join(parts), ' '.join(formats))
        return parsed.replace(tzinfo=_get_timezone(tz))
    parts.append(_timezone_as_offset(tz))
    formats.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
797
798
# Patterns for the interval output formats that are tried by cast_interval().
# In all of these patterns, the last group that is converted to a number
# captures the digits after the decimal point of the seconds.

_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(groups):
    """Prepare the matched groups of an interval for conversion to numbers.

    Groups that did not match are replaced with '0'.  The last group must
    be the fractional part of the seconds; it is padded or cut to exactly
    six digits, so that it can be taken as the number of microseconds.
    """
    fields = [d or '0' for d in groups]
    # the digits after the decimal point are a fraction of a second and
    # not a number of microseconds, e.g. '5' in '01.5' means 500000 usecs
    fields[-1] = fields[-1].ljust(6, '0')[:6]
    return fields


def cast_interval(value):
    """Cast an interval value.

    The value is parsed with the patterns for all known output formats
    and returned as a timedelta, counting 365 days for each year and
    30 days for each month.  Fractional seconds are converted to
    microseconds, whatever number of digits is used for the fraction.

    Raises a ValueError if the value matches none of the patterns.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = _interval_fields(m.groups())
        # a separate sign in front of the seconds
        secs_ago = m.pop(5) == '-'
        m = [int(d) for d in m]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = _interval_fields(m.groups()[:8]), m.group(9)
            secs_ago = m.pop(5) == '-'
            # a trailing "ago" negates all parts of the interval
            m = [-int(d) for d in m] if ago else [int(d) for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = - secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            # all groups are optional, so require that at least one matched
            if m and any(m.groups()):
                m = _interval_fields(m.groups())
                # the sign in front of the time applies to the whole time
                hours_ago = m.pop(3) == '-'
                m = [int(d) for d in m]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = _interval_fields(m.groups())
                    # separate signs for years-months and for the time
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = [int(d) for d in m]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
884
885
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Note that the basic types are already handled by the C extension.
    They only need to be handled here as record or array components.

    The dictionary works as a cache: entries are created lazily by
    __missing__() from the class-wide defaults, so looking up a type
    never raises a KeyError but yields None for types without a cast.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': UUID,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        A TypeError is raised if the type name is not a string.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # a leading underscore denotes the array type of a base type
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                # only cache array casts that wrap a real base cast
                self[typ] = cast
        else:
            # maybe a composite type: build a cast from its attributes
            attnames = self.get_attnames(typ)
            if attnames:
                casts = [self[v.pgtype] for v in attnames.values()]
                cast = self.create_record_cast(typ, attnames, casts)
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument."""
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            # the signature cannot be inspected (e.g. for some builtins)
            return False
        else:
            # the first argument is always the value to be cast
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type.

        Returns the default if no cast function exists for the type.
        """
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        The type can be a single name or a sequence of names.  Passing
        None as cast removes the cached cast function(s).  A TypeError
        is raised if cast is neither None nor callable.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                # the cached array cast wraps the old base cast
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.
        """
        if typ is None:
            self.clear()
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.pop(t, None)
                # Also forget the corresponding array cast: it has been
                # built around the base cast that is being reset and
                # would otherwise keep using the stale function.
                self.pop('_%s' % t, None)

    @classmethod
    def get_default(cls, typ):
        """Get the default typecast function for the given database type."""
        return cls.defaults.get(typ)

    @classmethod
    def set_default(cls, typ, cast):
        """Set a default typecast function for the given database type(s).

        The type can be a single name or a sequence of names.  Passing
        None as cast removes the default cast function(s).  A TypeError
        is raised if cast is neither None nor callable.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = cls.defaults
        if cast is None:
            for t in typ:
                defaults.pop(t, None)
                defaults.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                defaults[t] = cast
                defaults.pop('_%s' % t, None)

    def get_attnames(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with the get_attnames() method of DbTypes.
        """
        return {}

    def dateformat(self):
        """Return the current date format.

        This method will be replaced with the dateformat() method of DbTypes.
        """
        return '%Y-%m-%d'

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
1037
1038
def get_typecast(typ):
    """Return the global default typecast function for a database type.

    This only consults the class-wide defaults of Typecasts, not the
    casts cached or overridden by a particular connection.
    """
    default_cast = Typecasts.get_default(typ)
    return default_cast
1042
1043
def set_typecast(typ, cast):
    """Register a global typecast function for the given database type(s).

    The type can be a single name or a sequence of names; passing None
    as cast removes the global default for these types.

    Note that connections cache cast functions.  To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ=typ, cast=cast)
1051
1052
class DbType(str):
    """String subclass carrying extra information about a database type.

    Instances get the following attributes set by their creator:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Return the names and types of the fields of a composite type."""
        # the lookup function is attached to the instance by its creator
        lookup = self._get_attnames
        return lookup(self)
1073
1074
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    The cache is keyed by type OIDs and type names; its values are
    DbType objects describing the associated database type.  Missing
    entries are looked up in the database on first access.
    """

    _num_types = frozenset([
        'int', 'float', 'num', 'money',
        'int2', 'int4', 'int8', 'float4', 'float8', 'numeric'])

    def __init__(self, db):
        """Initialize type cache for connection."""
        super(DbTypes, self).__init__()
        self._regtypes = False
        self._get_attnames = db.get_attnames
        # connection specific typecast registry
        typecasts = Typecasts()
        self._typecasts = typecasts
        typecasts.get_attnames = self.get_attnames
        typecasts.connection = db
        # from here on, work with the underlying connection
        cnx = db.db
        self.query = cnx.query
        self.escape_string = cnx.escape_string
        if cnx.server_version < 80400:
            # older remote databases (not officially supported)
            regtype = "typname::text::regtype"
            category = "null as typcategory"
        else:
            regtype = "typname::regtype"
            category = "typcategory"
        self._query_pg_type = ''.join([
            "SELECT oid, typname, ", regtype, ",",
            " typtype, ", category, ", typdelim, typrelid",
            " FROM pg_type WHERE oid=%s::regtype"])

    def add(self, oid, pgtype, regtype,
               typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        If the OID is already cached, the cached type is returned.
        Note that a newly created type is not stored in the cache here.
        """
        if oid in self:
            return self[oid]
        if relid:
            simple = 'record'
        else:
            simple = _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        info = (
            ('oid', oid), ('simple', simple), ('pgtype', pgtype),
            ('regtype', regtype), ('typtype', typtype),
            ('category', category), ('delim', delim), ('relid', relid),
            ('_get_attnames', self.get_attnames))
        for attr, val in info:
            setattr(typ, attr, val)
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached."""
        try:
            sql = self._query_pg_type % (_quote_if_unqualified('$1', key),)
            rows = self.query(sql, (key,)).getresult()
        except ProgrammingError:
            rows = None
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        typ = self.add(*rows[0])
        # cache the type under both its OID and its PostgreSQL name
        self[typ.oid] = typ
        self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            typ = self[key]
        except KeyError:
            typ = default
        return typ

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None for unknown types and types without a relation.
        """
        if isinstance(typ, DbType):
            dbtype = typ
        else:
            dbtype = self.get(typ) or None
        if dbtype is None or not dbtype.relid:
            return None
        return self._get_attnames(dbtype.relid, with_oid=False)

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        typecasts = self._typecasts
        return typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        typecasts = self._typecasts
        typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        typecasts = self._typecasts
        typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type."""
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if not isinstance(typ, DbType):
            found = self.get(typ)
            typ = found.pgtype if found else found
        if not typ:
            # unknown type, so no typecast is possible
            return value
        cast = self.get_typecast(typ)
        if cast and cast is not str:
            return cast(value)
        # no typecast is necessary
        return value
1183
1184
1185def _namedresult(q):
1186    """Get query result as named tuples."""
1187    row = namedtuple('Row', q.listfields())
1188    return [row(*r) for r in q.getresult()]
1189
1190
1191class _MemoryQuery:
1192    """Class that embodies a given query result."""
1193
1194    def __init__(self, result, fields):
1195        """Create query from given result rows and field names."""
1196        self.result = result
1197        self.fields = fields
1198
1199    def listfields(self):
1200        """Return the stored field names of this query."""
1201        return self.fields
1202
1203    def getresult(self):
1204        """Return the stored result of this query."""
1205        return self.result
1206
1207
def _db_error(msg, cls=DatabaseError):
    """Create an error of the given class with empty sqlstate attribute.

    The error is returned, not raised; by default it is a DatabaseError.
    """
    exc = cls(msg)
    # errors created on the client side carry no SQLSTATE code
    setattr(exc, 'sqlstate', None)
    return exc
1213
1214
def _int_error(msg):
    """Create (but do not raise) an InternalError with empty sqlstate."""
    return _db_error(msg, cls=InternalError)
1218
1219
def _prg_error(msg):
    """Create (but do not raise) a ProgrammingError with empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
1223
1224
# Initialize the C module
#
# The setter functions below come from the star import of the _pg
# extension.  They are handed the named tuple factory defined above,
# the Decimal class and the JSON decoder imported at the top of this file.
# NOTE(review): how the extension uses these objects is defined in the
# C code and cannot be seen from here.

set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
1230
1231
1232# The notification handler
1233
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            # never share a mutable default between handlers
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is garbage collected."""
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Returns the result of the query sending the notification, or
        None if the handler is currently not listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # The payload is sent as a string literal, so it must be
                # escaped; otherwise a quote inside the payload would
                # break (or alter) the generated SQL command.
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        If a notification for an unexpected event is received, the
        handler stops listening and raises a DatabaseError.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            # wait on the socket of the connection when not polling
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        # this also terminates both loops
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1353
1354
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its traditional, deprecated name.

    All arguments are passed on to NotificationHandler unchanged;
    a DeprecationWarning is issued for the caller.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1360
1361
# The actual PostgreSQL database connection interface:
1363
1364class DB:
1365    """Wrapper class for the _pg connection type."""
1366
1367    def __init__(self, *args, **kw):
1368        """Create a new connection
1369
1370        You can pass either the connection parameters or an existing
1371        _pg or pgdb connection. This allows you to use the methods
1372        of the classic pg interface with a DB-API 2 pgdb connection.
1373        """
1374        if not args and len(kw) == 1:
1375            db = kw.get('db')
1376        elif not kw and len(args) == 1:
1377            db = args[0]
1378        else:
1379            db = None
1380        if db:
1381            if isinstance(db, DB):
1382                db = db.db
1383            else:
1384                try:
1385                    db = db._cnx
1386                except AttributeError:
1387                    pass
1388        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
1389            db = connect(*args, **kw)
1390            self._closeable = True
1391        else:
1392            self._closeable = False
1393        self.db = db
1394        self.dbname = db.db
1395        self._regtypes = False
1396        self._attnames = {}
1397        self._pkeys = {}
1398        self._privileges = {}
1399        self._args = args, kw
1400        self.adapter = Adapter(self)
1401        self.dbtypes = DbTypes(self)
1402        if db.server_version < 80400:
1403            # support older remote data bases
1404            self._query_attnames = (
1405                "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype,"
1406                " t.typtype, null as typcategory, t.typdelim, t.typrelid"
1407                " FROM pg_attribute a"
1408                " JOIN pg_type t ON t.oid = a.atttypid"
1409                " WHERE a.attrelid = %s::regclass AND %s"
1410                " AND NOT a.attisdropped ORDER BY a.attnum")
1411        else:
1412            self._query_attnames = (
1413                "SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1414                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1415                " FROM pg_attribute a"
1416                " JOIN pg_type t ON t.oid = a.atttypid"
1417                " WHERE a.attrelid = %s::regclass AND %s"
1418                " AND NOT a.attisdropped ORDER BY a.attnum")
1419        db.set_cast_hook(self.dbtypes.typecast)
1420        self.debug = None  # For debugging scripts, this can be set
1421            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
1422            # * to a file object to write debug statements or
1423            # * to a callable object which takes a string argument
1424            # * to any other true value to just print debug statements
1425
1426    def __getattr__(self, name):
1427        # All undefined members are same as in underlying connection:
1428        if self.db:
1429            return getattr(self.db, name)
1430        else:
1431            raise _int_error('Connection is not valid')
1432
1433    def __dir__(self):
1434        # Custom dir function including the attributes of the connection:
1435        attrs = set(self.__class__.__dict__)
1436        attrs.update(self.__dict__)
1437        attrs.update(dir(self.db))
1438        return sorted(attrs)
1439
1440    # Context manager methods
1441
1442    def __enter__(self):
1443        """Enter the runtime context. This will start a transaction."""
1444        self.begin()
1445        return self
1446
1447    def __exit__(self, et, ev, tb):
1448        """Exit the runtime context. This will end the transaction."""
1449        if et is None and ev is None and tb is None:
1450            self.commit()
1451        else:
1452            self.rollback()
1453
1454    # Auxiliary methods
1455
1456    def _do_debug(self, *args):
1457        """Print a debug message"""
1458        if self.debug:
1459            s = '\n'.join(str(arg) for arg in args)
1460            if isinstance(self.debug, basestring):
1461                print(self.debug % s)
1462            elif hasattr(self.debug, 'write'):
1463                self.debug.write(s + '\n')
1464            elif callable(self.debug):
1465                self.debug(s)
1466            else:
1467                print(s)
1468
1469    def _escape_qualified_name(self, s):
1470        """Escape a qualified name.
1471
1472        Escapes the name for use as an SQL identifier, unless the
1473        name contains a dot, in which case the name is ambiguous
1474        (could be a qualified name or just a name with a dot in it)
1475        and must be quoted manually by the caller.
1476        """
1477        if '.' not in s:
1478            s = self.escape_identifier(s)
1479        return s
1480
1481    @staticmethod
1482    def _make_bool(d):
1483        """Get boolean value corresponding to d."""
1484        return bool(d) if get_bool() else ('t' if d else 'f')
1485
1486    def _list_params(self, params):
1487        """Create a human readable parameter list."""
1488        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1489
1490    # Public methods
1491
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (wrapped as a static method, so that the module level function
    # does not get passed the DB instance when called on an instance)
    unescape_bytea = staticmethod(unescape_bytea)
1495
1496    def decode_json(self, s):
1497        """Decode a JSON string coming from the database."""
1498        return (get_jsondecode() or jsondecode)(s)
1499
1500    def encode_json(self, d):
1501        """Encode a JSON string for use within SQL."""
1502        return jsonencode(d)
1503
1504    def close(self):
1505        """Close the database connection."""
1506        # Wraps shared library function so we can track state.
1507        if self._closeable:
1508            if self.db:
1509                self.db.set_cast_hook(None)
1510                self.db.close()
1511                self.db = None
1512            else:
1513                raise _int_error('Connection already closed')
1514
1515    def reset(self):
1516        """Reset connection with current parameters.
1517
1518        All derived queries and large objects derived from this connection
1519        will not be usable after this call.
1520
1521        """
1522        if self.db:
1523            self.db.reset()
1524        else:
1525            raise _int_error('Connection already closed')
1526
1527    def reopen(self):
1528        """Reopen connection to the database.
1529
1530        Used in case we need another connection to the same database.
1531        Note that we can still reopen a database that we have closed.
1532
1533        """
1534        # There is no such shared library function.
1535        if self._closeable:
1536            db = connect(*self._args[0], **self._args[1])
1537            if self.db:
1538                self.db.set_cast_hook(None)
1539                self.db.close()
1540            self.db = db
1541
1542    def begin(self, mode=None):
1543        """Begin a transaction."""
1544        qstr = 'BEGIN'
1545        if mode:
1546            qstr += ' ' + mode
1547        return self.query(qstr)
1548
1549    start = begin
1550
1551    def commit(self):
1552        """Commit the current transaction."""
1553        return self.query('COMMIT')
1554
1555    end = commit
1556
1557    def rollback(self, name=None):
1558        """Roll back the current transaction."""
1559        qstr = 'ROLLBACK'
1560        if name:
1561            qstr += ' TO ' + name
1562        return self.query(qstr)
1563
1564    abort = rollback
1565
1566    def savepoint(self, name):
1567        """Define a new savepoint within the current transaction."""
1568        return self.query('SAVEPOINT ' + name)
1569
1570    def release(self, name):
1571        """Destroy a previously defined savepoint."""
1572        return self.query('RELEASE ' + name)
1573
1574    def get_parameter(self, parameter):
1575        """Get the value of a run-time parameter.
1576
1577        If the parameter is a string, the return value will also be a string
1578        that is the current setting of the run-time parameter with that name.
1579
1580        You can get several parameters at once by passing a list, set or dict.
1581        When passing a list of parameter names, the return value will be a
1582        corresponding list of parameter settings.  When passing a set of
1583        parameter names, a new dict will be returned, mapping these parameter
1584        names to their settings.  Finally, if you pass a dict as parameter,
1585        its values will be set to the current parameter settings corresponding
1586        to its keys.
1587
1588        By passing the special name 'all' as the parameter, you can get a dict
1589        of all existing configuration parameters.
1590        """
1591        if isinstance(parameter, basestring):
1592            parameter = [parameter]
1593            values = None
1594        elif isinstance(parameter, (list, tuple)):
1595            values = []
1596        elif isinstance(parameter, (set, frozenset)):
1597            values = {}
1598        elif isinstance(parameter, dict):
1599            values = parameter
1600        else:
1601            raise TypeError(
1602                'The parameter must be a string, list, set or dict')
1603        if not parameter:
1604            raise TypeError('No parameter has been specified')
1605        params = {} if isinstance(values, dict) else []
1606        for key in parameter:
1607            param = key.strip().lower() if isinstance(
1608                key, basestring) else None
1609            if not param:
1610                raise TypeError('Invalid parameter')
1611            if param == 'all':
1612                q = 'SHOW ALL'
1613                values = self.db.query(q).getresult()
1614                values = dict(value[:2] for value in values)
1615                break
1616            if isinstance(values, dict):
1617                params[param] = key
1618            else:
1619                params.append(param)
1620        else:
1621            for param in params:
1622                q = 'SHOW %s' % (param,)
1623                value = self.db.query(q).getresult()[0][0]
1624                if values is None:
1625                    values = value
1626                elif isinstance(values, list):
1627                    values.append(value)
1628                else:
1629                    values[params[param]] = value
1630        return values
1631
1632    def set_parameter(self, parameter, value=None, local=False):
1633        """Set the value of a run-time parameter.
1634
1635        If the parameter and the value are strings, the run-time parameter
1636        will be set to that value.  If no value or None is passed as a value,
1637        then the run-time parameter will be restored to its default value.
1638
1639        You can set several parameters at once by passing a list of parameter
1640        names, together with a single value that all parameters should be
1641        set to or with a corresponding list of values.  You can also pass
1642        the parameters as a set if you only provide a single value.
1643        Finally, you can pass a dict with parameter names as keys.  In this
1644        case, you should not pass a value, since the values for the parameters
1645        will be taken from the dict.
1646
1647        By passing the special name 'all' as the parameter, you can reset
1648        all existing settable run-time parameters to their default values.
1649
1650        If you set local to True, then the command takes effect for only the
1651        current transaction.  After commit() or rollback(), the session-level
1652        setting takes effect again.  Setting local to True will appear to
1653        have no effect if it is executed outside a transaction, since the
1654        transaction will end immediately.
1655        """
1656        if isinstance(parameter, basestring):
1657            parameter = {parameter: value}
1658        elif isinstance(parameter, (list, tuple)):
1659            if isinstance(value, (list, tuple)):
1660                parameter = dict(zip(parameter, value))
1661            else:
1662                parameter = dict.fromkeys(parameter, value)
1663        elif isinstance(parameter, (set, frozenset)):
1664            if isinstance(value, (list, tuple, set, frozenset)):
1665                value = set(value)
1666                if len(value) == 1:
1667                    value = value.pop()
1668            if not(value is None or isinstance(value, basestring)):
1669                raise ValueError('A single value must be specified'
1670                    ' when parameter is a set')
1671            parameter = dict.fromkeys(parameter, value)
1672        elif isinstance(parameter, dict):
1673            if value is not None:
1674                raise ValueError('A value must not be specified'
1675                    ' when parameter is a dictionary')
1676        else:
1677            raise TypeError(
1678                'The parameter must be a string, list, set or dict')
1679        if not parameter:
1680            raise TypeError('No parameter has been specified')
1681        params = {}
1682        for key, value in parameter.items():
1683            param = key.strip().lower() if isinstance(
1684                key, basestring) else None
1685            if not param:
1686                raise TypeError('Invalid parameter')
1687            if param == 'all':
1688                if value is not None:
1689                    raise ValueError('A value must ot be specified'
1690                        " when parameter is 'all'")
1691                params = {'all': None}
1692                break
1693            params[param] = value
1694        local = ' LOCAL' if local else ''
1695        for param, value in params.items():
1696            if value is None:
1697                q = 'RESET%s %s' % (local, param)
1698            else:
1699                q = 'SET%s %s TO %s' % (local, param, value)
1700            self._do_debug(q)
1701            self.db.query(q)
1702
1703    def query(self, command, *args):
1704        """Execute a SQL command string.
1705
1706        This method simply sends a SQL query to the database.  If the query is
1707        an insert statement that inserted exactly one row into a table that
1708        has OIDs, the return value is the OID of the newly inserted row.
1709        If the query is an update or delete statement, or an insert statement
1710        that did not insert exactly one row in a table with OIDs, then the
1711        number of rows affected is returned as a string.  If it is a statement
1712        that returns rows as a result (usually a select statement, but maybe
1713        also an "insert/update ... returning" statement), this method returns
1714        a Query object that can be accessed via getresult() or dictresult()
1715        or simply printed.  Otherwise, it returns `None`.
1716
1717        The query can contain numbered parameters of the form $1 in place
1718        of any data constant.  Arguments given after the query string will
1719        be substituted for the corresponding numbered parameter.  Parameter
1720        values can also be given as a single list or tuple argument.
1721        """
1722        # Wraps shared library function for debugging.
1723        if not self.db:
1724            raise _int_error('Connection is not valid')
1725        if args:
1726            self._do_debug(command, args)
1727            return self.db.query(command, args)
1728        self._do_debug(command)
1729        return self.db.query(command)
1730
1731    def query_formatted(self, command, parameters, types=None, inline=False):
1732        """Execute a formatted SQL command string.
1733
1734        Similar to query, but using Python format placeholders of the form
1735        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1736        The parameters must be passed as a tuple, list or dict.  You can
1737        also pass a corresponding tuple, list or dict of database types in
1738        order to format the parameters properly in case there is ambiguity.
1739
1740        If you set inline to True, the parameters will be sent to the database
1741        embedded in the SQL command, otherwise they will be sent separately.
1742        """
1743        return self.query(*self.adapter.format_query(
1744            command, parameters, types, inline))
1745
1746    def pkey(self, table, composite=False, flush=False):
1747        """Get or set the primary key of a table.
1748
1749        Single primary keys are returned as strings unless you
1750        set the composite flag.  Composite primary keys are always
1751        represented as tuples.  Note that this raises a KeyError
1752        if the table does not have a primary key.
1753
1754        If flush is set then the internal cache for primary keys will
1755        be flushed.  This may be necessary after the database schema or
1756        the search path has been changed.
1757        """
1758        pkeys = self._pkeys
1759        if flush:
1760            pkeys.clear()
1761            self._do_debug('The pkey cache has been flushed')
1762        try:  # cache lookup
1763            pkey = pkeys[table]
1764        except KeyError:  # cache miss, check the database
1765            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1766                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1767                " AND a.attnum = ANY(i.indkey)"
1768                " AND NOT a.attisdropped"
1769                " WHERE i.indrelid=%s::regclass"
1770                " AND i.indisprimary ORDER BY a.attnum") % (
1771                    _quote_if_unqualified('$1', table),)
1772            pkey = self.db.query(q, (table,)).getresult()
1773            if not pkey:
1774                raise KeyError('Table %s has no primary key' % table)
1775            # we want to use the order defined in the primary key index here,
1776            # not the order as defined by the columns in the table
1777            if len(pkey) > 1:
1778                indkey = pkey[0][2]
1779                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1780                pkey = tuple(row[0] for row in pkey)
1781            else:
1782                pkey = pkey[0][0]
1783            pkeys[table] = pkey  # cache it
1784        if composite and not isinstance(pkey, tuple):
1785            pkey = (pkey,)
1786        return pkey
1787
1788    def get_databases(self):
1789        """Get list of databases in the system."""
1790        return [s[0] for s in
1791            self.db.query('SELECT datname FROM pg_database').getresult()]
1792
1793    def get_relations(self, kinds=None, system=False):
1794        """Get list of relations in connected database of specified kinds.
1795
1796        If kinds is None or empty, all kinds of relations are returned.
1797        Otherwise kinds can be a string or sequence of type letters
1798        specifying which kind of relations you want to list.
1799
1800        Set the system flag if you want to get the system relations as well.
1801        """
1802        where = []
1803        if kinds:
1804            where.append("r.relkind IN (%s)" %
1805                ','.join("'%s'" % k for k in kinds))
1806        if not system:
1807            where.append("s.nspname NOT SIMILAR"
1808                " TO 'pg/_%|information/_schema' ESCAPE '/'")
1809        where = " WHERE %s" % ' AND '.join(where) if where else ''
1810        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1811            " FROM pg_class r"
1812            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
1813            " ORDER BY s.nspname, r.relname") % where
1814        return [r[0] for r in self.db.query(q).getresult()]
1815
1816    def get_tables(self, system=False):
1817        """Return list of tables in connected database.
1818
1819        Set the system flag if you want to get the system tables as well.
1820        """
1821        return self.get_relations('r', system)
1822
1823    def get_attnames(self, table, with_oid=True, flush=False):
1824        """Given the name of a table, dig out the set of attribute names.
1825
1826        Returns a read-only dictionary of attribute names (the names are
1827        the keys, the values are the names of the attributes' types)
1828        with the column names in the proper order if you iterate over it.
1829
1830        If flush is set, then the internal cache for attribute names will
1831        be flushed. This may be necessary after the database schema or
1832        the search path has been changed.
1833
1834        By default, only a limited number of simple types will be returned.
1835        You can get the regular types after calling use_regtypes(True).
1836        """
1837        attnames = self._attnames
1838        if flush:
1839            attnames.clear()
1840            self._do_debug('The attnames cache has been flushed')
1841        try:  # cache lookup
1842            names = attnames[table]
1843        except KeyError:  # cache miss, check the database
1844            q = "a.attnum > 0"
1845            if with_oid:
1846                q = "(%s OR a.attname = 'oid')" % q
1847            q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
1848            names = self.db.query(q, (table,)).getresult()
1849            types = self.dbtypes
1850            names = ((name[0], types.add(*name[1:])) for name in names)
1851            names = AttrDict(names)
1852            attnames[table] = names  # cache it
1853        return names
1854
1855    def use_regtypes(self, regtypes=None):
1856        """Use regular type names instead of simplified type names."""
1857        if regtypes is None:
1858            return self.dbtypes._regtypes
1859        else:
1860            regtypes = bool(regtypes)
1861            if regtypes != self.dbtypes._regtypes:
1862                self.dbtypes._regtypes = regtypes
1863                self._attnames.clear()
1864                self.dbtypes.clear()
1865            return regtypes
1866
1867    def has_table_privilege(self, table, privilege='select', flush=False):
1868        """Check whether current user has specified table privilege.
1869
1870        If flush is set, then the internal cache for table privileges will
1871        be flushed. This may be necessary after privileges have been changed.
1872        """
1873        privileges = self._privileges
1874        if flush:
1875            privileges.clear()
1876            self._do_debug('The privileges cache has been flushed')
1877        privilege = privilege.lower()
1878        try:  # ask cache
1879            ret = privileges[table, privilege]
1880        except KeyError:  # cache miss, ask the database
1881            q = "SELECT has_table_privilege(%s, $2)" % (
1882                _quote_if_unqualified('$1', table),)
1883            q = self.db.query(q, (table, privilege))
1884            ret = q.getresult()[0][0] == self._make_bool(True)
1885            privileges[table, privilege] = ret  # cache it
1886        return ret
1887
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        A KeyError is raised if the row dictionary lacks a value for the
        primary key (and no OID can be used instead) or if the number of
        passed values differs from the number of key columns.  The errors
        created by _prg_error and _db_error are raised if no key can be
        determined for the table or if no matching record exists.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # qoid is the munged key under which the oid is stored in the row,
        # or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)  # normalize a single column name to a tuple
        # make a munged oid in the row temporarily available as plain 'oid'
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        if not isinstance(row, dict):
            # the key values have been passed directly instead of a row,
            # so build a new row dictionary from the key names and values
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add  # adds a value and returns what goes into the SQL
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # the plain 'oid' key was only needed for building the condition,
        # in the returned row the oid appears under its munged name only
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row, storing the oid as munged key
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1962
1963    def insert(self, table, row=None, **kw):
1964        """Insert a row into a database table.
1965
1966        This method inserts a row into a table.  The name of the table must
1967        be passed as the first parameter.  The other parameters are used for
1968        providing the data of the row that shall be inserted into the table.
1969        If a dictionary is supplied as the second parameter, it starts with
1970        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1971        is updated from the keywords.
1972
1973        The dictionary is then reloaded with the values actually inserted in
1974        order to pick up values modified by rules, triggers, etc.
1975        """
1976        if table.endswith('*'):  # hint for descendant tables can be ignored
1977            table = table[:-1].rstrip()
1978        if row is None:
1979            row = {}
1980        row.update(kw)
1981        if 'oid' in row:
1982            del row['oid']  # do not insert oid
1983        attnames = self.get_attnames(table)
1984        qoid = _oid_key(table) if 'oid' in attnames else None
1985        params = self.adapter.parameter_list()
1986        adapt = params.add
1987        col = self.escape_identifier
1988        names, values = [], []
1989        for n in attnames:
1990            if n in row:
1991                names.append(col(n))
1992                values.append(adapt(row[n], attnames[n]))
1993        if not names:
1994            raise _prg_error('No column found that can be inserted')
1995        names, values = ', '.join(names), ', '.join(values)
1996        ret = 'oid, *' if qoid else '*'
1997        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1998            self._escape_qualified_name(table), names, values, ret)
1999        self._do_debug(q, params)
2000        q = self.db.query(q, params)
2001        res = q.dictresult()
2002        if res:  # this should always be true
2003            for n, value in res[0].items():
2004                if qoid and n == 'oid':
2005                    n = qoid
2006                row[n] = value
2007        return row
2008
2009    def update(self, table, row=None, **kw):
2010        """Update an existing row in a database table.
2011
2012        Similar to insert, but updates an existing row.  The update is based
2013        on the primary key of the table or the OID value as munged by get()
2014        or passed as keyword.  The OID will take precedence if provided, so
2015        that it is possible to update the primary key itself.
2016
2017        The dictionary is then modified to reflect any changes caused by the
2018        update due to triggers, rules, default values, etc.
2019        """
2020        if table.endswith('*'):
2021            table = table[:-1].rstrip()  # need parent table name
2022        attnames = self.get_attnames(table)
2023        qoid = _oid_key(table) if 'oid' in attnames else None
2024        if row is None:
2025            row = {}
2026        elif 'oid' in row:
2027            del row['oid']  # only accept oid key from named args for safety
2028        row.update(kw)
2029        if qoid and qoid in row and 'oid' not in row:
2030            row['oid'] = row[qoid]
2031        if qoid and 'oid' in row:  # try using the oid
2032            keyname = ('oid',)
2033        else:  # try using the primary key
2034            try:
2035                keyname = self.pkey(table, True)
2036            except KeyError:  # the table has no primary key
2037                raise _prg_error('Table %s has no primary key' % table)
2038            # check whether all key columns have values
2039            if not set(keyname).issubset(row):
2040                raise KeyError('Missing value for primary key in row')
2041        params = self.adapter.parameter_list()
2042        adapt = params.add
2043        col = self.escape_identifier
2044        where = ' AND '.join('%s = %s' % (
2045            col(k), adapt(row[k], attnames[k])) for k in keyname)
2046        if 'oid' in row:
2047            if qoid:
2048                row[qoid] = row['oid']
2049            del row['oid']
2050        values = []
2051        keyname = set(keyname)
2052        for n in attnames:
2053            if n in row and n not in keyname:
2054                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
2055        if not values:
2056            return row
2057        values = ', '.join(values)
2058        ret = 'oid, *' if qoid else '*'
2059        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
2060            self._escape_qualified_name(table), values, where, ret)
2061        self._do_debug(q, params)
2062        q = self.db.query(q, params)
2063        res = q.dictresult()
2064        if res:  # may be empty when row does not exist
2065            for n, value in res[0].items():
2066                if qoid and n == 'oid':
2067                    n = qoid
2068                row[n] = value
2069        return row
2070
2071    def upsert(self, table, row=None, **kw):
2072        """Insert a row into a database table with conflict resolution
2073
2074        This method inserts a row into a table, but instead of raising a
2075        ProgrammingError exception in case a row with the same primary key
2076        already exists, an update will be executed instead.  This will be
2077        performed as a single atomic operation on the database, so race
2078        conditions can be avoided.
2079
2080        Like the insert method, the first parameter is the name of the
2081        table and the second parameter can be used to pass the values to
2082        be inserted as a dictionary.
2083
2084        Unlike the insert und update statement, keyword parameters are not
2085        used to modify the dictionary, but to specify which columns shall
2086        be updated in case of a conflict, and in which way:
2087
2088        A value of False or None means the column shall not be updated,
2089        a value of True means the column shall be updated with the value
2090        that has been proposed for insertion, i.e. has been passed as value
2091        in the dictionary.  Columns that are not specified by keywords but
2092        appear as keys in the dictionary are also updated like in the case
2093        keywords had been passed with the value True.
2094
2095        So if in the case of a conflict you want to update every column that
2096        has been passed in the dictionary row, you would call upsert(table, row).
2097        If you don't want to do anything in case of a conflict, i.e. leave
2098        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2099
2100        If you need more fine-grained control of what gets updated, you can
2101        also pass strings in the keyword parameters.  These strings will
2102        be used as SQL expressions for the update columns.  In these
2103        expressions you can refer to the value that already exists in
2104        the table by prefixing the column name with "included.", and to
2105        the value that has been proposed for insertion by prefixing the
2106        column name with the "excluded."
2107
2108        The dictionary is modified in any case to reflect the values in
2109        the database after the operation has completed.
2110
2111        Note: The method uses the PostgreSQL "upsert" feature which is
2112        only available since PostgreSQL 9.5.
2113        """
2114        if table.endswith('*'):  # hint for descendant tables can be ignored
2115            table = table[:-1].rstrip()
2116        if row is None:
2117            row = {}
2118        if 'oid' in row:
2119            del row['oid']  # do not insert oid
2120        if 'oid' in kw:
2121            del kw['oid']  # do not update oid
2122        attnames = self.get_attnames(table)
2123        qoid = _oid_key(table) if 'oid' in attnames else None
2124        params = self.adapter.parameter_list()
2125        adapt = params.add
2126        col = self.escape_identifier
2127        names, values, updates = [], [], []
2128        for n in attnames:
2129            if n in row:
2130                names.append(col(n))
2131                values.append(adapt(row[n], attnames[n]))
2132        names, values = ', '.join(names), ', '.join(values)
2133        try:
2134            keyname = self.pkey(table, True)
2135        except KeyError:
2136            raise _prg_error('Table %s has no primary key' % table)
2137        target = ', '.join(col(k) for k in keyname)
2138        update = []
2139        keyname = set(keyname)
2140        keyname.add('oid')
2141        for n in attnames:
2142            if n not in keyname:
2143                value = kw.get(n, True)
2144                if value:
2145                    if not isinstance(value, basestring):
2146                        value = 'excluded.%s' % col(n)
2147                    update.append('%s = %s' % (col(n), value))
2148        if not values:
2149            return row
2150        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2151        ret = 'oid, *' if qoid else '*'
2152        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2153            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2154                self._escape_qualified_name(table), names, values,
2155                target, do, ret)
2156        self._do_debug(q, params)
2157        try:
2158            q = self.db.query(q, params)
2159        except ProgrammingError:
2160            if self.server_version < 90500:
2161                raise _prg_error(
2162                    'Upsert operation is not supported by PostgreSQL version')
2163            raise  # re-raise original error
2164        res = q.dictresult()
2165        if res:  # may be empty with "do nothing"
2166            for n, value in res[0].items():
2167                if qoid and n == 'oid':
2168                    n = qoid
2169                row[n] = value
2170        else:
2171            self.get(table, row)
2172        return row
2173
2174    def clear(self, table, row=None):
2175        """Clear all the attributes to values determined by the types.
2176
2177        Numeric types are set to 0, Booleans are set to false, and everything
2178        else is set to the empty string.  If the row argument is present,
2179        it is used as the row dictionary and any entries matching attribute
2180        names are cleared with everything else left unchanged.
2181        """
2182        # At some point we will need a way to get defaults from a table.
2183        if row is None:
2184            row = {}  # empty if argument is not present
2185        attnames = self.get_attnames(table)
2186        for n, t in attnames.items():
2187            if n == 'oid':
2188                continue
2189            t = t.simple
2190            if t in DbTypes._num_types:
2191                row[n] = 0
2192            elif t == 'bool':
2193                row[n] = self._make_bool(False)
2194            else:
2195                row[n] = ''
2196        return row
2197
2198    def delete(self, table, row=None, **kw):
2199        """Delete an existing row in a database table.
2200
2201        This method deletes the row from a table.  It deletes based on the
2202        primary key of the table or the OID value as munged by get() or
2203        passed as keyword.  The OID will take precedence if provided.
2204
2205        The return value is the number of deleted rows (i.e. 0 if the row
2206        did not exist and 1 if the row was deleted).
2207
2208        Note that if the row cannot be deleted because e.g. it is still
2209        referenced by another table, this method raises a ProgrammingError.
2210        """
2211        if table.endswith('*'):  # hint for descendant tables can be ignored
2212            table = table[:-1].rstrip()
2213        attnames = self.get_attnames(table)
2214        qoid = _oid_key(table) if 'oid' in attnames else None
2215        if row is None:
2216            row = {}
2217        elif 'oid' in row:
2218            del row['oid']  # only accept oid key from named args for safety
2219        row.update(kw)
2220        if qoid and qoid in row and 'oid' not in row:
2221            row['oid'] = row[qoid]
2222        if qoid and 'oid' in row:  # try using the oid
2223            keyname = ('oid',)
2224        else:  # try using the primary key
2225            try:
2226                keyname = self.pkey(table, True)
2227            except KeyError:  # the table has no primary key
2228                raise _prg_error('Table %s has no primary key' % table)
2229            # check whether all key columns have values
2230            if not set(keyname).issubset(row):
2231                raise KeyError('Missing value for primary key in row')
2232        params = self.adapter.parameter_list()
2233        adapt = params.add
2234        col = self.escape_identifier
2235        where = ' AND '.join('%s = %s' % (
2236            col(k), adapt(row[k], attnames[k])) for k in keyname)
2237        if 'oid' in row:
2238            if qoid:
2239                row[qoid] = row['oid']
2240            del row['oid']
2241        q = 'DELETE FROM %s WHERE %s' % (
2242            self._escape_qualified_name(table), where)
2243        self._do_debug(q, params)
2244        res = self.db.query(q, params)
2245        return int(res)
2246
2247    def truncate(self, table, restart=False, cascade=False, only=False):
2248        """Empty a table or set of tables.
2249
2250        This method quickly removes all rows from the given table or set
2251        of tables.  It has the same effect as an unqualified DELETE on each
2252        table, but since it does not actually scan the tables it is faster.
2253        Furthermore, it reclaims disk space immediately, rather than requiring
2254        a subsequent VACUUM operation. This is most useful on large tables.
2255
2256        If restart is set to True, sequences owned by columns of the truncated
2257        table(s) are automatically restarted.  If cascade is set to True, it
2258        also truncates all tables that have foreign-key references to any of
2259        the named tables.  If the parameter only is not set to True, all the
2260        descendant tables (if any) will also be truncated. Optionally, a '*'
2261        can be specified after the table name to explicitly indicate that
2262        descendant tables are included.
2263        """
2264        if isinstance(table, basestring):
2265            only = {table: only}
2266            table = [table]
2267        elif isinstance(table, (list, tuple)):
2268            if isinstance(only, (list, tuple)):
2269                only = dict(zip(table, only))
2270            else:
2271                only = dict.fromkeys(table, only)
2272        elif isinstance(table, (set, frozenset)):
2273            only = dict.fromkeys(table, only)
2274        else:
2275            raise TypeError('The table must be a string, list or set')
2276        if not (restart is None or isinstance(restart, (bool, int))):
2277            raise TypeError('Invalid type for the restart option')
2278        if not (cascade is None or isinstance(cascade, (bool, int))):
2279            raise TypeError('Invalid type for the cascade option')
2280        tables = []
2281        for t in table:
2282            u = only.get(t)
2283            if not (u is None or isinstance(u, (bool, int))):
2284                raise TypeError('Invalid type for the only option')
2285            if t.endswith('*'):
2286                if u:
2287                    raise ValueError(
2288                        'Contradictory table name and only options')
2289                t = t[:-1].rstrip()
2290            t = self._escape_qualified_name(t)
2291            if u:
2292                t = 'ONLY %s' % t
2293            tables.append(t)
2294        q = ['TRUNCATE', ', '.join(tables)]
2295        if restart:
2296            q.append('RESTART IDENTITY')
2297        if cascade:
2298            q.append('CASCADE')
2299        q = ' '.join(q)
2300        self._do_debug(q)
2301        return self.db.query(q)
2302
2303    def get_as_list(self, table, what=None, where=None,
2304            order=None, limit=None, offset=None, scalar=False):
2305        """Get a table as a list.
2306
2307        This gets a convenient representation of the table as a list
2308        of named tuples in Python.  You only need to pass the name of
2309        the table (or any other SQL expression returning rows).  Note that
2310        by default this will return the full content of the table which
2311        can be huge and overflow your memory.  However, you can control
2312        the amount of data returned using the other optional parameters.
2313
2314        The parameter 'what' can restrict the query to only return a
2315        subset of the table columns.  It can be a string, list or a tuple.
2316        The parameter 'where' can restrict the query to only return a
2317        subset of the table rows.  It can be a string, list or a tuple
2318        of SQL expressions that all need to be fulfilled.  The parameter
2319        'order' specifies the ordering of the rows.  It can also be a
2320        other string, list or a tuple.  If no ordering is specified,
2321        the result will be ordered by the primary key(s) or all columns
2322        if no primary key exists.  You can set 'order' to False if you
2323        don't care about the ordering.  The parameters 'limit' and 'offset'
2324        can be integers specifying the maximum number of rows returned
2325        and a number of rows skipped over.
2326
2327        If you set the 'scalar' option to True, then instead of the
2328        named tuples you will get the first items of these tuples.
2329        This is useful if the result has only one column anyway.
2330        """
2331        if not table:
2332            raise TypeError('The table name is missing')
2333        if what:
2334            if isinstance(what, (list, tuple)):
2335                what = ', '.join(map(str, what))
2336            if order is None:
2337                order = what
2338        else:
2339            what = '*'
2340        q = ['SELECT', what, 'FROM', table]
2341        if where:
2342            if isinstance(where, (list, tuple)):
2343                where = ' AND '.join(map(str, where))
2344            q.extend(['WHERE', where])
2345        if order is None:
2346            try:
2347                order = self.pkey(table, True)
2348            except (KeyError, ProgrammingError):
2349                try:
2350                    order = list(self.get_attnames(table))
2351                except (KeyError, ProgrammingError):
2352                    pass
2353        if order:
2354            if isinstance(order, (list, tuple)):
2355                order = ', '.join(map(str, order))
2356            q.extend(['ORDER BY', order])
2357        if limit:
2358            q.append('LIMIT %d' % limit)
2359        if offset:
2360            q.append('OFFSET %d' % offset)
2361        q = ' '.join(q)
2362        self._do_debug(q)
2363        q = self.db.query(q)
2364        res = q.namedresult()
2365        if res and scalar:
2366            res = [row[0] for row in res]
2367        return res
2368
2369    def get_as_dict(self, table, keyname=None, what=None, where=None,
2370            order=None, limit=None, offset=None, scalar=False):
2371        """Get a table as a dictionary.
2372
2373        This method is similar to get_as_list(), but returns the table
2374        as a Python dict instead of a Python list, which can be even
2375        more convenient. The primary key column(s) of the table will
2376        be used as the keys of the dictionary, while the other column(s)
2377        will be the corresponding values.  The keys will be named tuples
2378        if the table has a composite primary key.  The rows will be also
2379        named tuples unless the 'scalar' option has been set to True.
2380        With the optional parameter 'keyname' you can specify an alternative
2381        set of columns to be used as the keys of the dictionary.  It must
2382        be set as a string, list or a tuple.
2383
2384        If the Python version supports it, the dictionary will be an
2385        OrderedDict using the order specified with the 'order' parameter
2386        or the key column(s) if not specified.  You can set 'order' to False
2387        if you don't care about the ordering.  In this case the returned
2388        dictionary will be an ordinary one.
2389        """
2390        if not table:
2391            raise TypeError('The table name is missing')
2392        if not keyname:
2393            try:
2394                keyname = self.pkey(table, True)
2395            except (KeyError, ProgrammingError):
2396                raise _prg_error('Table %s has no primary key' % table)
2397        if isinstance(keyname, basestring):
2398            keyname = [keyname]
2399        elif not isinstance(keyname, (list, tuple)):
2400            raise KeyError('The keyname must be a string, list or tuple')
2401        if what:
2402            if isinstance(what, (list, tuple)):
2403                what = ', '.join(map(str, what))
2404            if order is None:
2405                order = what
2406        else:
2407            what = '*'
2408        q = ['SELECT', what, 'FROM', table]
2409        if where:
2410            if isinstance(where, (list, tuple)):
2411                where = ' AND '.join(map(str, where))
2412            q.extend(['WHERE', where])
2413        if order is None:
2414            order = keyname
2415        if order:
2416            if isinstance(order, (list, tuple)):
2417                order = ', '.join(map(str, order))
2418            q.extend(['ORDER BY', order])
2419        if limit:
2420            q.append('LIMIT %d' % limit)
2421        if offset:
2422            q.append('OFFSET %d' % offset)
2423        q = ' '.join(q)
2424        self._do_debug(q)
2425        q = self.db.query(q)
2426        res = q.getresult()
2427        cls = OrderedDict if order else dict
2428        if not res:
2429            return cls()
2430        keyset = set(keyname)
2431        fields = q.listfields()
2432        if not keyset.issubset(fields):
2433            raise KeyError('Missing keyname in row')
2434        keyind, rowind = [], []
2435        for i, f in enumerate(fields):
2436            (keyind if f in keyset else rowind).append(i)
2437        keytuple = len(keyind) > 1
2438        getkey = itemgetter(*keyind)
2439        keys = map(getkey, res)
2440        if scalar:
2441            rowind = rowind[:1]
2442            rowtuple = False
2443        else:
2444            rowtuple = len(rowind) > 1
2445        if scalar or rowtuple:
2446            getrow = itemgetter(*rowind)
2447        else:
2448            rowind = rowind[0]
2449            getrow = lambda row: (row[rowind],)
2450            rowtuple = True
2451        rows = map(getrow, res)
2452        if keytuple or rowtuple:
2453            namedresult = get_namedresult()
2454            if namedresult:
2455                if keytuple:
2456                    keys = namedresult(_MemoryQuery(keys, keyname))
2457                if rowtuple:
2458                    fields = [f for f in fields if f not in keyset]
2459                    rows = namedresult(_MemoryQuery(rows, fields))
2460        return cls(zip(keys, rows))
2461
2462    def notification_handler(self,
2463            event, callback, arg_dict=None, timeout=None, stop_event=None):
2464        """Get notification handler that will run the given callback."""
2465        return NotificationHandler(self,
2466            event, callback, arg_dict, timeout, stop_event)
2467
2468
# if run as script, print some information

if __name__ == '__main__':
    # pass the version as a separate argument so that print() puts
    # a space between the label and the version number
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.