source: trunk/pg.py @ 894

Last change on this file since 894 was 894, checked in by cito, 3 years ago

Cache the namedtuple classes used for query result rows

  • Property svn:keywords set to Id
File size: 93.6 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 894 2016-09-23 14:04:06Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function, division
31
32from _pg import *
33
34__version__ = version
35
36import select
37import warnings
38import weakref
39
40from datetime import date, time, datetime, timedelta, tzinfo
41from decimal import Decimal
42from math import isnan, isinf
43from collections import namedtuple
44from operator import itemgetter
45from functools import partial
46from re import compile as regex
47from json import loads as jsondecode, dumps as jsonencode
48from uuid import UUID
49
# Compatibility shim: make the name "long" available on every Python version.
try:
    long
except NameError:  # Python >= 3.0
    # there is no separate long type here, so int takes its place
    long = int

# Compatibility shim: make the name "basestring" available as well.
try:
    basestring
except NameError:  # Python >= 3.0
    # a tuple of types is sufficient for the isinstance() checks below
    basestring = (str, bytes)
59
60try:
61    from functools import lru_cache
62except ImportError:  # Python < 3.2
63    from functools import update_wrapper
64    try:
65        from _thread import RLock
66    except ImportError:
67        class RLock:  # for builds without threads
68            def __enter__(self): pass
69
70            def __exit__(self, exctype, excinst, exctb): pass
71
    def lru_cache(maxsize=128):
        """Simplified functools.lru_cache decorator for one argument.

        The decorated function must take exactly one hashable argument.
        With maxsize=0 nothing is cached, with maxsize=None the cache
        grows without bound, otherwise at most maxsize results are kept
        and the least recently used one is replaced when the cache is full.
        """

        def decorator(function):
            sentinel = object()  # marks a cache miss in the unbounded case
            cache = {}
            get = cache.get
            lock = RLock()
            # The bounded cache keeps its entries in a circular doubly linked
            # list; each link is a list [previous, next, argument, result]
            # and root is the link separating the newest from the oldest one.
            root = []
            # current root and "cache is full" flag, kept in a list
            # so that the nested wrapper can replace them
            root_full = [root, False]
            root[:] = [root, root, None, None]

            if maxsize == 0:

                # no caching at all, just call the function
                def wrapper(arg):
                    res = function(arg)
                    return res

            elif maxsize is None:

                # unbounded cache, a plain dictionary lookup is enough
                def wrapper(arg):
                    res = get(arg, sentinel)
                    if res is not sentinel:
                        return res
                    res = function(arg)
                    cache[arg] = res
                    return res

            else:

                # bounded cache with least-recently-used replacement
                def wrapper(arg):
                    with lock:
                        link = get(arg)
                        if link is not None:
                            # cache hit: unlink the entry and
                            # re-insert it just in front of the root
                            root = root_full[0]
                            prev, next, _arg, res = link
                            prev[1] = next
                            next[0] = prev
                            last = root[0]
                            last[1] = root[0] = link
                            link[0] = last
                            link[1] = root
                            return res
                    # cache miss: the function is called without holding the lock
                    res = function(arg)
                    with lock:
                        root, full = root_full
                        if arg in cache:
                            # the result was stored while the lock was released
                            pass
                        elif full:
                            # store the new result in the old root and make
                            # the oldest link the new (emptied) root
                            oldroot = root
                            oldroot[2] = arg
                            oldroot[3] = res
                            root = root_full[0] = oldroot[1]
                            oldarg = root[2]
                            oldres = root[3]  # keep reference
                            root[2] = root[3] = None
                            del cache[oldarg]
                            cache[arg] = oldroot
                        else:
                            # still room: append a new link in front of the root
                            last = root[0]
                            link = [last, root, arg, res]
                            last[1] = root[0] = cache[arg] = link
                            if len(cache) >= maxsize:
                                root_full[1] = True
                    return res

            # expose the undecorated function like functools.lru_cache does
            wrapper.__wrapped__ = function
            return update_wrapper(wrapper, function)

        return decorator
142
143
144# Auxiliary classes and functions that are independent from a DB connection:
145
146try:
147    from collections import OrderedDict
148except ImportError:  # Python 2.6 or 3.0
149    OrderedDict = dict
150
151
152    class AttrDict(dict):
153        """Simple read-only ordered dictionary for storing attribute names."""
154
155        def __init__(self, *args, **kw):
156            if len(args) > 1 or kw:
157                raise TypeError
158            items = args[0] if args else []
159            if isinstance(items, dict):
160                raise TypeError
161            items = list(items)
162            self._keys = [item[0] for item in items]
163            dict.__init__(self, items)
164            self._read_only = True
165            error = self._read_only_error
166            self.clear = self.update = error
167            self.pop = self.setdefault = self.popitem = error
168
169        def __setitem__(self, key, value):
170            if self._read_only:
171                self._read_only_error()
172            dict.__setitem__(self, key, value)
173
174        def __delitem__(self, key):
175            if self._read_only:
176                self._read_only_error()
177            dict.__delitem__(self, key)
178
179        def __iter__(self):
180            return iter(self._keys)
181
182        def keys(self):
183            return list(self._keys)
184
185        def values(self):
186            return [self[key] for key in self]
187
188        def items(self):
189            return [(key, self[key]) for key in self]
190
191        def iterkeys(self):
192            return self.__iter__()
193
194        def itervalues(self):
195            return iter(self.values())
196
197        def iteritems(self):
198            return iter(self.items())
199
200        @staticmethod
201        def _read_only_error(*args, **kw):
202            raise TypeError('This object is read-only')
203
204else:
205
206     class AttrDict(OrderedDict):
207        """Simple read-only ordered dictionary for storing attribute names."""
208
209        def __init__(self, *args, **kw):
210            self._read_only = False
211            OrderedDict.__init__(self, *args, **kw)
212            self._read_only = True
213            error = self._read_only_error
214            self.clear = self.update = error
215            self.pop = self.setdefault = self.popitem = error
216
217        def __setitem__(self, key, value):
218            if self._read_only:
219                self._read_only_error()
220            OrderedDict.__setitem__(self, key, value)
221
222        def __delitem__(self, key):
223            if self._read_only:
224                self._read_only_error()
225            OrderedDict.__delitem__(self, key)
226
227        @staticmethod
228        def _read_only_error(*args, **kw):
229            raise TypeError('This object is read-only')
230
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        argspec = getargspec(func)
        return argspec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given callable."""
        parameters = signature(func).parameters
        return list(parameters)
242
243try:
244    from datetime import timezone
245except ImportError:  # Python < 3.2
246
247    class timezone(tzinfo):
248        """Simple timezone implementation."""
249
250        def __init__(self, offset, name=None):
251            self.offset = offset
252            if not name:
253                minutes = self.offset.days * 1440 + self.offset.seconds // 60
254                if minutes < 0:
255                    hours, minutes = divmod(-minutes, 60)
256                    hours = -hours
257                else:
258                    hours, minutes = divmod(minutes, 60)
259                name = 'UTC%+03d:%02d' % (hours, minutes)
260            self.name = name
261
262        def utcoffset(self, dt):
263            return self.offset
264
265        def tzname(self, dt):
266            return self.name
267
268        def dst(self, dt):
269            return None
270
271    timezone.utc = timezone(timedelta(0), 'UTC')
272
273    _has_timezone = False
274else:
275    _has_timezone = True
276
# Time zone abbreviations that may appear in Postgres timestamptz output,
# mapped to their offsets from UTC in the form +HHMM or -HHMM; names
# missing here are treated as UTC by _timezone_as_offset() below
_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
    UCT='+0000', UTC='+0000', WET='+0000')
281
282
283def _timezone_as_offset(tz):
284    if tz.startswith(('+', '-')):
285        if len(tz) < 5:
286            return tz + '00'
287        return tz.replace(':', '')
288    return _timezones.get(tz, '+0000')
289
290
def _get_timezone(tz):
    """Create a timezone object from a time zone name or offset.

    The normalized offset in the form +HHMM or -HHMM is used as the
    name of the returned time zone.
    """
    offset = _timezone_as_offset(tz)
    hours = int(offset[1:3])
    minutes = 60 * hours + int(offset[3:5])
    if offset[0] == '-':
        minutes = -minutes
    return timezone(timedelta(minutes=minutes), offset)
297
298
299def _oid_key(table):
300    """Build oid key from a table name."""
301    return 'oid(%s)' % table
302
303
304class _SimpleTypes(dict):
305    """Dictionary mapping pg_type names to simple type names."""
306
307    _types = {'bool': 'bool',
308        'bytea': 'bytea',
309        'date': 'date interval time timetz timestamp timestamptz'
310            ' abstime reltime',  # these are very old
311        'float': 'float4 float8',
312        'int': 'cid int2 int4 int8 oid xid',
313        'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid',
314        'num': 'numeric', 'money': 'money',
315        'text': 'bpchar char name text varchar'}
316
317    def __init__(self):
318        for typ, keys in self._types.items():
319            for key in keys.split():
320                self[key] = typ
321                self['_%s' % key] = '%s[]' % typ
322
323    # this could be a static method in Python > 2.6
324    def __missing__(self, key):
325        return 'text'
326
327_simpletypes = _SimpleTypes()
328
329
def _quote_if_unqualified(param, name):
    """Quote parameter representing a qualified name.

    Wraps the given parameter in a quote_ident() call, but only if the
    name is a string that does not contain a dot.  A name with a dot is
    ambiguous (it could be a qualified name or just a name with a dot
    in it), so in that case the parameter is returned unchanged and the
    caller must care about the quoting.
    """
    if not isinstance(name, basestring) or '.' in name:
        return param
    return 'quote_ident(%s)' % (param,)
341
342
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    The adapt() method used by add() is not defined in this class,
    it must be set on the instance before add() is called.
    """

    def add(self, value, typ=None):
        """Typecast value with known database type and build parameter list.

        A literal value is returned as it is.  Any other value is appended
        to the parameter list and a placeholder of the form $n is returned,
        where n is the position of the value in the list.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
357
358
class Bytea(bytes):
    """Marker class telling that a bytes value shall be treated as bytea."""
361
362
class Hstore(dict):
    """Wrapper class for marking hstore values.

    The string representation of an instance is the text format that
    Postgres accepts as input for the hstore type.
    """

    # strings that must be put in double quotes
    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
    # characters that must be escaped with a backslash
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Quote a key or value so that it can be used in hstore input.

        None is rendered as NULL, the empty string as an empty pair of
        double quotes.  Double quotes and backslashes are escaped with a
        backslash; leaving backslashes unescaped would let the server
        take them as escape characters and alter the value.
        """
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        quote = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % s
        return s

    def __str__(self):
        """Return the dictionary in the hstore text format."""
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
382
383
class Json:
    """Marker class telling that the wrapped object shall be sent as JSON."""

    def __init__(self, obj):
        """Keep the wrapped object in the attribute obj."""
        self.obj = obj
389
390
class Literal(str):
    """Marker class telling that a string is a literal piece of SQL."""
393
394
395class Adapter:
396    """Class providing methods for adapting parameters to the database."""
397
398    _bool_true_values = frozenset('t true 1 y yes on'.split())
399
400    _date_literals = frozenset('current_date current_time'
401        ' current_timestamp localtime localtimestamp'.split())
402
403    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
404    _re_record_quote = regex(r'[(,"\\]')
405    _re_array_escape = _re_record_escape = regex(r'(["\\])')
406
407    def __init__(self, db):
408        self.db = weakref.proxy(db)
409
410    @classmethod
411    def _adapt_bool(cls, v):
412        """Adapt a boolean parameter."""
413        if isinstance(v, basestring):
414            if not v:
415                return None
416            v = v.lower() in cls._bool_true_values
417        return 't' if v else 'f'
418
419    @classmethod
420    def _adapt_date(cls, v):
421        """Adapt a date parameter."""
422        if not v:
423            return None
424        if isinstance(v, basestring) and v.lower() in cls._date_literals:
425            return Literal(v)
426        return v
427
428    @staticmethod
429    def _adapt_num(v):
430        """Adapt a numeric parameter."""
431        if not v and v != 0:
432            return None
433        return v
434
435    _adapt_int = _adapt_float = _adapt_money = _adapt_num
436
437    def _adapt_bytea(self, v):
438        """Adapt a bytea parameter."""
439        return self.db.escape_bytea(v)
440
441    def _adapt_json(self, v):
442        """Adapt a json parameter."""
443        if not v:
444            return None
445        if isinstance(v, basestring):
446            return v
447        return self.db.encode_json(v)
448
449    @classmethod
450    def _adapt_text_array(cls, v):
451        """Adapt a text type array parameter."""
452        if isinstance(v, list):
453            adapt = cls._adapt_text_array
454            return '{%s}' % ','.join(adapt(v) for v in v)
455        if v is None:
456            return 'null'
457        if not v:
458            return '""'
459        v = str(v)
460        if cls._re_array_quote.search(v):
461            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
462        return v
463
464    _adapt_date_array = _adapt_text_array
465
466    @classmethod
467    def _adapt_bool_array(cls, v):
468        """Adapt a boolean array parameter."""
469        if isinstance(v, list):
470            adapt = cls._adapt_bool_array
471            return '{%s}' % ','.join(adapt(v) for v in v)
472        if v is None:
473            return 'null'
474        if isinstance(v, basestring):
475            if not v:
476                return 'null'
477            v = v.lower() in cls._bool_true_values
478        return 't' if v else 'f'
479
480    @classmethod
481    def _adapt_num_array(cls, v):
482        """Adapt a numeric array parameter."""
483        if isinstance(v, list):
484            adapt = cls._adapt_num_array
485            return '{%s}' % ','.join(adapt(v) for v in v)
486        if not v and v != 0:
487            return 'null'
488        return str(v)
489
490    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
491            _adapt_num_array
492
493    def _adapt_bytea_array(self, v):
494        """Adapt a bytea array parameter."""
495        if isinstance(v, list):
496            return b'{' + b','.join(
497                self._adapt_bytea_array(v) for v in v) + b'}'
498        if v is None:
499            return b'null'
500        return self.db.escape_bytea(v).replace(b'\\', b'\\\\')
501
502    def _adapt_json_array(self, v):
503        """Adapt a json array parameter."""
504        if isinstance(v, list):
505            adapt = self._adapt_json_array
506            return '{%s}' % ','.join(adapt(v) for v in v)
507        if not v:
508            return 'null'
509        if not isinstance(v, basestring):
510            v = self.db.encode_json(v)
511        if self._re_array_quote.search(v):
512            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
513        return v
514
515    def _adapt_record(self, v, typ):
516        """Adapt a record parameter with given type."""
517        typ = self.get_attnames(typ).values()
518        if len(typ) != len(v):
519            raise TypeError('Record parameter %s has wrong size' % v)
520        adapt = self.adapt
521        value = []
522        for v, t in zip(v, typ):
523            v = adapt(v, t)
524            if v is None:
525                v = ''
526            elif not v:
527                v = '""'
528            else:
529                if isinstance(v, bytes):
530                    if str is not bytes:
531                        v = v.decode('ascii')
532                else:
533                    v = str(v)
534                if self._re_record_quote.search(v):
535                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
536            value.append(v)
537        return '(%s)' % ','.join(value)
538
539    def adapt(self, value, typ=None):
540        """Adapt a value with known database type."""
541        if value is not None and not isinstance(value, Literal):
542            if typ:
543                simple = self.get_simple_name(typ)
544            else:
545                typ = simple = self.guess_simple_type(value) or 'text'
546            try:
547                value = value.__pg_str__(typ)
548            except AttributeError:
549                pass
550            if simple == 'text':
551                pass
552            elif simple == 'record':
553                if isinstance(value, tuple):
554                    value = self._adapt_record(value, typ)
555            elif simple.endswith('[]'):
556                if isinstance(value, list):
557                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
558                    value = adapt(value)
559            else:
560                adapt = getattr(self, '_adapt_%s' % simple)
561                value = adapt(value)
562        return value
563
564    @staticmethod
565    def simple_type(name):
566        """Create a simple database type with given attribute names."""
567        typ = DbType(name)
568        typ.simple = name
569        return typ
570
571    @staticmethod
572    def get_simple_name(typ):
573        """Get the simple name of a database type."""
574        if isinstance(typ, DbType):
575            return typ.simple
576        return _simpletypes[typ]
577
578    @staticmethod
579    def get_attnames(typ):
580        """Get the attribute names of a composite database type."""
581        if isinstance(typ, DbType):
582            return typ.attnames
583        return {}
584
585    @classmethod
586    def guess_simple_type(cls, value):
587        """Try to guess which database type the given value has."""
588        if isinstance(value, Bytea):
589            return 'bytea'
590        if isinstance(value, basestring):
591            return 'text'
592        if isinstance(value, bool):
593            return 'bool'
594        if isinstance(value, (int, long)):
595            return 'int'
596        if isinstance(value, float):
597            return 'float'
598        if isinstance(value, Decimal):
599            return 'num'
600        if isinstance(value, (date, time, datetime, timedelta)):
601            return 'date'
602        if isinstance(value, list):
603            return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
604        if isinstance(value, tuple):
605            simple_type = cls.simple_type
606            typ = simple_type('record')
607            guess = cls.guess_simple_type
608            def get_attnames(self):
609                return AttrDict((str(n + 1), simple_type(guess(v)))
610                    for n, v in enumerate(value))
611            typ._get_attnames = get_attnames
612            return typ
613
614    @classmethod
615    def guess_simple_base_type(cls, value):
616        """Try to guess the base type of a given array."""
617        for v in value:
618            if isinstance(v, list):
619                typ = cls.guess_simple_base_type(v)
620            else:
621                typ = cls.guess_simple_type(v)
622            if typ:
623                return typ
624
625    def adapt_inline(self, value, nested=False):
626        """Adapt a value that is put into the SQL and needs to be quoted."""
627        if value is None:
628            return 'NULL'
629        if isinstance(value, Literal):
630            return value
631        if isinstance(value, Bytea):
632            value = self.db.escape_bytea(value)
633            if bytes is not str:  # Python >= 3.0
634                value = value.decode('ascii')
635        elif isinstance(value, Json):
636            if value.encode:
637                return value.encode()
638            value = self.db.encode_json(value)
639        elif isinstance(value, (datetime, date, time, timedelta)):
640            value = str(value)
641        if isinstance(value, basestring):
642            value = self.db.escape_string(value)
643            return "'%s'" % value
644        if isinstance(value, bool):
645            return 'true' if value else 'false'
646        if isinstance(value, float):
647            if isinf(value):
648                return "'-Infinity'" if value < 0 else "'Infinity'"
649            if isnan(value):
650                return "'NaN'"
651            return value
652        if isinstance(value, (int, long, Decimal)):
653            return value
654        if isinstance(value, list):
655            q = self.adapt_inline
656            s = '[%s]' if nested else 'ARRAY[%s]'
657            return s % ','.join(str(q(v, nested=True)) for v in value)
658        if isinstance(value, tuple):
659            q = self.adapt_inline
660            return '(%s)' % ','.join(str(q(v)) for v in value)
661        try:
662            value = value.__pg_repr__()
663        except AttributeError:
664            raise InterfaceError(
665                'Do not know how to adapt type %s' % type(value))
666        if isinstance(value, (tuple, list)):
667            value = self.adapt_inline(value)
668        return value
669
670    def parameter_list(self):
671        """Return a parameter list for parameters with known database types.
672
673        The list has an add(value, typ) method that will build up the
674        list and return either the literal value or a placeholder.
675        """
676        params = _ParameterList()
677        params.adapt = self.adapt
678        return params
679
680    def format_query(self, command, values, types=None, inline=False):
681        """Format a database query using the given values and types."""
682        if inline and types:
683            raise ValueError('Typed parameters must be sent separately')
684        params = self.parameter_list()
685        if isinstance(values, (list, tuple)):
686            if inline:
687                adapt = self.adapt_inline
688                literals = [adapt(value) for value in values]
689            else:
690                add = params.add
691                literals = []
692                append = literals.append
693                if types:
694                    if (not isinstance(types, (list, tuple)) or
695                            len(types) != len(values)):
696                        raise TypeError('The values and types do not match')
697                    for value, typ in zip(values, types):
698                        append(add(value, typ))
699                else:
700                    for value in values:
701                        append(add(value))
702            command %= tuple(literals)
703        elif isinstance(values, dict):
704            # we want to allow extra keys in the dictionary,
705            # so we first must find the values actually used in the command
706            used_values = {}
707            literals = dict.fromkeys(values, '')
708            for key in literals:
709                del literals[key]
710                try:
711                    command % literals
712                except KeyError:
713                    used_values[key] = values[key]
714                literals[key] = ''
715            values = used_values
716            if inline:
717                adapt = self.adapt_inline
718                literals = dict((key, adapt(value))
719                    for key, value in values.items())
720            else:
721                add = params.add
722                literals = {}
723                if types:
724                    if not isinstance(types, dict):
725                        raise TypeError('The values and types do not match')
726                    for key in sorted(values):
727                        literals[key] = add(values[key], types.get(key))
728                else:
729                    for key in sorted(values):
730                        literals[key] = add(values[key])
731            command %= literals
732        else:
733            raise TypeError('The values must be passed as tuple, list or dict')
734        return command, params
735
736
def cast_bool(value):
    """Cast a boolean value.

    The value is returned unchanged if get_bool() gives a false value,
    otherwise it is converted to a Python bool that is True if the value
    starts with the letter 't'.
    """
    if get_bool():
        return value[0] == 't'
    return value
742
743
def cast_json(value):
    """Cast a JSON value.

    The value is decoded with the function returned by get_jsondecode(),
    or returned unchanged if there is no such function.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
750
751
def cast_num(value):
    """Cast a numeric value.

    The value is converted with the type returned by get_decimal(),
    or to a float if there is no such type.
    """
    numeric = get_decimal() or float
    return numeric(value)
755
756
def cast_money(value):
    """Cast a money value.

    The value is returned unchanged if get_decimal_point() gives a false
    value.  Otherwise everything except the digits, the decimal point and
    the sign is removed (an opening parenthesis counts as minus sign), and
    the result is converted with the type returned by get_decimal(), or
    to a float if there is no such type.
    """
    point = get_decimal_point()
    if not point:
        return value
    value = value.replace('(', '-')
    # Only the configured decimal point is kept here.  If this is not a
    # dot, then dots in the value can only be grouping characters which
    # must be removed, otherwise the number could not be parsed.
    value = ''.join(c for c in value
        if c.isdigit() or c == point or c == '-')
    if point != '.':
        value = value.replace(point, '.')
    return (get_decimal() or float)(value)
767
768
def cast_int2vector(value):
    """Cast an int2vector value to a list of integers."""
    return list(map(int, value.split()))
772
773
def cast_date(value, connection):
    """Cast a date value.

    The date format is requested from the connection.  Infinite dates,
    dates before Christ and dates that are too long to be parsed are
    mapped to the smallest or largest possible date.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:
        return date.max
    parsed = datetime.strptime(day, connection.date_format())
    return parsed.date()
792
793
def cast_time(value):
    """Cast a time value, with or without fractions of a second."""
    has_fraction = len(value) > 8
    fmt = '%H:%M:%S.%f' if has_fraction else '%H:%M:%S'
    parsed = datetime.strptime(value, fmt)
    return parsed.time()
798
799
800_re_timezone = regex('(.*)([+-].*)')
801
802
def cast_timetz(value):
    """Cast a timetz value.

    The time zone is split off from the value, with UTC as the default
    if there is no time zone in the value.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if not _has_timezone:
        # set the time zone after parsing instead of parsing it with %z
        naive = datetime.strptime(value, fmt).timetz()
        return naive.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
817
818
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The date format is requested from the connection.  Infinite values,
    values before Christ and values with a year that is too long to be
    parsed are mapped to the smallest or largest possible datetime.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # the value has the year at the end; the first part is skipped
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
840
841
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    The date format is requested from the connection.  Infinite values,
    values before Christ and values with a year that is too long to be
    parsed are mapped to the smallest or largest possible datetime.
    All other values are returned as datetime with a time zone.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # the value has year and time zone at the end;
        # the first part is skipped
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
        tz = parts.pop()
    else:
        if date_fmt.startswith('%Y-'):
            # the time zone is directly attached to the time here
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            # the time zone is the last part of the value
            tz = parts.pop()
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    if _has_timezone:
        parts.append(_timezone_as_offset(tz))
        formats.append('%z')
        return datetime.strptime(' '.join(parts), ' '.join(formats))
    # set the time zone after parsing instead of parsing it with %z
    naive = datetime.strptime(' '.join(parts), ' '.join(formats))
    return naive.replace(tzinfo=_get_timezone(tz))
877
878
# The following patterns match the interval output formats tried by
# cast_interval(); the last group of the seconds always holds the
# digits after the decimal point.

_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_numbers(fields):
    """Convert the matched fields of an interval to a list of integers.

    The last field holds the digits after the decimal point of the
    seconds.  Since trailing zeros may be missing there (as in '01.5'),
    these digits are padded to six places to get microseconds.
    """
    fields = list(fields)
    fields[-1] = fields[-1].ljust(6, '0')[:6]
    return [int(d) for d in fields]


def cast_interval(value):
    """Cast an interval value.

    Returns the interval as timedelta, counting a year as 365 days and
    a month as 30 days.  Raises a ValueError if the value does not match
    any of the known interval formats.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'  # the sign of the seconds
        years, mons, days, hours, mins, secs, usecs = _interval_numbers(m)
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'  # the sign of the seconds
            m = _interval_numbers(m)
            if ago:  # this negates all fields
                m = [-d for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'  # the sign of the time part
                years, mons, days, hours, mins, secs, usecs = (
                    _interval_numbers(m))
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'  # sign of years and months
                    hours_ago = m.pop(3) == '-'  # the sign of the time part
                    years, mons, days, hours, mins, secs, usecs = (
                        _interval_numbers(m))
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
964
965
class Typecasts(dict):
    """Mapping of database type names to typecast functions.

    A typecast function receives the string representation of a database
    value and returns the corresponding Python object.  It is never passed
    None, because NULL values are dealt with before a typecast function
    is called.

    The basic types are already converted by the C extension, so their
    casts are only needed here for components of records or arrays.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {
        'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': UUID,
        'anyarray': cast_array, 'record': cast_record,
    }

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Build the cast function for a type that is not cached yet.

        A lookup in this class never raises a KeyError; when there is no
        special cast function for the type, the lookup yields None.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # remember the default so that the next lookup is faster
            cast = self._add_connection(cast)
            self[typ] = cast
            return cast
        if typ.startswith('_'):
            # a leading underscore marks the array type of a base type
            element_cast = self[typ[1:]]
            array_cast = self.create_array_cast(element_cast)
            if element_cast:
                self[typ] = array_cast
            return array_cast
        fields = self.get_attnames(typ)
        if fields:
            # a composite type: every field is cast with its own function
            field_casts = [self[field.pgtype] for field in fields.values()]
            cast = self.create_record_cast(typ, fields, field_casts)
            self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Tell whether a typecast function takes a connection argument."""
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Bind the connection to the typecast function where required."""
        if self.connection and self._needs_connection(cast):
            return partial(cast, connection=self.connection)
        return cast

    def get(self, typ, default=None):
        """Return the typecast function for the given database type."""
        cast = self[typ]
        return cast if cast else default

    def set(self, typ, cast):
        """Install a typecast function for the given database type(s).

        Passing None as cast removes the typecast functions cached for
        the given type(s) and for their array types.
        """
        names = [typ] if isinstance(typ, basestring) else typ
        if cast is None:
            for name in names:
                self.pop(name, None)
                self.pop('_%s' % name, None)
            return
        if not callable(cast):
            raise TypeError("Cast parameter must be callable")
        for name in names:
            self[name] = self._add_connection(cast)
            # the cached array cast is outdated now
            self.pop('_%s' % name, None)

    def reset(self, typ=None):
        """Restore the default typecasts for the given type(s).

        All typecasts are reset when no type is given.
        """
        if typ is None:
            self.clear()
            return
        names = [typ] if isinstance(typ, basestring) else typ
        for name in names:
            self.pop(name, None)

    @classmethod
    def get_default(cls, typ):
        """Return the default typecast function for the database type."""
        return cls.defaults.get(typ)

    @classmethod
    def set_default(cls, typ, cast):
        """Install a default typecast function for the given type(s).

        Passing None as cast removes the defaults for the given type(s)
        and for their array types.
        """
        names = [typ] if isinstance(typ, basestring) else typ
        registry = cls.defaults
        if cast is None:
            for name in names:
                registry.pop(name, None)
                registry.pop('_%s' % name, None)
            return
        if not callable(cast):
            raise TypeError("Cast parameter must be callable")
        for name in names:
            registry[name] = cast
            registry.pop('_%s' % name, None)

    def get_attnames(self, typ):
        """Return the fields of the given record type.

        This method will be replaced with the get_attnames() method of DbTypes.
        """
        return {}

    def dateformat(self):
        """Return the date format that is currently in use.

        This method will be replaced with the dateformat() method of DbTypes.
        """
        return '%Y-%m-%d'

    def create_array_cast(self, basecast):
        """Return a typecast for arrays with the given base cast."""
        array_cast = self['anyarray']

        def cast_elements(value):
            return array_cast(value, basecast)

        return cast_elements

    def create_record_cast(self, name, fields, casts):
        """Return a typecast that creates named records.

        The created records are named tuples with the given name and
        fields; the components are converted with the given casts.
        """
        record_cast = self['record']
        record_class = namedtuple(name, fields)

        def cast_fields(value):
            return record_class(*record_cast(value, casts))

        return cast_fields
1119
1120
def get_typecast(typ):
    """Get the global typecast function for the given database type.

    This returns the entry of the class-level defaults of Typecasts,
    or None if no default typecast is registered for the type.
    """
    return Typecasts.get_default(typ)
1124
1125
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The type can be a single type name or a sequence of type names.
    Passing None as cast removes the global typecast for the type(s).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1133
1134
class DbType(str):
    """A type name carrying additional information about the type.

    Instances are strings with these additional attributes:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Names and types of the fields of a composite type."""
        # the lookup function is attached to the instance by its creator
        lookup = self._get_attnames
        return lookup(self)
1155
1156
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    The keys of this cache are type OIDs and type names, the values are
    DbType objects describing the corresponding database type.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Set up the type cache for the given connection."""
        super(DbTypes, self).__init__()
        # only keep a weak reference to the connection
        self._db = weakref.proxy(db)
        self._regtypes = False
        typecasts = Typecasts()
        typecasts.get_attnames = self.get_attnames
        typecasts.connection = self._db
        self._typecasts = typecasts
        if db.server_version >= 80400:
            query = (
                "SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")
        else:
            # older remote databases (not officially supported)
            query = (
                "SELECT oid, typname, typname::text::regtype,"
                " typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")
        self._query_pg_type = query

    def add(self, oid, pgtype, regtype,
               typtype, category, delim, relid):
        """Return a type name with additional info for a PostgreSQL type.

        If the OID is already cached, the cached type is returned.
        A newly created type is not stored in the cache by this method.
        """
        if oid in self:
            return self[oid]
        if relid:
            simple = 'record'
        else:
            simple = _simpletypes[pgtype]
        name = regtype if self._regtypes else simple
        typ = DbType(name)
        typ.oid, typ.simple = oid, simple
        typ.pgtype, typ.regtype = pgtype, regtype
        typ.typtype, typ.category = typtype, category
        typ.delim, typ.relid = delim, relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Fetch the info for an uncached type from the database."""
        try:
            query = self._query_pg_type % (
                _quote_if_unqualified('$1', key),)
            rows = self._db.query(query, (key,)).getresult()
        except ProgrammingError:
            rows = None
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        typ = self.add(*rows[0])
        # make the type available under its OID and under its name
        self[typ.oid] = typ
        self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Return the type, looking it up if it is not cached.

        The default is returned if the type cannot be found.
        """
        try:
            found = self[key]
        except KeyError:
            found = default
        return found

    def get_attnames(self, typ):
        """Return names and types of the fields of a composite type.

        None is returned for unknown types and for types that have
        no corresponding relation.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        relid = typ.relid
        if relid:
            return self._db.get_attnames(relid, with_oid=False)
        return None

    def get_typecast(self, typ):
        """Return the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Install a typecast function for the given database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the given database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Convert the value according to the given database type."""
        if value is None:
            # NULL values never need a typecast
            return None
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if typ:
                typ = typ.pgtype
        cast = self.get_typecast(typ) if typ else None
        if cast and cast is not str:
            return cast(value)
        # there is nothing to convert
        return value
1262
1263
1264# The result rows for database operations are returned as named tuples
1265# by default. Since creating namedtuple classes is a somewhat expensive
1266# operation, we cache up to 1024 of these classes by default.
1267
1268@lru_cache(maxsize=1024)
1269def _row_factory(names):
1270    """Get a namedtuple factory for row results with the given names."""
1271    try:
1272        try:
1273            return namedtuple('Row', names, rename=True)._make
1274        except TypeError:  # Python 2.6 and 3.0 do not support rename
1275            names = [v if v.isalnum() else 'column_%d' % (n,)
1276                     for n, v in enumerate(names)]
1277            return namedtuple('Row', names)._make
1278    except ValueError:  # there is still a problem with the field names
1279        names = ['column_%d' % (n,) for n in range(len(names))]
1280        return namedtuple('Row', names)._make
1281
1282
def set_row_factory_size(maxsize):
    """Set a new size for the cache of namedtuple factories.

    Passing None as maxsize lets the cache grow without bound.
    """
    global _row_factory
    # wrap the undecorated function anew, this also empties the cache
    uncached = _row_factory.__wrapped__
    _row_factory = lru_cache(maxsize)(uncached)
1290
1291
def _namedresult(q):
    """Return the result of a query as a list of named tuples."""
    make_row = _row_factory(q.listfields())
    rows = q.getresult()
    return [make_row(values) for values in rows]
1296
1297
1298class _MemoryQuery:
1299    """Class that embodies a given query result."""
1300
1301    def __init__(self, result, fields):
1302        """Create query from given result rows and field names."""
1303        self.result = result
1304        self.fields = tuple(fields)
1305
1306    def listfields(self):
1307        """Return the stored field names of this query."""
1308        return self.fields
1309
1310    def getresult(self):
1311        """Return the stored result of this query."""
1312        return self.result
1313
1314
def _db_error(msg, cls=DatabaseError):
    """Create an error of the given class without an SQLSTATE code.

    The error is returned, not raised; its sqlstate attribute is None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
1320
1321
def _int_error(msg):
    """Return InternalError.

    The error is created with _db_error(), so it is returned instead
    of raised and its sqlstate attribute is None.
    """
    return _db_error(msg, InternalError)
1325
1326
def _prg_error(msg):
    """Return ProgrammingError.

    The error is created with _db_error(), so it is returned instead
    of raised and its sqlstate attribute is None.
    """
    return _db_error(msg, ProgrammingError)
1330
1331
1332# Initialize the C module
1333
# register the function that turns query results into named tuples
set_namedresult(_namedresult)
# register the Decimal class (presumably used by the C module for
# numeric values -- confirm against the _pg documentation)
set_decimal(Decimal)
# register json.loads as the function for decoding JSON values
set_jsondecode(jsondecode)
1337
1338
1339# The notification handler
1340
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is garbage collected."""
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is escaped before it is inserted into the command,
        so it may contain arbitrary text including single quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent and None is returned if the handler is not
        listening; otherwise the result of the query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # The payload becomes an SQL string literal, so it must be
                # escaped, or a single quote would break the command.
                payload = '%s' % (payload,)
                escape = getattr(db, 'escape_string', None)
                if escape:
                    # prefer the escape method of the connection
                    payload = escape(payload)
                else:
                    payload = payload.replace("'", "''")
                q += ", '%s'" % payload
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        A DatabaseError is raised if a notification for any other event
        is received; the handler stops listening in this case.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            # wait on the socket of the connection with select()
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        # this also ends both loops
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1460
1461
def pgnotify(*args, **kw):
    """Create a NotificationHandler, using the traditional name.

    This name is deprecated, a DeprecationWarning will be issued.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1467
1468
# The actual PostgreSQL database connection interface:
1470
1471class DB:
1472    """Wrapper class for the _pg connection type."""
1473
1474    db = None  # invalid fallback for underlying connection
1475
    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing connection can be passed as the only positional
        argument or as the only keyword argument with the name 'db'.
        A connection that was passed in will not be closed by this
        instance, only a connection opened here is closeable.
        """
        # check whether a single argument has been passed
        # that could be an existing connection
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # use the underlying connection of the other wrapper
                db = db.db
            else:
                try:
                    # presumably the underlying connection of a pgdb
                    # connection -- confirm against the pgdb module
                    db = db._cnx
                except AttributeError:
                    pass
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            # this is not a usable connection object, so treat
            # all arguments as connection parameters
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        # remember the arguments so that the connection can be reopened
        self._args = args, kw
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        if db.server_version < 80400:
            # support older remote data bases
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype,"
                " t.typtype, null as typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        else:
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        # let the connection cast values with the types of this instance
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1534
1535    def __getattr__(self, name):
1536        # All undefined members are same as in underlying connection:
1537        if self.db:
1538            return getattr(self.db, name)
1539        else:
1540            raise _int_error('Connection is not valid')
1541
1542    def __dir__(self):
1543        # Custom dir function including the attributes of the connection:
1544        attrs = set(self.__class__.__dict__)
1545        attrs.update(self.__dict__)
1546        attrs.update(dir(self.db))
1547        return sorted(attrs)
1548
1549    # Context manager methods
1550
    def __enter__(self):
        """Enter the runtime context. This will start a transaction.

        Returns this instance, so it can be bound with "with ... as".
        """
        self.begin()
        return self
1555
1556    def __exit__(self, et, ev, tb):
1557        """Exit the runtime context. This will end the transaction."""
1558        if et is None and ev is None and tb is None:
1559            self.commit()
1560        else:
1561            self.rollback()
1562
1563    def __del__(self):
1564        try:
1565            db = self.db
1566        except AttributeError:
1567            db = None
1568        if db:
1569            db.set_cast_hook(None)
1570            if self._closeable:
1571                db.close()
1572
1573    # Auxiliary methods
1574
1575    def _do_debug(self, *args):
1576        """Print a debug message"""
1577        if self.debug:
1578            s = '\n'.join(str(arg) for arg in args)
1579            if isinstance(self.debug, basestring):
1580                print(self.debug % s)
1581            elif hasattr(self.debug, 'write'):
1582                self.debug.write(s + '\n')
1583            elif callable(self.debug):
1584                self.debug(s)
1585            else:
1586                print(s)
1587
1588    def _escape_qualified_name(self, s):
1589        """Escape a qualified name.
1590
1591        Escapes the name for use as an SQL identifier, unless the
1592        name contains a dot, in which case the name is ambiguous
1593        (could be a qualified name or just a name with a dot in it)
1594        and must be quoted manually by the caller.
1595        """
1596        if '.' not in s:
1597            s = self.escape_identifier(s)
1598        return s
1599
1600    @staticmethod
1601    def _make_bool(d):
1602        """Get boolean value corresponding to d."""
1603        return bool(d) if get_bool() else ('t' if d else 'f')
1604
1605    def _list_params(self, params):
1606        """Create a human readable parameter list."""
1607        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1608
1609    # Public methods
1610
1611    # escape_string and escape_bytea exist as methods,
1612    # so we define unescape_bytea as a method as well
1613    unescape_bytea = staticmethod(unescape_bytea)
1614
    def decode_json(self, s):
        """Decode a JSON string coming from the database.

        The function returned by get_jsondecode() is used for decoding;
        if that is not set, the default JSON decoder is used instead.
        """
        return (get_jsondecode() or jsondecode)(s)
1618
    def encode_json(self, d):
        """Encode a JSON string for use within SQL.

        The object d is serialized with the default JSON encoder.
        """
        return jsonencode(d)
1622
1623    def close(self):
1624        """Close the database connection."""
1625        # Wraps shared library function so we can track state.
1626        if self._closeable:
1627            if self.db:
1628                self.db.set_cast_hook(None)
1629                self.db.close()
1630                self.db = None
1631            else:
1632                raise _int_error('Connection already closed')
1633
1634    def reset(self):
1635        """Reset connection with current parameters.
1636
1637        All derived queries and large objects derived from this connection
1638        will not be usable after this call.
1639
1640        """
1641        if self.db:
1642            self.db.reset()
1643        else:
1644            raise _int_error('Connection already closed')
1645
1646    def reopen(self):
1647        """Reopen connection to the database.
1648
1649        Used in case we need another connection to the same database.
1650        Note that we can still reopen a database that we have closed.
1651
1652        """
1653        # There is no such shared library function.
1654        if self._closeable:
1655            db = connect(*self._args[0], **self._args[1])
1656            if self.db:
1657                self.db.set_cast_hook(None)
1658                self.db.close()
1659            self.db = db
1660
1661    def begin(self, mode=None):
1662        """Begin a transaction."""
1663        qstr = 'BEGIN'
1664        if mode:
1665            qstr += ' ' + mode
1666        return self.query(qstr)
1667
1668    start = begin
1669
    def commit(self):
        """Commit the current transaction.

        Returns the result of the COMMIT query.
        """
        return self.query('COMMIT')
1673
1674    end = commit
1675
1676    def rollback(self, name=None):
1677        """Roll back the current transaction."""
1678        qstr = 'ROLLBACK'
1679        if name:
1680            qstr += ' TO ' + name
1681        return self.query(qstr)
1682
1683    abort = rollback
1684
    def savepoint(self, name):
        """Define a new savepoint within the current transaction.

        The name is inserted into the command as is, without quoting.
        """
        return self.query('SAVEPOINT ' + name)
1688
    def release(self, name):
        """Destroy a previously defined savepoint.

        The name is inserted into the command as is, without quoting.
        """
        return self.query('RELEASE ' + name)
1692
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        A TypeError is raised if the parameter has an unsupported type,
        if it is empty or if one of the names is empty or not a string.
        """
        # The variable "values" determines the shape of the return value:
        # None stands for a single value, a list collects the settings in
        # order and a dict maps the passed names to the settings.
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            # the passed dict itself will be updated and returned
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # For dict results, params maps the normalized names to the names
        # as passed, otherwise it is just a list of the normalized names.
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            # parameter names are normalized to stripped lower case
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # get all settings with one query, ignoring all other names
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                # only the first two columns of each row are used
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # this runs only if the loop was not left because of 'all'
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    # store the setting under the name as it was passed
                    values[params[param]] = value
        return values
1750
1751    def set_parameter(self, parameter, value=None, local=False):
1752        """Set the value of a run-time parameter.
1753
1754        If the parameter and the value are strings, the run-time parameter
1755        will be set to that value.  If no value or None is passed as a value,
1756        then the run-time parameter will be restored to its default value.
1757
1758        You can set several parameters at once by passing a list of parameter
1759        names, together with a single value that all parameters should be
1760        set to or with a corresponding list of values.  You can also pass
1761        the parameters as a set if you only provide a single value.
1762        Finally, you can pass a dict with parameter names as keys.  In this
1763        case, you should not pass a value, since the values for the parameters
1764        will be taken from the dict.
1765
1766        By passing the special name 'all' as the parameter, you can reset
1767        all existing settable run-time parameters to their default values.
1768
1769        If you set local to True, then the command takes effect for only the
1770        current transaction.  After commit() or rollback(), the session-level
1771        setting takes effect again.  Setting local to True will appear to
1772        have no effect if it is executed outside a transaction, since the
1773        transaction will end immediately.
1774        """
1775        if isinstance(parameter, basestring):
1776            parameter = {parameter: value}
1777        elif isinstance(parameter, (list, tuple)):
1778            if isinstance(value, (list, tuple)):
1779                parameter = dict(zip(parameter, value))
1780            else:
1781                parameter = dict.fromkeys(parameter, value)
1782        elif isinstance(parameter, (set, frozenset)):
1783            if isinstance(value, (list, tuple, set, frozenset)):
1784                value = set(value)
1785                if len(value) == 1:
1786                    value = value.pop()
1787            if not(value is None or isinstance(value, basestring)):
1788                raise ValueError('A single value must be specified'
1789                    ' when parameter is a set')
1790            parameter = dict.fromkeys(parameter, value)
1791        elif isinstance(parameter, dict):
1792            if value is not None:
1793                raise ValueError('A value must not be specified'
1794                    ' when parameter is a dictionary')
1795        else:
1796            raise TypeError(
1797                'The parameter must be a string, list, set or dict')
1798        if not parameter:
1799            raise TypeError('No parameter has been specified')
1800        params = {}
1801        for key, value in parameter.items():
1802            param = key.strip().lower() if isinstance(
1803                key, basestring) else None
1804            if not param:
1805                raise TypeError('Invalid parameter')
1806            if param == 'all':
1807                if value is not None:
1808                    raise ValueError('A value must ot be specified'
1809                        " when parameter is 'all'")
1810                params = {'all': None}
1811                break
1812            params[param] = value
1813        local = ' LOCAL' if local else ''
1814        for param, value in params.items():
1815            if value is None:
1816                q = 'RESET%s %s' % (local, param)
1817            else:
1818                q = 'SET%s %s TO %s' % (local, param, value)
1819            self._do_debug(q)
1820            self.db.query(q)
1821
1822    def query(self, command, *args):
1823        """Execute a SQL command string.
1824
1825        This method simply sends a SQL query to the database.  If the query is
1826        an insert statement that inserted exactly one row into a table that
1827        has OIDs, the return value is the OID of the newly inserted row.
1828        If the query is an update or delete statement, or an insert statement
1829        that did not insert exactly one row in a table with OIDs, then the
1830        number of rows affected is returned as a string.  If it is a statement
1831        that returns rows as a result (usually a select statement, but maybe
1832        also an "insert/update ... returning" statement), this method returns
1833        a Query object that can be accessed via getresult() or dictresult()
1834        or simply printed.  Otherwise, it returns `None`.
1835
1836        The query can contain numbered parameters of the form $1 in place
1837        of any data constant.  Arguments given after the query string will
1838        be substituted for the corresponding numbered parameter.  Parameter
1839        values can also be given as a single list or tuple argument.
1840        """
1841        # Wraps shared library function for debugging.
1842        if not self.db:
1843            raise _int_error('Connection is not valid')
1844        if args:
1845            self._do_debug(command, args)
1846            return self.db.query(command, args)
1847        self._do_debug(command)
1848        return self.db.query(command)
1849
1850    def query_formatted(self, command, parameters, types=None, inline=False):
1851        """Execute a formatted SQL command string.
1852
1853        Similar to query, but using Python format placeholders of the form
1854        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1855        The parameters must be passed as a tuple, list or dict.  You can
1856        also pass a corresponding tuple, list or dict of database types in
1857        order to format the parameters properly in case there is ambiguity.
1858
1859        If you set inline to True, the parameters will be sent to the database
1860        embedded in the SQL command, otherwise they will be sent separately.
1861        """
1862        return self.query(*self.adapter.format_query(
1863            command, parameters, types, inline))
1864
1865    def pkey(self, table, composite=False, flush=False):
1866        """Get or set the primary key of a table.
1867
1868        Single primary keys are returned as strings unless you
1869        set the composite flag.  Composite primary keys are always
1870        represented as tuples.  Note that this raises a KeyError
1871        if the table does not have a primary key.
1872
1873        If flush is set then the internal cache for primary keys will
1874        be flushed.  This may be necessary after the database schema or
1875        the search path has been changed.
1876        """
1877        pkeys = self._pkeys
1878        if flush:
1879            pkeys.clear()
1880            self._do_debug('The pkey cache has been flushed')
1881        try:  # cache lookup
1882            pkey = pkeys[table]
1883        except KeyError:  # cache miss, check the database
1884            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1885                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1886                " AND a.attnum = ANY(i.indkey)"
1887                " AND NOT a.attisdropped"
1888                " WHERE i.indrelid=%s::regclass"
1889                " AND i.indisprimary ORDER BY a.attnum") % (
1890                    _quote_if_unqualified('$1', table),)
1891            pkey = self.db.query(q, (table,)).getresult()
1892            if not pkey:
1893                raise KeyError('Table %s has no primary key' % table)
1894            # we want to use the order defined in the primary key index here,
1895            # not the order as defined by the columns in the table
1896            if len(pkey) > 1:
1897                indkey = pkey[0][2]
1898                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1899                pkey = tuple(row[0] for row in pkey)
1900            else:
1901                pkey = pkey[0][0]
1902            pkeys[table] = pkey  # cache it
1903        if composite and not isinstance(pkey, tuple):
1904            pkey = (pkey,)
1905        return pkey
1906
1907    def get_databases(self):
1908        """Get list of databases in the system."""
1909        return [s[0] for s in
1910            self.db.query('SELECT datname FROM pg_database').getresult()]
1911
1912    def get_relations(self, kinds=None, system=False):
1913        """Get list of relations in connected database of specified kinds.
1914
1915        If kinds is None or empty, all kinds of relations are returned.
1916        Otherwise kinds can be a string or sequence of type letters
1917        specifying which kind of relations you want to list.
1918
1919        Set the system flag if you want to get the system relations as well.
1920        """
1921        where = []
1922        if kinds:
1923            where.append("r.relkind IN (%s)" %
1924                ','.join("'%s'" % k for k in kinds))
1925        if not system:
1926            where.append("s.nspname NOT SIMILAR"
1927                " TO 'pg/_%|information/_schema' ESCAPE '/'")
1928        where = " WHERE %s" % ' AND '.join(where) if where else ''
1929        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1930            " FROM pg_class r"
1931            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
1932            " ORDER BY s.nspname, r.relname") % where
1933        return [r[0] for r in self.db.query(q).getresult()]
1934
1935    def get_tables(self, system=False):
1936        """Return list of tables in connected database.
1937
1938        Set the system flag if you want to get the system tables as well.
1939        """
1940        return self.get_relations('r', system)
1941
1942    def get_attnames(self, table, with_oid=True, flush=False):
1943        """Given the name of a table, dig out the set of attribute names.
1944
1945        Returns a read-only dictionary of attribute names (the names are
1946        the keys, the values are the names of the attributes' types)
1947        with the column names in the proper order if you iterate over it.
1948
1949        If flush is set, then the internal cache for attribute names will
1950        be flushed. This may be necessary after the database schema or
1951        the search path has been changed.
1952
1953        By default, only a limited number of simple types will be returned.
1954        You can get the regular types after calling use_regtypes(True).
1955        """
1956        attnames = self._attnames
1957        if flush:
1958            attnames.clear()
1959            self._do_debug('The attnames cache has been flushed')
1960        try:  # cache lookup
1961            names = attnames[table]
1962        except KeyError:  # cache miss, check the database
1963            q = "a.attnum > 0"
1964            if with_oid:
1965                q = "(%s OR a.attname = 'oid')" % q
1966            q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
1967            names = self.db.query(q, (table,)).getresult()
1968            types = self.dbtypes
1969            names = ((name[0], types.add(*name[1:])) for name in names)
1970            names = AttrDict(names)
1971            attnames[table] = names  # cache it
1972        return names
1973
1974    def use_regtypes(self, regtypes=None):
1975        """Use regular type names instead of simplified type names."""
1976        if regtypes is None:
1977            return self.dbtypes._regtypes
1978        else:
1979            regtypes = bool(regtypes)
1980            if regtypes != self.dbtypes._regtypes:
1981                self.dbtypes._regtypes = regtypes
1982                self._attnames.clear()
1983                self.dbtypes.clear()
1984            return regtypes
1985
1986    def has_table_privilege(self, table, privilege='select', flush=False):
1987        """Check whether current user has specified table privilege.
1988
1989        If flush is set, then the internal cache for table privileges will
1990        be flushed. This may be necessary after privileges have been changed.
1991        """
1992        privileges = self._privileges
1993        if flush:
1994            privileges.clear()
1995            self._do_debug('The privileges cache has been flushed')
1996        privilege = privilege.lower()
1997        try:  # ask cache
1998            ret = privileges[table, privilege]
1999        except KeyError:  # cache miss, ask the database
2000            q = "SELECT has_table_privilege(%s, $2)" % (
2001                _quote_if_unqualified('$1', table),)
2002            q = self.db.query(q, (table, privilege))
2003            ret = q.getresult()[0][0] == self._make_bool(True)
2004            privileges[table, privilege] = ret  # cache it
2005        return ret
2006
2007    def get(self, table, row, keyname=None):
2008        """Get a row from a database table or view.
2009
2010        This method is the basic mechanism to get a single row.  It assumes
2011        that the keyname specifies a unique row.  It must be the name of a
2012        single column or a tuple of column names.  If the keyname is not
2013        specified, then the primary key for the table is used.
2014
2015        If row is a dictionary, then the value for the key is taken from it.
2016        Otherwise, the row must be a single value or a tuple of values
2017        corresponding to the passed keyname or primary key.  The fetched row
2018        from the table will be returned as a new dictionary or used to replace
2019        the existing values when row was passed as a dictionary.
2020
2021        The OID is also put into the dictionary if the table has one, but
2022        in order to allow the caller to work with multiple tables, it is
2023        munged as "oid(table)" using the actual name of the table.
2024        """
2025        if table.endswith('*'):  # hint for descendant tables can be ignored
2026            table = table[:-1].rstrip()
2027        attnames = self.get_attnames(table)
2028        qoid = _oid_key(table) if 'oid' in attnames else None
2029        if keyname and isinstance(keyname, basestring):
2030            keyname = (keyname,)
2031        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
2032            row['oid'] = row[qoid]
2033        if not keyname:
2034            try:  # if keyname is not specified, try using the primary key
2035                keyname = self.pkey(table, True)
2036            except KeyError:  # the table has no primary key
2037                # try using the oid instead
2038                if qoid and isinstance(row, dict) and 'oid' in row:
2039                    keyname = ('oid',)
2040                else:
2041                    raise _prg_error('Table %s has no primary key' % table)
2042            else:  # the table has a primary key
2043                # check whether all key columns have values
2044                if isinstance(row, dict) and not set(keyname).issubset(row):
2045                    # try using the oid instead
2046                    if qoid and 'oid' in row:
2047                        keyname = ('oid',)
2048                    else:
2049                        raise KeyError(
2050                            'Missing value in row for specified keyname')
2051        if not isinstance(row, dict):
2052            if not isinstance(row, (tuple, list)):
2053                row = [row]
2054            if len(keyname) != len(row):
2055                raise KeyError(
2056                    'Differing number of items in keyname and row')
2057            row = dict(zip(keyname, row))
2058        params = self.adapter.parameter_list()
2059        adapt = params.add
2060        col = self.escape_identifier
2061        what = 'oid, *' if qoid else '*'
2062        where = ' AND '.join('%s = %s' % (
2063            col(k), adapt(row[k], attnames[k])) for k in keyname)
2064        if 'oid' in row:
2065            if qoid:
2066                row[qoid] = row['oid']
2067            del row['oid']
2068        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
2069            what, self._escape_qualified_name(table), where)
2070        self._do_debug(q, params)
2071        q = self.db.query(q, params)
2072        res = q.dictresult()
2073        if not res:
2074            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
2075                table, where, self._list_params(params)))
2076        for n, value in res[0].items():
2077            if qoid and n == 'oid':
2078                n = qoid
2079            row[n] = value
2080        return row
2081
2082    def insert(self, table, row=None, **kw):
2083        """Insert a row into a database table.
2084
2085        This method inserts a row into a table.  The name of the table must
2086        be passed as the first parameter.  The other parameters are used for
2087        providing the data of the row that shall be inserted into the table.
2088        If a dictionary is supplied as the second parameter, it starts with
2089        that.  Otherwise it uses a blank dictionary. Either way the dictionary
2090        is updated from the keywords.
2091
2092        The dictionary is then reloaded with the values actually inserted in
2093        order to pick up values modified by rules, triggers, etc.
2094        """
2095        if table.endswith('*'):  # hint for descendant tables can be ignored
2096            table = table[:-1].rstrip()
2097        if row is None:
2098            row = {}
2099        row.update(kw)
2100        if 'oid' in row:
2101            del row['oid']  # do not insert oid
2102        attnames = self.get_attnames(table)
2103        qoid = _oid_key(table) if 'oid' in attnames else None
2104        params = self.adapter.parameter_list()
2105        adapt = params.add
2106        col = self.escape_identifier
2107        names, values = [], []
2108        for n in attnames:
2109            if n in row:
2110                names.append(col(n))
2111                values.append(adapt(row[n], attnames[n]))
2112        if not names:
2113            raise _prg_error('No column found that can be inserted')
2114        names, values = ', '.join(names), ', '.join(values)
2115        ret = 'oid, *' if qoid else '*'
2116        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
2117            self._escape_qualified_name(table), names, values, ret)
2118        self._do_debug(q, params)
2119        q = self.db.query(q, params)
2120        res = q.dictresult()
2121        if res:  # this should always be true
2122            for n, value in res[0].items():
2123                if qoid and n == 'oid':
2124                    n = qoid
2125                row[n] = value
2126        return row
2127
2128    def update(self, table, row=None, **kw):
2129        """Update an existing row in a database table.
2130
2131        Similar to insert, but updates an existing row.  The update is based
2132        on the primary key of the table or the OID value as munged by get()
2133        or passed as keyword.  The OID will take precedence if provided, so
2134        that it is possible to update the primary key itself.
2135
2136        The dictionary is then modified to reflect any changes caused by the
2137        update due to triggers, rules, default values, etc.
2138        """
2139        if table.endswith('*'):
2140            table = table[:-1].rstrip()  # need parent table name
2141        attnames = self.get_attnames(table)
2142        qoid = _oid_key(table) if 'oid' in attnames else None
2143        if row is None:
2144            row = {}
2145        elif 'oid' in row:
2146            del row['oid']  # only accept oid key from named args for safety
2147        row.update(kw)
2148        if qoid and qoid in row and 'oid' not in row:
2149            row['oid'] = row[qoid]
2150        if qoid and 'oid' in row:  # try using the oid
2151            keyname = ('oid',)
2152        else:  # try using the primary key
2153            try:
2154                keyname = self.pkey(table, True)
2155            except KeyError:  # the table has no primary key
2156                raise _prg_error('Table %s has no primary key' % table)
2157            # check whether all key columns have values
2158            if not set(keyname).issubset(row):
2159                raise KeyError('Missing value for primary key in row')
2160        params = self.adapter.parameter_list()
2161        adapt = params.add
2162        col = self.escape_identifier
2163        where = ' AND '.join('%s = %s' % (
2164            col(k), adapt(row[k], attnames[k])) for k in keyname)
2165        if 'oid' in row:
2166            if qoid:
2167                row[qoid] = row['oid']
2168            del row['oid']
2169        values = []
2170        keyname = set(keyname)
2171        for n in attnames:
2172            if n in row and n not in keyname:
2173                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
2174        if not values:
2175            return row
2176        values = ', '.join(values)
2177        ret = 'oid, *' if qoid else '*'
2178        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
2179            self._escape_qualified_name(table), values, where, ret)
2180        self._do_debug(q, params)
2181        q = self.db.query(q, params)
2182        res = q.dictresult()
2183        if res:  # may be empty when row does not exist
2184            for n, value in res[0].items():
2185                if qoid and n == 'oid':
2186                    n = qoid
2187                row[n] = value
2188        return row
2189
2190    def upsert(self, table, row=None, **kw):
2191        """Insert a row into a database table with conflict resolution
2192
2193        This method inserts a row into a table, but instead of raising a
2194        ProgrammingError exception in case a row with the same primary key
2195        already exists, an update will be executed instead.  This will be
2196        performed as a single atomic operation on the database, so race
2197        conditions can be avoided.
2198
2199        Like the insert method, the first parameter is the name of the
2200        table and the second parameter can be used to pass the values to
2201        be inserted as a dictionary.
2202
2203        Unlike the insert und update statement, keyword parameters are not
2204        used to modify the dictionary, but to specify which columns shall
2205        be updated in case of a conflict, and in which way:
2206
2207        A value of False or None means the column shall not be updated,
2208        a value of True means the column shall be updated with the value
2209        that has been proposed for insertion, i.e. has been passed as value
2210        in the dictionary.  Columns that are not specified by keywords but
2211        appear as keys in the dictionary are also updated like in the case
2212        keywords had been passed with the value True.
2213
2214        So if in the case of a conflict you want to update every column that
2215        has been passed in the dictionary row, you would call upsert(table, row).
2216        If you don't want to do anything in case of a conflict, i.e. leave
2217        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2218
2219        If you need more fine-grained control of what gets updated, you can
2220        also pass strings in the keyword parameters.  These strings will
2221        be used as SQL expressions for the update columns.  In these
2222        expressions you can refer to the value that already exists in
2223        the table by prefixing the column name with "included.", and to
2224        the value that has been proposed for insertion by prefixing the
2225        column name with the "excluded."
2226
2227        The dictionary is modified in any case to reflect the values in
2228        the database after the operation has completed.
2229
2230        Note: The method uses the PostgreSQL "upsert" feature which is
2231        only available since PostgreSQL 9.5.
2232        """
2233        if table.endswith('*'):  # hint for descendant tables can be ignored
2234            table = table[:-1].rstrip()
2235        if row is None:
2236            row = {}
2237        if 'oid' in row:
2238            del row['oid']  # do not insert oid
2239        if 'oid' in kw:
2240            del kw['oid']  # do not update oid
2241        attnames = self.get_attnames(table)
2242        qoid = _oid_key(table) if 'oid' in attnames else None
2243        params = self.adapter.parameter_list()
2244        adapt = params.add
2245        col = self.escape_identifier
2246        names, values, updates = [], [], []
2247        for n in attnames:
2248            if n in row:
2249                names.append(col(n))
2250                values.append(adapt(row[n], attnames[n]))
2251        names, values = ', '.join(names), ', '.join(values)
2252        try:
2253            keyname = self.pkey(table, True)
2254        except KeyError:
2255            raise _prg_error('Table %s has no primary key' % table)
2256        target = ', '.join(col(k) for k in keyname)
2257        update = []
2258        keyname = set(keyname)
2259        keyname.add('oid')
2260        for n in attnames:
2261            if n not in keyname:
2262                value = kw.get(n, True)
2263                if value:
2264                    if not isinstance(value, basestring):
2265                        value = 'excluded.%s' % col(n)
2266                    update.append('%s = %s' % (col(n), value))
2267        if not values:
2268            return row
2269        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2270        ret = 'oid, *' if qoid else '*'
2271        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2272            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2273                self._escape_qualified_name(table), names, values,
2274                target, do, ret)
2275        self._do_debug(q, params)
2276        try:
2277            q = self.db.query(q, params)
2278        except ProgrammingError:
2279            if self.server_version < 90500:
2280                raise _prg_error(
2281                    'Upsert operation is not supported by PostgreSQL version')
2282            raise  # re-raise original error
2283        res = q.dictresult()
2284        if res:  # may be empty with "do nothing"
2285            for n, value in res[0].items():
2286                if qoid and n == 'oid':
2287                    n = qoid
2288                row[n] = value
2289        else:
2290            self.get(table, row)
2291        return row
2292
2293    def clear(self, table, row=None):
2294        """Clear all the attributes to values determined by the types.
2295
2296        Numeric types are set to 0, Booleans are set to false, and everything
2297        else is set to the empty string.  If the row argument is present,
2298        it is used as the row dictionary and any entries matching attribute
2299        names are cleared with everything else left unchanged.
2300        """
2301        # At some point we will need a way to get defaults from a table.
2302        if row is None:
2303            row = {}  # empty if argument is not present
2304        attnames = self.get_attnames(table)
2305        for n, t in attnames.items():
2306            if n == 'oid':
2307                continue
2308            t = t.simple
2309            if t in DbTypes._num_types:
2310                row[n] = 0
2311            elif t == 'bool':
2312                row[n] = self._make_bool(False)
2313            else:
2314                row[n] = ''
2315        return row
2316
2317    def delete(self, table, row=None, **kw):
2318        """Delete an existing row in a database table.
2319
2320        This method deletes the row from a table.  It deletes based on the
2321        primary key of the table or the OID value as munged by get() or
2322        passed as keyword.  The OID will take precedence if provided.
2323
2324        The return value is the number of deleted rows (i.e. 0 if the row
2325        did not exist and 1 if the row was deleted).
2326
2327        Note that if the row cannot be deleted because e.g. it is still
2328        referenced by another table, this method raises a ProgrammingError.
2329        """
2330        if table.endswith('*'):  # hint for descendant tables can be ignored
2331            table = table[:-1].rstrip()
2332        attnames = self.get_attnames(table)
2333        qoid = _oid_key(table) if 'oid' in attnames else None
2334        if row is None:
2335            row = {}
2336        elif 'oid' in row:
2337            del row['oid']  # only accept oid key from named args for safety
2338        row.update(kw)
2339        if qoid and qoid in row and 'oid' not in row:
2340            row['oid'] = row[qoid]
2341        if qoid and 'oid' in row:  # try using the oid
2342            keyname = ('oid',)
2343        else:  # try using the primary key
2344            try:
2345                keyname = self.pkey(table, True)
2346            except KeyError:  # the table has no primary key
2347                raise _prg_error('Table %s has no primary key' % table)
2348            # check whether all key columns have values
2349            if not set(keyname).issubset(row):
2350                raise KeyError('Missing value for primary key in row')
2351        params = self.adapter.parameter_list()
2352        adapt = params.add
2353        col = self.escape_identifier
2354        where = ' AND '.join('%s = %s' % (
2355            col(k), adapt(row[k], attnames[k])) for k in keyname)
2356        if 'oid' in row:
2357            if qoid:
2358                row[qoid] = row['oid']
2359            del row['oid']
2360        q = 'DELETE FROM %s WHERE %s' % (
2361            self._escape_qualified_name(table), where)
2362        self._do_debug(q, params)
2363        res = self.db.query(q, params)
2364        return int(res)
2365
2366    def truncate(self, table, restart=False, cascade=False, only=False):
2367        """Empty a table or set of tables.
2368
2369        This method quickly removes all rows from the given table or set
2370        of tables.  It has the same effect as an unqualified DELETE on each
2371        table, but since it does not actually scan the tables it is faster.
2372        Furthermore, it reclaims disk space immediately, rather than requiring
2373        a subsequent VACUUM operation. This is most useful on large tables.
2374
2375        If restart is set to True, sequences owned by columns of the truncated
2376        table(s) are automatically restarted.  If cascade is set to True, it
2377        also truncates all tables that have foreign-key references to any of
2378        the named tables.  If the parameter only is not set to True, all the
2379        descendant tables (if any) will also be truncated. Optionally, a '*'
2380        can be specified after the table name to explicitly indicate that
2381        descendant tables are included.
2382        """
2383        if isinstance(table, basestring):
2384            only = {table: only}
2385            table = [table]
2386        elif isinstance(table, (list, tuple)):
2387            if isinstance(only, (list, tuple)):
2388                only = dict(zip(table, only))
2389            else:
2390                only = dict.fromkeys(table, only)
2391        elif isinstance(table, (set, frozenset)):
2392            only = dict.fromkeys(table, only)
2393        else:
2394            raise TypeError('The table must be a string, list or set')
2395        if not (restart is None or isinstance(restart, (bool, int))):
2396            raise TypeError('Invalid type for the restart option')
2397        if not (cascade is None or isinstance(cascade, (bool, int))):
2398            raise TypeError('Invalid type for the cascade option')
2399        tables = []
2400        for t in table:
2401            u = only.get(t)
2402            if not (u is None or isinstance(u, (bool, int))):
2403                raise TypeError('Invalid type for the only option')
2404            if t.endswith('*'):
2405                if u:
2406                    raise ValueError(
2407                        'Contradictory table name and only options')
2408                t = t[:-1].rstrip()
2409            t = self._escape_qualified_name(t)
2410            if u:
2411                t = 'ONLY %s' % t
2412            tables.append(t)
2413        q = ['TRUNCATE', ', '.join(tables)]
2414        if restart:
2415            q.append('RESTART IDENTITY')
2416        if cascade:
2417            q.append('CASCADE')
2418        q = ' '.join(q)
2419        self._do_debug(q)
2420        return self.db.query(q)
2421
2422    def get_as_list(self, table, what=None, where=None,
2423            order=None, limit=None, offset=None, scalar=False):
2424        """Get a table as a list.
2425
2426        This gets a convenient representation of the table as a list
2427        of named tuples in Python.  You only need to pass the name of
2428        the table (or any other SQL expression returning rows).  Note that
2429        by default this will return the full content of the table which
2430        can be huge and overflow your memory.  However, you can control
2431        the amount of data returned using the other optional parameters.
2432
2433        The parameter 'what' can restrict the query to only return a
2434        subset of the table columns.  It can be a string, list or a tuple.
2435        The parameter 'where' can restrict the query to only return a
2436        subset of the table rows.  It can be a string, list or a tuple
2437        of SQL expressions that all need to be fulfilled.  The parameter
2438        'order' specifies the ordering of the rows.  It can also be a
2439        other string, list or a tuple.  If no ordering is specified,
2440        the result will be ordered by the primary key(s) or all columns
2441        if no primary key exists.  You can set 'order' to False if you
2442        don't care about the ordering.  The parameters 'limit' and 'offset'
2443        can be integers specifying the maximum number of rows returned
2444        and a number of rows skipped over.
2445
2446        If you set the 'scalar' option to True, then instead of the
2447        named tuples you will get the first items of these tuples.
2448        This is useful if the result has only one column anyway.
2449        """
2450        if not table:
2451            raise TypeError('The table name is missing')
2452        if what:
2453            if isinstance(what, (list, tuple)):
2454                what = ', '.join(map(str, what))
2455            if order is None:
2456                order = what
2457        else:
2458            what = '*'
2459        q = ['SELECT', what, 'FROM', table]
2460        if where:
2461            if isinstance(where, (list, tuple)):
2462                where = ' AND '.join(map(str, where))
2463            q.extend(['WHERE', where])
2464        if order is None:
2465            try:
2466                order = self.pkey(table, True)
2467            except (KeyError, ProgrammingError):
2468                try:
2469                    order = list(self.get_attnames(table))
2470                except (KeyError, ProgrammingError):
2471                    pass
2472        if order:
2473            if isinstance(order, (list, tuple)):
2474                order = ', '.join(map(str, order))
2475            q.extend(['ORDER BY', order])
2476        if limit:
2477            q.append('LIMIT %d' % limit)
2478        if offset:
2479            q.append('OFFSET %d' % offset)
2480        q = ' '.join(q)
2481        self._do_debug(q)
2482        q = self.db.query(q)
2483        res = q.namedresult()
2484        if res and scalar:
2485            res = [row[0] for row in res]
2486        return res
2487
2488    def get_as_dict(self, table, keyname=None, what=None, where=None,
2489            order=None, limit=None, offset=None, scalar=False):
2490        """Get a table as a dictionary.
2491
2492        This method is similar to get_as_list(), but returns the table
2493        as a Python dict instead of a Python list, which can be even
2494        more convenient. The primary key column(s) of the table will
2495        be used as the keys of the dictionary, while the other column(s)
2496        will be the corresponding values.  The keys will be named tuples
2497        if the table has a composite primary key.  The rows will be also
2498        named tuples unless the 'scalar' option has been set to True.
2499        With the optional parameter 'keyname' you can specify an alternative
2500        set of columns to be used as the keys of the dictionary.  It must
2501        be set as a string, list or a tuple.
2502
2503        If the Python version supports it, the dictionary will be an
2504        OrderedDict using the order specified with the 'order' parameter
2505        or the key column(s) if not specified.  You can set 'order' to False
2506        if you don't care about the ordering.  In this case the returned
2507        dictionary will be an ordinary one.
2508        """
2509        if not table:
2510            raise TypeError('The table name is missing')
2511        if not keyname:
2512            try:
2513                keyname = self.pkey(table, True)
2514            except (KeyError, ProgrammingError):
2515                raise _prg_error('Table %s has no primary key' % table)
2516        if isinstance(keyname, basestring):
2517            keyname = [keyname]
2518        elif not isinstance(keyname, (list, tuple)):
2519            raise KeyError('The keyname must be a string, list or tuple')
2520        if what:
2521            if isinstance(what, (list, tuple)):
2522                what = ', '.join(map(str, what))
2523            if order is None:
2524                order = what
2525        else:
2526            what = '*'
2527        q = ['SELECT', what, 'FROM', table]
2528        if where:
2529            if isinstance(where, (list, tuple)):
2530                where = ' AND '.join(map(str, where))
2531            q.extend(['WHERE', where])
2532        if order is None:
2533            order = keyname
2534        if order:
2535            if isinstance(order, (list, tuple)):
2536                order = ', '.join(map(str, order))
2537            q.extend(['ORDER BY', order])
2538        if limit:
2539            q.append('LIMIT %d' % limit)
2540        if offset:
2541            q.append('OFFSET %d' % offset)
2542        q = ' '.join(q)
2543        self._do_debug(q)
2544        q = self.db.query(q)
2545        res = q.getresult()
2546        cls = OrderedDict if order else dict
2547        if not res:
2548            return cls()
2549        keyset = set(keyname)
2550        fields = q.listfields()
2551        if not keyset.issubset(fields):
2552            raise KeyError('Missing keyname in row')
2553        keyind, rowind = [], []
2554        for i, f in enumerate(fields):
2555            (keyind if f in keyset else rowind).append(i)
2556        keytuple = len(keyind) > 1
2557        getkey = itemgetter(*keyind)
2558        keys = map(getkey, res)
2559        if scalar:
2560            rowind = rowind[:1]
2561            rowtuple = False
2562        else:
2563            rowtuple = len(rowind) > 1
2564        if scalar or rowtuple:
2565            getrow = itemgetter(*rowind)
2566        else:
2567            rowind = rowind[0]
2568            getrow = lambda row: (row[rowind],)
2569            rowtuple = True
2570        rows = map(getrow, res)
2571        if keytuple or rowtuple:
2572            namedresult = get_namedresult()
2573            if namedresult:
2574                if keytuple:
2575                    keys = namedresult(_MemoryQuery(keys, keyname))
2576                if rowtuple:
2577                    fields = [f for f in fields if f not in keyset]
2578                    rows = namedresult(_MemoryQuery(rows, fields))
2579        return cls(zip(keys, rows))
2580
2581    def notification_handler(self,
2582            event, callback, arg_dict=None, timeout=None, stop_event=None):
2583        """Get notification handler that will run the given callback."""
2584        return NotificationHandler(self,
2585            event, callback, arg_dict, timeout, stop_event)
2586
2587
# if run as script, print some information

if __name__ == '__main__':
    # pass the version as a separate argument so that print() puts a
    # space between the label and the version number (the module uses
    # the print function, imported from __future__ at the top)
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.