source: trunk/pg.py

Last change on this file was 964, checked in by cito, 6 weeks ago

Graceful exit of DB destructor on closed connection

Also, in the 5.1 branch, the DB wrapper can now be closed
(without closing the underlying connection) and reopened
(reusing the same connection).

This fixes GitHub issue #11.

  • Property svn:keywords set to Id
File size: 96.7 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 964 2019-01-05 13:51:23Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2019 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function, division
31
32from _pg import *
33
34__version__ = version
35
36import select
37import warnings
38import weakref
39
40from datetime import date, time, datetime, timedelta, tzinfo
41from decimal import Decimal
42from math import isnan, isinf
43from collections import namedtuple
44from keyword import iskeyword
45from operator import itemgetter
46from functools import partial
47from re import compile as regex
48from json import loads as jsondecode, dumps as jsonencode
49from uuid import UUID
50
# Python 3 merged the long type into int and dropped basestring; provide
# stand-ins so that the isinstance() checks in this module work everywhere.
try:
    long  # exists only in Python 2
except NameError:  # Python >= 3.0
    long = int

try:
    basestring  # exists only in Python 2
except NameError:  # Python >= 3.0
    basestring = (str, bytes)
60
try:
    from functools import lru_cache
except ImportError:  # Python < 3.2
    from functools import update_wrapper
    try:
        # _thread.RLock only exists since Python 3.2 and Python 2 has no
        # _thread module at all, so importing the lock from there always
        # fell back to the dummy lock below and left the cache unprotected.
        # The threading module provides RLock on all versions needing this.
        from threading import RLock
    except ImportError:
        class RLock:  # for builds without threads
            def __enter__(self): pass

            def __exit__(self, exctype, excinst, exctb): pass

    def lru_cache(maxsize=128):
        """Simplified functools.lru_cache decorator for one argument.

        The decorated function must take exactly one hashable argument.
        With maxsize=0 nothing is cached, with maxsize=None the cache is
        unbounded, otherwise the least recently used entry is evicted once
        maxsize entries are stored.  The undecorated function is available
        as the __wrapped__ attribute of the returned wrapper.
        """

        def decorator(function):
            sentinel = object()
            # maps argument -> link (bounded) or -> result (unbounded)
            cache = {}
            get = cache.get
            lock = RLock()
            # Circular doubly linked list of links [prev, next, arg, res].
            # The root link is a dummy; root[0] is the most recently used
            # entry and root[1] the least recently used one.
            root = []
            root_full = [root, False]  # current root, "cache is full" flag
            root[:] = [root, root, None, None]

            if maxsize == 0:

                def wrapper(arg):
                    res = function(arg)
                    return res

            elif maxsize is None:

                def wrapper(arg):
                    res = get(arg, sentinel)
                    if res is not sentinel:
                        return res
                    res = function(arg)
                    cache[arg] = res
                    return res

            else:

                def wrapper(arg):
                    with lock:
                        link = get(arg)
                        if link is not None:
                            # cache hit: move the link to the front
                            root = root_full[0]
                            prev, next, _arg, res = link
                            prev[1] = next
                            next[0] = prev
                            last = root[0]
                            last[1] = root[0] = link
                            link[0] = last
                            link[1] = root
                            return res
                    res = function(arg)
                    with lock:
                        root, full = root_full
                        if arg in cache:
                            # another thread added the same argument
                            pass
                        elif full:
                            # reuse the old root as the new entry and turn
                            # the least recently used link into the root
                            oldroot = root
                            oldroot[2] = arg
                            oldroot[3] = res
                            root = root_full[0] = oldroot[1]
                            oldarg = root[2]
                            oldres = root[3]  # keep reference
                            root[2] = root[3] = None
                            del cache[oldarg]
                            cache[arg] = oldroot
                        else:
                            last = root[0]
                            link = [last, root, arg, res]
                            last[1] = root[0] = cache[arg] = link
                            if len(cache) >= maxsize:
                                root_full[1] = True
                    return res

            wrapper.__wrapped__ = function
            return update_wrapper(wrapper, function)

        return decorator
143
144
145# Auxiliary classes and functions that are independent from a DB connection:
146
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Simple read-only ordered dictionary for storing attribute names.

        Fallback for Python versions without OrderedDict: the key order is
        kept in a separate list.  Only a single sequence of (key, value)
        pairs is accepted, presumably because a plain dict or keyword
        arguments would not preserve the order on these versions.
        """

        def __init__(self, *args, **kw):
            if len(args) > 1 or kw:
                raise TypeError
            items = args[0] if args else []
            if isinstance(items, dict):
                raise TypeError
            items = list(items)
            # Record every key only once, at its first position, like an
            # OrderedDict does; duplicate keys in the input used to make
            # iteration yield the same key several times.
            keys = []
            seen = set()
            for item in items:
                key = item[0]
                if key not in seen:
                    seen.add(key)
                    keys.append(key)
            self._keys = keys
            dict.__init__(self, items)
            self._read_only = True
            # shadow the mutating dict methods on the instance
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[key] for key in self]

        def items(self):
            return [(key, self[key]) for key in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Simple read-only ordered dictionary for storing attribute names.

        Items can only be set while the instance is being initialized;
        afterwards all mutating operations raise TypeError.
        """

        def __init__(self, *args, **kw):
            # allow __setitem__ while OrderedDict fills in the items
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow the mutating dict methods on the instance
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
231
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of func (via getargspec)."""
        argspec = getargspec(func)
        return argspec.args
else:

    def get_args(func):
        """Return the list of parameter names of func (via signature)."""
        parameters = signature(func).parameters
        return [name for name in parameters]
243
244try:
245    from datetime import timezone
246except ImportError:  # Python < 3.2
247
248    class timezone(tzinfo):
249        """Simple timezone implementation."""
250
251        def __init__(self, offset, name=None):
252            self.offset = offset
253            if not name:
254                minutes = self.offset.days * 1440 + self.offset.seconds // 60
255                if minutes < 0:
256                    hours, minutes = divmod(-minutes, 60)
257                    hours = -hours
258                else:
259                    hours, minutes = divmod(minutes, 60)
260                name = 'UTC%+03d:%02d' % (hours, minutes)
261            self.name = name
262
263        def utcoffset(self, dt):
264            return self.offset
265
266        def tzname(self, dt):
267            return self.name
268
269        def dst(self, dt):
270            return None
271
272    timezone.utc = timezone(timedelta(0), 'UTC')
273
274    _has_timezone = False
275else:
276    _has_timezone = True
277
278# time zones used in Postgres timestamptz output
279_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
280    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
281    UCT='+0000', UTC='+0000', WET='+0000')
282
283
def _timezone_as_offset(tz):
    """Convert a Postgres time zone string to a '+HHMM' style offset.

    Numeric zones such as '+01' or '-05:30' become '+0100' or '-0530'.
    Abbreviations are looked up in _timezones; unknown ones give '+0000'.
    """
    if not tz.startswith(('+', '-')):
        return _timezones.get(tz, '+0000')
    # only hours given: append the missing minutes
    return tz + '00' if len(tz) < 5 else tz.replace(':', '')
290
291
def _get_timezone(tz):
    """Return a fixed-offset timezone object for a Postgres time zone.

    The zone is first normalized with _timezone_as_offset(), and the
    resulting '+HHMM' string is also used as the timezone name.
    """
    offset = _timezone_as_offset(tz)
    hours, mins = int(offset[1:3]), int(offset[3:5])
    total = hours * 60 + mins
    if offset[0] == '-':
        total = -total
    return timezone(timedelta(minutes=total), offset)
298
299
300def _oid_key(table):
301    """Build oid key from a table name."""
302    return 'oid(%s)' % table
303
304
305class _SimpleTypes(dict):
306    """Dictionary mapping pg_type names to simple type names."""
307
308    _types = {'bool': 'bool',
309        'bytea': 'bytea',
310        'date': 'date interval time timetz timestamp timestamptz'
311            ' abstime reltime',  # these are very old
312        'float': 'float4 float8',
313        'int': 'cid int2 int4 int8 oid xid',
314        'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid',
315        'num': 'numeric', 'money': 'money',
316        'text': 'bpchar char name text varchar'}
317
318    def __init__(self):
319        for typ, keys in self._types.items():
320            for key in keys.split():
321                self[key] = typ
322                self['_%s' % key] = '%s[]' % typ
323
324    # this could be a static method in Python > 2.6
325    def __missing__(self, key):
326        return 'text'
327
328_simpletypes = _SimpleTypes()
329
330
def _quote_if_unqualified(param, name):
    """Wrap a parameter in quote_ident() unless the name may be qualified.

    If name is a string without a dot, an SQL expression applying
    quote_ident() to param (presumably a placeholder such as '$1') is
    returned.  A name containing a dot is ambiguous (it could be a
    qualified name or just a name with a dot in it), so param is returned
    unchanged and the caller must quote the name manually.  Non-string
    names also leave param unchanged.
    """
    if not isinstance(name, basestring) or '.' in name:
        return param
    return 'quote_ident(%s)' % (param,)
342
343
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    An adapt(value, typ) callable must be set as an attribute on the
    instance before add() is used (Adapter.parameter_list() does this).
    """

    def add(self, value, typ=None):
        """Adapt a value and return what to put into the SQL command.

        Literal values are returned unchanged and not stored.  Any other
        adapted value is appended to this list and its '$n' placeholder
        (n being its 1-based position) is returned.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
358
359
class Bytea(bytes):
    """Marker subclass of bytes for values to be passed as bytea."""
362
363
class Hstore(dict):
    """Wrapper class for marking hstore values.

    str() of an instance gives the hstore text representation, such as
    'key=>value,"other key"=>NULL'.
    """

    # Keys/values must be double-quoted if they spell NULL (which would
    # otherwise be read as SQL NULL) or contain a separator or whitespace
    # (which would end an unquoted token; not only plain spaces).
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')

    @classmethod
    def _quote(cls, s):
        """Quote a single hstore key or value (None becomes NULL)."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        # The backslash is the hstore escape character, so it must be
        # escaped itself (before escaping double quotes), otherwise the
        # server silently drops it from the value.
        s = s.replace('\\', '\\\\').replace('"', '\\"')
        if cls._re_quote.search(s):
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
383
384
class Json:
    """Marker wrapping an object that is to be passed as a JSON value."""

    def __init__(self, obj):
        # the wrapped object is stored as is, it is not encoded here
        self.obj = obj
390
391
class Literal(str):
    """Marker for SQL text to be put into a command verbatim.

    The adapters pass such values through without quoting or escaping.
    """
394
395
class Adapter:
    """Class providing methods for adapting parameters to the database.

    An adapter is created for a DB wrapper, which it references through a
    weak proxy, and uses its escape_bytea(), escape_string() and
    encode_json() methods for the actual conversions.
    """

    # strings (compared in lower case) that are adapted as boolean true
    _bool_true_values = frozenset('t true 1 y yes on'.split())

    # SQL keywords for the current date and time, passed through as literals
    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    # array elements must be quoted if they contain braces, commas, double
    # quotes, backslashes or whitespace, or if they spell NULL
    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    # record fields must be quoted if they contain parentheses, commas,
    # double quotes or backslashes; an unquoted ')' ends the record early
    _re_record_quote = regex(r'[(),"\\]')
    # double quotes and backslashes in quoted elements/fields are escaped
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        # only a weak proxy is kept, presumably to avoid a reference cycle
        # with the DB wrapper that owns this adapter -- TODO confirm
        self.db = weakref.proxy(db)

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Strings are interpreted (the empty string becomes NULL), other
        values are tested for truth.  Returns 't', 'f' or None.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter (also used for times and intervals).

        The empty string becomes NULL and the SQL keywords for the current
        date/time are passed as literals.  All other values are passed
        through unchanged.  Only empty *strings* are mapped to NULL, since
        timedelta(0) (and time(0) on Python < 3.5) is falsy but is a valid
        value that must not turn into NULL.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            if v.lower() in cls._date_literals:
                return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter; empty values except zero become NULL."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.db.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Empty values become NULL, strings are passed unchanged (assumed to
        be JSON already), other objects are encoded as JSON.
        """
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.db.encode_json(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter (nested lists are allowed)."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if not v:
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter (nested lists are allowed)."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter (nested lists are allowed)."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter (nested lists are allowed)."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(v) for v in v) + b'}'
        if v is None:
            return b'null'
        return self.db.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter (nested lists are allowed)."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.db.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Each field is adapted with the type of the corresponding attribute
        of typ.  Raises TypeError if the number of fields does not match.
        """
        attypes = self.get_attnames(typ).values()
        if len(attypes) != len(v):
            # wrap v in a tuple, otherwise the % operator would try to use
            # the fields of the record as separate formatting arguments
            raise TypeError('Record parameter %s has wrong size' % (v,))
        adapt = self.adapt
        value = []
        for field, attype in zip(v, attypes):
            field = adapt(field, attype)
            if field is None:
                field = ''  # an empty unquoted field is NULL
            else:
                if isinstance(field, bytes):
                    if str is not bytes:
                        field = field.decode('ascii')
                else:
                    field = str(field)
                # test for emptiness only after conversion to a string, so
                # that falsy values such as 0 are not turned into ""
                if not field:
                    field = '""'
                elif self._re_record_quote.search(field):
                    field = '"%s"' % self._re_record_escape.sub(
                        r'\\\1', field)
            value.append(field)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is given, it is guessed from the value.  None and
        Literal values are returned unchanged.  Objects may provide a
        __pg_str__(typ) method returning the value to be adapted.
        """
        if value is None or isinstance(value, Literal):
            return value
        if typ:
            simple = self.get_simple_name(typ)
        else:
            typ = simple = self.guess_simple_type(value) or 'text'
        # look the hook up first, so that AttributeErrors raised inside
        # __pg_str__ itself are not silently swallowed
        pg_str = getattr(value, '__pg_str__', None)
        if pg_str is not None:
            value = pg_str(typ)
        if simple == 'text':
            pass
        elif simple == 'record':
            if isinstance(value, tuple):
                value = self._adapt_record(value, typ)
        elif simple.endswith('[]'):
            if isinstance(value, list):
                adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                value = adapt(value)
        else:
            adapt = getattr(self, '_adapt_%s' % simple)
            value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns a simple type name, a record type for tuples, or None if
        the type cannot be guessed.
        """
        # Bytea must be tested before basestring (which includes bytes in
        # Python 3) and bool before int (bool is a subclass of int)
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, list):
            return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type
            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))
            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given (possibly nested) array."""
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Raises InterfaceError if the value cannot be adapted.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.db.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # Json only wraps the object (in its obj attribute) and has no
            # encode() method, so encode the wrapped object, not the wrapper
            value = self.db.encode_json(value.obj)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.db.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, Decimal):
            # special values must be quoted just like their float versions,
            # unquoted they would be taken as column names
            if value.is_nan():
                return "'NaN'"
            if value.is_infinite():
                return "'-Infinity'" if value < 0 else "'Infinity'"
            return value
        if isinstance(value, (int, long)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        # look the hook up first, so that AttributeErrors raised inside
        # __pg_repr__ itself are not reported as an unknown type
        pg_repr = getattr(value, '__pg_repr__', None)
        if pg_repr is None:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        value = pg_repr()
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values=None, types=None, inline=False):
        """Format a database query using the given values and types.

        The command uses %s style placeholders: positional ones if values
        is a list or tuple, named ones (%(name)s) if values is a dict.
        With inline=True the values are quoted directly into the command,
        otherwise they are replaced with $n placeholders and collected in
        the returned parameter list.  Returns (command, parameters).
        Raises ValueError if inline is combined with types, and TypeError
        if values and types do not match or values has the wrong type.
        """
        if not values:
            return command, []
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command %= tuple(literals)
        elif isinstance(values, dict):
            # we want to allow extra keys in the dictionary,
            # so we first must find the values actually used in the command:
            # a key is used if formatting fails with KeyError without it
            used_values = {}
            literals = dict.fromkeys(values, '')
            for key in values:
                del literals[key]
                try:
                    command % literals
                except KeyError:
                    used_values[key] = values[key]
                literals[key] = ''
            values = used_values
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                # sorted, so that the placeholder numbering is deterministic
                if types:
                    if not isinstance(types, dict):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types.get(key))
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command %= literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
738
739
def cast_bool(value):
    """Cast a boolean value.

    If get_bool() is false, the raw string is returned unchanged, otherwise
    the value is True if its first character is 't'.
    """
    if get_bool():
        return value[0] == 't'
    return value
745
746
def cast_json(value):
    """Cast a JSON value with the configured decoder, if there is one."""
    decode = get_jsondecode()
    return decode(value) if decode else value
753
754
def cast_num(value):
    """Cast a numeric value with the configured decimal type or float."""
    number_type = get_decimal() or float
    return number_type(value)
758
759
def cast_money(value):
    """Cast a money value.

    The value is returned unchanged if no decimal point has been set.
    Otherwise '(' is replaced by '-' so that amounts written in parentheses
    become negative.  All characters except digits, the minus sign and the
    configured decimal point (e.g. currency symbols and thousands
    separators) are dropped, and the result is converted with the
    configured decimal type or float.
    """
    point = get_decimal_point()
    if not point:
        return value
    value = value.replace('(', '-')
    # Filter before normalizing the decimal point: if the point is not '.',
    # a '.' in the value is a thousands separator (as in '1.234,56') and
    # must be dropped instead of being kept as a second decimal point.
    value = ''.join(c for c in value if c.isdigit() or c == point or c == '-')
    if point != '.':
        value = value.replace(point, '.')
    return (get_decimal() or float)(value)
770
771
def cast_int2vector(value):
    """Cast an int2vector value (space separated integers) to a list."""
    return list(map(int, value.split()))
775
776
def cast_date(value, connection):
    """Cast a date value.

    Infinite dates and dates BC are mapped to date.min or date.max, as are
    dates with more than ten characters (years beyond 9999).  Otherwise the
    value is parsed with the format returned by connection.date_format().
    """
    # The output format depends on the server setting DateStyle.  The ISO
    # and German settings are unambiguous, but the order of day and month
    # in the other two settings is not, so the setting must be consulted.
    if value in ('-infinity', 'infinity'):
        return date.min if value.startswith('-') else date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
795
796
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    if len(value) > 8:  # longer than HH:MM:SS, so there is a fraction
        pattern = '%H:%M:%S.%f'
    else:
        pattern = '%H:%M:%S'
    parsed = datetime.strptime(value, pattern)
    return parsed.time()
801
802
803_re_timezone = regex('(.*)([+-].*)')
804
805
def cast_timetz(value):
    """Cast a timetz value.

    The time zone is the trailing part starting at the last '+' or '-';
    if there is none, UTC is assumed.  The result is timezone aware.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    if len(value) > 8:
        pattern = '%H:%M:%S.%f'
    else:
        pattern = '%H:%M:%S'
    if not _has_timezone:
        # strptime cannot parse %z here, so attach the zone afterwards
        parsed = datetime.strptime(value, pattern).timetz()
        return parsed.replace(tzinfo=_get_timezone(tz))
    parsed = datetime.strptime(
        value + _timezone_as_offset(tz), pattern + '%z')
    return parsed.timetz()
820
821
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The parsing depends on the format returned by connection.date_format().
    Infinite timestamps and timestamps BC map to datetime.min/max, and so
    do years with more than four digits (datetime.max).
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the leading token (presumably the name of the day)
        parts = parts[1:5]
        if len(parts[3]) > 4:  # more than four year digits
            return datetime.max
        day_month = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        clock = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        pattern = ' '.join((day_month, clock, '%Y'))
    else:
        if len(parts[0]) > 10:
            return datetime.max
        clock = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        pattern = ' '.join((date_fmt, clock))
    return datetime.strptime(' '.join(parts), pattern)
843
844
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    Parsed like a timestamp, but the value also carries a time zone, either
    as a separate last token or, with a format starting with '%Y-', attached
    to the time.  Without such a zone UTC is assumed.  The result is
    timezone aware, except for the datetime.min/max special values.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the leading token (presumably the name of the day)
        parts = parts[1:]
        if len(parts[3]) > 4:  # more than four year digits
            return datetime.max
        pattern = [
            '%d %b' if date_fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S',
            '%Y']
        parts, tz = parts[:-1], parts[-1]
    else:
        if not date_fmt.startswith('%Y-'):
            parts, tz = parts[:-1], parts[-1]
        else:
            # the zone is attached to the time, e.g. '13:51:23+01'
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        if len(parts[0]) > 10:
            return datetime.max
        pattern = [date_fmt,
                   '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S']
    if not _has_timezone:
        # strptime cannot parse %z here, so attach the zone afterwards
        naive = datetime.strptime(' '.join(parts), ' '.join(pattern))
        return naive.replace(tzinfo=_get_timezone(tz))
    parts.append(_timezone_as_offset(tz))
    pattern.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(pattern))
880
881
# sql_standard IntervalStyle, e.g. '1-2', '3 4:05:06' or '-1-2 +3 -4:05:06'
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# postgres IntervalStyle, e.g. '1 year 2 mons 3 days 04:05:06'
_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# postgres_verbose IntervalStyle, e.g. '@ 3 days 4 hours 5 mins 6 secs ago'
_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

# iso_8601 IntervalStyle, e.g. 'P1Y2M3DT4H5M6S'
_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(fields):
    """Convert the seven string fields of a parsed interval to integers.

    The fields are years, months, days, hours, minutes, seconds and the
    digits after the decimal point of the seconds.  PostgreSQL strips
    trailing zeros from these digits (it writes '1.5', not '1.500000'),
    so they are padded to six places in order to get microseconds.
    """
    fields = list(fields)
    fields[-1] = fields[-1][:6].ljust(6, '0')
    return [int(d) for d in fields]


def cast_interval(value):
    """Cast an interval value.

    The value may be given in any of the output formats selected by the
    PostgreSQL IntervalStyle setting (iso_8601, postgres_verbose, postgres
    or sql_standard).  Since a timedelta has no notion of months or years,
    a year is counted as 365 days and a month as 30 days.

    Raises ValueError if the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        m = _interval_fields(m)
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            m = _interval_fields(m)
            if ago:  # 'ago' negates all the fields
                m = [-d for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            # this pattern also matches the empty string, hence the check
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                m = _interval_fields(m)
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    groups = m.groups()
                    years_ago = groups[0] == '-'
                    hours_ago = groups[4] == '-'
                    if (groups[4] is None and groups[5] is not None
                            and groups[1] is None
                            and (groups[3] or '').startswith('-')):
                        # A negative day-time interval is written with
                        # a single leading sign (e.g. '-3 4:05:06') which
                        # applies to the time part as well.
                        hours_ago = True
                    m = [d or '0' for d in groups]
                    m.pop(4)  # sign of the time part
                    m.pop(0)  # sign of the year-month part
                    m = _interval_fields(m)
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
967
968
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Note that the basic types are already handled by the C extension.
    They only need to be handled here as record or array components.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': UUID,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.

        Raises TypeError if the type is not given as a string.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # array type: build the cast from the cast of the base type,
            # but only cache it if the base type has a cast function
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                self[typ] = cast
        else:
            # possibly a composite type: build a named record cast
            attnames = self.get_attnames(typ)
            if attnames:
                casts = [self[v.pgtype] for v in attnames.values()]
                cast = self.create_record_cast(typ, attnames, casts)
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument.

        This is the case when func has a parameter named 'connection'
        other than its first one.  Functions whose arguments cannot be
        inspected are assumed not to need a connection.
        """
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type.

        Returns default if there is no special cast function for the type.
        """
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        Passing None as cast removes the typecast function.  Cached array
        typecasts for the specified types are discarded in both cases.

        Raises TypeError if cast is neither None nor callable.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.  Cached
        array typecasts derived from the specified types are discarded as
        well, so that they will be recreated from the restored base casts.
        """
        if typ is None:
            self.clear()
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.pop(t, None)
                # the array cast may have been built from the old base cast
                self.pop('_%s' % t, None)

    @classmethod
    def get_default(cls, typ):
        """Get the default typecast function for the given database type."""
        return cls.defaults.get(typ)

    @classmethod
    def set_default(cls, typ, cast):
        """Set a default typecast function for the given database type(s).

        Passing None as cast removes the default typecast function.

        Raises TypeError if cast is neither None nor callable.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = cls.defaults
        if cast is None:
            for t in typ:
                defaults.pop(t, None)
                defaults.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                defaults[t] = cast
                defaults.pop('_%s' % t, None)

    def get_attnames(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with the get_attnames() method of DbTypes.
        """
        return {}

    def dateformat(self):
        """Return the current date format.

        This method will be replaced with the dateformat() method of DbTypes.
        """
        return '%Y-%m-%d'

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        cast_array = self['anyarray']
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        cast_record = self['record']
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
1122
1123
def get_typecast(typ):
    """Get the global typecast function for the given database type.

    Only a single type name can be passed.  Returns None if there is no
    default typecast function for this type.
    """
    return Typecasts.get_default(typ)
1127
1128
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    Passing None as cast removes the global typecast function.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1136
1137
class DbType(str):
    """Type name (a str) carrying further info on a PostgreSQL type.

    Besides the name itself, instances provide these attributes:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Get names and types of the fields of a composite type."""
        lookup = self._get_attnames
        return lookup(self)
1158
1159
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    The keys of this dictionary are type OIDs and type names, and the
    values are DbType objects carrying additional information on the
    corresponding database type.  Entries that are missing are looked
    up in the pg_type system catalog on first access.
    """

    _num_types = frozenset(
        'int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Set up an empty type cache for the given DB wrapper.

        Only a weak reference proxy to the wrapper is kept.
        """
        super(DbTypes, self).__init__()
        self._db = weakref.proxy(db)
        self._regtypes = False
        typecasts = Typecasts()
        typecasts.get_attnames = self.get_attnames
        typecasts.connection = self._db
        self._typecasts = typecasts
        if db.server_version >= 80400:
            query = (
                "SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")
        else:
            # older remote databases (not officially supported)
            query = (
                "SELECT oid, typname, typname::text::regtype,"
                " typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")
        self._query_pg_type = query

    def add(self, oid, pgtype, regtype,
            typtype, category, delim, relid):
        """Make a DbType from the given pg_type info.

        If the OID is already cached, the cached DbType is returned.
        Note that a newly made DbType is not stored in the cache here.
        """
        if oid in self:
            return self[oid]
        if relid:
            simple = 'record'
        else:
            simple = _simpletypes[pgtype]
        name = regtype if self._regtypes else simple
        typ = DbType(name)
        typ.oid, typ.pgtype, typ.regtype = oid, pgtype, regtype
        typ.simple = simple
        typ.typtype, typ.category = typtype, category
        typ.delim, typ.relid = delim, relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Fetch the type info for an uncached key from the database.

        The result is cached under both its OID and its pgtype name.
        Raises KeyError if the type cannot be found.
        """
        try:
            query = self._query_pg_type % (_quote_if_unqualified('$1', key),)
            rows = self._db.query(query, (key,)).getresult()
        except ProgrammingError:
            rows = None
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        typ = self.add(*rows[0])
        self[typ.oid] = self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type for the key, fetching it from the database if needed.

        Returns default if the type cannot be found.
        """
        try:
            typ = self[key]
        except KeyError:
            typ = default
        return typ

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or not a composite type.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if typ.relid:
            return self._db.get_attnames(typ.relid, with_oid=False)
        return None

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        typecasts = self._typecasts
        return typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        typecasts = self._typecasts
        typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        typecasts = self._typecasts
        typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        NULL values (None) and values without a special typecast function
        are returned unchanged.
        """
        if value is None:
            return None
        if not isinstance(typ, DbType):
            dbtype = self.get(typ)
            typ = dbtype.pgtype if dbtype else dbtype
        cast = self.get_typecast(typ) if typ else None
        if cast and cast is not str:
            return cast(value)
        # no typecast is necessary
        return value
1265
1266
_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$')

# Query result rows are returned as named tuples by default.  Creating a
# namedtuple class is relatively costly, so the row factories for up to
# 1024 different sets of field names are kept in an LRU cache by default.

@lru_cache(maxsize=1024)
def _row_factory(names):
    """Return a function creating named tuple rows with the given names.

    Field names that cannot be used for a named tuple are replaced by
    generated names.  Because of the cache, names must be hashable.
    """
    try:
        try:
            row_type = namedtuple('Row', names, rename=True)
        except TypeError:  # Python 2.6 and 3.0 do not support rename
            fixed_names = []
            for index, name in enumerate(names):
                if not _re_fieldname.match(name) or iskeyword(name):
                    name = 'column_%d' % (index,)
                fixed_names.append(name)
            row_type = namedtuple('Row', fixed_names)
    except ValueError:  # there is still a problem with the field names
        row_type = namedtuple(
            'Row', ['column_%d' % (index,) for index in range(len(names))])
    return row_type._make


def set_row_factory_size(maxsize):
    """Set the maximum number of cached named tuple row factories.

    If maxsize is set to None, the cache can grow without bound.
    Note that this discards all row factories cached so far.
    """
    global _row_factory
    uncached_factory = _row_factory.__wrapped__
    _row_factory = lru_cache(maxsize)(uncached_factory)


def _namedresult(q):
    """Return the rows of the query result q as a list of named tuples."""
    make_row = _row_factory(q.listfields())
    return list(map(make_row, q.getresult()))
1302
1303
class _MemoryQuery:
    """Query-like object for result rows that are already in memory.

    Like a real query, it provides the listfields() and getresult()
    methods, returning the field names and the rows it was created with.
    """

    def __init__(self, result, fields):
        """Store the given result rows and field names."""
        self.fields = tuple(fields)
        self.result = result

    def listfields(self):
        """Get the field names as a tuple."""
        fields = self.fields
        return fields

    def getresult(self):
        """Get the result rows as they were passed in."""
        result = self.result
        return result
1319
1320
def _db_error(msg, cls=DatabaseError):
    """Create (not raise) an error of class cls with an empty sqlstate.

    The error class defaults to DatabaseError.
    """
    err = cls(msg)
    err.sqlstate = None
    return err


def _int_error(msg):
    """Create (not raise) an InternalError with an empty sqlstate."""
    return _db_error(msg, cls=InternalError)


def _prg_error(msg):
    """Create (not raise) a ProgrammingError with an empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
1336
1337
# Initialize the C module
# (register Python helpers with the C extension; how each one is used is
# defined in the C code -- the notes below follow from the function names
# and should be confirmed there)

# function converting a query into a list of named tuple rows
set_namedresult(_namedresult)
# presumably the class used for numeric values - TODO confirm
set_decimal(Decimal)
# presumably the function used for decoding JSON values - TODO confirm
set_jsondecode(jsondecode)
1343
1344
1345# The notification handler
1346
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # stop listening when the handler is garbage collected
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.  The
        payload is escaped with the escape_string() method of the database
        connection, so it may contain quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent (and None is returned) if the handler is not
        listening.  Otherwise the result of the notify query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # convert the payload to a string like before, but escape
                # it so that quotes in the payload cannot break the command
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1466
1467
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its traditional name.

    This alias is deprecated and issues a DeprecationWarning.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
1473
1474
1475# The actual PostgreSQL database connection interface:
1476
1477class DB:
1478    """Wrapper class for the _pg connection type."""
1479
1480    db = None  # invalid fallback for underlying connection
1481
    def __init__(self, *args, **kw):
        """Create a new connection

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing DB instance can also be passed, in which case its
        underlying connection is shared.  If the single positional
        argument (or the single keyword argument db) does not look like a
        connection, i.e. has no db and query attributes, all arguments are
        passed to connect() instead.  Only a connection opened here is
        closed again by close() and by the destructor of this wrapper.
        """
        # a single positional argument or a single keyword argument 'db'
        # may be an existing connection
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # share the connection of another DB wrapper
                db = db.db
            else:
                try:
                    # unwrap the inner connection of objects holding one
                    # in _cnx (presumably pgdb connections)
                    db = db._cnx
                except AttributeError:
                    pass
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            # not a usable connection: open a new one with the given
            # parameters, which are kept for reopen()
            db = connect(*args, **kw)
            self._db_args = args, kw
            self._closeable = True
        else:
            # existing connection: kept for reopen(), but not closed
            # by close() or by the destructor of this wrapper
            self._db_args = db
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        # caches, presumably filled by other methods of this class
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        if db.server_version < 80400:
            # support older remote data bases
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype,"
                " t.typtype, null as typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        else:
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        # let the connection use the type cache of this wrapper for casting
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1541
1542    def __getattr__(self, name):
1543        # All undefined members are same as in underlying connection:
1544        if self.db:
1545            return getattr(self.db, name)
1546        else:
1547            raise _int_error('Connection is not valid')
1548
1549    def __dir__(self):
1550        # Custom dir function including the attributes of the connection:
1551        attrs = set(self.__class__.__dict__)
1552        attrs.update(self.__dict__)
1553        attrs.update(dir(self.db))
1554        return sorted(attrs)
1555
1556    # Context manager methods
1557
1558    def __enter__(self):
1559        """Enter the runtime context. This will start a transaction."""
1560        self.begin()
1561        return self
1562
1563    def __exit__(self, et, ev, tb):
1564        """Exit the runtime context. This will end the transaction."""
1565        if et is None and ev is None and tb is None:
1566            self.commit()
1567        else:
1568            self.rollback()
1569
1570    def __del__(self):
1571        try:
1572            db = self.db
1573        except AttributeError:
1574            db = None
1575        if db:
1576            try:
1577                db.set_cast_hook(None)
1578            except TypeError:
1579                pass  # probably already closed
1580            if self._closeable:
1581                try:
1582                    db.close()
1583                except InternalError:
1584                    pass  # probably already closed
1585
1586    # Auxiliary methods
1587
1588    def _do_debug(self, *args):
1589        """Print a debug message"""
1590        if self.debug:
1591            s = '\n'.join(str(arg) for arg in args)
1592            if isinstance(self.debug, basestring):
1593                print(self.debug % s)
1594            elif hasattr(self.debug, 'write'):
1595                self.debug.write(s + '\n')
1596            elif callable(self.debug):
1597                self.debug(s)
1598            else:
1599                print(s)
1600
1601    def _escape_qualified_name(self, s):
1602        """Escape a qualified name.
1603
1604        Escapes the name for use as an SQL identifier, unless the
1605        name contains a dot, in which case the name is ambiguous
1606        (could be a qualified name or just a name with a dot in it)
1607        and must be quoted manually by the caller.
1608        """
1609        if '.' not in s:
1610            s = self.escape_identifier(s)
1611        return s
1612
1613    @staticmethod
1614    def _make_bool(d):
1615        """Get boolean value corresponding to d."""
1616        return bool(d) if get_bool() else ('t' if d else 'f')
1617
1618    def _list_params(self, params):
1619        """Create a human readable parameter list."""
1620        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1621
1622    # Public methods
1623
    # escape_string and escape_bytea exist as methods (of the connection,
    # reachable through __getattr__), so we define unescape_bytea as a
    # method as well; it wraps the module-level unescape_bytea function
    # (presumably the one from _pg) as a static method
    unescape_bytea = staticmethod(unescape_bytea)
1627
1628    def decode_json(self, s):
1629        """Decode a JSON string coming from the database."""
1630        return (get_jsondecode() or jsondecode)(s)
1631
1632    def encode_json(self, d):
1633        """Encode a JSON string for use within SQL."""
1634        return jsonencode(d)
1635
1636    def close(self):
1637        """Close the database connection."""
1638        # Wraps shared library function so we can track state.
1639        db = self.db
1640        if db:
1641            try:
1642                db.set_cast_hook(None)
1643            except TypeError:
1644                pass  # probably already closed
1645            if self._closeable:
1646                db.close()
1647            self.db = None
1648        else:
1649            raise _int_error('Connection already closed')
1650
1651    def reset(self):
1652        """Reset connection with current parameters.
1653
1654        All derived queries and large objects derived from this connection
1655        will not be usable after this call.
1656
1657        """
1658        if self.db:
1659            self.db.reset()
1660        else:
1661            raise _int_error('Connection already closed')
1662
1663    def reopen(self):
1664        """Reopen connection to the database.
1665
1666        Used in case we need another connection to the same database.
1667        Note that we can still reopen a database that we have closed.
1668
1669        """
1670        # There is no such shared library function.
1671        if self._closeable:
1672            db = connect(*self._db_args[0], **self._db_args[1])
1673            if self.db:
1674                self.db.set_cast_hook(None)
1675                self.db.close()
1676            db.set_cast_hook(self.dbtypes.typecast)
1677            self.db = db
1678        else:
1679            self.db = self._db_args
1680
1681    def begin(self, mode=None):
1682        """Begin a transaction."""
1683        qstr = 'BEGIN'
1684        if mode:
1685            qstr += ' ' + mode
1686        return self.query(qstr)
1687
1688    start = begin
1689
1690    def commit(self):
1691        """Commit the current transaction."""
1692        return self.query('COMMIT')
1693
1694    end = commit
1695
1696    def rollback(self, name=None):
1697        """Roll back the current transaction."""
1698        qstr = 'ROLLBACK'
1699        if name:
1700            qstr += ' TO ' + name
1701        return self.query(qstr)
1702
1703    abort = rollback
1704
1705    def savepoint(self, name):
1706        """Define a new savepoint within the current transaction."""
1707        return self.query('SAVEPOINT ' + name)
1708
1709    def release(self, name):
1710        """Destroy a previously defined savepoint."""
1711        return self.query('RELEASE ' + name)
1712
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.  Tuples and frozensets are accepted as well.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.  In this case, a new dict
        is returned even if a dict has been passed, and any names following
        'all' are ignored.

        Raises TypeError if the argument has an unsupported type, contains
        no names, or contains a name that is not a non-empty string.

        NOTE(review): the names are inserted into the SHOW command without
        escaping, so they should not come from untrusted input.
        """
        # Turn the argument into an iterable of names and prepare the
        # result: None (single string), a list, or a dict.
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # normalized names, mapped to the original keys if a dict is built
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # only the first two columns (name and setting) are used
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # no 'all' found: query each parameter separately
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1770
1771    def set_parameter(self, parameter, value=None, local=False):
1772        """Set the value of a run-time parameter.
1773
1774        If the parameter and the value are strings, the run-time parameter
1775        will be set to that value.  If no value or None is passed as a value,
1776        then the run-time parameter will be restored to its default value.
1777
1778        You can set several parameters at once by passing a list of parameter
1779        names, together with a single value that all parameters should be
1780        set to or with a corresponding list of values.  You can also pass
1781        the parameters as a set if you only provide a single value.
1782        Finally, you can pass a dict with parameter names as keys.  In this
1783        case, you should not pass a value, since the values for the parameters
1784        will be taken from the dict.
1785
1786        By passing the special name 'all' as the parameter, you can reset
1787        all existing settable run-time parameters to their default values.
1788
1789        If you set local to True, then the command takes effect for only the
1790        current transaction.  After commit() or rollback(), the session-level
1791        setting takes effect again.  Setting local to True will appear to
1792        have no effect if it is executed outside a transaction, since the
1793        transaction will end immediately.
1794        """
1795        if isinstance(parameter, basestring):
1796            parameter = {parameter: value}
1797        elif isinstance(parameter, (list, tuple)):
1798            if isinstance(value, (list, tuple)):
1799                parameter = dict(zip(parameter, value))
1800            else:
1801                parameter = dict.fromkeys(parameter, value)
1802        elif isinstance(parameter, (set, frozenset)):
1803            if isinstance(value, (list, tuple, set, frozenset)):
1804                value = set(value)
1805                if len(value) == 1:
1806                    value = value.pop()
1807            if not(value is None or isinstance(value, basestring)):
1808                raise ValueError('A single value must be specified'
1809                    ' when parameter is a set')
1810            parameter = dict.fromkeys(parameter, value)
1811        elif isinstance(parameter, dict):
1812            if value is not None:
1813                raise ValueError('A value must not be specified'
1814                    ' when parameter is a dictionary')
1815        else:
1816            raise TypeError(
1817                'The parameter must be a string, list, set or dict')
1818        if not parameter:
1819            raise TypeError('No parameter has been specified')
1820        params = {}
1821        for key, value in parameter.items():
1822            param = key.strip().lower() if isinstance(
1823                key, basestring) else None
1824            if not param:
1825                raise TypeError('Invalid parameter')
1826            if param == 'all':
1827                if value is not None:
1828                    raise ValueError('A value must ot be specified'
1829                        " when parameter is 'all'")
1830                params = {'all': None}
1831                break
1832            params[param] = value
1833        local = ' LOCAL' if local else ''
1834        for param, value in params.items():
1835            if value is None:
1836                q = 'RESET%s %s' % (local, param)
1837            else:
1838                q = 'SET%s %s TO %s' % (local, param, value)
1839            self._do_debug(q)
1840            self.db.query(q)
1841
1842    def query(self, command, *args):
1843        """Execute a SQL command string.
1844
1845        This method simply sends a SQL query to the database.  If the query is
1846        an insert statement that inserted exactly one row into a table that
1847        has OIDs, the return value is the OID of the newly inserted row.
1848        If the query is an update or delete statement, or an insert statement
1849        that did not insert exactly one row in a table with OIDs, then the
1850        number of rows affected is returned as a string.  If it is a statement
1851        that returns rows as a result (usually a select statement, but maybe
1852        also an "insert/update ... returning" statement), this method returns
1853        a Query object that can be accessed via getresult() or dictresult()
1854        or simply printed.  Otherwise, it returns `None`.
1855
1856        The query can contain numbered parameters of the form $1 in place
1857        of any data constant.  Arguments given after the query string will
1858        be substituted for the corresponding numbered parameter.  Parameter
1859        values can also be given as a single list or tuple argument.
1860        """
1861        # Wraps shared library function for debugging.
1862        if not self.db:
1863            raise _int_error('Connection is not valid')
1864        if args:
1865            self._do_debug(command, args)
1866            return self.db.query(command, args)
1867        self._do_debug(command)
1868        return self.db.query(command)
1869
1870    def query_formatted(self, command,
1871            parameters=None, types=None, inline=False):
1872        """Execute a formatted SQL command string.
1873
1874        Similar to query, but using Python format placeholders of the form
1875        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1876        The parameters must be passed as a tuple, list or dict.  You can
1877        also pass a corresponding tuple, list or dict of database types in
1878        order to format the parameters properly in case there is ambiguity.
1879
1880        If you set inline to True, the parameters will be sent to the database
1881        embedded in the SQL command, otherwise they will be sent separately.
1882        """
1883        return self.query(*self.adapter.format_query(
1884            command, parameters, types, inline))
1885
1886    def query_prepared(self, name, *args):
1887        """Execute a prepared SQL statement.
1888
1889        This works like the query() method, except that instead of passing
1890        the SQL command, you pass the name of a prepared statement.  If you
1891        pass an empty name, the unnamed statement will be executed.
1892        """
1893        if not self.db:
1894            raise _int_error('Connection is not valid')
1895        if name is None:
1896            name = ''
1897        if args:
1898            self._do_debug('EXECUTE', name, args)
1899            return self.db.query_prepared(name, args)
1900        self._do_debug('EXECUTE', name)
1901        return self.db.query_prepared(name)
1902
1903    def prepare(self, name, command):
1904        """Create a prepared SQL statement.
1905
1906        This creates a prepared statement for the given command with the
1907        the given name for later execution with the query_prepared() method.
1908
1909        The name can be empty to create an unnamed statement, in which case
1910        any pre-existing unnamed statement is automatically replaced;
1911        otherwise it is an error if the statement name is already
1912        defined in the current database session. We recommend always using
1913        named queries, since unnamed queries have a limited lifetime and
1914        can be automatically replaced or destroyed by various operations.
1915        """
1916        if not self.db:
1917            raise _int_error('Connection is not valid')
1918        if name is None:
1919            name = ''
1920        self._do_debug('prepare', name, command)
1921        return self.db.prepare(name, command)
1922
1923    def describe_prepared(self, name=None):
1924        """Describe a prepared SQL statement.
1925
1926        This method returns a Query object describing the result columns of
1927        the prepared statement with the given name. If you omit the name,
1928        the unnamed statement will be described if you created one before.
1929        """
1930        if name is None:
1931            name = ''
1932        return self.db.describe_prepared(name)
1933
1934    def delete_prepared(self, name=None):
1935        """Delete a prepared SQL statement
1936
1937        This deallocates a previously prepared SQL statement with the given
1938        name, or deallocates all prepared statements if you do not specify a
1939        name. Note that prepared statements are also deallocated automatically
1940        when the current session ends.
1941        """
1942        q = "DEALLOCATE %s" % (name or 'ALL',)
1943        self._do_debug(q)
1944        return self.db.query(q)
1945
1946    def pkey(self, table, composite=False, flush=False):
1947        """Get or set the primary key of a table.
1948
1949        Single primary keys are returned as strings unless you
1950        set the composite flag.  Composite primary keys are always
1951        represented as tuples.  Note that this raises a KeyError
1952        if the table does not have a primary key.
1953
1954        If flush is set then the internal cache for primary keys will
1955        be flushed.  This may be necessary after the database schema or
1956        the search path has been changed.
1957        """
1958        pkeys = self._pkeys
1959        if flush:
1960            pkeys.clear()
1961            self._do_debug('The pkey cache has been flushed')
1962        try:  # cache lookup
1963            pkey = pkeys[table]
1964        except KeyError:  # cache miss, check the database
1965            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1966                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1967                " AND a.attnum = ANY(i.indkey)"
1968                " AND NOT a.attisdropped"
1969                " WHERE i.indrelid=%s::regclass"
1970                " AND i.indisprimary ORDER BY a.attnum") % (
1971                    _quote_if_unqualified('$1', table),)
1972            pkey = self.db.query(q, (table,)).getresult()
1973            if not pkey:
1974                raise KeyError('Table %s has no primary key' % table)
1975            # we want to use the order defined in the primary key index here,
1976            # not the order as defined by the columns in the table
1977            if len(pkey) > 1:
1978                indkey = pkey[0][2]
1979                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1980                pkey = tuple(row[0] for row in pkey)
1981            else:
1982                pkey = pkey[0][0]
1983            pkeys[table] = pkey  # cache it
1984        if composite and not isinstance(pkey, tuple):
1985            pkey = (pkey,)
1986        return pkey
1987
1988    def get_databases(self):
1989        """Get list of databases in the system."""
1990        return [s[0] for s in
1991            self.db.query('SELECT datname FROM pg_database').getresult()]
1992
1993    def get_relations(self, kinds=None, system=False):
1994        """Get list of relations in connected database of specified kinds.
1995
1996        If kinds is None or empty, all kinds of relations are returned.
1997        Otherwise kinds can be a string or sequence of type letters
1998        specifying which kind of relations you want to list.
1999
2000        Set the system flag if you want to get the system relations as well.
2001        """
2002        where = []
2003        if kinds:
2004            where.append("r.relkind IN (%s)" %
2005                ','.join("'%s'" % k for k in kinds))
2006        if not system:
2007            where.append("s.nspname NOT SIMILAR"
2008                " TO 'pg/_%|information/_schema' ESCAPE '/'")
2009        where = " WHERE %s" % ' AND '.join(where) if where else ''
2010        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
2011            " FROM pg_class r"
2012            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
2013            " ORDER BY s.nspname, r.relname") % where
2014        return [r[0] for r in self.db.query(q).getresult()]
2015
2016    def get_tables(self, system=False):
2017        """Return list of tables in connected database.
2018
2019        Set the system flag if you want to get the system tables as well.
2020        """
2021        return self.get_relations('r', system)
2022
2023    def get_attnames(self, table, with_oid=True, flush=False):
2024        """Given the name of a table, dig out the set of attribute names.
2025
2026        Returns a read-only dictionary of attribute names (the names are
2027        the keys, the values are the names of the attributes' types)
2028        with the column names in the proper order if you iterate over it.
2029
2030        If flush is set, then the internal cache for attribute names will
2031        be flushed. This may be necessary after the database schema or
2032        the search path has been changed.
2033
2034        By default, only a limited number of simple types will be returned.
2035        You can get the regular types after calling use_regtypes(True).
2036        """
2037        attnames = self._attnames
2038        if flush:
2039            attnames.clear()
2040            self._do_debug('The attnames cache has been flushed')
2041        try:  # cache lookup
2042            names = attnames[table]
2043        except KeyError:  # cache miss, check the database
2044            q = "a.attnum > 0"
2045            if with_oid:
2046                q = "(%s OR a.attname = 'oid')" % q
2047            q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
2048            names = self.db.query(q, (table,)).getresult()
2049            types = self.dbtypes
2050            names = ((name[0], types.add(*name[1:])) for name in names)
2051            names = AttrDict(names)
2052            attnames[table] = names  # cache it
2053        return names
2054
2055    def use_regtypes(self, regtypes=None):
2056        """Use regular type names instead of simplified type names."""
2057        if regtypes is None:
2058            return self.dbtypes._regtypes
2059        else:
2060            regtypes = bool(regtypes)
2061            if regtypes != self.dbtypes._regtypes:
2062                self.dbtypes._regtypes = regtypes
2063                self._attnames.clear()
2064                self.dbtypes.clear()
2065            return regtypes
2066
2067    def has_table_privilege(self, table, privilege='select', flush=False):
2068        """Check whether current user has specified table privilege.
2069
2070        If flush is set, then the internal cache for table privileges will
2071        be flushed. This may be necessary after privileges have been changed.
2072        """
2073        privileges = self._privileges
2074        if flush:
2075            privileges.clear()
2076            self._do_debug('The privileges cache has been flushed')
2077        privilege = privilege.lower()
2078        try:  # ask cache
2079            ret = privileges[table, privilege]
2080        except KeyError:  # cache miss, ask the database
2081            q = "SELECT has_table_privilege(%s, $2)" % (
2082                _quote_if_unqualified('$1', table),)
2083            q = self.db.query(q, (table, privilege))
2084            ret = q.getresult()[0][0] == self._make_bool(True)
2085            privileges[table, privilege] = ret  # cache it
2086        return ret
2087
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        Errors: if the table has no primary key and no oid value can be
        used, the error built by _prg_error is raised.  A KeyError is raised
        when values for the key columns are missing from the row dictionary
        or their number does not match the keyname.  If no matching row
        exists, the error built by _db_error is raised.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # munged oid key for this table; None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # accept a munged oid value from the row as the plain 'oid' key
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # a scalar or a sequence of values is turned into a new row dict
        # mapping each key column to its value
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # the plain 'oid' key was only needed for building the where clause;
        # in the returned row the oid is kept under its munged key
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row, munging the oid column name
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
2162
2163    def insert(self, table, row=None, **kw):
2164        """Insert a row into a database table.
2165
2166        This method inserts a row into a table.  The name of the table must
2167        be passed as the first parameter.  The other parameters are used for
2168        providing the data of the row that shall be inserted into the table.
2169        If a dictionary is supplied as the second parameter, it starts with
2170        that.  Otherwise it uses a blank dictionary. Either way the dictionary
2171        is updated from the keywords.
2172
2173        The dictionary is then reloaded with the values actually inserted in
2174        order to pick up values modified by rules, triggers, etc.
2175        """
2176        if table.endswith('*'):  # hint for descendant tables can be ignored
2177            table = table[:-1].rstrip()
2178        if row is None:
2179            row = {}
2180        row.update(kw)
2181        if 'oid' in row:
2182            del row['oid']  # do not insert oid
2183        attnames = self.get_attnames(table)
2184        qoid = _oid_key(table) if 'oid' in attnames else None
2185        params = self.adapter.parameter_list()
2186        adapt = params.add
2187        col = self.escape_identifier
2188        names, values = [], []
2189        for n in attnames:
2190            if n in row:
2191                names.append(col(n))
2192                values.append(adapt(row[n], attnames[n]))
2193        if not names:
2194            raise _prg_error('No column found that can be inserted')
2195        names, values = ', '.join(names), ', '.join(values)
2196        ret = 'oid, *' if qoid else '*'
2197        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
2198            self._escape_qualified_name(table), names, values, ret)
2199        self._do_debug(q, params)
2200        q = self.db.query(q, params)
2201        res = q.dictresult()
2202        if res:  # this should always be true
2203            for n, value in res[0].items():
2204                if qoid and n == 'oid':
2205                    n = qoid
2206                row[n] = value
2207        return row
2208
2209    def update(self, table, row=None, **kw):
2210        """Update an existing row in a database table.
2211
2212        Similar to insert, but updates an existing row.  The update is based
2213        on the primary key of the table or the OID value as munged by get()
2214        or passed as keyword.  The OID will take precedence if provided, so
2215        that it is possible to update the primary key itself.
2216
2217        The dictionary is then modified to reflect any changes caused by the
2218        update due to triggers, rules, default values, etc.
2219        """
2220        if table.endswith('*'):
2221            table = table[:-1].rstrip()  # need parent table name
2222        attnames = self.get_attnames(table)
2223        qoid = _oid_key(table) if 'oid' in attnames else None
2224        if row is None:
2225            row = {}
2226        elif 'oid' in row:
2227            del row['oid']  # only accept oid key from named args for safety
2228        row.update(kw)
2229        if qoid and qoid in row and 'oid' not in row:
2230            row['oid'] = row[qoid]
2231        if qoid and 'oid' in row:  # try using the oid
2232            keyname = ('oid',)
2233        else:  # try using the primary key
2234            try:
2235                keyname = self.pkey(table, True)
2236            except KeyError:  # the table has no primary key
2237                raise _prg_error('Table %s has no primary key' % table)
2238            # check whether all key columns have values
2239            if not set(keyname).issubset(row):
2240                raise KeyError('Missing value for primary key in row')
2241        params = self.adapter.parameter_list()
2242        adapt = params.add
2243        col = self.escape_identifier
2244        where = ' AND '.join('%s = %s' % (
2245            col(k), adapt(row[k], attnames[k])) for k in keyname)
2246        if 'oid' in row:
2247            if qoid:
2248                row[qoid] = row['oid']
2249            del row['oid']
2250        values = []
2251        keyname = set(keyname)
2252        for n in attnames:
2253            if n in row and n not in keyname:
2254                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
2255        if not values:
2256            return row
2257        values = ', '.join(values)
2258        ret = 'oid, *' if qoid else '*'
2259        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
2260            self._escape_qualified_name(table), values, where, ret)
2261        self._do_debug(q, params)
2262        q = self.db.query(q, params)
2263        res = q.dictresult()
2264        if res:  # may be empty when row does not exist
2265            for n, value in res[0].items():
2266                if qoid and n == 'oid':
2267                    n = qoid
2268                row[n] = value
2269        return row
2270
2271    def upsert(self, table, row=None, **kw):
2272        """Insert a row into a database table with conflict resolution
2273
2274        This method inserts a row into a table, but instead of raising a
2275        ProgrammingError exception in case a row with the same primary key
2276        already exists, an update will be executed instead.  This will be
2277        performed as a single atomic operation on the database, so race
2278        conditions can be avoided.
2279
2280        Like the insert method, the first parameter is the name of the
2281        table and the second parameter can be used to pass the values to
2282        be inserted as a dictionary.
2283
2284        Unlike the insert und update statement, keyword parameters are not
2285        used to modify the dictionary, but to specify which columns shall
2286        be updated in case of a conflict, and in which way:
2287
2288        A value of False or None means the column shall not be updated,
2289        a value of True means the column shall be updated with the value
2290        that has been proposed for insertion, i.e. has been passed as value
2291        in the dictionary.  Columns that are not specified by keywords but
2292        appear as keys in the dictionary are also updated like in the case
2293        keywords had been passed with the value True.
2294
2295        So if in the case of a conflict you want to update every column that
2296        has been passed in the dictionary row, you would call upsert(table, row).
2297        If you don't want to do anything in case of a conflict, i.e. leave
2298        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2299
2300        If you need more fine-grained control of what gets updated, you can
2301        also pass strings in the keyword parameters.  These strings will
2302        be used as SQL expressions for the update columns.  In these
2303        expressions you can refer to the value that already exists in
2304        the table by prefixing the column name with "included.", and to
2305        the value that has been proposed for insertion by prefixing the
2306        column name with the "excluded."
2307
2308        The dictionary is modified in any case to reflect the values in
2309        the database after the operation has completed.
2310
2311        Note: The method uses the PostgreSQL "upsert" feature which is
2312        only available since PostgreSQL 9.5.
2313        """
2314        if table.endswith('*'):  # hint for descendant tables can be ignored
2315            table = table[:-1].rstrip()
2316        if row is None:
2317            row = {}
2318        if 'oid' in row:
2319            del row['oid']  # do not insert oid
2320        if 'oid' in kw:
2321            del kw['oid']  # do not update oid
2322        attnames = self.get_attnames(table)
2323        qoid = _oid_key(table) if 'oid' in attnames else None
2324        params = self.adapter.parameter_list()
2325        adapt = params.add
2326        col = self.escape_identifier
2327        names, values, updates = [], [], []
2328        for n in attnames:
2329            if n in row:
2330                names.append(col(n))
2331                values.append(adapt(row[n], attnames[n]))
2332        names, values = ', '.join(names), ', '.join(values)
2333        try:
2334            keyname = self.pkey(table, True)
2335        except KeyError:
2336            raise _prg_error('Table %s has no primary key' % table)
2337        target = ', '.join(col(k) for k in keyname)
2338        update = []
2339        keyname = set(keyname)
2340        keyname.add('oid')
2341        for n in attnames:
2342            if n not in keyname:
2343                value = kw.get(n, True)
2344                if value:
2345                    if not isinstance(value, basestring):
2346                        value = 'excluded.%s' % col(n)
2347                    update.append('%s = %s' % (col(n), value))
2348        if not values:
2349            return row
2350        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2351        ret = 'oid, *' if qoid else '*'
2352        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2353            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2354                self._escape_qualified_name(table), names, values,
2355                target, do, ret)
2356        self._do_debug(q, params)
2357        try:
2358            q = self.db.query(q, params)
2359        except ProgrammingError:
2360            if self.server_version < 90500:
2361                raise _prg_error(
2362                    'Upsert operation is not supported by PostgreSQL version')
2363            raise  # re-raise original error
2364        res = q.dictresult()
2365        if res:  # may be empty with "do nothing"
2366            for n, value in res[0].items():
2367                if qoid and n == 'oid':
2368                    n = qoid
2369                row[n] = value
2370        else:
2371            self.get(table, row)
2372        return row
2373
2374    def clear(self, table, row=None):
2375        """Clear all the attributes to values determined by the types.
2376
2377        Numeric types are set to 0, Booleans are set to false, and everything
2378        else is set to the empty string.  If the row argument is present,
2379        it is used as the row dictionary and any entries matching attribute
2380        names are cleared with everything else left unchanged.
2381        """
2382        # At some point we will need a way to get defaults from a table.
2383        if row is None:
2384            row = {}  # empty if argument is not present
2385        attnames = self.get_attnames(table)
2386        for n, t in attnames.items():
2387            if n == 'oid':
2388                continue
2389            t = t.simple
2390            if t in DbTypes._num_types:
2391                row[n] = 0
2392            elif t == 'bool':
2393                row[n] = self._make_bool(False)
2394            else:
2395                row[n] = ''
2396        return row
2397
2398    def delete(self, table, row=None, **kw):
2399        """Delete an existing row in a database table.
2400
2401        This method deletes the row from a table.  It deletes based on the
2402        primary key of the table or the OID value as munged by get() or
2403        passed as keyword.  The OID will take precedence if provided.
2404
2405        The return value is the number of deleted rows (i.e. 0 if the row
2406        did not exist and 1 if the row was deleted).
2407
2408        Note that if the row cannot be deleted because e.g. it is still
2409        referenced by another table, this method raises a ProgrammingError.
2410        """
2411        if table.endswith('*'):  # hint for descendant tables can be ignored
2412            table = table[:-1].rstrip()
2413        attnames = self.get_attnames(table)
2414        qoid = _oid_key(table) if 'oid' in attnames else None
2415        if row is None:
2416            row = {}
2417        elif 'oid' in row:
2418            del row['oid']  # only accept oid key from named args for safety
2419        row.update(kw)
2420        if qoid and qoid in row and 'oid' not in row:
2421            row['oid'] = row[qoid]
2422        if qoid and 'oid' in row:  # try using the oid
2423            keyname = ('oid',)
2424        else:  # try using the primary key
2425            try:
2426                keyname = self.pkey(table, True)
2427            except KeyError:  # the table has no primary key
2428                raise _prg_error('Table %s has no primary key' % table)
2429            # check whether all key columns have values
2430            if not set(keyname).issubset(row):
2431                raise KeyError('Missing value for primary key in row')
2432        params = self.adapter.parameter_list()
2433        adapt = params.add
2434        col = self.escape_identifier
2435        where = ' AND '.join('%s = %s' % (
2436            col(k), adapt(row[k], attnames[k])) for k in keyname)
2437        if 'oid' in row:
2438            if qoid:
2439                row[qoid] = row['oid']
2440            del row['oid']
2441        q = 'DELETE FROM %s WHERE %s' % (
2442            self._escape_qualified_name(table), where)
2443        self._do_debug(q, params)
2444        res = self.db.query(q, params)
2445        return int(res)
2446
2447    def truncate(self, table, restart=False, cascade=False, only=False):
2448        """Empty a table or set of tables.
2449
2450        This method quickly removes all rows from the given table or set
2451        of tables.  It has the same effect as an unqualified DELETE on each
2452        table, but since it does not actually scan the tables it is faster.
2453        Furthermore, it reclaims disk space immediately, rather than requiring
2454        a subsequent VACUUM operation. This is most useful on large tables.
2455
2456        If restart is set to True, sequences owned by columns of the truncated
2457        table(s) are automatically restarted.  If cascade is set to True, it
2458        also truncates all tables that have foreign-key references to any of
2459        the named tables.  If the parameter only is not set to True, all the
2460        descendant tables (if any) will also be truncated. Optionally, a '*'
2461        can be specified after the table name to explicitly indicate that
2462        descendant tables are included.
2463        """
2464        if isinstance(table, basestring):
2465            only = {table: only}
2466            table = [table]
2467        elif isinstance(table, (list, tuple)):
2468            if isinstance(only, (list, tuple)):
2469                only = dict(zip(table, only))
2470            else:
2471                only = dict.fromkeys(table, only)
2472        elif isinstance(table, (set, frozenset)):
2473            only = dict.fromkeys(table, only)
2474        else:
2475            raise TypeError('The table must be a string, list or set')
2476        if not (restart is None or isinstance(restart, (bool, int))):
2477            raise TypeError('Invalid type for the restart option')
2478        if not (cascade is None or isinstance(cascade, (bool, int))):
2479            raise TypeError('Invalid type for the cascade option')
2480        tables = []
2481        for t in table:
2482            u = only.get(t)
2483            if not (u is None or isinstance(u, (bool, int))):
2484                raise TypeError('Invalid type for the only option')
2485            if t.endswith('*'):
2486                if u:
2487                    raise ValueError(
2488                        'Contradictory table name and only options')
2489                t = t[:-1].rstrip()
2490            t = self._escape_qualified_name(t)
2491            if u:
2492                t = 'ONLY %s' % t
2493            tables.append(t)
2494        q = ['TRUNCATE', ', '.join(tables)]
2495        if restart:
2496            q.append('RESTART IDENTITY')
2497        if cascade:
2498            q.append('CASCADE')
2499        q = ' '.join(q)
2500        self._do_debug(q)
2501        return self.db.query(q)
2502
2503    def get_as_list(self, table, what=None, where=None,
2504            order=None, limit=None, offset=None, scalar=False):
2505        """Get a table as a list.
2506
2507        This gets a convenient representation of the table as a list
2508        of named tuples in Python.  You only need to pass the name of
2509        the table (or any other SQL expression returning rows).  Note that
2510        by default this will return the full content of the table which
2511        can be huge and overflow your memory.  However, you can control
2512        the amount of data returned using the other optional parameters.
2513
2514        The parameter 'what' can restrict the query to only return a
2515        subset of the table columns.  It can be a string, list or a tuple.
2516        The parameter 'where' can restrict the query to only return a
2517        subset of the table rows.  It can be a string, list or a tuple
2518        of SQL expressions that all need to be fulfilled.  The parameter
2519        'order' specifies the ordering of the rows.  It can also be a
2520        other string, list or a tuple.  If no ordering is specified,
2521        the result will be ordered by the primary key(s) or all columns
2522        if no primary key exists.  You can set 'order' to False if you
2523        don't care about the ordering.  The parameters 'limit' and 'offset'
2524        can be integers specifying the maximum number of rows returned
2525        and a number of rows skipped over.
2526
2527        If you set the 'scalar' option to True, then instead of the
2528        named tuples you will get the first items of these tuples.
2529        This is useful if the result has only one column anyway.
2530        """
2531        if not table:
2532            raise TypeError('The table name is missing')
2533        if what:
2534            if isinstance(what, (list, tuple)):
2535                what = ', '.join(map(str, what))
2536            if order is None:
2537                order = what
2538        else:
2539            what = '*'
2540        q = ['SELECT', what, 'FROM', table]
2541        if where:
2542            if isinstance(where, (list, tuple)):
2543                where = ' AND '.join(map(str, where))
2544            q.extend(['WHERE', where])
2545        if order is None:
2546            try:
2547                order = self.pkey(table, True)
2548            except (KeyError, ProgrammingError):
2549                try:
2550                    order = list(self.get_attnames(table))
2551                except (KeyError, ProgrammingError):
2552                    pass
2553        if order:
2554            if isinstance(order, (list, tuple)):
2555                order = ', '.join(map(str, order))
2556            q.extend(['ORDER BY', order])
2557        if limit:
2558            q.append('LIMIT %d' % limit)
2559        if offset:
2560            q.append('OFFSET %d' % offset)
2561        q = ' '.join(q)
2562        self._do_debug(q)
2563        q = self.db.query(q)
2564        res = q.namedresult()
2565        if res and scalar:
2566            res = [row[0] for row in res]
2567        return res
2568
2569    def get_as_dict(self, table, keyname=None, what=None, where=None,
2570            order=None, limit=None, offset=None, scalar=False):
2571        """Get a table as a dictionary.
2572
2573        This method is similar to get_as_list(), but returns the table
2574        as a Python dict instead of a Python list, which can be even
2575        more convenient. The primary key column(s) of the table will
2576        be used as the keys of the dictionary, while the other column(s)
2577        will be the corresponding values.  The keys will be named tuples
2578        if the table has a composite primary key.  The rows will be also
2579        named tuples unless the 'scalar' option has been set to True.
2580        With the optional parameter 'keyname' you can specify an alternative
2581        set of columns to be used as the keys of the dictionary.  It must
2582        be set as a string, list or a tuple.
2583
2584        If the Python version supports it, the dictionary will be an
2585        OrderedDict using the order specified with the 'order' parameter
2586        or the key column(s) if not specified.  You can set 'order' to False
2587        if you don't care about the ordering.  In this case the returned
2588        dictionary will be an ordinary one.
2589        """
2590        if not table:
2591            raise TypeError('The table name is missing')
2592        if not keyname:
2593            try:
2594                keyname = self.pkey(table, True)
2595            except (KeyError, ProgrammingError):
2596                raise _prg_error('Table %s has no primary key' % table)
2597        if isinstance(keyname, basestring):
2598            keyname = [keyname]
2599        elif not isinstance(keyname, (list, tuple)):
2600            raise KeyError('The keyname must be a string, list or tuple')
2601        if what:
2602            if isinstance(what, (list, tuple)):
2603                what = ', '.join(map(str, what))
2604            if order is None:
2605                order = what
2606        else:
2607            what = '*'
2608        q = ['SELECT', what, 'FROM', table]
2609        if where:
2610            if isinstance(where, (list, tuple)):
2611                where = ' AND '.join(map(str, where))
2612            q.extend(['WHERE', where])
2613        if order is None:
2614            order = keyname
2615        if order:
2616            if isinstance(order, (list, tuple)):
2617                order = ', '.join(map(str, order))
2618            q.extend(['ORDER BY', order])
2619        if limit:
2620            q.append('LIMIT %d' % limit)
2621        if offset:
2622            q.append('OFFSET %d' % offset)
2623        q = ' '.join(q)
2624        self._do_debug(q)
2625        q = self.db.query(q)
2626        res = q.getresult()
2627        cls = OrderedDict if order else dict
2628        if not res:
2629            return cls()
2630        keyset = set(keyname)
2631        fields = q.listfields()
2632        if not keyset.issubset(fields):
2633            raise KeyError('Missing keyname in row')
2634        keyind, rowind = [], []
2635        for i, f in enumerate(fields):
2636            (keyind if f in keyset else rowind).append(i)
2637        keytuple = len(keyind) > 1
2638        getkey = itemgetter(*keyind)
2639        keys = map(getkey, res)
2640        if scalar:
2641            rowind = rowind[:1]
2642            rowtuple = False
2643        else:
2644            rowtuple = len(rowind) > 1
2645        if scalar or rowtuple:
2646            getrow = itemgetter(*rowind)
2647        else:
2648            rowind = rowind[0]
2649            getrow = lambda row: (row[rowind],)
2650            rowtuple = True
2651        rows = map(getrow, res)
2652        if keytuple or rowtuple:
2653            namedresult = get_namedresult()
2654            if namedresult:
2655                if keytuple:
2656                    keys = namedresult(_MemoryQuery(keys, keyname))
2657                if rowtuple:
2658                    fields = [f for f in fields if f not in keyset]
2659                    rows = namedresult(_MemoryQuery(rows, fields))
2660        return cls(zip(keys, rows))
2661
2662    def notification_handler(self,
2663            event, callback, arg_dict=None, timeout=None, stop_event=None):
2664        """Get notification handler that will run the given callback."""
2665        return NotificationHandler(self,
2666            event, callback, arg_dict, timeout, stop_event)
2667
2668
2669# if run as script, print some information
2670
2671if __name__ == '__main__':
2672    print('PyGreSQL version' + version)
2673    print('')
2674    print(__doc__)
Note: See TracBrowser for help on using the repository browser.