source: trunk/pg.py @ 900

Last change on this file since 900 was 900, checked in by cito, 3 years ago

Allow query_formatted() to be used without parameters

  • Property svn:keywords set to Id
File size: 93.7 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 900 2016-12-31 08:40:24Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function, division
31
32from _pg import *
33
34__version__ = version
35
36import select
37import warnings
38import weakref
39
40from datetime import date, time, datetime, timedelta, tzinfo
41from decimal import Decimal
42from math import isnan, isinf
43from collections import namedtuple
44from operator import itemgetter
45from functools import partial
46from re import compile as regex
47from json import loads as jsondecode, dumps as jsonencode
48from uuid import UUID
49
# Python 2/3 compatibility: make sure the name ``long`` can always be used.
try:
    long
except NameError:  # Python >= 3.0
    long = int  # fall back to the only integer type

# Python 2/3 compatibility: make sure the name ``basestring`` can always be
# used as the second argument of isinstance() for "any kind of string".
try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)
59
# Use the lru_cache from the standard library if it is available,
# otherwise define a reduced replacement with the same name below.
try:
    from functools import lru_cache
except ImportError:  # Python < 3.2
    from functools import update_wrapper
    try:
        from _thread import RLock
    except ImportError:
        class RLock:  # for builds without threads
            # dummy lock that only supports the "with" statement
            def __enter__(self): pass

            def __exit__(self, exctype, excinst, exctb): pass

    def lru_cache(maxsize=128):
        """Simplified functools.lru_cache decorator for one argument.

        The decorated function must take exactly one positional argument
        that is usable as a dictionary key.  With maxsize set to 0 nothing
        is cached, with maxsize set to None the cache grows without bound,
        otherwise the cache is limited to maxsize entries and the least
        recently used entry is dropped when a new one must be stored.
        """

        def decorator(function):
            sentinel = object()  # marks a cache miss (None may be a result)
            cache = {}
            get = cache.get
            lock = RLock()
            # Entries are kept as links [prev, next, arg, res] of a circular
            # doubly linked list; root is the link that carries no entry.
            root = []
            # mutable cell holding the current root link and the "full" flag,
            # so that the wrapper below can rebind both of them
            root_full = [root, False]
            root[:] = [root, root, None, None]

            if maxsize == 0:

                # no caching at all, just call the function
                def wrapper(arg):
                    res = function(arg)
                    return res

            elif maxsize is None:

                # unbounded cache, a plain dictionary is sufficient
                def wrapper(arg):
                    res = get(arg, sentinel)
                    if res is not sentinel:
                        return res
                    res = function(arg)
                    cache[arg] = res
                    return res

            else:

                # bounded cache that tracks the order of usage
                def wrapper(arg):
                    with lock:
                        link = get(arg)
                        if link is not None:
                            # cache hit: unlink the entry and insert it
                            # again directly in front of the root link
                            root = root_full[0]
                            prev, next, _arg, res = link
                            prev[1] = next
                            next[0] = prev
                            last = root[0]
                            last[1] = root[0] = link
                            link[0] = last
                            link[1] = root
                            return res
                    # cache miss: the function is called outside of the lock
                    res = function(arg)
                    with lock:
                        root, full = root_full
                        if arg in cache:
                            # the entry was added while the lock was released
                            pass
                        elif full:
                            # store the new entry in the old root link and
                            # make the following link the new (empty) root,
                            # removing the entry that was stored there
                            oldroot = root
                            oldroot[2] = arg
                            oldroot[3] = res
                            root = root_full[0] = oldroot[1]
                            oldarg = root[2]
                            oldres = root[3]  # keep reference
                            root[2] = root[3] = None
                            del cache[oldarg]
                            cache[arg] = oldroot
                        else:
                            # append a new link in front of the root link
                            last = root[0]
                            link = [last, root, arg, res]
                            last[1] = root[0] = cache[arg] = link
                            if len(cache) >= maxsize:
                                root_full[1] = True
                    return res

            wrapper.__wrapped__ = function
            return update_wrapper(wrapper, function)

        return decorator
142
143
144# Auxiliary classes and functions that are independent from a DB connection:
145
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict


    class AttrDict(dict):
        """Read-only dictionary that remembers the order of its keys.

        This variant is used when no OrderedDict is available.  It accepts
        at most one positional argument, an iterable of (key, value) pairs
        that must not be a dict, and no keyword arguments.
        """

        def __init__(self, *args, **kw):
            if len(args) > 1 or kw:
                raise TypeError
            pairs = args[0] if args else []
            if isinstance(pairs, dict):
                raise TypeError
            pairs = list(pairs)
            # the key order is tracked separately from the dictionary
            self._keys = [pair[0] for pair in pairs]
            dict.__init__(self, pairs)
            self._read_only = True
            # shadow all mutating methods on the instance
            refuse = self._read_only_error
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, refuse)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return self._keys[:]

        def values(self):
            return [self[name] for name in self]

        def items(self):
            return [(name, self[name]) for name in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Read-only ordered dictionary for storing attribute names.

        The dictionary can be filled only by its constructor, afterwards
        every attempt to change it raises a TypeError.
        """

        def __init__(self, *args, **kw):
            # the dictionary must be writable while it is being filled
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow all mutating methods on the instance
            refuse = self._read_only_error
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, refuse)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
230
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given function."""
        parameters = signature(func).parameters
        return [name for name in parameters]
242
243try:
244    from datetime import timezone
245except ImportError:  # Python < 3.2
246
247    class timezone(tzinfo):
248        """Simple timezone implementation."""
249
250        def __init__(self, offset, name=None):
251            self.offset = offset
252            if not name:
253                minutes = self.offset.days * 1440 + self.offset.seconds // 60
254                if minutes < 0:
255                    hours, minutes = divmod(-minutes, 60)
256                    hours = -hours
257                else:
258                    hours, minutes = divmod(minutes, 60)
259                name = 'UTC%+03d:%02d' % (hours, minutes)
260            self.name = name
261
262        def utcoffset(self, dt):
263            return self.offset
264
265        def tzname(self, dt):
266            return self.name
267
268        def dst(self, dt):
269            return None
270
271    timezone.utc = timezone(timedelta(0), 'UTC')
272
273    _has_timezone = False
274else:
275    _has_timezone = True
276
277# time zones used in Postgres timestamptz output
278_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
279    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
280    UCT='+0000', UTC='+0000', WET='+0000')
281
282
283def _timezone_as_offset(tz):
284    if tz.startswith(('+', '-')):
285        if len(tz) < 5:
286            return tz + '00'
287        return tz.replace(':', '')
288    return _timezones.get(tz, '+0000')
289
290
def _get_timezone(tz):
    """Create a timezone object for the given time zone string.

    The string is first normalized to an offset of the form +HHMM or -HHMM,
    which is also used as the name of the returned timezone.
    """
    offset = _timezone_as_offset(tz)
    total = 60 * int(offset[1:3]) + int(offset[3:5])
    negative = offset[0] == '-'
    delta = timedelta(minutes=-total if negative else total)
    return timezone(delta, offset)
297
298
299def _oid_key(table):
300    """Build oid key from a table name."""
301    return 'oid(%s)' % table
302
303
304class _SimpleTypes(dict):
305    """Dictionary mapping pg_type names to simple type names."""
306
307    _types = {'bool': 'bool',
308        'bytea': 'bytea',
309        'date': 'date interval time timetz timestamp timestamptz'
310            ' abstime reltime',  # these are very old
311        'float': 'float4 float8',
312        'int': 'cid int2 int4 int8 oid xid',
313        'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid',
314        'num': 'numeric', 'money': 'money',
315        'text': 'bpchar char name text varchar'}
316
317    def __init__(self):
318        for typ, keys in self._types.items():
319            for key in keys.split():
320                self[key] = typ
321                self['_%s' % key] = '%s[]' % typ
322
323    # this could be a static method in Python > 2.6
324    def __missing__(self, key):
325        return 'text'
326
327_simpletypes = _SimpleTypes()
328
329
330def _quote_if_unqualified(param, name):
331    """Quote parameter representing a qualified name.
332
333    Puts a quote_ident() call around the give parameter unless
334    the name contains a dot, in which case the name is ambiguous
335    (could be a qualified name or just a name with a dot in it)
336    and must be quoted manually by the caller.
337    """
338    if isinstance(name, basestring) and '.' not in name:
339        return 'quote_ident(%s)' % (param,)
340    return param
341
342
class _ParameterList(list):
    """List of query parameters that is built up value by value.

    The adapt() function used by the add() method is not defined here,
    it must be set as an attribute of the list before add() is called.
    """

    def add(self, value, typ=None):
        """Adapt a value of known database type and register it.

        Literal values are returned as they are and are not stored.  All
        other values are appended to the list and a placeholder with the
        position of the value in the list is returned instead.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            return '$%d' % len(self)
        return adapted
357
358
class Bytea(bytes):
    """Marker type for byte strings that shall be treated as bytea values."""
361
362
class Hstore(dict):
    """Wrapper class for marking hstore values.

    Converting an instance to a string yields the dictionary in the
    text format of the hstore type, i.e. as comma separated key=>value
    pairs where keys and values are quoted and escaped as necessary.
    """

    # strings that must be put in double quotes
    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
    # characters that must be escaped with a backslash
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Return a key or value in the form needed for the hstore format.

        None is represented as NULL, the empty string as a pair of double
        quotes.  Values that are not strings are converted to strings.
        """
        if s is None:
            return 'NULL'
        if not isinstance(s, basestring):
            s = str(s)
        if not s:
            return '""'
        quote = cls._re_quote.search(s)
        # Backslashes must be escaped as well as double quotes,
        # otherwise they are swallowed when the value is parsed.
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
382
383
class Json:
    """Marker type for objects that shall be treated as JSON values.

    The wrapped object is available as the attribute obj.
    """

    def __init__(self, obj):
        setattr(self, 'obj', obj)
389
390
class Literal(str):
    """Marker type for strings that shall be used as literal SQL."""
393
394
class Adapter:
    """Class providing methods for adapting parameters to the database.

    Values can be adapted in two ways: The adapt() method converts values
    that are sent to the database separately from the command as query
    parameters, while the adapt_inline() method converts values to quoted
    and escaped strings that are put directly into the SQL command.  The
    format_query() method applies one of them to all values of a query.
    """

    # lower case strings that are regarded as true boolean values
    _bool_true_values = frozenset('t true 1 y yes on'.split())

    # special date values that must be passed on as SQL literals
    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    # patterns for elements that must be quoted or escaped
    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    _re_record_quote = regex(r'[(,"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        """Create an adapter that stores a weak proxy to the given db."""
        self.db = weakref.proxy(db)

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Strings are checked against the known true values regardless of
        case, an empty string is adapted to None.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        Special values like current_date are returned as literals,
        values that evaluate to false are adapted to None.
        """
        if not v:
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter.

        Values that evaluate to false are adapted to None unless they
        are equal to zero.
        """
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.db.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Strings are passed on unchanged, other values are encoded with
        the JSON encoder of the database wrapper.  Values that evaluate
        to false are adapted to None.
        """
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.db.encode_json(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if v is None:
            return 'null'
        if not v:
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(item) for item in v) + b'}'
        if v is None:
            return b'null'
        # backslashes must be doubled inside of arrays
        return self.db.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.db.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Raises a TypeError if the number of values in the record does
        not match the number of attributes of the given type.
        """
        typ = self.get_attnames(typ).values()
        if len(typ) != len(v):
            # The record is wrapped in a tuple, because it is usually a
            # tuple itself and would be taken as the list of arguments.
            raise TypeError('Record parameter %s has wrong size' % (v,))
        adapt = self.adapt
        value = []
        for item, item_typ in zip(v, typ):
            item = adapt(item, item_typ)
            if item is None:
                item = ''
            elif not item:
                item = '""'
            else:
                if isinstance(item, bytes):
                    if str is not bytes:
                        item = item.decode('ascii')
                else:
                    item = str(item)
                if self._re_record_quote.search(item):
                    item = '"%s"' % self._re_record_escape.sub(
                        r'\\\1', item)
            value.append(item)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is passed, the type is guessed from the value, using
        text as the last resort.  None and literals are returned as is.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            # values can take care of their adaptation themselves
            try:
                value = value.__pg_str__(typ)
            except AttributeError:
                pass
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns None if no type can be determined.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):  # must be checked before int
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, list):
            return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type

            def get_attnames(self):
                # Members with a type that cannot be guessed are treated
                # as text, like it is done in adapt(), since otherwise
                # they would get a simple type of None.
                return AttrDict((str(n + 1), simple_type(guess(v) or 'text'))
                    for n, v in enumerate(value))

            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array.

        Returns the type of the first element for which a type can be
        guessed, or None if there is no such element.
        """
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Raises an InterfaceError if the type of the value is not known
        and the value does not provide a __pg_repr__() method.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.db.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # The Json wrapper does not provide an encode() method, so it
            # must not be taken for granted.  If a wrapper does provide
            # one, its result is returned unchanged as before.
            encode = getattr(value, 'encode', None)
            if encode:
                return encode()
            # the wrapped object must be encoded, not the wrapper itself
            value = self.db.encode_json(value.obj)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.db.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            # only the outermost array gets the ARRAY keyword
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values=None, types=None, inline=False):
        """Format a database query using the given values and types.

        The values can be passed as a list or tuple for commands with
        positional format specifiers, or as a dict for commands with named
        format specifiers.  The types, if any, must be passed in the same
        form.  If inline is set, the values are put into the command as
        quoted literals, otherwise they are replaced with placeholders.

        Returns the formatted command and the list of the parameters
        that must be sent along with it.  Raises a ValueError if types
        are passed together with inline, and a TypeError if the values
        or types do not have the required form.
        """
        if not values:
            return command, []
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command %= tuple(literals)
        elif isinstance(values, dict):
            # we want to allow extra keys in the dictionary,
            # so we first must find the values actually used in the command
            used_values = {}
            literals = dict.fromkeys(values, '')
            # We must iterate over the values here, since the literals
            # are modified inside the loop, which is not allowed for a
            # dictionary that is being iterated over.
            for key in values:
                del literals[key]
                try:
                    command % literals
                except KeyError:
                    # the command cannot be formatted without this key
                    used_values[key] = values[key]
                literals[key] = ''
            values = used_values
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                # sort the keys to get a reproducible order of parameters
                if types:
                    if not isinstance(types, dict):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types.get(key))
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command %= literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
737
738
def cast_bool(value):
    """Convert a boolean value coming from the database.

    The value is returned unchanged if get_bool() tells that booleans
    shall not be converted.
    """
    if get_bool():
        return value[0] == 't'
    return value
744
745
def cast_json(value):
    """Convert a JSON value coming from the database.

    The value is returned unchanged if get_jsondecode() does not
    provide a function for decoding it.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
752
753
def cast_num(value):
    """Convert a numeric value coming from the database.

    The type returned by get_decimal() is used for the conversion,
    or float if no such type has been set.
    """
    factory = get_decimal() or float
    return factory(value)
757
758
def cast_money(value):
    """Convert a money value coming from the database.

    The value is returned unchanged if get_decimal_point() does not
    provide a decimal point.  Otherwise everything except digits, the
    decimal point and the sign is removed before the conversion.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        value = value.replace(point, '.')
    # an opening parenthesis marks a negative amount
    value = value.replace('(', '-')
    kept = [char for char in value if char.isdigit() or char in '.-']
    value = ''.join(kept)
    factory = get_decimal() or float
    return factory(value)
769
770
def cast_int2vector(value):
    """Convert an int2vector value to a list of integers."""
    return list(map(int, value.split()))
774
775
def cast_date(value, connection):
    """Convert a date value coming from the database.

    Values that are outside of the range supported by Python are
    converted to the smallest or largest possible date.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:  # the year has more than four digits
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
794
795
def cast_time(value):
    """Convert a time value coming from the database."""
    has_fraction = len(value) > 8
    fmt = '%H:%M:%S.%f' if has_fraction else '%H:%M:%S'
    parsed = datetime.strptime(value, fmt)
    return parsed.time()
800
801
# splits a value into the part before the last sign and the time zone
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Convert a timetz value coming from the database.

    Values without a time zone are regarded as UTC.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if not _has_timezone:
        # parse without time zone and set the time zone afterwards
        naive = datetime.strptime(value, fmt).timetz()
        return naive.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
819
820
def cast_timestamp(value, connection):
    """Convert a timestamp value coming from the database.

    Values that are outside of the range supported by Python are
    converted to the smallest or largest possible datetime.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the first part, then take day and month, time and year
        parts = parts[1:5]
        if len(parts[3]) > 4:  # the year has more than four digits
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:  # the year has more than four digits
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
842
843
def cast_timestamptz(value, connection):
    """Convert a timestamptz value coming from the database.

    Values that are outside of the range supported by Python are
    converted to the smallest or largest possible datetime.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the first part, the time zone is in the last part
        parts = parts[1:]
        if len(parts[3]) > 4:  # the year has more than four digits
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
        parts, tz = parts[:-1], parts[-1]
    else:
        if date_fmt.startswith('%Y-'):
            # the time zone is directly appended to the time
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            parts, tz = parts[:-1], parts[-1]
        if len(parts[0]) > 10:  # the year has more than four digits
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    if not _has_timezone:
        # parse without time zone and set the time zone afterwards
        naive = datetime.strptime(' '.join(parts), ' '.join(formats))
        return naive.replace(tzinfo=_get_timezone(tz))
    parts.append(_timezone_as_offset(tz))
    formats.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
879
880
# The following patterns are used by cast_interval() for parsing interval
# values.  All of their parts are optional.  Where a sign is captured in a
# group of its own, cast_interval() applies it to all of the following parts.

# groups: sign of years and months, years, months, days (with sign),
# sign of the time, hours, minutes, seconds, fraction of seconds
# NOTE(review): presumably the output for IntervalStyle sql_standard
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# groups: years, months, days (all with sign),
# sign of the time, hours, minutes, seconds, fraction of seconds
# NOTE(review): presumably the output for IntervalStyle postgres
_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# groups: years, months, days, hours, minutes (all with sign),
# sign of the seconds, seconds, fraction of seconds, the word "ago"
# NOTE(review): presumably the output for IntervalStyle postgres_verbose
_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

# groups: years, months, days, hours, minutes (all with sign),
# sign of the seconds, seconds, fraction of seconds
# NOTE(review): presumably the output for IntervalStyle iso_8601
_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
907
908
909def cast_interval(value):
910    """Cast an interval value."""
911    # The output format depends on the server setting IntervalStyle, but it's
912    # not necessary to consult this setting to parse it.  It's faster to just
913    # check all possible formats, and there is no ambiguity here.
914    m = _re_interval_iso_8601.match(value)
915    if m:
916        m = [d or '0' for d in m.groups()]
917        secs_ago = m.pop(5) == '-'
918        m = [int(d) for d in m]
919        years, mons, days, hours, mins, secs, usecs = m
920        if secs_ago:
921            secs = -secs
922            usecs = -usecs
923    else:
924        m = _re_interval_postgres_verbose.match(value)
925        if m:
926            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
927            secs_ago = m.pop(5) == '-'
928            m = [-int(d) for d in m] if ago else [int(d) for d in m]
929            years, mons, days, hours, mins, secs, usecs = m
930            if secs_ago:
931                secs = - secs
932                usecs = -usecs
933        else:
934            m = _re_interval_postgres.match(value)
935            if m and any(m.groups()):
936                m = [d or '0' for d in m.groups()]
937                hours_ago = m.pop(3) == '-'
938                m = [int(d) for d in m]
939                years, mons, days, hours, mins, secs, usecs = m
940                if hours_ago:
941                    hours = -hours
942                    mins = -mins
943                    secs = -secs
944                    usecs = -usecs
945            else:
946                m = _re_interval_sql_standard.match(value)
947                if m and any(m.groups()):
948                    m = [d or '0' for d in m.groups()]
949                    years_ago = m.pop(0) == '-'
950                    hours_ago = m.pop(3) == '-'
951                    m = [int(d) for d in m]
952                    years, mons, days, hours, mins, secs, usecs = m
953                    if years_ago:
954                        years = -years
955                        mons = -mons
956                    if hours_ago:
957                        hours = -hours
958                        mins = -mins
959                        secs = -secs
960                        usecs = -usecs
961                else:
962                    raise ValueError('Cannot parse interval: %s' % value)
963    days += 365 * years + 30 * mons
964    return timedelta(days=days, hours=hours, minutes=mins,
965        seconds=secs, microseconds=usecs)
966
967
class Typecasts(dict):
    """Registry of typecast functions, keyed by database type name.

    A typecast function is called with the string representation of a
    database value and must return the corresponding Python object.
    NULL values are handled before a typecast function is called, so
    the passed string will never be None.

    The C extension already casts the basic types by itself.  They only
    need to be handled here as components of records or arrays.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {
        'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool,
        'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int,
        'int8': long,
        'oid': int,
        'hstore': cast_hstore,
        'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num,
        'money': cast_money,
        'date': cast_date,
        'interval': cast_interval,
        'time': cast_time,
        'timetz': cast_timetz,
        'timestamp': cast_timestamp,
        'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector,
        'uuid': UUID,
        'anyarray': cast_array,
        'record': cast_record}

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Build the typecast function for a type that is not cached yet.

        A KeyError is never raised by this class.  If there is no special
        typecast function for the type, None is returned instead.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # keep the default in the cache for faster access
            self[typ] = cast = self._add_connection(cast)
            return cast
        if typ.startswith('_'):
            # an array type, named after its element type
            element_cast = self[typ[1:]]
            cast = self.create_array_cast(element_cast)
            if element_cast:
                self[typ] = cast
            return cast
        fields = self.get_attnames(typ)
        if fields:
            # a composite type, cast field by field
            field_casts = [self[field.pgtype] for field in fields.values()]
            cast = self.create_record_cast(typ, fields, field_casts)
            self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Tell whether the typecast function takes a connection argument."""
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Bind the connection to the typecast function if it needs one."""
        if self.connection and self._needs_connection(cast):
            return partial(cast, connection=self.connection)
        return cast

    def get(self, typ, default=None):
        """Return the typecast function for the given database type."""
        cast = self[typ]
        return cast or default

    def set(self, typ, cast):
        """Register a typecast function for the given database type(s).

        Passing None as cast removes the typecast function again.  In both
        cases the entry for the corresponding array type is removed too.
        """
        names = [typ] if isinstance(typ, basestring) else typ
        if cast is None:
            for name in names:
                self.pop(name, None)
                self.pop('_%s' % name, None)
            return
        if not callable(cast):
            raise TypeError("Cast parameter must be callable")
        for name in names:
            self[name] = self._add_connection(cast)
            self.pop('_%s' % name, None)

    def reset(self, typ=None):
        """Forget the typecast functions for the given database type(s).

        All typecast functions are forgotten when no type is specified.
        The defaults will then be used again on the next lookup.
        """
        if typ is None:
            self.clear()
            return
        names = [typ] if isinstance(typ, basestring) else typ
        for name in names:
            self.pop(name, None)

    @classmethod
    def get_default(cls, typ):
        """Return the default typecast function for the given type."""
        default_cast = cls.defaults.get(typ)
        return default_cast

    @classmethod
    def set_default(cls, typ, cast):
        """Register a default typecast function for the given type(s).

        Passing None as cast removes the default typecast function.  In
        both cases the entry for the corresponding array type is removed.
        """
        names = [typ] if isinstance(typ, basestring) else typ
        registry = cls.defaults
        if cast is None:
            for name in names:
                registry.pop(name, None)
                registry.pop('_%s' % name, None)
            return
        if not callable(cast):
            raise TypeError("Cast parameter must be callable")
        for name in names:
            registry[name] = cast
            registry.pop('_%s' % name, None)

    def get_attnames(self, typ):
        """Return the fields of the given record type.

        This placeholder gets replaced with the get_attnames() method
        of DbTypes.
        """
        return {}

    def dateformat(self):
        """Return the current date format.

        This placeholder gets replaced with the dateformat() method
        of DbTypes.
        """
        return '%Y-%m-%d'

    def create_array_cast(self, basecast):
        """Return a typecast function for arrays with the given base cast."""
        array_cast = self['anyarray']

        def cast(v):
            return array_cast(v, basecast)

        return cast

    def create_record_cast(self, name, fields, casts):
        """Return a typecast function creating named records.

        The records are named tuples with the given name and fields,
        and the values are converted with the given casts.
        """
        record_cast = self['record']
        record_type = namedtuple(name, fields)

        def cast(v):
            return record_type(*record_cast(v, casts))

        return cast
1121
1122
def get_typecast(typ):
    """Return the global default typecast function for a database type."""
    default_cast = Typecasts.get_default(typ)
    return default_cast
1126
1127
def set_typecast(typ, cast):
    """Register a global default typecast function for database type(s).

    Connections keep their own cache of typecast functions.  To make sure
    that a running connection picks up a global change, reset its cache
    by calling db.dbtypes.reset_typecast().
    """
    set_default = Typecasts.set_default
    set_default(typ, cast)
1135
1136
class DbType(str):
    """Type name string carrying additional information as attributes.

    Instances are expected to provide these attributes:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Names and types of the fields if this is a composite type."""
        # the lookup function is attached to the instance by its creator
        lookup = self._get_attnames
        return lookup(self)
1157
1158
class DbTypes(dict):
    """Cache for information on PostgreSQL data types.

    The cache maps the OIDs and the names of database types to DbType
    objects that describe these types.
    """

    _num_types = frozenset((
        'int', 'float', 'num', 'money',
        'int2', 'int4', 'int8', 'float4', 'float8', 'numeric', 'money'))

    def __init__(self, db):
        """Create an empty type cache for the given connection."""
        super(DbTypes, self).__init__()
        # only a weak reference is kept to the connection
        db_proxy = weakref.proxy(db)
        typecasts = Typecasts()
        typecasts.get_attnames = self.get_attnames
        typecasts.connection = db_proxy
        self._db = db_proxy
        self._regtypes = False
        self._typecasts = typecasts
        if db.server_version < 80400:
            # older remote databases (not officially supported)
            self._query_pg_type = (
                "SELECT oid, typname, typname::text::regtype,"
                " typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")
        else:
            self._query_pg_type = (
                "SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")

    def add(self, oid, pgtype, regtype,
               typtype, category, delim, relid):
        """Return a PostgreSQL type name with additional info.

        If the OID is already in the cache, the cached entry is returned.
        Otherwise a new DbType is created, but not stored in the cache.
        """
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        info = (
            ('oid', oid), ('simple', simple), ('pgtype', pgtype),
            ('regtype', regtype), ('typtype', typtype),
            ('category', category), ('delim', delim), ('relid', relid),
            ('_get_attnames', self.get_attnames))
        for attr, val in info:
            setattr(typ, attr, val)
        return typ

    def __missing__(self, key):
        """Fetch the type info from the database and put it in the cache."""
        try:
            query = self._query_pg_type % (
                _quote_if_unqualified('$1', key),)
            rows = self._db.query(query, (key,)).getresult()
        except ProgrammingError:
            rows = None
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        typ = self.add(*rows[0])
        # make the type available under its OID and under its name
        self[typ.oid] = typ
        self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Return the type, looking it up if it is not cached yet."""
        try:
            typ = self[key]
        except KeyError:
            return default
        return typ

    def get_attnames(self, typ):
        """Return names and types of the fields of a composite type.

        None is returned for unknown types and for types that have
        no corresponding relation.
        """
        if isinstance(typ, DbType):
            dbtype = typ
        else:
            dbtype = self.get(typ)
            if not dbtype:
                return None
        relid = dbtype.relid
        if not relid:
            return None
        return self._db.get_attnames(relid, with_oid=False)

    def get_typecast(self, typ):
        """Return the typecast function for the given database type."""
        typecasts = self._typecasts
        return typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Register a typecast function for the given database type(s)."""
        typecasts = self._typecasts
        typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the given database type(s)."""
        typecasts = self._typecasts
        typecasts.reset(typ)

    def typecast(self, value, typ):
        """Convert the value according to the given database type."""
        if value is None:
            # NULL values need no typecast
            return None
        if not isinstance(typ, DbType):
            dbtype = self.get(typ)
            typ = dbtype.pgtype if dbtype else dbtype
        if typ:
            cast = self.get_typecast(typ)
        else:
            cast = None
        if cast and cast is not str:
            return cast(value)
        # the value can be used as it is
        return value
1264
1265
1266# The result rows for database operations are returned as named tuples
1267# by default. Since creating namedtuple classes is a somewhat expensive
1268# operation, we cache up to 1024 of these classes by default.
1269
@lru_cache(maxsize=1024)
def _row_factory(names):
    """Return a function creating named tuples with the given field names.

    The returned function takes an iterable with the values of a row.
    Field names that cannot be used are replaced with generated ones.
    """
    try:
        try:
            row_type = namedtuple('Row', names, rename=True)
        except TypeError:  # Python 2.6 and 3.0 do not support rename
            names = [v if v.isalnum() else 'column_%d' % (n,)
                     for n, v in enumerate(names)]
            row_type = namedtuple('Row', names)
    except ValueError:  # there is still a problem with the field names
        names = ['column_%d' % (n,) for n in range(len(names))]
        row_type = namedtuple('Row', names)
    return row_type._make
1283
1284
def set_row_factory_size(maxsize):
    """Set the maximum size of the cache for named tuple factories.

    The cache can grow without bound when maxsize is set to None.
    The cache is created anew, so all cached factories are dropped.
    """
    global _row_factory
    uncached = _row_factory.__wrapped__
    _row_factory = lru_cache(maxsize)(uncached)
1292
1293
def _namedresult(q):
    """Return the result rows of the query as a list of named tuples."""
    make_row = _row_factory(q.listfields())
    return list(map(make_row, q.getresult()))
1298
1299
class _MemoryQuery:
    """Query object that simply holds a given result in memory."""

    def __init__(self, result, fields):
        """Store the given result rows and field names."""
        # the field names are frozen as a tuple, the rows are kept as passed
        self.fields = tuple(fields)
        self.result = result

    def listfields(self):
        """Return the field names as a tuple."""
        return self.fields

    def getresult(self):
        """Return the result rows as they were passed in."""
        return self.result
1315
1316
def _db_error(msg, cls=DatabaseError):
    """Create an error of the given class with sqlstate set to None.

    The error is returned and must be raised by the caller.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
1322
1323
def _int_error(msg):
    """Create an InternalError with sqlstate set to None."""
    return _db_error(msg, cls=InternalError)
1327
1328
def _prg_error(msg):
    """Create a ProgrammingError with sqlstate set to None."""
    return _db_error(msg, cls=ProgrammingError)
1332
1333
# Initialize the C module
#
# The named result helper defined above, the Decimal class and the JSON
# decoder are handed over through the setter functions (presumably
# provided by the star import from _pg at the top of this file).

set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
1339
1340
1341# The notification handler
1342
class NotificationHandler(object):
    """Client-side handler for asynchronous PostgreSQL notifications."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Set up the notification handler.

        The handler needs a PyGreSQL database connection, the name of the
        event (notification channel) to listen for and a callback function.

        The dictionary arg_dict, if given, is the one that will be passed
        as the single argument to the callback function.  The timeout is
        given in seconds (use a floating point number for fractions of
        seconds).  Without a timeout (None), the handler waits forever.
        When the timeout is reached, the callback function is called with
        None as its single argument.  With a timeout of zero, the handler
        polls the notifications synchronously and returns.

        The name of the event that tells the handler to stop listening
        can be passed as stop_event.  It defaults to the event name
        prefixed with 'stop_'.
        """
        if not stop_event:
            stop_event = 'stop_%s' % event
        self.db = db
        self.event = event
        self.stop_event = stop_event
        self.listening = False
        self.callback = callback
        self.arg_dict = {} if arg_dict is None else arg_dict
        self.timeout = timeout

    def __del__(self):
        self.unlisten()

    def close(self):
        """Stop listening and close the database connection."""
        if not self.db:
            return
        self.unlisten()
        self.db.close()
        self.db = None

    def listen(self):
        """Start listening for the event and for the stop event."""
        if self.listening:
            return
        for channel in (self.event, self.stop_event):
            self.db.query('listen "%s"' % channel)
        self.listening = True

    def unlisten(self):
        """Stop listening for the event and for the stop event."""
        if not self.listening:
            return
        for channel in (self.event, self.stop_event):
            self.db.query('unlisten "%s"' % channel)
        self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Send a notification for the event.

        A payload can be passed along with the notification.

        With the stop flag set, the notification is sent for the stop
        event instead, which makes the handler stop listening.

        Nothing is sent if the handler is currently not listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if not self.listening:
            return None
        if not db:
            db = self.db
        channel = self.stop_event if stop else self.event
        q = 'notify "%s"' % channel
        if payload:
            q += ", '%s'" % payload
        return db.query(q)

    def __call__(self):
        """Run the notification handler.

        The handler starts listening and then loops, waiting for
        notifications on the event and the stop event.  For each of them,
        the 'pid', the 'event' and the payload as 'extra' are stored in
        the arg_dict dictionary, and the callback is called with this
        dictionary as its single argument.  A notification for the stop
        event makes the handler stop listening, which ends the loop.
        A notification for any other event raises an error.

        If the handler times out while waiting, it stops listening and
        calls the callback with None.  If the timeout has been set to
        zero, the handler does not wait at all.  It processes the
        notifications that are available and returns, and it keeps
        listening unless a stop event was among them.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        polling = self.timeout == 0
        if not polling:
            rlist = [self.db.fileno()]
        while self.listening:
            if not polling and not select.select(
                    rlist, [], [], self.timeout)[0]:
                # we timed out
                self.unlisten()
                self.callback(None)
                continue
            while self.listening:
                notice = self.db.getnotify()
                if not notice:  # no more messages
                    break
                event, pid, extra = notice
                if event not in (self.event, self.stop_event):
                    self.unlisten()
                    raise _db_error(
                        'Listening for "%s" and "%s", but notified of "%s"'
                        % (self.event, self.stop_event, event))
                if event == self.stop_event:
                    self.unlisten()
                self.arg_dict.update(pid=pid, event=event, extra=extra)
                self.callback(self.arg_dict)
            if polling:
                break
1462
1463
def pgnotify(*args, **kw):
    """Create a NotificationHandler, using its deprecated former name."""
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
1469
1470
1471# The actual PostGreSQL database connection interface:
1472
1473class DB:
1474    """Wrapper class for the _pg connection type."""
1475
1476    db = None  # invalid fallback for underlying connection
1477
1478    def __init__(self, *args, **kw):
1479        """Create a new connection
1480
1481        You can pass either the connection parameters or an existing
1482        _pg or pgdb connection. This allows you to use the methods
1483        of the classic pg interface with a DB-API 2 pgdb connection.
1484        """
1485        if not args and len(kw) == 1:
1486            db = kw.get('db')
1487        elif not kw and len(args) == 1:
1488            db = args[0]
1489        else:
1490            db = None
1491        if db:
1492            if isinstance(db, DB):
1493                db = db.db
1494            else:
1495                try:
1496                    db = db._cnx
1497                except AttributeError:
1498                    pass
1499        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
1500            db = connect(*args, **kw)
1501            self._closeable = True
1502        else:
1503            self._closeable = False
1504        self.db = db
1505        self.dbname = db.db
1506        self._regtypes = False
1507        self._attnames = {}
1508        self._pkeys = {}
1509        self._privileges = {}
1510        self._args = args, kw
1511        self.adapter = Adapter(self)
1512        self.dbtypes = DbTypes(self)
1513        if db.server_version < 80400:
1514            # support older remote data bases
1515            self._query_attnames = (
1516                "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype,"
1517                " t.typtype, null as typcategory, t.typdelim, t.typrelid"
1518                " FROM pg_attribute a"
1519                " JOIN pg_type t ON t.oid = a.atttypid"
1520                " WHERE a.attrelid = %s::regclass AND %s"
1521                " AND NOT a.attisdropped ORDER BY a.attnum")
1522        else:
1523            self._query_attnames = (
1524                "SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1525                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1526                " FROM pg_attribute a"
1527                " JOIN pg_type t ON t.oid = a.atttypid"
1528                " WHERE a.attrelid = %s::regclass AND %s"
1529                " AND NOT a.attisdropped ORDER BY a.attnum")
1530        db.set_cast_hook(self.dbtypes.typecast)
1531        self.debug = None  # For debugging scripts, this can be set
1532            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
1533            # * to a file object to write debug statements or
1534            # * to a callable object which takes a string argument
1535            # * to any other true value to just print debug statements
1536
1537    def __getattr__(self, name):
1538        # All undefined members are same as in underlying connection:
1539        if self.db:
1540            return getattr(self.db, name)
1541        else:
1542            raise _int_error('Connection is not valid')
1543
1544    def __dir__(self):
1545        # Custom dir function including the attributes of the connection:
1546        attrs = set(self.__class__.__dict__)
1547        attrs.update(self.__dict__)
1548        attrs.update(dir(self.db))
1549        return sorted(attrs)
1550
1551    # Context manager methods
1552
1553    def __enter__(self):
1554        """Enter the runtime context. This will start a transaction."""
1555        self.begin()
1556        return self
1557
1558    def __exit__(self, et, ev, tb):
1559        """Exit the runtime context. This will end the transaction."""
1560        if et is None and ev is None and tb is None:
1561            self.commit()
1562        else:
1563            self.rollback()
1564
1565    def __del__(self):
1566        try:
1567            db = self.db
1568        except AttributeError:
1569            db = None
1570        if db:
1571            db.set_cast_hook(None)
1572            if self._closeable:
1573                db.close()
1574
1575    # Auxiliary methods
1576
1577    def _do_debug(self, *args):
1578        """Print a debug message"""
1579        if self.debug:
1580            s = '\n'.join(str(arg) for arg in args)
1581            if isinstance(self.debug, basestring):
1582                print(self.debug % s)
1583            elif hasattr(self.debug, 'write'):
1584                self.debug.write(s + '\n')
1585            elif callable(self.debug):
1586                self.debug(s)
1587            else:
1588                print(s)
1589
1590    def _escape_qualified_name(self, s):
1591        """Escape a qualified name.
1592
1593        Escapes the name for use as an SQL identifier, unless the
1594        name contains a dot, in which case the name is ambiguous
1595        (could be a qualified name or just a name with a dot in it)
1596        and must be quoted manually by the caller.
1597        """
1598        if '.' not in s:
1599            s = self.escape_identifier(s)
1600        return s
1601
1602    @staticmethod
1603    def _make_bool(d):
1604        """Get boolean value corresponding to d."""
1605        return bool(d) if get_bool() else ('t' if d else 'f')
1606
1607    def _list_params(self, params):
1608        """Create a human readable parameter list."""
1609        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1610
1611    # Public methods
1612
1613    # escape_string and escape_bytea exist as methods,
1614    # so we define unescape_bytea as a method as well
1615    unescape_bytea = staticmethod(unescape_bytea)
1616
1617    def decode_json(self, s):
1618        """Decode a JSON string coming from the database."""
1619        return (get_jsondecode() or jsondecode)(s)
1620
1621    def encode_json(self, d):
1622        """Encode a JSON string for use within SQL."""
1623        return jsonencode(d)
1624
1625    def close(self):
1626        """Close the database connection."""
1627        # Wraps shared library function so we can track state.
1628        if self._closeable:
1629            if self.db:
1630                self.db.set_cast_hook(None)
1631                self.db.close()
1632                self.db = None
1633            else:
1634                raise _int_error('Connection already closed')
1635
1636    def reset(self):
1637        """Reset connection with current parameters.
1638
1639        All derived queries and large objects derived from this connection
1640        will not be usable after this call.
1641
1642        """
1643        if self.db:
1644            self.db.reset()
1645        else:
1646            raise _int_error('Connection already closed')
1647
1648    def reopen(self):
1649        """Reopen connection to the database.
1650
1651        Used in case we need another connection to the same database.
1652        Note that we can still reopen a database that we have closed.
1653
1654        """
1655        # There is no such shared library function.
1656        if self._closeable:
1657            db = connect(*self._args[0], **self._args[1])
1658            if self.db:
1659                self.db.set_cast_hook(None)
1660                self.db.close()
1661            self.db = db
1662
1663    def begin(self, mode=None):
1664        """Begin a transaction."""
1665        qstr = 'BEGIN'
1666        if mode:
1667            qstr += ' ' + mode
1668        return self.query(qstr)
1669
1670    start = begin
1671
1672    def commit(self):
1673        """Commit the current transaction."""
1674        return self.query('COMMIT')
1675
1676    end = commit
1677
1678    def rollback(self, name=None):
1679        """Roll back the current transaction."""
1680        qstr = 'ROLLBACK'
1681        if name:
1682            qstr += ' TO ' + name
1683        return self.query(qstr)
1684
1685    abort = rollback
1686
1687    def savepoint(self, name):
1688        """Define a new savepoint within the current transaction."""
1689        return self.query('SAVEPOINT ' + name)
1690
1691    def release(self, name):
1692        """Destroy a previously defined savepoint."""
1693        return self.query('RELEASE ' + name)
1694
1695    def get_parameter(self, parameter):
1696        """Get the value of a run-time parameter.
1697
1698        If the parameter is a string, the return value will also be a string
1699        that is the current setting of the run-time parameter with that name.
1700
1701        You can get several parameters at once by passing a list, set or dict.
1702        When passing a list of parameter names, the return value will be a
1703        corresponding list of parameter settings.  When passing a set of
1704        parameter names, a new dict will be returned, mapping these parameter
1705        names to their settings.  Finally, if you pass a dict as parameter,
1706        its values will be set to the current parameter settings corresponding
1707        to its keys.
1708
1709        By passing the special name 'all' as the parameter, you can get a dict
1710        of all existing configuration parameters.
1711        """
1712        if isinstance(parameter, basestring):
1713            parameter = [parameter]
1714            values = None
1715        elif isinstance(parameter, (list, tuple)):
1716            values = []
1717        elif isinstance(parameter, (set, frozenset)):
1718            values = {}
1719        elif isinstance(parameter, dict):
1720            values = parameter
1721        else:
1722            raise TypeError(
1723                'The parameter must be a string, list, set or dict')
1724        if not parameter:
1725            raise TypeError('No parameter has been specified')
1726        params = {} if isinstance(values, dict) else []
1727        for key in parameter:
1728            param = key.strip().lower() if isinstance(
1729                key, basestring) else None
1730            if not param:
1731                raise TypeError('Invalid parameter')
1732            if param == 'all':
1733                q = 'SHOW ALL'
1734                values = self.db.query(q).getresult()
1735                values = dict(value[:2] for value in values)
1736                break
1737            if isinstance(values, dict):
1738                params[param] = key
1739            else:
1740                params.append(param)
1741        else:
1742            for param in params:
1743                q = 'SHOW %s' % (param,)
1744                value = self.db.query(q).getresult()[0][0]
1745                if values is None:
1746                    values = value
1747                elif isinstance(values, list):
1748                    values.append(value)
1749                else:
1750                    values[params[param]] = value
1751        return values
1752
1753    def set_parameter(self, parameter, value=None, local=False):
1754        """Set the value of a run-time parameter.
1755
1756        If the parameter and the value are strings, the run-time parameter
1757        will be set to that value.  If no value or None is passed as a value,
1758        then the run-time parameter will be restored to its default value.
1759
1760        You can set several parameters at once by passing a list of parameter
1761        names, together with a single value that all parameters should be
1762        set to or with a corresponding list of values.  You can also pass
1763        the parameters as a set if you only provide a single value.
1764        Finally, you can pass a dict with parameter names as keys.  In this
1765        case, you should not pass a value, since the values for the parameters
1766        will be taken from the dict.
1767
1768        By passing the special name 'all' as the parameter, you can reset
1769        all existing settable run-time parameters to their default values.
1770
1771        If you set local to True, then the command takes effect for only the
1772        current transaction.  After commit() or rollback(), the session-level
1773        setting takes effect again.  Setting local to True will appear to
1774        have no effect if it is executed outside a transaction, since the
1775        transaction will end immediately.
1776        """
1777        if isinstance(parameter, basestring):
1778            parameter = {parameter: value}
1779        elif isinstance(parameter, (list, tuple)):
1780            if isinstance(value, (list, tuple)):
1781                parameter = dict(zip(parameter, value))
1782            else:
1783                parameter = dict.fromkeys(parameter, value)
1784        elif isinstance(parameter, (set, frozenset)):
1785            if isinstance(value, (list, tuple, set, frozenset)):
1786                value = set(value)
1787                if len(value) == 1:
1788                    value = value.pop()
1789            if not(value is None or isinstance(value, basestring)):
1790                raise ValueError('A single value must be specified'
1791                    ' when parameter is a set')
1792            parameter = dict.fromkeys(parameter, value)
1793        elif isinstance(parameter, dict):
1794            if value is not None:
1795                raise ValueError('A value must not be specified'
1796                    ' when parameter is a dictionary')
1797        else:
1798            raise TypeError(
1799                'The parameter must be a string, list, set or dict')
1800        if not parameter:
1801            raise TypeError('No parameter has been specified')
1802        params = {}
1803        for key, value in parameter.items():
1804            param = key.strip().lower() if isinstance(
1805                key, basestring) else None
1806            if not param:
1807                raise TypeError('Invalid parameter')
1808            if param == 'all':
1809                if value is not None:
1810                    raise ValueError('A value must ot be specified'
1811                        " when parameter is 'all'")
1812                params = {'all': None}
1813                break
1814            params[param] = value
1815        local = ' LOCAL' if local else ''
1816        for param, value in params.items():
1817            if value is None:
1818                q = 'RESET%s %s' % (local, param)
1819            else:
1820                q = 'SET%s %s TO %s' % (local, param, value)
1821            self._do_debug(q)
1822            self.db.query(q)
1823
1824    def query(self, command, *args):
1825        """Execute a SQL command string.
1826
1827        This method simply sends a SQL query to the database.  If the query is
1828        an insert statement that inserted exactly one row into a table that
1829        has OIDs, the return value is the OID of the newly inserted row.
1830        If the query is an update or delete statement, or an insert statement
1831        that did not insert exactly one row in a table with OIDs, then the
1832        number of rows affected is returned as a string.  If it is a statement
1833        that returns rows as a result (usually a select statement, but maybe
1834        also an "insert/update ... returning" statement), this method returns
1835        a Query object that can be accessed via getresult() or dictresult()
1836        or simply printed.  Otherwise, it returns `None`.
1837
1838        The query can contain numbered parameters of the form $1 in place
1839        of any data constant.  Arguments given after the query string will
1840        be substituted for the corresponding numbered parameter.  Parameter
1841        values can also be given as a single list or tuple argument.
1842        """
1843        # Wraps shared library function for debugging.
1844        if not self.db:
1845            raise _int_error('Connection is not valid')
1846        if args:
1847            self._do_debug(command, args)
1848            return self.db.query(command, args)
1849        self._do_debug(command)
1850        return self.db.query(command)
1851
1852    def query_formatted(self, command,
1853            parameters=None, types=None, inline=False):
1854        """Execute a formatted SQL command string.
1855
1856        Similar to query, but using Python format placeholders of the form
1857        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1858        The parameters must be passed as a tuple, list or dict.  You can
1859        also pass a corresponding tuple, list or dict of database types in
1860        order to format the parameters properly in case there is ambiguity.
1861
1862        If you set inline to True, the parameters will be sent to the database
1863        embedded in the SQL command, otherwise they will be sent separately.
1864        """
1865        return self.query(*self.adapter.format_query(
1866            command, parameters, types, inline))
1867
1868    def pkey(self, table, composite=False, flush=False):
1869        """Get or set the primary key of a table.
1870
1871        Single primary keys are returned as strings unless you
1872        set the composite flag.  Composite primary keys are always
1873        represented as tuples.  Note that this raises a KeyError
1874        if the table does not have a primary key.
1875
1876        If flush is set then the internal cache for primary keys will
1877        be flushed.  This may be necessary after the database schema or
1878        the search path has been changed.
1879        """
1880        pkeys = self._pkeys
1881        if flush:
1882            pkeys.clear()
1883            self._do_debug('The pkey cache has been flushed')
1884        try:  # cache lookup
1885            pkey = pkeys[table]
1886        except KeyError:  # cache miss, check the database
1887            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1888                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1889                " AND a.attnum = ANY(i.indkey)"
1890                " AND NOT a.attisdropped"
1891                " WHERE i.indrelid=%s::regclass"
1892                " AND i.indisprimary ORDER BY a.attnum") % (
1893                    _quote_if_unqualified('$1', table),)
1894            pkey = self.db.query(q, (table,)).getresult()
1895            if not pkey:
1896                raise KeyError('Table %s has no primary key' % table)
1897            # we want to use the order defined in the primary key index here,
1898            # not the order as defined by the columns in the table
1899            if len(pkey) > 1:
1900                indkey = pkey[0][2]
1901                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1902                pkey = tuple(row[0] for row in pkey)
1903            else:
1904                pkey = pkey[0][0]
1905            pkeys[table] = pkey  # cache it
1906        if composite and not isinstance(pkey, tuple):
1907            pkey = (pkey,)
1908        return pkey
1909
1910    def get_databases(self):
1911        """Get list of databases in the system."""
1912        return [s[0] for s in
1913            self.db.query('SELECT datname FROM pg_database').getresult()]
1914
1915    def get_relations(self, kinds=None, system=False):
1916        """Get list of relations in connected database of specified kinds.
1917
1918        If kinds is None or empty, all kinds of relations are returned.
1919        Otherwise kinds can be a string or sequence of type letters
1920        specifying which kind of relations you want to list.
1921
1922        Set the system flag if you want to get the system relations as well.
1923        """
1924        where = []
1925        if kinds:
1926            where.append("r.relkind IN (%s)" %
1927                ','.join("'%s'" % k for k in kinds))
1928        if not system:
1929            where.append("s.nspname NOT SIMILAR"
1930                " TO 'pg/_%|information/_schema' ESCAPE '/'")
1931        where = " WHERE %s" % ' AND '.join(where) if where else ''
1932        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1933            " FROM pg_class r"
1934            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
1935            " ORDER BY s.nspname, r.relname") % where
1936        return [r[0] for r in self.db.query(q).getresult()]
1937
1938    def get_tables(self, system=False):
1939        """Return list of tables in connected database.
1940
1941        Set the system flag if you want to get the system tables as well.
1942        """
1943        return self.get_relations('r', system)
1944
1945    def get_attnames(self, table, with_oid=True, flush=False):
1946        """Given the name of a table, dig out the set of attribute names.
1947
1948        Returns a read-only dictionary of attribute names (the names are
1949        the keys, the values are the names of the attributes' types)
1950        with the column names in the proper order if you iterate over it.
1951
1952        If flush is set, then the internal cache for attribute names will
1953        be flushed. This may be necessary after the database schema or
1954        the search path has been changed.
1955
1956        By default, only a limited number of simple types will be returned.
1957        You can get the regular types after calling use_regtypes(True).
1958        """
1959        attnames = self._attnames
1960        if flush:
1961            attnames.clear()
1962            self._do_debug('The attnames cache has been flushed')
1963        try:  # cache lookup
1964            names = attnames[table]
1965        except KeyError:  # cache miss, check the database
1966            q = "a.attnum > 0"
1967            if with_oid:
1968                q = "(%s OR a.attname = 'oid')" % q
1969            q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
1970            names = self.db.query(q, (table,)).getresult()
1971            types = self.dbtypes
1972            names = ((name[0], types.add(*name[1:])) for name in names)
1973            names = AttrDict(names)
1974            attnames[table] = names  # cache it
1975        return names
1976
1977    def use_regtypes(self, regtypes=None):
1978        """Use regular type names instead of simplified type names."""
1979        if regtypes is None:
1980            return self.dbtypes._regtypes
1981        else:
1982            regtypes = bool(regtypes)
1983            if regtypes != self.dbtypes._regtypes:
1984                self.dbtypes._regtypes = regtypes
1985                self._attnames.clear()
1986                self.dbtypes.clear()
1987            return regtypes
1988
1989    def has_table_privilege(self, table, privilege='select', flush=False):
1990        """Check whether current user has specified table privilege.
1991
1992        If flush is set, then the internal cache for table privileges will
1993        be flushed. This may be necessary after privileges have been changed.
1994        """
1995        privileges = self._privileges
1996        if flush:
1997            privileges.clear()
1998            self._do_debug('The privileges cache has been flushed')
1999        privilege = privilege.lower()
2000        try:  # ask cache
2001            ret = privileges[table, privilege]
2002        except KeyError:  # cache miss, ask the database
2003            q = "SELECT has_table_privilege(%s, $2)" % (
2004                _quote_if_unqualified('$1', table),)
2005            q = self.db.query(q, (table, privilege))
2006            ret = q.getresult()[0][0] == self._make_bool(True)
2007            privileges[table, privilege] = ret  # cache it
2008        return ret
2009
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        A KeyError is raised if the row dictionary lacks a value for the
        primary key and no OID can be used instead, or if the number of
        values passed as row differs from the number of key columns.
        The errors created by _prg_error and _db_error are raised if the
        table has no usable key or if no matching row has been found.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # name of the munged oid key, only set if the table has an oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        # a single column name is treated like a tuple with one name
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # temporarily make a munged oid available under the plain name 'oid'
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # values passed without column names are matched with the key columns
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        # every key value is added to params and compared with its column
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # the plain 'oid' key must not stay in the row, only the munged one
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # take over the fetched values, storing the oid under its munged name
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
2084
2085    def insert(self, table, row=None, **kw):
2086        """Insert a row into a database table.
2087
2088        This method inserts a row into a table.  The name of the table must
2089        be passed as the first parameter.  The other parameters are used for
2090        providing the data of the row that shall be inserted into the table.
2091        If a dictionary is supplied as the second parameter, it starts with
2092        that.  Otherwise it uses a blank dictionary. Either way the dictionary
2093        is updated from the keywords.
2094
2095        The dictionary is then reloaded with the values actually inserted in
2096        order to pick up values modified by rules, triggers, etc.
2097        """
2098        if table.endswith('*'):  # hint for descendant tables can be ignored
2099            table = table[:-1].rstrip()
2100        if row is None:
2101            row = {}
2102        row.update(kw)
2103        if 'oid' in row:
2104            del row['oid']  # do not insert oid
2105        attnames = self.get_attnames(table)
2106        qoid = _oid_key(table) if 'oid' in attnames else None
2107        params = self.adapter.parameter_list()
2108        adapt = params.add
2109        col = self.escape_identifier
2110        names, values = [], []
2111        for n in attnames:
2112            if n in row:
2113                names.append(col(n))
2114                values.append(adapt(row[n], attnames[n]))
2115        if not names:
2116            raise _prg_error('No column found that can be inserted')
2117        names, values = ', '.join(names), ', '.join(values)
2118        ret = 'oid, *' if qoid else '*'
2119        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
2120            self._escape_qualified_name(table), names, values, ret)
2121        self._do_debug(q, params)
2122        q = self.db.query(q, params)
2123        res = q.dictresult()
2124        if res:  # this should always be true
2125            for n, value in res[0].items():
2126                if qoid and n == 'oid':
2127                    n = qoid
2128                row[n] = value
2129        return row
2130
2131    def update(self, table, row=None, **kw):
2132        """Update an existing row in a database table.
2133
2134        Similar to insert, but updates an existing row.  The update is based
2135        on the primary key of the table or the OID value as munged by get()
2136        or passed as keyword.  The OID will take precedence if provided, so
2137        that it is possible to update the primary key itself.
2138
2139        The dictionary is then modified to reflect any changes caused by the
2140        update due to triggers, rules, default values, etc.
2141        """
2142        if table.endswith('*'):
2143            table = table[:-1].rstrip()  # need parent table name
2144        attnames = self.get_attnames(table)
2145        qoid = _oid_key(table) if 'oid' in attnames else None
2146        if row is None:
2147            row = {}
2148        elif 'oid' in row:
2149            del row['oid']  # only accept oid key from named args for safety
2150        row.update(kw)
2151        if qoid and qoid in row and 'oid' not in row:
2152            row['oid'] = row[qoid]
2153        if qoid and 'oid' in row:  # try using the oid
2154            keyname = ('oid',)
2155        else:  # try using the primary key
2156            try:
2157                keyname = self.pkey(table, True)
2158            except KeyError:  # the table has no primary key
2159                raise _prg_error('Table %s has no primary key' % table)
2160            # check whether all key columns have values
2161            if not set(keyname).issubset(row):
2162                raise KeyError('Missing value for primary key in row')
2163        params = self.adapter.parameter_list()
2164        adapt = params.add
2165        col = self.escape_identifier
2166        where = ' AND '.join('%s = %s' % (
2167            col(k), adapt(row[k], attnames[k])) for k in keyname)
2168        if 'oid' in row:
2169            if qoid:
2170                row[qoid] = row['oid']
2171            del row['oid']
2172        values = []
2173        keyname = set(keyname)
2174        for n in attnames:
2175            if n in row and n not in keyname:
2176                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
2177        if not values:
2178            return row
2179        values = ', '.join(values)
2180        ret = 'oid, *' if qoid else '*'
2181        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
2182            self._escape_qualified_name(table), values, where, ret)
2183        self._do_debug(q, params)
2184        q = self.db.query(q, params)
2185        res = q.dictresult()
2186        if res:  # may be empty when row does not exist
2187            for n, value in res[0].items():
2188                if qoid and n == 'oid':
2189                    n = qoid
2190                row[n] = value
2191        return row
2192
2193    def upsert(self, table, row=None, **kw):
2194        """Insert a row into a database table with conflict resolution
2195
2196        This method inserts a row into a table, but instead of raising a
2197        ProgrammingError exception in case a row with the same primary key
2198        already exists, an update will be executed instead.  This will be
2199        performed as a single atomic operation on the database, so race
2200        conditions can be avoided.
2201
2202        Like the insert method, the first parameter is the name of the
2203        table and the second parameter can be used to pass the values to
2204        be inserted as a dictionary.
2205
2206        Unlike the insert und update statement, keyword parameters are not
2207        used to modify the dictionary, but to specify which columns shall
2208        be updated in case of a conflict, and in which way:
2209
2210        A value of False or None means the column shall not be updated,
2211        a value of True means the column shall be updated with the value
2212        that has been proposed for insertion, i.e. has been passed as value
2213        in the dictionary.  Columns that are not specified by keywords but
2214        appear as keys in the dictionary are also updated like in the case
2215        keywords had been passed with the value True.
2216
2217        So if in the case of a conflict you want to update every column that
2218        has been passed in the dictionary row, you would call upsert(table, row).
2219        If you don't want to do anything in case of a conflict, i.e. leave
2220        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2221
2222        If you need more fine-grained control of what gets updated, you can
2223        also pass strings in the keyword parameters.  These strings will
2224        be used as SQL expressions for the update columns.  In these
2225        expressions you can refer to the value that already exists in
2226        the table by prefixing the column name with "included.", and to
2227        the value that has been proposed for insertion by prefixing the
2228        column name with the "excluded."
2229
2230        The dictionary is modified in any case to reflect the values in
2231        the database after the operation has completed.
2232
2233        Note: The method uses the PostgreSQL "upsert" feature which is
2234        only available since PostgreSQL 9.5.
2235        """
2236        if table.endswith('*'):  # hint for descendant tables can be ignored
2237            table = table[:-1].rstrip()
2238        if row is None:
2239            row = {}
2240        if 'oid' in row:
2241            del row['oid']  # do not insert oid
2242        if 'oid' in kw:
2243            del kw['oid']  # do not update oid
2244        attnames = self.get_attnames(table)
2245        qoid = _oid_key(table) if 'oid' in attnames else None
2246        params = self.adapter.parameter_list()
2247        adapt = params.add
2248        col = self.escape_identifier
2249        names, values, updates = [], [], []
2250        for n in attnames:
2251            if n in row:
2252                names.append(col(n))
2253                values.append(adapt(row[n], attnames[n]))
2254        names, values = ', '.join(names), ', '.join(values)
2255        try:
2256            keyname = self.pkey(table, True)
2257        except KeyError:
2258            raise _prg_error('Table %s has no primary key' % table)
2259        target = ', '.join(col(k) for k in keyname)
2260        update = []
2261        keyname = set(keyname)
2262        keyname.add('oid')
2263        for n in attnames:
2264            if n not in keyname:
2265                value = kw.get(n, True)
2266                if value:
2267                    if not isinstance(value, basestring):
2268                        value = 'excluded.%s' % col(n)
2269                    update.append('%s = %s' % (col(n), value))
2270        if not values:
2271            return row
2272        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2273        ret = 'oid, *' if qoid else '*'
2274        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2275            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2276                self._escape_qualified_name(table), names, values,
2277                target, do, ret)
2278        self._do_debug(q, params)
2279        try:
2280            q = self.db.query(q, params)
2281        except ProgrammingError:
2282            if self.server_version < 90500:
2283                raise _prg_error(
2284                    'Upsert operation is not supported by PostgreSQL version')
2285            raise  # re-raise original error
2286        res = q.dictresult()
2287        if res:  # may be empty with "do nothing"
2288            for n, value in res[0].items():
2289                if qoid and n == 'oid':
2290                    n = qoid
2291                row[n] = value
2292        else:
2293            self.get(table, row)
2294        return row
2295
2296    def clear(self, table, row=None):
2297        """Clear all the attributes to values determined by the types.
2298
2299        Numeric types are set to 0, Booleans are set to false, and everything
2300        else is set to the empty string.  If the row argument is present,
2301        it is used as the row dictionary and any entries matching attribute
2302        names are cleared with everything else left unchanged.
2303        """
2304        # At some point we will need a way to get defaults from a table.
2305        if row is None:
2306            row = {}  # empty if argument is not present
2307        attnames = self.get_attnames(table)
2308        for n, t in attnames.items():
2309            if n == 'oid':
2310                continue
2311            t = t.simple
2312            if t in DbTypes._num_types:
2313                row[n] = 0
2314            elif t == 'bool':
2315                row[n] = self._make_bool(False)
2316            else:
2317                row[n] = ''
2318        return row
2319
2320    def delete(self, table, row=None, **kw):
2321        """Delete an existing row in a database table.
2322
2323        This method deletes the row from a table.  It deletes based on the
2324        primary key of the table or the OID value as munged by get() or
2325        passed as keyword.  The OID will take precedence if provided.
2326
2327        The return value is the number of deleted rows (i.e. 0 if the row
2328        did not exist and 1 if the row was deleted).
2329
2330        Note that if the row cannot be deleted because e.g. it is still
2331        referenced by another table, this method raises a ProgrammingError.
2332        """
2333        if table.endswith('*'):  # hint for descendant tables can be ignored
2334            table = table[:-1].rstrip()
2335        attnames = self.get_attnames(table)
2336        qoid = _oid_key(table) if 'oid' in attnames else None
2337        if row is None:
2338            row = {}
2339        elif 'oid' in row:
2340            del row['oid']  # only accept oid key from named args for safety
2341        row.update(kw)
2342        if qoid and qoid in row and 'oid' not in row:
2343            row['oid'] = row[qoid]
2344        if qoid and 'oid' in row:  # try using the oid
2345            keyname = ('oid',)
2346        else:  # try using the primary key
2347            try:
2348                keyname = self.pkey(table, True)
2349            except KeyError:  # the table has no primary key
2350                raise _prg_error('Table %s has no primary key' % table)
2351            # check whether all key columns have values
2352            if not set(keyname).issubset(row):
2353                raise KeyError('Missing value for primary key in row')
2354        params = self.adapter.parameter_list()
2355        adapt = params.add
2356        col = self.escape_identifier
2357        where = ' AND '.join('%s = %s' % (
2358            col(k), adapt(row[k], attnames[k])) for k in keyname)
2359        if 'oid' in row:
2360            if qoid:
2361                row[qoid] = row['oid']
2362            del row['oid']
2363        q = 'DELETE FROM %s WHERE %s' % (
2364            self._escape_qualified_name(table), where)
2365        self._do_debug(q, params)
2366        res = self.db.query(q, params)
2367        return int(res)
2368
2369    def truncate(self, table, restart=False, cascade=False, only=False):
2370        """Empty a table or set of tables.
2371
2372        This method quickly removes all rows from the given table or set
2373        of tables.  It has the same effect as an unqualified DELETE on each
2374        table, but since it does not actually scan the tables it is faster.
2375        Furthermore, it reclaims disk space immediately, rather than requiring
2376        a subsequent VACUUM operation. This is most useful on large tables.
2377
2378        If restart is set to True, sequences owned by columns of the truncated
2379        table(s) are automatically restarted.  If cascade is set to True, it
2380        also truncates all tables that have foreign-key references to any of
2381        the named tables.  If the parameter only is not set to True, all the
2382        descendant tables (if any) will also be truncated. Optionally, a '*'
2383        can be specified after the table name to explicitly indicate that
2384        descendant tables are included.
2385        """
2386        if isinstance(table, basestring):
2387            only = {table: only}
2388            table = [table]
2389        elif isinstance(table, (list, tuple)):
2390            if isinstance(only, (list, tuple)):
2391                only = dict(zip(table, only))
2392            else:
2393                only = dict.fromkeys(table, only)
2394        elif isinstance(table, (set, frozenset)):
2395            only = dict.fromkeys(table, only)
2396        else:
2397            raise TypeError('The table must be a string, list or set')
2398        if not (restart is None or isinstance(restart, (bool, int))):
2399            raise TypeError('Invalid type for the restart option')
2400        if not (cascade is None or isinstance(cascade, (bool, int))):
2401            raise TypeError('Invalid type for the cascade option')
2402        tables = []
2403        for t in table:
2404            u = only.get(t)
2405            if not (u is None or isinstance(u, (bool, int))):
2406                raise TypeError('Invalid type for the only option')
2407            if t.endswith('*'):
2408                if u:
2409                    raise ValueError(
2410                        'Contradictory table name and only options')
2411                t = t[:-1].rstrip()
2412            t = self._escape_qualified_name(t)
2413            if u:
2414                t = 'ONLY %s' % t
2415            tables.append(t)
2416        q = ['TRUNCATE', ', '.join(tables)]
2417        if restart:
2418            q.append('RESTART IDENTITY')
2419        if cascade:
2420            q.append('CASCADE')
2421        q = ' '.join(q)
2422        self._do_debug(q)
2423        return self.db.query(q)
2424
2425    def get_as_list(self, table, what=None, where=None,
2426            order=None, limit=None, offset=None, scalar=False):
2427        """Get a table as a list.
2428
2429        This gets a convenient representation of the table as a list
2430        of named tuples in Python.  You only need to pass the name of
2431        the table (or any other SQL expression returning rows).  Note that
2432        by default this will return the full content of the table which
2433        can be huge and overflow your memory.  However, you can control
2434        the amount of data returned using the other optional parameters.
2435
2436        The parameter 'what' can restrict the query to only return a
2437        subset of the table columns.  It can be a string, list or a tuple.
2438        The parameter 'where' can restrict the query to only return a
2439        subset of the table rows.  It can be a string, list or a tuple
2440        of SQL expressions that all need to be fulfilled.  The parameter
2441        'order' specifies the ordering of the rows.  It can also be a
2442        other string, list or a tuple.  If no ordering is specified,
2443        the result will be ordered by the primary key(s) or all columns
2444        if no primary key exists.  You can set 'order' to False if you
2445        don't care about the ordering.  The parameters 'limit' and 'offset'
2446        can be integers specifying the maximum number of rows returned
2447        and a number of rows skipped over.
2448
2449        If you set the 'scalar' option to True, then instead of the
2450        named tuples you will get the first items of these tuples.
2451        This is useful if the result has only one column anyway.
2452        """
2453        if not table:
2454            raise TypeError('The table name is missing')
2455        if what:
2456            if isinstance(what, (list, tuple)):
2457                what = ', '.join(map(str, what))
2458            if order is None:
2459                order = what
2460        else:
2461            what = '*'
2462        q = ['SELECT', what, 'FROM', table]
2463        if where:
2464            if isinstance(where, (list, tuple)):
2465                where = ' AND '.join(map(str, where))
2466            q.extend(['WHERE', where])
2467        if order is None:
2468            try:
2469                order = self.pkey(table, True)
2470            except (KeyError, ProgrammingError):
2471                try:
2472                    order = list(self.get_attnames(table))
2473                except (KeyError, ProgrammingError):
2474                    pass
2475        if order:
2476            if isinstance(order, (list, tuple)):
2477                order = ', '.join(map(str, order))
2478            q.extend(['ORDER BY', order])
2479        if limit:
2480            q.append('LIMIT %d' % limit)
2481        if offset:
2482            q.append('OFFSET %d' % offset)
2483        q = ' '.join(q)
2484        self._do_debug(q)
2485        q = self.db.query(q)
2486        res = q.namedresult()
2487        if res and scalar:
2488            res = [row[0] for row in res]
2489        return res
2490
2491    def get_as_dict(self, table, keyname=None, what=None, where=None,
2492            order=None, limit=None, offset=None, scalar=False):
2493        """Get a table as a dictionary.
2494
2495        This method is similar to get_as_list(), but returns the table
2496        as a Python dict instead of a Python list, which can be even
2497        more convenient. The primary key column(s) of the table will
2498        be used as the keys of the dictionary, while the other column(s)
2499        will be the corresponding values.  The keys will be named tuples
2500        if the table has a composite primary key.  The rows will be also
2501        named tuples unless the 'scalar' option has been set to True.
2502        With the optional parameter 'keyname' you can specify an alternative
2503        set of columns to be used as the keys of the dictionary.  It must
2504        be set as a string, list or a tuple.
2505
2506        If the Python version supports it, the dictionary will be an
2507        OrderedDict using the order specified with the 'order' parameter
2508        or the key column(s) if not specified.  You can set 'order' to False
2509        if you don't care about the ordering.  In this case the returned
2510        dictionary will be an ordinary one.
2511        """
2512        if not table:
2513            raise TypeError('The table name is missing')
2514        if not keyname:
2515            try:
2516                keyname = self.pkey(table, True)
2517            except (KeyError, ProgrammingError):
2518                raise _prg_error('Table %s has no primary key' % table)
2519        if isinstance(keyname, basestring):
2520            keyname = [keyname]
2521        elif not isinstance(keyname, (list, tuple)):
2522            raise KeyError('The keyname must be a string, list or tuple')
2523        if what:
2524            if isinstance(what, (list, tuple)):
2525                what = ', '.join(map(str, what))
2526            if order is None:
2527                order = what
2528        else:
2529            what = '*'
2530        q = ['SELECT', what, 'FROM', table]
2531        if where:
2532            if isinstance(where, (list, tuple)):
2533                where = ' AND '.join(map(str, where))
2534            q.extend(['WHERE', where])
2535        if order is None:
2536            order = keyname
2537        if order:
2538            if isinstance(order, (list, tuple)):
2539                order = ', '.join(map(str, order))
2540            q.extend(['ORDER BY', order])
2541        if limit:
2542            q.append('LIMIT %d' % limit)
2543        if offset:
2544            q.append('OFFSET %d' % offset)
2545        q = ' '.join(q)
2546        self._do_debug(q)
2547        q = self.db.query(q)
2548        res = q.getresult()
2549        cls = OrderedDict if order else dict
2550        if not res:
2551            return cls()
2552        keyset = set(keyname)
2553        fields = q.listfields()
2554        if not keyset.issubset(fields):
2555            raise KeyError('Missing keyname in row')
2556        keyind, rowind = [], []
2557        for i, f in enumerate(fields):
2558            (keyind if f in keyset else rowind).append(i)
2559        keytuple = len(keyind) > 1
2560        getkey = itemgetter(*keyind)
2561        keys = map(getkey, res)
2562        if scalar:
2563            rowind = rowind[:1]
2564            rowtuple = False
2565        else:
2566            rowtuple = len(rowind) > 1
2567        if scalar or rowtuple:
2568            getrow = itemgetter(*rowind)
2569        else:
2570            rowind = rowind[0]
2571            getrow = lambda row: (row[rowind],)
2572            rowtuple = True
2573        rows = map(getrow, res)
2574        if keytuple or rowtuple:
2575            namedresult = get_namedresult()
2576            if namedresult:
2577                if keytuple:
2578                    keys = namedresult(_MemoryQuery(keys, keyname))
2579                if rowtuple:
2580                    fields = [f for f in fields if f not in keyset]
2581                    rows = namedresult(_MemoryQuery(rows, fields))
2582        return cls(zip(keys, rows))
2583
2584    def notification_handler(self,
2585            event, callback, arg_dict=None, timeout=None, stop_event=None):
2586        """Get notification handler that will run the given callback."""
2587        return NotificationHandler(self,
2588            event, callback, arg_dict, timeout, stop_event)
2589
2590
# if run as script, print the version and the module documentation

if __name__ == '__main__':
    # pass the version as a separate argument so that print() puts a
    # space between the label and the version number
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.