source: trunk/pg.py @ 989

Last change on this file since 989 was 989, checked in by cito, 3 months ago

Minor whitespace fixes and IDE hints

  • Property svn:keywords set to Id
File size: 97.4 KB
Line 
1#!/usr/bin/python
2#
3# $Id: pg.py 989 2019-04-24 15:49:20Z cito $
4#
5# PyGreSQL - a Python interface for the PostgreSQL database.
6#
7# This file contains the classic pg module.
8#
9# Copyright (c) 2019 by the PyGreSQL Development Team
10#
11# The notification handler is based on pgnotify which is
12# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
13#
14# Please see the LICENSE.TXT file for specific restrictions.
15
16"""PyGreSQL classic interface.
17
18This pg module implements some basic database management stuff.
19It includes the _pg module and builds on it, providing the higher
20level wrapper class named DB with additional functionality.
21This is known as the "classic" ("old style") PyGreSQL interface.
22For a DB-API 2 compliant interface use the newer pgdb module.
23"""
24
25from __future__ import print_function, division
26
27from _pg import *
28
29__version__ = version
30
31import select
32import warnings
33import weakref
34
35from datetime import date, time, datetime, timedelta, tzinfo
36from decimal import Decimal
37from math import isnan, isinf
38from collections import namedtuple
39from keyword import iskeyword
40from operator import itemgetter
41from functools import partial
42from re import compile as regex
43from json import loads as jsondecode, dumps as jsonencode
44from uuid import UUID
45
# Compatibility aliases for names that are only built in on Python 2.
try:  # noinspection PyUnresolvedReferences
    long
except NameError:  # Python >= 3.0
    # there is no separate long type here, so alias the name to int
    long = int

try:  # noinspection PyUnresolvedReferences
    basestring
except NameError:  # Python >= 3.0
    # a tuple of types, so that isinstance(x, basestring) keeps working
    basestring = (str, bytes)
55
try:
    from functools import lru_cache
except ImportError:  # Python < 3.2
    from functools import update_wrapper
    try:
        from _thread import RLock
    except ImportError:
        class RLock:  # for builds without threads
            """Do-nothing replacement that can be used in a with statement."""

            def __enter__(self):
                pass

            def __exit__(self, exctype, excinst, exctb):
                pass

    def lru_cache(maxsize=128):
        """Simplified functools.lru_cache decorator for one argument.

        With maxsize 0 nothing is cached, with maxsize None everything
        is cached without limit; otherwise the results are kept in a
        circular doubly linked list and the oldest entry is recycled
        once the cache has been marked as full.
        """
        # positions of the fields inside one link of the circular list
        PREV, NEXT, KEY, RESULT = 0, 1, 2, 3

        def decorator(function):
            sentinel = object()  # marks a cache miss in the unbounded case
            cache = {}
            get = cache.get
            lock = RLock()
            root = []
            # mutable cell holding the current root link and the "full" flag
            state = [root, False]
            root[:] = [root, root, None, None]  # empty list points to itself

            if maxsize == 0:

                def wrapper(arg):
                    # no caching at all, just call through
                    result = function(arg)
                    return result

            elif maxsize is None:

                def wrapper(arg):
                    # unbounded cache, no ordering needs to be tracked
                    result = get(arg, sentinel)
                    if result is sentinel:
                        result = function(arg)
                        cache[arg] = result
                    return result

            else:

                def wrapper(arg):
                    with lock:
                        entry = get(arg)
                        if entry is not None:
                            # hit: unlink the entry and re-insert it
                            # directly in front of the root
                            head = state[0]
                            before, after, _key, result = entry
                            before[NEXT] = after
                            after[PREV] = before
                            tail = head[PREV]
                            tail[NEXT] = head[PREV] = entry
                            entry[PREV] = tail
                            entry[NEXT] = head
                            return result
                    # the function is called without holding the lock
                    result = function(arg)
                    with lock:
                        head, full = state
                        if arg not in cache:
                            if full:
                                # store the new result in the old root and
                                # turn the oldest link into the new root
                                old_head = head
                                old_head[KEY] = arg
                                old_head[RESULT] = result
                                head = state[0] = old_head[NEXT]
                                evicted_key = head[KEY]
                                evicted_result = head[RESULT]  # keep reference
                                head[KEY] = head[RESULT] = None
                                del cache[evicted_key]
                                cache[arg] = old_head
                            else:
                                # append a new link in front of the root
                                tail = head[PREV]
                                entry = [tail, head, arg, result]
                                tail[NEXT] = head[PREV] = cache[arg] = entry
                                if len(cache) >= maxsize:
                                    state[1] = True
                    return result

            wrapper.__wrapped__ = function
            return update_wrapper(wrapper, function)

        return decorator
138
139
140# Auxiliary classes and functions that are independent from a DB connection:
141
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Read-only dictionary of attribute names that keeps its key order.

        Variant for interpreters without collections.OrderedDict: the
        order of the keys is remembered in a separate list.
        """

        def __init__(self, *args, **kw):
            # only one positional argument is accepted, and it must not
            # be a dict
            if kw or len(args) > 1:
                raise TypeError
            pairs = args[0] if args else []
            if isinstance(pairs, dict):
                raise TypeError
            pairs = list(pairs)
            self._keys = [pair[0] for pair in pairs]
            dict.__init__(self, pairs)
            self._read_only = True
            # shadow the mutating methods on the instance
            refuse = self._read_only_error
            for method in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, method, refuse)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[name] for name in self]

        def items(self):
            return [(name, self[name]) for name in self]

        def iterkeys(self):
            return iter(self)

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Read-only dictionary of attribute names that keeps its key order.

        The dictionary can only be filled through the constructor;
        afterwards every attempt to change it raises a TypeError.
        """

        def __init__(self, *args, **kw):
            # writing must be allowed while the base class fills the dict
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow the mutating methods on the instance
            refuse = self._read_only_error
            for method in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, method, refuse)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
226
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given function."""
        parameters = signature(func).parameters
        return list(parameters)
238
239try:
240    from datetime import timezone
241except ImportError:  # Python < 3.2
242
243    class timezone(tzinfo):
244        """Simple timezone implementation."""
245
246        def __init__(self, offset, name=None):
247            self.offset = offset
248            if not name:
249                minutes = self.offset.days * 1440 + self.offset.seconds // 60
250                if minutes < 0:
251                    hours, minutes = divmod(-minutes, 60)
252                    hours = -hours
253                else:
254                    hours, minutes = divmod(minutes, 60)
255                name = 'UTC%+03d:%02d' % (hours, minutes)
256            self.name = name
257
258        def utcoffset(self, dt):
259            return self.offset
260
261        def tzname(self, dt):
262            return self.name
263
264        def dst(self, dt):
265            return None
266
267    timezone.utc = timezone(timedelta(0), 'UTC')
268
269    _has_timezone = False
270else:
271    _has_timezone = True
272
273# time zones used in Postgres timestamptz output
274_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
275    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
276    UCT='+0000', UTC='+0000', WET='+0000')
277
278
279def _timezone_as_offset(tz):
280    if tz.startswith(('+', '-')):
281        if len(tz) < 5:
282            return tz + '00'
283        return tz.replace(':', '')
284    return _timezones.get(tz, '+0000')
285
286
def _get_timezone(tz):
    """Create a timezone object from a time zone name or offset string.

    The resulting object is named after the normalized offset string.
    """
    offset = _timezone_as_offset(tz)
    hours = int(offset[1:3])
    remainder = int(offset[3:5])
    minutes = 60 * hours + remainder
    if offset[0] == '-':
        minutes = -minutes
    return timezone(timedelta(minutes=minutes), offset)
293
294
295def _oid_key(table):
296    """Build oid key from a table name."""
297    return 'oid(%s)' % table
298
299
300class _SimpleTypes(dict):
301    """Dictionary mapping pg_type names to simple type names."""
302
303    _types = {'bool': 'bool',
304        'bytea': 'bytea',
305        'date': 'date interval time timetz timestamp timestamptz'
306            ' abstime reltime',  # these are very old
307        'float': 'float4 float8',
308        'int': 'cid int2 int4 int8 oid xid',
309        'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid',
310        'num': 'numeric', 'money': 'money',
311        'text': 'bpchar char name text varchar'}
312
313    def __init__(self):
314        for typ, keys in self._types.items():
315            for key in keys.split():
316                self[key] = typ
317                self['_%s' % key] = '%s[]' % typ
318
319    # this could be a static method in Python > 2.6
320    def __missing__(self, key):
321        return 'text'
322
323_simpletypes = _SimpleTypes()
324
325
def _quote_if_unqualified(param, name):
    """Quote parameter representing a qualified name.

    Wraps the given parameter in a quote_ident() call, unless the
    name contains a dot.  In that case the name is ambiguous (it
    could be a qualified name or just a name with a dot in it), so
    the parameter is returned unchanged and the caller must take
    care of the quoting.  Names that are not strings are left
    unchanged as well.
    """
    if not isinstance(name, basestring) or '.' in name:
        return param
    return 'quote_ident(%s)' % (param,)
337
338
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    The adapt() method used by add() is not defined here,
    it must be set on the instance before add() is called.
    """

    def add(self, value, typ=None):
        """Typecast value with known database type and build parameter list.

        Literal values are returned as they are.  All other values are
        appended to the list, and the placeholder for the new position
        ('$1', '$2', ...) is returned instead.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            return '$%d' % len(self)
        return adapted
353
354
class Bytea(bytes):
    """Subclass of bytes used for marking values as bytea."""
357
358
class Hstore(dict):
    """Wrapper class for marking hstore values.

    Converting an instance to a string yields the keys and values
    as comma separated 'key=>value' pairs.
    """

    # keys and values that must be put in double quotes
    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
    # characters that must be escaped with a backslash
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Quote and escape a key or value for the hstore representation.

        None becomes NULL and the empty string becomes a pair of double
        quotes.  Double quotes and backslashes are escaped with a
        backslash, and the result is put in double quotes if the string
        reads like NULL or contains a space, comma, '=' or '>'.
        """
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        quote = cls._re_quote.search(s)
        # backslashes must be escaped as well as double quotes,
        # otherwise they would be taken as escape characters
        s = cls._re_escape.sub(r'\\\1', s)
        if quote:
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
378
379
class Json:
    """Wrapper class used for marking an object as a Json value."""

    def __init__(self, obj):
        # the wrapped object is stored as it was passed
        self.obj = obj
385
386
class Literal(str):
    """Subclass of str used for marking values as literal SQL."""
389
390
391class Adapter:
392    """Class providing methods for adapting parameters to the database."""
393
394    _bool_true_values = frozenset('t true 1 y yes on'.split())
395
396    _date_literals = frozenset('current_date current_time'
397        ' current_timestamp localtime localtimestamp'.split())
398
399    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
400    _re_record_quote = regex(r'[(,"\\]')
401    _re_array_escape = _re_record_escape = regex(r'(["\\])')
402
403    def __init__(self, db):
404        self.db = weakref.proxy(db)
405
406    @classmethod
407    def _adapt_bool(cls, v):
408        """Adapt a boolean parameter."""
409        if isinstance(v, basestring):
410            if not v:
411                return None
412            v = v.lower() in cls._bool_true_values
413        return 't' if v else 'f'
414
415    @classmethod
416    def _adapt_date(cls, v):
417        """Adapt a date parameter."""
418        if not v:
419            return None
420        if isinstance(v, basestring) and v.lower() in cls._date_literals:
421            return Literal(v)
422        return v
423
424    @staticmethod
425    def _adapt_num(v):
426        """Adapt a numeric parameter."""
427        if not v and v != 0:
428            return None
429        return v
430
431    _adapt_int = _adapt_float = _adapt_money = _adapt_num
432
433    def _adapt_bytea(self, v):
434        """Adapt a bytea parameter."""
435        return self.db.escape_bytea(v)
436
437    def _adapt_json(self, v):
438        """Adapt a json parameter."""
439        if not v:
440            return None
441        if isinstance(v, basestring):
442            return v
443        return self.db.encode_json(v)
444
445    @classmethod
446    def _adapt_text_array(cls, v):
447        """Adapt a text type array parameter."""
448        if isinstance(v, list):
449            adapt = cls._adapt_text_array
450            return '{%s}' % ','.join(adapt(v) for v in v)
451        if v is None:
452            return 'null'
453        if not v:
454            return '""'
455        v = str(v)
456        if cls._re_array_quote.search(v):
457            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
458        return v
459
460    _adapt_date_array = _adapt_text_array
461
462    @classmethod
463    def _adapt_bool_array(cls, v):
464        """Adapt a boolean array parameter."""
465        if isinstance(v, list):
466            adapt = cls._adapt_bool_array
467            return '{%s}' % ','.join(adapt(v) for v in v)
468        if v is None:
469            return 'null'
470        if isinstance(v, basestring):
471            if not v:
472                return 'null'
473            v = v.lower() in cls._bool_true_values
474        return 't' if v else 'f'
475
476    @classmethod
477    def _adapt_num_array(cls, v):
478        """Adapt a numeric array parameter."""
479        if isinstance(v, list):
480            adapt = cls._adapt_num_array
481            return '{%s}' % ','.join(adapt(v) for v in v)
482        if not v and v != 0:
483            return 'null'
484        return str(v)
485
486    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
487            _adapt_num_array
488
489    def _adapt_bytea_array(self, v):
490        """Adapt a bytea array parameter."""
491        if isinstance(v, list):
492            return b'{' + b','.join(
493                self._adapt_bytea_array(v) for v in v) + b'}'
494        if v is None:
495            return b'null'
496        return self.db.escape_bytea(v).replace(b'\\', b'\\\\')
497
498    def _adapt_json_array(self, v):
499        """Adapt a json array parameter."""
500        if isinstance(v, list):
501            adapt = self._adapt_json_array
502            return '{%s}' % ','.join(adapt(v) for v in v)
503        if not v:
504            return 'null'
505        if not isinstance(v, basestring):
506            v = self.db.encode_json(v)
507        if self._re_array_quote.search(v):
508            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
509        return v
510
511    def _adapt_record(self, v, typ):
512        """Adapt a record parameter with given type."""
513        typ = self.get_attnames(typ).values()
514        if len(typ) != len(v):
515            raise TypeError('Record parameter %s has wrong size' % v)
516        adapt = self.adapt
517        value = []
518        for v, t in zip(v, typ):
519            v = adapt(v, t)
520            if v is None:
521                v = ''
522            elif not v:
523                v = '""'
524            else:
525                if isinstance(v, bytes):
526                    if str is not bytes:
527                        v = v.decode('ascii')
528                else:
529                    v = str(v)
530                if self._re_record_quote.search(v):
531                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
532            value.append(v)
533        return '(%s)' % ','.join(value)
534
535    def adapt(self, value, typ=None):
536        """Adapt a value with known database type."""
537        if value is not None and not isinstance(value, Literal):
538            if typ:
539                simple = self.get_simple_name(typ)
540            else:
541                typ = simple = self.guess_simple_type(value) or 'text'
542            pg_str = getattr(value, '__pg_str__', None)
543            if pg_str:
544                value = pg_str(typ)
545            if simple == 'text':
546                pass
547            elif simple == 'record':
548                if isinstance(value, tuple):
549                    value = self._adapt_record(value, typ)
550            elif simple.endswith('[]'):
551                if isinstance(value, list):
552                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
553                    value = adapt(value)
554            else:
555                adapt = getattr(self, '_adapt_%s' % simple)
556                value = adapt(value)
557        return value
558
559    @staticmethod
560    def simple_type(name):
561        """Create a simple database type with given attribute names."""
562        typ = DbType(name)
563        typ.simple = name
564        return typ
565
566    @staticmethod
567    def get_simple_name(typ):
568        """Get the simple name of a database type."""
569        if isinstance(typ, DbType):
570            return typ.simple
571        return _simpletypes[typ]
572
573    @staticmethod
574    def get_attnames(typ):
575        """Get the attribute names of a composite database type."""
576        if isinstance(typ, DbType):
577            return typ.attnames
578        return {}
579
580    _frequent_simple_types = {
581        Bytea: 'bytea',
582        str: 'text',
583        bytes: 'text',
584        bool: 'bool',
585        int: 'int',
586        long: 'int',
587        float: 'float',
588        Decimal: 'num',
589        date: 'date',
590        time: 'date',
591        datetime: 'date',
592        timedelta: 'date'
593    }
594
595    @classmethod
596    def guess_simple_type(cls, value):
597        """Try to guess which database type the given value has."""
598        # optimize for most frequent types
599        simple_type = cls._frequent_simple_types.get(type(value))
600        if simple_type:
601            return simple_type
602        if isinstance(value, Bytea):
603            return 'bytea'
604        if isinstance(value, basestring):
605            return 'text'
606        if isinstance(value, bool):
607            return 'bool'
608        if isinstance(value, (int, long)):
609            return 'int'
610        if isinstance(value, float):
611            return 'float'
612        if isinstance(value, Decimal):
613            return 'num'
614        if isinstance(value, (date, time, datetime, timedelta)):
615            return 'date'
616        if isinstance(value, list):
617            return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
618        if isinstance(value, tuple):
619            simple_type = cls.simple_type
620            guess = cls.guess_simple_type
621
622            def get_attnames(self):
623                return AttrDict((str(n + 1), simple_type(guess(v)))
624                    for n, v in enumerate(value))
625
626            typ = simple_type('record')
627            typ._get_attnames = get_attnames
628            return typ
629
630    @classmethod
631    def guess_simple_base_type(cls, value):
632        """Try to guess the base type of a given array."""
633        for v in value:
634            if isinstance(v, list):
635                typ = cls.guess_simple_base_type(v)
636            else:
637                typ = cls.guess_simple_type(v)
638            if typ:
639                return typ
640
641    def adapt_inline(self, value, nested=False):
642        """Adapt a value that is put into the SQL and needs to be quoted."""
643        if value is None:
644            return 'NULL'
645        if isinstance(value, Literal):
646            return value
647        if isinstance(value, Bytea):
648            value = self.db.escape_bytea(value)
649            if bytes is not str:  # Python >= 3.0
650                value = value.decode('ascii')
651        elif isinstance(value, Json):
652            if value.encode:
653                return value.encode()
654            value = self.db.encode_json(value)
655        elif isinstance(value, (datetime, date, time, timedelta)):
656            value = str(value)
657        if isinstance(value, basestring):
658            value = self.db.escape_string(value)
659            return "'%s'" % value
660        if isinstance(value, bool):
661            return 'true' if value else 'false'
662        if isinstance(value, float):
663            if isinf(value):
664                return "'-Infinity'" if value < 0 else "'Infinity'"
665            if isnan(value):
666                return "'NaN'"
667            return value
668        if isinstance(value, (int, long, Decimal)):
669            return value
670        if isinstance(value, list):
671            q = self.adapt_inline
672            s = '[%s]' if nested else 'ARRAY[%s]'
673            return s % ','.join(str(q(v, nested=True)) for v in value)
674        if isinstance(value, tuple):
675            q = self.adapt_inline
676            return '(%s)' % ','.join(str(q(v)) for v in value)
677        pg_repr = getattr(value, '__pg_repr__', None)
678        if not pg_repr:
679            raise InterfaceError(
680                'Do not know how to adapt type %s' % type(value))
681        value = pg_repr()
682        if isinstance(value, (tuple, list)):
683            value = self.adapt_inline(value)
684        return value
685
686    def parameter_list(self):
687        """Return a parameter list for parameters with known database types.
688
689        The list has an add(value, typ) method that will build up the
690        list and return either the literal value or a placeholder.
691        """
692        params = _ParameterList()
693        params.adapt = self.adapt
694        return params
695
696    def format_query(self, command, values=None, types=None, inline=False):
697        """Format a database query using the given values and types."""
698        if not values:
699            return command, []
700        if inline and types:
701            raise ValueError('Typed parameters must be sent separately')
702        params = self.parameter_list()
703        if isinstance(values, (list, tuple)):
704            if inline:
705                adapt = self.adapt_inline
706                literals = [adapt(value) for value in values]
707            else:
708                add = params.add
709                if types:
710                    if (not isinstance(types, (list, tuple)) or
711                            len(types) != len(values)):
712                        raise TypeError('The values and types do not match')
713                    literals = [add(value, typ)
714                        for value, typ in zip(values, types)]
715                else:
716                    literals = [add(value) for value in values]
717            command %= tuple(literals)
718        elif isinstance(values, dict):
719            # we want to allow extra keys in the dictionary,
720            # so we first must find the values actually used in the command
721            used_values = {}
722            literals = dict.fromkeys(values, '')
723            for key in values:
724                del literals[key]
725                try:
726                    command % literals
727                except KeyError:
728                    used_values[key] = values[key]
729                literals[key] = ''
730            values = used_values
731            if inline:
732                adapt = self.adapt_inline
733                literals = dict((key, adapt(value))
734                    for key, value in values.items())
735            else:
736                add = params.add
737                if types:
738                    if not isinstance(types, dict):
739                        raise TypeError('The values and types do not match')
740                    literals = dict((key, add(values[key], types.get(key)))
741                        for key in sorted(values))
742                else:
743                    literals = dict((key, add(values[key]))
744                        for key in sorted(values))
745            command %= literals
746        else:
747            raise TypeError('The values must be passed as tuple, list or dict')
748        return command, params
749
750
def cast_bool(value):
    """Cast a boolean value.

    If get_bool() returns a false value, the value is returned
    unchanged, otherwise it is checked whether it starts with 't'.
    """
    return value[0] == 't' if get_bool() else value
756
757
def cast_json(value):
    """Cast a JSON value.

    The value is decoded with the function returned by get_jsondecode().
    If there is no such function, the value is returned unchanged.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
764
765
def cast_num(value):
    """Cast a numeric value.

    The value is converted with the type returned by get_decimal(),
    or with float if there is no such type.
    """
    convert = get_decimal() or float
    return convert(value)
769
770
def cast_money(value):
    """Cast a money value.

    If get_decimal_point() returns a false value, the value is
    returned unchanged.  Otherwise that decimal point is replaced with
    a dot, an opening parenthesis is taken as minus sign, and all
    characters except digits, dots and minus signs are removed.  The
    result is converted with the type returned by get_decimal(), or
    with float if there is no such type.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        value = value.replace(point, '.')
    value = value.replace('(', '-')
    kept = [char for char in value if char.isdigit() or char in '.-']
    value = ''.join(kept)
    convert = get_decimal() or float
    return convert(value)
781
782
def cast_int2vector(value):
    """Cast an int2vector value to a list of integers."""
    return list(map(int, value.split()))
786
787
def cast_date(value, connection):
    """Cast a date value.

    The value is parsed with the format returned by the date_format()
    method of the connection.  Dates that are infinite, marked as BC
    or longer than ten characters are mapped to date.min or date.max.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    text = parts[0]
    if len(text) > 10:
        return date.max
    parsed = datetime.strptime(text, connection.date_format())
    return parsed.date()
806
807
def cast_time(value):
    """Cast a time value.

    Values longer than eight characters are parsed with microseconds.
    """
    if len(value) > 8:
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    parsed = datetime.strptime(value, fmt)
    return parsed.time()
812
813
# Splits a value into the part before the time zone offset and the offset;
# since the first group is greedy, the offset starts at the last sign.
_re_timezone = regex('(.*)([+-].*)')
815
816
def cast_timetz(value):
    """Cast a timetz value.

    The time zone offset is split off from the value, with '+0000'
    as default if there is none.  Time parts longer than eight
    characters are parsed with microseconds.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    if len(value) > 8:
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    if not _has_timezone:
        # parse without the offset and attach the time zone afterwards
        parsed = datetime.strptime(value, fmt).timetz()
        return parsed.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
831
832
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The way the value is parsed depends on the format returned by the
    date_format() method of the connection.  Timestamps that are
    infinite, marked as BC or have too many digits in the year are
    mapped to datetime.min or datetime.max.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the first part and use the next four parts,
        # which are parsed as day and month, time and year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        day_month = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_month, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
854
855
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    The way the value is parsed depends on the format returned by the
    date_format() method of the connection.  Timestamps that are
    infinite, marked as BC or have too many digits in the year are
    mapped to datetime.min or datetime.max.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the first part; the remaining parts are parsed as
        # day and month, time and year, followed by the time zone
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        day_month = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_month, time_fmt, '%Y']
        parts, tz = parts[:-1], parts[-1]
    else:
        if date_fmt.startswith('%Y-'):
            # the time zone offset is attached to the time part
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            # the time zone is the last part of the value
            parts, tz = parts[:-1], parts[-1]
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    if not _has_timezone:
        # parse without the offset and attach the time zone afterwards
        parsed = datetime.strptime(' '.join(parts), ' '.join(formats))
        return parsed.replace(tzinfo=_get_timezone(tz))
    parts.append(_timezone_as_offset(tz))
    formats.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
891
892
# Patterns for parsing interval values in the various output formats.
# All parts of the patterns are optional.

# Optionally signed years-months, then a day count that must not be
# followed by a colon, then optionally signed hours:minutes:seconds
# with optional fractional digits.
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# Signed counts followed by the units year(s), mon(s) and day(s), then
# optionally signed hours:minutes:seconds with optional fractional digits.
_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# Starts with '@', followed by signed counts with the units year(s),
# mon(s), day(s), hour(s), min(s) and sec(s), where the seconds can have
# fractional digits, and an optional 'ago' at the end.
_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

# Starts with 'P', followed by signed counts with the designators Y, M
# and D, and after a 'T' signed counts with the designators H, M and S,
# where the seconds can have fractional digits.
_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
919
920
921def cast_interval(value):
922    """Cast an interval value."""
923    # The output format depends on the server setting IntervalStyle, but it's
924    # not necessary to consult this setting to parse it.  It's faster to just
925    # check all possible formats, and there is no ambiguity here.
926    m = _re_interval_iso_8601.match(value)
927    if m:
928        m = [d or '0' for d in m.groups()]
929        secs_ago = m.pop(5) == '-'
930        m = [int(d) for d in m]
931        years, mons, days, hours, mins, secs, usecs = m
932        if secs_ago:
933            secs = -secs
934            usecs = -usecs
935    else:
936        m = _re_interval_postgres_verbose.match(value)
937        if m:
938            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
939            secs_ago = m.pop(5) == '-'
940            m = [-int(d) for d in m] if ago else [int(d) for d in m]
941            years, mons, days, hours, mins, secs, usecs = m
942            if secs_ago:
943                secs = - secs
944                usecs = -usecs
945        else:
946            m = _re_interval_postgres.match(value)
947            if m and any(m.groups()):
948                m = [d or '0' for d in m.groups()]
949                hours_ago = m.pop(3) == '-'
950                m = [int(d) for d in m]
951                years, mons, days, hours, mins, secs, usecs = m
952                if hours_ago:
953                    hours = -hours
954                    mins = -mins
955                    secs = -secs
956                    usecs = -usecs
957            else:
958                m = _re_interval_sql_standard.match(value)
959                if m and any(m.groups()):
960                    m = [d or '0' for d in m.groups()]
961                    years_ago = m.pop(0) == '-'
962                    hours_ago = m.pop(3) == '-'
963                    m = [int(d) for d in m]
964                    years, mons, days, hours, mins, secs, usecs = m
965                    if years_ago:
966                        years = -years
967                        mons = -mons
968                    if hours_ago:
969                        hours = -hours
970                        mins = -mins
971                        secs = -secs
972                        usecs = -usecs
973                else:
974                    raise ValueError('Cannot parse interval: %s' % value)
975    days += 365 * years + 30 * mons
976    return timedelta(days=days, hours=hours, minutes=mins,
977        seconds=secs, microseconds=usecs)
978
979
class Typecasts(dict):
    """Registry of typecast functions, keyed by database type name.

    Every cast function receives the string representation of a database
    value and must return the corresponding Python object.  NULL values
    are dealt with before a cast function is called, so the string that
    is passed in is never None.

    The C extension already converts the basic types by itself; the
    entries for them are only needed for components of records or arrays.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {
        'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': UUID,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Build the cast function for a type that is not cached yet.

        A KeyError is never raised by this class; when there is no
        special cast function for the type, None is returned instead.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # keep the default in the instance for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
            return cast
        if typ.startswith('_'):
            # array type: wrap the cast of the element type
            element_cast = self[typ[1:]]
            cast = self.create_array_cast(element_cast)
            if element_cast:
                self[typ] = cast
            return cast
        # otherwise this may be a composite type with named fields
        attnames = self.get_attnames(typ)
        if attnames:
            field_casts = [self[v.pgtype] for v in attnames.values()]
            cast = self.create_record_cast(typ, attnames, field_casts)
            self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Tell whether a typecast function takes a connection argument."""
        try:
            params = get_args(func)
        except (TypeError, ValueError):
            return False
        return 'connection' in params[1:]

    def _add_connection(self, cast):
        """Bind the connection to the typecast function if it wants one."""
        if self.connection and self._needs_connection(cast):
            return partial(cast, connection=self.connection)
        return cast

    def get(self, typ, default=None):
        """Return the typecast function for the given database type."""
        cast = self[typ]
        return cast if cast else default

    def set(self, typ, cast):
        """Register a typecast function for the given database type(s)."""
        types = [typ] if isinstance(typ, basestring) else typ
        if cast is None:
            for name in types:
                self.pop(name, None)
                self.pop('_%s' % name, None)
            return
        if not callable(cast):
            raise TypeError("Cast parameter must be callable")
        for name in types:
            self[name] = self._add_connection(cast)
            # drop the cached array cast which wraps the old function
            self.pop('_%s' % name, None)

    def reset(self, typ=None):
        """Restore the default typecasts for the given type(s).

        All typecasts are reset when no type has been passed.
        """
        if typ is None:
            self.clear()
            return
        types = [typ] if isinstance(typ, basestring) else typ
        for name in types:
            self.pop(name, None)

    @classmethod
    def get_default(cls, typ):
        """Return the default typecast function for the given type."""
        return cls.defaults.get(typ)

    @classmethod
    def set_default(cls, typ, cast):
        """Register a default typecast function for the given type(s)."""
        types = [typ] if isinstance(typ, basestring) else typ
        registry = cls.defaults
        if cast is None:
            for name in types:
                registry.pop(name, None)
                registry.pop('_%s' % name, None)
            return
        if not callable(cast):
            raise TypeError("Cast parameter must be callable")
        for name in types:
            registry[name] = cast
            registry.pop('_%s' % name, None)

    def get_attnames(self, typ):
        """Return the fields of the given record type.

        The get_attnames() method of DbTypes will replace this method.
        """
        return {}

    def dateformat(self):
        """Return the current date format.

        The dateformat() method of DbTypes will replace this method.
        """
        return '%Y-%m-%d'

    def create_array_cast(self, basecast):
        """Return an array typecast that uses the given base cast."""
        array_caster = self['anyarray']

        def cast(v):
            return array_caster(v, basecast)

        return cast

    def create_record_cast(self, name, fields, casts):
        """Return a named record typecast for the given fields and casts."""
        record_caster = self['record']
        record_type = namedtuple(name, fields)

        def cast(v):
            return record_type(*record_caster(v, casts))

        return cast
1133
1134
def get_typecast(typ):
    """Get the global typecast function for the given database type.

    Returns the default cast function registered in Typecasts for this
    single type name, or None if no default has been registered.
    """
    return Typecasts.get_default(typ)
1138
1139
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The type can be a single type name or an iterable of type names.
    Passing None as cast removes the global typecast for the type(s).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1147
1148
class DbType(str):
    """String subclass holding a type name plus details about the type.

    Instances carry the following additional attributes:

        oid: the PostgreSQL type OID
        pgtype: the internal PostgreSQL data type name
        regtype: the registered PostgreSQL data type name
        simple: the more coarse-grained PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, B = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Return the names and types of the fields of a composite type."""
        # the lookup function is attached to the instance by its creator
        lookup = self._get_attnames
        return lookup(self)
1169
1170
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    The cache is keyed by type OIDs and type names; its values are
    DbType objects describing the corresponding database type.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Set up an empty type cache for the given connection."""
        super(DbTypes, self).__init__()
        self._db = weakref.proxy(db)
        self._regtypes = False
        typecasts = Typecasts()
        typecasts.get_attnames = self.get_attnames
        typecasts.connection = self._db
        self._typecasts = typecasts
        if db.server_version >= 80400:
            self._query_pg_type = (
                "SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")
        else:
            # older remote databases (not officially supported)
            self._query_pg_type = (
                "SELECT oid, typname, typname::text::regtype,"
                " typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")

    def add(self, oid, pgtype, regtype,
               typtype, category, delim, relid):
        """Return a DbType for the given type info, using the cache."""
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletypes[pgtype]
        name = regtype if self._regtypes else simple
        typ = DbType(name)
        for attr, val in (
                ('oid', oid), ('simple', simple), ('pgtype', pgtype),
                ('regtype', regtype), ('typtype', typtype),
                ('category', category), ('delim', delim), ('relid', relid),
                ('_get_attnames', self.get_attnames)):
            setattr(typ, attr, val)
        return typ

    def __missing__(self, key):
        """Fetch the info for an uncached type from the database."""
        try:
            query = self._query_pg_type % (
                _quote_if_unqualified('$1', key),)
            rows = self._db.query(query, (key,)).getresult()
        except ProgrammingError:
            rows = None
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        typ = self.add(*rows[0])
        # make the type available under both its OID and its name
        self[typ.oid] = typ
        self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Return the type, looking it up if it is not cached yet."""
        try:
            typ = self[key]
        except KeyError:
            typ = default
        return typ

    def get_attnames(self, typ):
        """Return the names and types of the fields of a composite type."""
        if isinstance(typ, DbType):
            dbtype = typ
        else:
            dbtype = self.get(typ)
            if not dbtype:
                return None
        if dbtype.relid:
            return self._db.get_attnames(dbtype.relid, with_oid=False)
        return None

    def get_typecast(self, typ):
        """Return the typecast function for the given database type."""
        typecasts = self._typecasts
        return typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Register a typecast function for the given database type(s)."""
        typecasts = self._typecasts
        typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Restore the default typecast for the given database type(s)."""
        typecasts = self._typecasts
        typecasts.reset(typ)

    def typecast(self, value, typ):
        """Convert the given value according to the given database type."""
        if value is None:
            # NULL values are passed through unchanged
            return None
        if not isinstance(typ, DbType):
            dbtype = self.get(typ)
            typ = dbtype.pgtype if dbtype else dbtype
        cast = self.get_typecast(typ) if typ else None
        if cast and cast is not str:
            return cast(value)
        # no typecast is necessary
        return value
1276
1277
1278_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$')
1279
1280# The result rows for database operations are returned as named tuples
1281# by default. Since creating namedtuple classes is a somewhat expensive
1282# operation, we cache up to 1024 of these classes by default.
1283
1284@lru_cache(maxsize=1024)
1285def _row_factory(names):
1286    """Get a namedtuple factory for row results with the given names."""
1287    try:
1288        try:
1289            return namedtuple('Row', names, rename=True)._make
1290        except TypeError:  # Python 2.6 and 3.0 do not support rename
1291            names = [v if _re_fieldname.match(v) and not iskeyword(v)
1292                        else 'column_%d' % (n,)
1293                     for n, v in enumerate(names)]
1294            return namedtuple('Row', names)._make
1295    except ValueError:  # there is still a problem with the field names
1296        names = ['column_%d' % (n,) for n in range(len(names))]
1297        return namedtuple('Row', names)._make
1298
1299
def set_row_factory_size(maxsize):
    """Set the maximum size of the namedtuple factory cache.

    Passing None as maxsize lets the cache grow without bound.
    """
    global _row_factory
    # wrap the undecorated function in a new cache of the requested size
    uncached = _row_factory.__wrapped__
    _row_factory = lru_cache(maxsize)(uncached)
1307
1308
1309# Helper functions used by the query object
1310
1311def _dictiter(q):
1312    """Get query result as an iterator of dictionaries."""
1313    fields = q.listfields()
1314    for r in q:
1315        yield dict(zip(fields, r))
1316
1317
def _namediter(q):
    """Iterate over the query result, yielding each row as a named tuple."""
    make_row = _row_factory(q.listfields())
    for values in q:
        yield make_row(values)
1323
1324
def _namednext(q):
    """Return the next row of the query result as a named tuple."""
    make_row = _row_factory(q.listfields())
    return make_row(next(q))
1328
1329
1330def _scalariter(q):
1331    """Get query result as an iterator of scalar values."""
1332    for r in q:
1333        yield r[0]
1334
1335
1336class _MemoryQuery:
1337    """Class that embodies a given query result."""
1338
1339    def __init__(self, result, fields):
1340        """Create query from given result rows and field names."""
1341        self.result = result
1342        self.fields = tuple(fields)
1343
1344    def listfields(self):
1345        """Return the stored field names of this query."""
1346        return self.fields
1347
1348    def getresult(self):
1349        """Return the stored result of this query."""
1350        return self.result
1351
1352    def __iter__(self):
1353        return iter(self.result)
1354
1355
def _db_error(msg, cls=DatabaseError):
    """Create a database error of the given class without an SQLSTATE.

    The error is returned, not raised; its sqlstate attribute is None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
1361
1362
def _int_error(msg):
    """Create an InternalError with the given message."""
    return _db_error(msg, cls=InternalError)
1366
1367
def _prg_error(msg):
    """Create a ProgrammingError with the given message."""
    return _db_error(msg, cls=ProgrammingError)
1371
1372
# Initialize the C module

# These calls run once at import time and hand Python-level objects over
# to the extension: the Decimal class, the JSON decoding function and the
# four query result helpers defined above.
# NOTE(review): the setter functions are presumably provided by the _pg
# extension through the star import at the top of the file -- confirm.
set_decimal(Decimal)
set_jsondecode(jsondecode)
set_query_helpers(_dictiter, _namediter, _namednext, _scalariter)
1378
1379
1380# The notification handler
1381
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.
        The callback is invoked unconditionally when the handler is called,
        so it must be set before running the handler.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is garbage collected."""
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent and None is returned if the handler is currently
        not listening; otherwise the result of the query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # The payload is embedded as a string literal, so it must
                # be escaped or a quote inside it would break the statement.
                # It is formatted first so that payloads that are not
                # strings are converted the same way as before.
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications is
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        When the timeout is reached without a notification, the handler
        stops listening and invokes the callback with None.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        Otherwise it will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            # when polling, select() is skipped and rlist is never used
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        # this also ends both loops after the callback
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1501
1502
def pgnotify(*args, **kw):
    """Create a NotificationHandler using the traditional function name.

    This emits a DeprecationWarning attributed to the caller.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1508
1509
1510# The actual PostgreSQL database connection interface:
1511
1512class DB:
1513    """Wrapper class for the _pg connection type."""
1514
1515    db = None  # invalid fallback for underlying connection
1516
    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        When connection parameters are passed, a new connection is opened
        and owned by this instance, so it will be closed by close().
        An existing connection that is passed in is only wrapped and will
        not be closed by this instance.
        """
        # A single positional argument or a single 'db' keyword argument
        # may be an existing connection instead of a connection parameter.
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # unwrap the connection of another DB instance
                db = db.db
            else:
                # presumably a pgdb connection, which keeps the underlying
                # connection in its _cnx attribute
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            # not a usable connection object, so the arguments are
            # connection parameters; they are kept for use by reopen()
            db = connect(*args, **kw)
            self._db_args = args, kw
            self._closeable = True
        else:
            # keep the passed connection so reopen() can restore it
            self._db_args = db
            self._closeable = False
        # self.db must be set before Adapter and DbTypes are created,
        # since DbTypes reads db.server_version through this instance
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        # The two queries only differ in how the regtype and the type
        # category are selected.
        if db.server_version < 80400:
            # support older remote data bases
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype,"
                " t.typtype, null as typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        else:
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        # let the connection cast values with the type cache of this instance
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1576
1577    def __getattr__(self, name):
1578        # All undefined members are same as in underlying connection:
1579        if self.db:
1580            return getattr(self.db, name)
1581        else:
1582            raise _int_error('Connection is not valid')
1583
1584    def __dir__(self):
1585        # Custom dir function including the attributes of the connection:
1586        attrs = set(self.__class__.__dict__)
1587        attrs.update(self.__dict__)
1588        attrs.update(dir(self.db))
1589        return sorted(attrs)
1590
1591    # Context manager methods
1592
1593    def __enter__(self):
1594        """Enter the runtime context. This will start a transaction."""
1595        self.begin()
1596        return self
1597
1598    def __exit__(self, et, ev, tb):
1599        """Exit the runtime context. This will end the transaction."""
1600        if et is None and ev is None and tb is None:
1601            self.commit()
1602        else:
1603            self.rollback()
1604
1605    def __del__(self):
1606        try:
1607            db = self.db
1608        except AttributeError:
1609            db = None
1610        if db:
1611            try:
1612                db.set_cast_hook(None)
1613            except TypeError:
1614                pass  # probably already closed
1615            if self._closeable:
1616                try:
1617                    db.close()
1618                except InternalError:
1619                    pass  # probably already closed
1620
1621    # Auxiliary methods
1622
1623    def _do_debug(self, *args):
1624        """Print a debug message"""
1625        if self.debug:
1626            s = '\n'.join(str(arg) for arg in args)
1627            if isinstance(self.debug, basestring):
1628                print(self.debug % s)
1629            elif hasattr(self.debug, 'write'):
1630                self.debug.write(s + '\n')
1631            elif callable(self.debug):
1632                self.debug(s)
1633            else:
1634                print(s)
1635
1636    def _escape_qualified_name(self, s):
1637        """Escape a qualified name.
1638
1639        Escapes the name for use as an SQL identifier, unless the
1640        name contains a dot, in which case the name is ambiguous
1641        (could be a qualified name or just a name with a dot in it)
1642        and must be quoted manually by the caller.
1643        """
1644        if '.' not in s:
1645            s = self.escape_identifier(s)
1646        return s
1647
1648    @staticmethod
1649    def _make_bool(d):
1650        """Get boolean value corresponding to d."""
1651        return bool(d) if get_bool() else ('t' if d else 'f')
1652
1653    def _list_params(self, params):
1654        """Create a human readable parameter list."""
1655        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1656
1657    # Public methods
1658
1659    # escape_string and escape_bytea exist as methods,
1660    # so we define unescape_bytea as a method as well
1661    unescape_bytea = staticmethod(unescape_bytea)
1662
1663    def decode_json(self, s):
1664        """Decode a JSON string coming from the database."""
1665        return (get_jsondecode() or jsondecode)(s)
1666
1667    def encode_json(self, d):
1668        """Encode a JSON string for use within SQL."""
1669        return jsonencode(d)
1670
1671    def close(self):
1672        """Close the database connection."""
1673        # Wraps shared library function so we can track state.
1674        db = self.db
1675        if db:
1676            try:
1677                db.set_cast_hook(None)
1678            except TypeError:
1679                pass  # probably already closed
1680            if self._closeable:
1681                db.close()
1682            self.db = None
1683        else:
1684            raise _int_error('Connection already closed')
1685
1686    def reset(self):
1687        """Reset connection with current parameters.
1688
1689        All derived queries and large objects derived from this connection
1690        will not be usable after this call.
1691
1692        """
1693        if self.db:
1694            self.db.reset()
1695        else:
1696            raise _int_error('Connection already closed')
1697
1698    def reopen(self):
1699        """Reopen connection to the database.
1700
1701        Used in case we need another connection to the same database.
1702        Note that we can still reopen a database that we have closed.
1703
1704        """
1705        # There is no such shared library function.
1706        if self._closeable:
1707            db = connect(*self._db_args[0], **self._db_args[1])
1708            if self.db:
1709                self.db.set_cast_hook(None)
1710                self.db.close()
1711            db.set_cast_hook(self.dbtypes.typecast)
1712            self.db = db
1713        else:
1714            self.db = self._db_args
1715
1716    def begin(self, mode=None):
1717        """Begin a transaction."""
1718        qstr = 'BEGIN'
1719        if mode:
1720            qstr += ' ' + mode
1721        return self.query(qstr)
1722
1723    start = begin
1724
1725    def commit(self):
1726        """Commit the current transaction."""
1727        return self.query('COMMIT')
1728
1729    end = commit
1730
1731    def rollback(self, name=None):
1732        """Roll back the current transaction."""
1733        qstr = 'ROLLBACK'
1734        if name:
1735            qstr += ' TO ' + name
1736        return self.query(qstr)
1737
1738    abort = rollback
1739
1740    def savepoint(self, name):
1741        """Define a new savepoint within the current transaction."""
1742        return self.query('SAVEPOINT ' + name)
1743
1744    def release(self, name):
1745        """Destroy a previously defined savepoint."""
1746        return self.query('RELEASE ' + name)
1747
1748    def get_parameter(self, parameter):
1749        """Get the value of a run-time parameter.
1750
1751        If the parameter is a string, the return value will also be a string
1752        that is the current setting of the run-time parameter with that name.
1753
1754        You can get several parameters at once by passing a list, set or dict.
1755        When passing a list of parameter names, the return value will be a
1756        corresponding list of parameter settings.  When passing a set of
1757        parameter names, a new dict will be returned, mapping these parameter
1758        names to their settings.  Finally, if you pass a dict as parameter,
1759        its values will be set to the current parameter settings corresponding
1760        to its keys.
1761
1762        By passing the special name 'all' as the parameter, you can get a dict
1763        of all existing configuration parameters.
1764        """
1765        if isinstance(parameter, basestring):
1766            parameter = [parameter]
1767            values = None
1768        elif isinstance(parameter, (list, tuple)):
1769            values = []
1770        elif isinstance(parameter, (set, frozenset)):
1771            values = {}
1772        elif isinstance(parameter, dict):
1773            values = parameter
1774        else:
1775            raise TypeError(
1776                'The parameter must be a string, list, set or dict')
1777        if not parameter:
1778            raise TypeError('No parameter has been specified')
1779        params = {} if isinstance(values, dict) else []
1780        for key in parameter:
1781            param = key.strip().lower() if isinstance(
1782                key, basestring) else None
1783            if not param:
1784                raise TypeError('Invalid parameter')
1785            if param == 'all':
1786                q = 'SHOW ALL'
1787                values = self.db.query(q).getresult()
1788                values = dict(value[:2] for value in values)
1789                break
1790            if isinstance(values, dict):
1791                params[param] = key
1792            else:
1793                params.append(param)
1794        else:
1795            for param in params:
1796                q = 'SHOW %s' % (param,)
1797                value = self.db.query(q).getresult()[0][0]
1798                if values is None:
1799                    values = value
1800                elif isinstance(values, list):
1801                    values.append(value)
1802                else:
1803                    values[params[param]] = value
1804        return values
1805
1806    def set_parameter(self, parameter, value=None, local=False):
1807        """Set the value of a run-time parameter.
1808
1809        If the parameter and the value are strings, the run-time parameter
1810        will be set to that value.  If no value or None is passed as a value,
1811        then the run-time parameter will be restored to its default value.
1812
1813        You can set several parameters at once by passing a list of parameter
1814        names, together with a single value that all parameters should be
1815        set to or with a corresponding list of values.  You can also pass
1816        the parameters as a set if you only provide a single value.
1817        Finally, you can pass a dict with parameter names as keys.  In this
1818        case, you should not pass a value, since the values for the parameters
1819        will be taken from the dict.
1820
1821        By passing the special name 'all' as the parameter, you can reset
1822        all existing settable run-time parameters to their default values.
1823
1824        If you set local to True, then the command takes effect for only the
1825        current transaction.  After commit() or rollback(), the session-level
1826        setting takes effect again.  Setting local to True will appear to
1827        have no effect if it is executed outside a transaction, since the
1828        transaction will end immediately.
1829        """
1830        if isinstance(parameter, basestring):
1831            parameter = {parameter: value}
1832        elif isinstance(parameter, (list, tuple)):
1833            if isinstance(value, (list, tuple)):
1834                parameter = dict(zip(parameter, value))
1835            else:
1836                parameter = dict.fromkeys(parameter, value)
1837        elif isinstance(parameter, (set, frozenset)):
1838            if isinstance(value, (list, tuple, set, frozenset)):
1839                value = set(value)
1840                if len(value) == 1:
1841                    value = value.pop()
1842            if not(value is None or isinstance(value, basestring)):
1843                raise ValueError('A single value must be specified'
1844                    ' when parameter is a set')
1845            parameter = dict.fromkeys(parameter, value)
1846        elif isinstance(parameter, dict):
1847            if value is not None:
1848                raise ValueError('A value must not be specified'
1849                    ' when parameter is a dictionary')
1850        else:
1851            raise TypeError(
1852                'The parameter must be a string, list, set or dict')
1853        if not parameter:
1854            raise TypeError('No parameter has been specified')
1855        params = {}
1856        for key, value in parameter.items():
1857            param = key.strip().lower() if isinstance(
1858                key, basestring) else None
1859            if not param:
1860                raise TypeError('Invalid parameter')
1861            if param == 'all':
1862                if value is not None:
1863                    raise ValueError('A value must ot be specified'
1864                        " when parameter is 'all'")
1865                params = {'all': None}
1866                break
1867            params[param] = value
1868        local = ' LOCAL' if local else ''
1869        for param, value in params.items():
1870            if value is None:
1871                q = 'RESET%s %s' % (local, param)
1872            else:
1873                q = 'SET%s %s TO %s' % (local, param, value)
1874            self._do_debug(q)
1875            self.db.query(q)
1876
1877    def query(self, command, *args):
1878        """Execute a SQL command string.
1879
1880        This method simply sends a SQL query to the database.  If the query is
1881        an insert statement that inserted exactly one row into a table that
1882        has OIDs, the return value is the OID of the newly inserted row.
1883        If the query is an update or delete statement, or an insert statement
1884        that did not insert exactly one row in a table with OIDs, then the
1885        number of rows affected is returned as a string.  If it is a statement
1886        that returns rows as a result (usually a select statement, but maybe
1887        also an "insert/update ... returning" statement), this method returns
1888        a Query object that can be accessed via getresult() or dictresult()
1889        or simply printed.  Otherwise, it returns `None`.
1890
1891        The query can contain numbered parameters of the form $1 in place
1892        of any data constant.  Arguments given after the query string will
1893        be substituted for the corresponding numbered parameter.  Parameter
1894        values can also be given as a single list or tuple argument.
1895        """
1896        # Wraps shared library function for debugging.
1897        if not self.db:
1898            raise _int_error('Connection is not valid')
1899        if args:
1900            self._do_debug(command, args)
1901            return self.db.query(command, args)
1902        self._do_debug(command)
1903        return self.db.query(command)
1904
1905    def query_formatted(self, command,
1906            parameters=None, types=None, inline=False):
1907        """Execute a formatted SQL command string.
1908
1909        Similar to query, but using Python format placeholders of the form
1910        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1911        The parameters must be passed as a tuple, list or dict.  You can
1912        also pass a corresponding tuple, list or dict of database types in
1913        order to format the parameters properly in case there is ambiguity.
1914
1915        If you set inline to True, the parameters will be sent to the database
1916        embedded in the SQL command, otherwise they will be sent separately.
1917        """
1918        return self.query(*self.adapter.format_query(
1919            command, parameters, types, inline))
1920
1921    def query_prepared(self, name, *args):
1922        """Execute a prepared SQL statement.
1923
1924        This works like the query() method, except that instead of passing
1925        the SQL command, you pass the name of a prepared statement.  If you
1926        pass an empty name, the unnamed statement will be executed.
1927        """
1928        if not self.db:
1929            raise _int_error('Connection is not valid')
1930        if name is None:
1931            name = ''
1932        if args:
1933            self._do_debug('EXECUTE', name, args)
1934            return self.db.query_prepared(name, args)
1935        self._do_debug('EXECUTE', name)
1936        return self.db.query_prepared(name)
1937
1938    def prepare(self, name, command):
1939        """Create a prepared SQL statement.
1940
1941        This creates a prepared statement for the given command with the
1942        the given name for later execution with the query_prepared() method.
1943
1944        The name can be empty to create an unnamed statement, in which case
1945        any pre-existing unnamed statement is automatically replaced;
1946        otherwise it is an error if the statement name is already
1947        defined in the current database session. We recommend always using
1948        named queries, since unnamed queries have a limited lifetime and
1949        can be automatically replaced or destroyed by various operations.
1950        """
1951        if not self.db:
1952            raise _int_error('Connection is not valid')
1953        if name is None:
1954            name = ''
1955        self._do_debug('prepare', name, command)
1956        return self.db.prepare(name, command)
1957
1958    def describe_prepared(self, name=None):
1959        """Describe a prepared SQL statement.
1960
1961        This method returns a Query object describing the result columns of
1962        the prepared statement with the given name. If you omit the name,
1963        the unnamed statement will be described if you created one before.
1964        """
1965        if name is None:
1966            name = ''
1967        return self.db.describe_prepared(name)
1968
1969    def delete_prepared(self, name=None):
1970        """Delete a prepared SQL statement
1971
1972        This deallocates a previously prepared SQL statement with the given
1973        name, or deallocates all prepared statements if you do not specify a
1974        name. Note that prepared statements are also deallocated automatically
1975        when the current session ends.
1976        """
1977        q = "DEALLOCATE %s" % (name or 'ALL',)
1978        self._do_debug(q)
1979        return self.db.query(q)
1980
1981    def pkey(self, table, composite=False, flush=False):
1982        """Get or set the primary key of a table.
1983
1984        Single primary keys are returned as strings unless you
1985        set the composite flag.  Composite primary keys are always
1986        represented as tuples.  Note that this raises a KeyError
1987        if the table does not have a primary key.
1988
1989        If flush is set then the internal cache for primary keys will
1990        be flushed.  This may be necessary after the database schema or
1991        the search path has been changed.
1992        """
1993        pkeys = self._pkeys
1994        if flush:
1995            pkeys.clear()
1996            self._do_debug('The pkey cache has been flushed')
1997        try:  # cache lookup
1998            pkey = pkeys[table]
1999        except KeyError:  # cache miss, check the database
2000            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
2001                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
2002                " AND a.attnum = ANY(i.indkey)"
2003                " AND NOT a.attisdropped"
2004                " WHERE i.indrelid=%s::regclass"
2005                " AND i.indisprimary ORDER BY a.attnum") % (
2006                    _quote_if_unqualified('$1', table),)
2007            pkey = self.db.query(q, (table,)).getresult()
2008            if not pkey:
2009                raise KeyError('Table %s has no primary key' % table)
2010            # we want to use the order defined in the primary key index here,
2011            # not the order as defined by the columns in the table
2012            if len(pkey) > 1:
2013                indkey = pkey[0][2]
2014                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
2015                pkey = tuple(row[0] for row in pkey)
2016            else:
2017                pkey = pkey[0][0]
2018            pkeys[table] = pkey  # cache it
2019        if composite and not isinstance(pkey, tuple):
2020            pkey = (pkey,)
2021        return pkey
2022
2023    def get_databases(self):
2024        """Get list of databases in the system."""
2025        return [s[0] for s in
2026            self.db.query('SELECT datname FROM pg_database').getresult()]
2027
2028    def get_relations(self, kinds=None, system=False):
2029        """Get list of relations in connected database of specified kinds.
2030
2031        If kinds is None or empty, all kinds of relations are returned.
2032        Otherwise kinds can be a string or sequence of type letters
2033        specifying which kind of relations you want to list.
2034
2035        Set the system flag if you want to get the system relations as well.
2036        """
2037        where = []
2038        if kinds:
2039            where.append("r.relkind IN (%s)" %
2040                ','.join("'%s'" % k for k in kinds))
2041        if not system:
2042            where.append("s.nspname NOT SIMILAR"
2043                " TO 'pg/_%|information/_schema' ESCAPE '/'")
2044        where = " WHERE %s" % ' AND '.join(where) if where else ''
2045        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
2046            " FROM pg_class r"
2047            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
2048            " ORDER BY s.nspname, r.relname") % where
2049        return [r[0] for r in self.db.query(q).getresult()]
2050
2051    def get_tables(self, system=False):
2052        """Return list of tables in connected database.
2053
2054        Set the system flag if you want to get the system tables as well.
2055        """
2056        return self.get_relations('r', system)
2057
2058    def get_attnames(self, table, with_oid=True, flush=False):
2059        """Given the name of a table, dig out the set of attribute names.
2060
2061        Returns a read-only dictionary of attribute names (the names are
2062        the keys, the values are the names of the attributes' types)
2063        with the column names in the proper order if you iterate over it.
2064
2065        If flush is set, then the internal cache for attribute names will
2066        be flushed. This may be necessary after the database schema or
2067        the search path has been changed.
2068
2069        By default, only a limited number of simple types will be returned.
2070        You can get the registered types after calling use_regtypes(True).
2071        """
2072        attnames = self._attnames
2073        if flush:
2074            attnames.clear()
2075            self._do_debug('The attnames cache has been flushed')
2076        try:  # cache lookup
2077            names = attnames[table]
2078        except KeyError:  # cache miss, check the database
2079            q = "a.attnum > 0"
2080            if with_oid:
2081                q = "(%s OR a.attname = 'oid')" % q
2082            q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
2083            names = self.db.query(q, (table,)).getresult()
2084            types = self.dbtypes
2085            names = ((name[0], types.add(*name[1:])) for name in names)
2086            names = AttrDict(names)
2087            attnames[table] = names  # cache it
2088        return names
2089
2090    def use_regtypes(self, regtypes=None):
2091        """Use registered type names instead of simplified type names."""
2092        if regtypes is None:
2093            return self.dbtypes._regtypes
2094        else:
2095            regtypes = bool(regtypes)
2096            if regtypes != self.dbtypes._regtypes:
2097                self.dbtypes._regtypes = regtypes
2098                self._attnames.clear()
2099                self.dbtypes.clear()
2100            return regtypes
2101
2102    def has_table_privilege(self, table, privilege='select', flush=False):
2103        """Check whether current user has specified table privilege.
2104
2105        If flush is set, then the internal cache for table privileges will
2106        be flushed. This may be necessary after privileges have been changed.
2107        """
2108        privileges = self._privileges
2109        if flush:
2110            privileges.clear()
2111            self._do_debug('The privileges cache has been flushed')
2112        privilege = privilege.lower()
2113        try:  # ask cache
2114            ret = privileges[table, privilege]
2115        except KeyError:  # cache miss, ask the database
2116            q = "SELECT has_table_privilege(%s, $2)" % (
2117                _quote_if_unqualified('$1', table),)
2118            q = self.db.query(q, (table, privilege))
2119            ret = q.getresult()[0][0] == self._make_bool(True)
2120            privileges[table, privilege] = ret  # cache it
2121        return ret
2122
2123    def get(self, table, row, keyname=None):
2124        """Get a row from a database table or view.
2125
2126        This method is the basic mechanism to get a single row.  It assumes
2127        that the keyname specifies a unique row.  It must be the name of a
2128        single column or a tuple of column names.  If the keyname is not
2129        specified, then the primary key for the table is used.
2130
2131        If row is a dictionary, then the value for the key is taken from it.
2132        Otherwise, the row must be a single value or a tuple of values
2133        corresponding to the passed keyname or primary key.  The fetched row
2134        from the table will be returned as a new dictionary or used to replace
2135        the existing values when row was passed as a dictionary.
2136
2137        The OID is also put into the dictionary if the table has one, but
2138        in order to allow the caller to work with multiple tables, it is
2139        munged as "oid(table)" using the actual name of the table.
2140        """
2141        if table.endswith('*'):  # hint for descendant tables can be ignored
2142            table = table[:-1].rstrip()
2143        attnames = self.get_attnames(table)
2144        qoid = _oid_key(table) if 'oid' in attnames else None
2145        if keyname and isinstance(keyname, basestring):
2146            keyname = (keyname,)
2147        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
2148            row['oid'] = row[qoid]
2149        if not keyname:
2150            try:  # if keyname is not specified, try using the primary key
2151                keyname = self.pkey(table, True)
2152            except KeyError:  # the table has no primary key
2153                # try using the oid instead
2154                if qoid and isinstance(row, dict) and 'oid' in row:
2155                    keyname = ('oid',)
2156                else:
2157                    raise _prg_error('Table %s has no primary key' % table)
2158            else:  # the table has a primary key
2159                # check whether all key columns have values
2160                if isinstance(row, dict) and not set(keyname).issubset(row):
2161                    # try using the oid instead
2162                    if qoid and 'oid' in row:
2163                        keyname = ('oid',)
2164                    else:
2165                        raise KeyError(
2166                            'Missing value in row for specified keyname')
2167        if not isinstance(row, dict):
2168            if not isinstance(row, (tuple, list)):
2169                row = [row]
2170            if len(keyname) != len(row):
2171                raise KeyError(
2172                    'Differing number of items in keyname and row')
2173            row = dict(zip(keyname, row))
2174        params = self.adapter.parameter_list()
2175        adapt = params.add
2176        col = self.escape_identifier
2177        what = 'oid, *' if qoid else '*'
2178        where = ' AND '.join('%s = %s' % (
2179            col(k), adapt(row[k], attnames[k])) for k in keyname)
2180        if 'oid' in row:
2181            if qoid:
2182                row[qoid] = row['oid']
2183            del row['oid']
2184        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
2185            what, self._escape_qualified_name(table), where)
2186        self._do_debug(q, params)
2187        q = self.db.query(q, params)
2188        res = q.dictresult()
2189        if not res:
2190            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
2191                table, where, self._list_params(params)))
2192        for n, value in res[0].items():
2193            if qoid and n == 'oid':
2194                n = qoid
2195            row[n] = value
2196        return row
2197
2198    def insert(self, table, row=None, **kw):
2199        """Insert a row into a database table.
2200
2201        This method inserts a row into a table.  The name of the table must
2202        be passed as the first parameter.  The other parameters are used for
2203        providing the data of the row that shall be inserted into the table.
2204        If a dictionary is supplied as the second parameter, it starts with
2205        that.  Otherwise it uses a blank dictionary. Either way the dictionary
2206        is updated from the keywords.
2207
2208        The dictionary is then reloaded with the values actually inserted in
2209        order to pick up values modified by rules, triggers, etc.
2210        """
2211        if table.endswith('*'):  # hint for descendant tables can be ignored
2212            table = table[:-1].rstrip()
2213        if row is None:
2214            row = {}
2215        row.update(kw)
2216        if 'oid' in row:
2217            del row['oid']  # do not insert oid
2218        attnames = self.get_attnames(table)
2219        qoid = _oid_key(table) if 'oid' in attnames else None
2220        params = self.adapter.parameter_list()
2221        adapt = params.add
2222        col = self.escape_identifier
2223        names, values = [], []
2224        for n in attnames:
2225            if n in row:
2226                names.append(col(n))
2227                values.append(adapt(row[n], attnames[n]))
2228        if not names:
2229            raise _prg_error('No column found that can be inserted')
2230        names, values = ', '.join(names), ', '.join(values)
2231        ret = 'oid, *' if qoid else '*'
2232        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
2233            self._escape_qualified_name(table), names, values, ret)
2234        self._do_debug(q, params)
2235        q = self.db.query(q, params)
2236        res = q.dictresult()
2237        if res:  # this should always be true
2238            for n, value in res[0].items():
2239                if qoid and n == 'oid':
2240                    n = qoid
2241                row[n] = value
2242        return row
2243
2244    def update(self, table, row=None, **kw):
2245        """Update an existing row in a database table.
2246
2247        Similar to insert, but updates an existing row.  The update is based
2248        on the primary key of the table or the OID value as munged by get()
2249        or passed as keyword.  The OID will take precedence if provided, so
2250        that it is possible to update the primary key itself.
2251
2252        The dictionary is then modified to reflect any changes caused by the
2253        update due to triggers, rules, default values, etc.
2254        """
2255        if table.endswith('*'):
2256            table = table[:-1].rstrip()  # need parent table name
2257        attnames = self.get_attnames(table)
2258        qoid = _oid_key(table) if 'oid' in attnames else None
2259        if row is None:
2260            row = {}
2261        elif 'oid' in row:
2262            del row['oid']  # only accept oid key from named args for safety
2263        row.update(kw)
2264        if qoid and qoid in row and 'oid' not in row:
2265            row['oid'] = row[qoid]
2266        if qoid and 'oid' in row:  # try using the oid
2267            keyname = ('oid',)
2268        else:  # try using the primary key
2269            try:
2270                keyname = self.pkey(table, True)
2271            except KeyError:  # the table has no primary key
2272                raise _prg_error('Table %s has no primary key' % table)
2273            # check whether all key columns have values
2274            if not set(keyname).issubset(row):
2275                raise KeyError('Missing value for primary key in row')
2276        params = self.adapter.parameter_list()
2277        adapt = params.add
2278        col = self.escape_identifier
2279        where = ' AND '.join('%s = %s' % (
2280            col(k), adapt(row[k], attnames[k])) for k in keyname)
2281        if 'oid' in row:
2282            if qoid:
2283                row[qoid] = row['oid']
2284            del row['oid']
2285        values = []
2286        keyname = set(keyname)
2287        for n in attnames:
2288            if n in row and n not in keyname:
2289                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
2290        if not values:
2291            return row
2292        values = ', '.join(values)
2293        ret = 'oid, *' if qoid else '*'
2294        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
2295            self._escape_qualified_name(table), values, where, ret)
2296        self._do_debug(q, params)
2297        q = self.db.query(q, params)
2298        res = q.dictresult()
2299        if res:  # may be empty when row does not exist
2300            for n, value in res[0].items():
2301                if qoid and n == 'oid':
2302                    n = qoid
2303                row[n] = value
2304        return row
2305
2306    def upsert(self, table, row=None, **kw):
2307        """Insert a row into a database table with conflict resolution
2308
2309        This method inserts a row into a table, but instead of raising a
2310        ProgrammingError exception in case a row with the same primary key
2311        already exists, an update will be executed instead.  This will be
2312        performed as a single atomic operation on the database, so race
2313        conditions can be avoided.
2314
2315        Like the insert method, the first parameter is the name of the
2316        table and the second parameter can be used to pass the values to
2317        be inserted as a dictionary.
2318
2319        Unlike the insert und update statement, keyword parameters are not
2320        used to modify the dictionary, but to specify which columns shall
2321        be updated in case of a conflict, and in which way:
2322
2323        A value of False or None means the column shall not be updated,
2324        a value of True means the column shall be updated with the value
2325        that has been proposed for insertion, i.e. has been passed as value
2326        in the dictionary.  Columns that are not specified by keywords but
2327        appear as keys in the dictionary are also updated like in the case
2328        keywords had been passed with the value True.
2329
2330        So if in the case of a conflict you want to update every column that
2331        has been passed in the dictionary row, you would call upsert(table, row).
2332        If you don't want to do anything in case of a conflict, i.e. leave
2333        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2334
2335        If you need more fine-grained control of what gets updated, you can
2336        also pass strings in the keyword parameters.  These strings will
2337        be used as SQL expressions for the update columns.  In these
2338        expressions you can refer to the value that already exists in
2339        the table by prefixing the column name with "included.", and to
2340        the value that has been proposed for insertion by prefixing the
2341        column name with the "excluded."
2342
2343        The dictionary is modified in any case to reflect the values in
2344        the database after the operation has completed.
2345
2346        Note: The method uses the PostgreSQL "upsert" feature which is
2347        only available since PostgreSQL 9.5.
2348        """
2349        if table.endswith('*'):  # hint for descendant tables can be ignored
2350            table = table[:-1].rstrip()
2351        if row is None:
2352            row = {}
2353        if 'oid' in row:
2354            del row['oid']  # do not insert oid
2355        if 'oid' in kw:
2356            del kw['oid']  # do not update oid
2357        attnames = self.get_attnames(table)
2358        qoid = _oid_key(table) if 'oid' in attnames else None
2359        params = self.adapter.parameter_list()
2360        adapt = params.add
2361        col = self.escape_identifier
2362        names, values, updates = [], [], []
2363        for n in attnames:
2364            if n in row:
2365                names.append(col(n))
2366                values.append(adapt(row[n], attnames[n]))
2367        names, values = ', '.join(names), ', '.join(values)
2368        try:
2369            keyname = self.pkey(table, True)
2370        except KeyError:
2371            raise _prg_error('Table %s has no primary key' % table)
2372        target = ', '.join(col(k) for k in keyname)
2373        update = []
2374        keyname = set(keyname)
2375        keyname.add('oid')
2376        for n in attnames:
2377            if n not in keyname:
2378                value = kw.get(n, True)
2379                if value:
2380                    if not isinstance(value, basestring):
2381                        value = 'excluded.%s' % col(n)
2382                    update.append('%s = %s' % (col(n), value))
2383        if not values:
2384            return row
2385        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2386        ret = 'oid, *' if qoid else '*'
2387        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2388            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2389                self._escape_qualified_name(table), names, values,
2390                target, do, ret)
2391        self._do_debug(q, params)
2392        try:
2393            q = self.db.query(q, params)
2394        except ProgrammingError:
2395            if self.server_version < 90500:
2396                raise _prg_error(
2397                    'Upsert operation is not supported by PostgreSQL version')
2398            raise  # re-raise original error
2399        res = q.dictresult()
2400        if res:  # may be empty with "do nothing"
2401            for n, value in res[0].items():
2402                if qoid and n == 'oid':
2403                    n = qoid
2404                row[n] = value
2405        else:
2406            self.get(table, row)
2407        return row
2408
2409    def clear(self, table, row=None):
2410        """Clear all the attributes to values determined by the types.
2411
2412        Numeric types are set to 0, Booleans are set to false, and everything
2413        else is set to the empty string.  If the row argument is present,
2414        it is used as the row dictionary and any entries matching attribute
2415        names are cleared with everything else left unchanged.
2416        """
2417        # At some point we will need a way to get defaults from a table.
2418        if row is None:
2419            row = {}  # empty if argument is not present
2420        attnames = self.get_attnames(table)
2421        for n, t in attnames.items():
2422            if n == 'oid':
2423                continue
2424            t = t.simple
2425            if t in DbTypes._num_types:
2426                row[n] = 0
2427            elif t == 'bool':
2428                row[n] = self._make_bool(False)
2429            else:
2430                row[n] = ''
2431        return row
2432
2433    def delete(self, table, row=None, **kw):
2434        """Delete an existing row in a database table.
2435
2436        This method deletes the row from a table.  It deletes based on the
2437        primary key of the table or the OID value as munged by get() or
2438        passed as keyword.  The OID will take precedence if provided.
2439
2440        The return value is the number of deleted rows (i.e. 0 if the row
2441        did not exist and 1 if the row was deleted).
2442
2443        Note that if the row cannot be deleted because e.g. it is still
2444        referenced by another table, this method raises a ProgrammingError.
2445        """
2446        if table.endswith('*'):  # hint for descendant tables can be ignored
2447            table = table[:-1].rstrip()
2448        attnames = self.get_attnames(table)
2449        qoid = _oid_key(table) if 'oid' in attnames else None
2450        if row is None:
2451            row = {}
2452        elif 'oid' in row:
2453            del row['oid']  # only accept oid key from named args for safety
2454        row.update(kw)
2455        if qoid and qoid in row and 'oid' not in row:
2456            row['oid'] = row[qoid]
2457        if qoid and 'oid' in row:  # try using the oid
2458            keyname = ('oid',)
2459        else:  # try using the primary key
2460            try:
2461                keyname = self.pkey(table, True)
2462            except KeyError:  # the table has no primary key
2463                raise _prg_error('Table %s has no primary key' % table)
2464            # check whether all key columns have values
2465            if not set(keyname).issubset(row):
2466                raise KeyError('Missing value for primary key in row')
2467        params = self.adapter.parameter_list()
2468        adapt = params.add
2469        col = self.escape_identifier
2470        where = ' AND '.join('%s = %s' % (
2471            col(k), adapt(row[k], attnames[k])) for k in keyname)
2472        if 'oid' in row:
2473            if qoid:
2474                row[qoid] = row['oid']
2475            del row['oid']
2476        q = 'DELETE FROM %s WHERE %s' % (
2477            self._escape_qualified_name(table), where)
2478        self._do_debug(q, params)
2479        res = self.db.query(q, params)
2480        return int(res)
2481
2482    def truncate(self, table, restart=False, cascade=False, only=False):
2483        """Empty a table or set of tables.
2484
2485        This method quickly removes all rows from the given table or set
2486        of tables.  It has the same effect as an unqualified DELETE on each
2487        table, but since it does not actually scan the tables it is faster.
2488        Furthermore, it reclaims disk space immediately, rather than requiring
2489        a subsequent VACUUM operation. This is most useful on large tables.
2490
2491        If restart is set to True, sequences owned by columns of the truncated
2492        table(s) are automatically restarted.  If cascade is set to True, it
2493        also truncates all tables that have foreign-key references to any of
2494        the named tables.  If the parameter only is not set to True, all the
2495        descendant tables (if any) will also be truncated. Optionally, a '*'
2496        can be specified after the table name to explicitly indicate that
2497        descendant tables are included.
2498        """
2499        if isinstance(table, basestring):
2500            only = {table: only}
2501            table = [table]
2502        elif isinstance(table, (list, tuple)):
2503            if isinstance(only, (list, tuple)):
2504                only = dict(zip(table, only))
2505            else:
2506                only = dict.fromkeys(table, only)
2507        elif isinstance(table, (set, frozenset)):
2508            only = dict.fromkeys(table, only)
2509        else:
2510            raise TypeError('The table must be a string, list or set')
2511        if not (restart is None or isinstance(restart, (bool, int))):
2512            raise TypeError('Invalid type for the restart option')
2513        if not (cascade is None or isinstance(cascade, (bool, int))):
2514            raise TypeError('Invalid type for the cascade option')
2515        tables = []
2516        for t in table:
2517            u = only.get(t)
2518            if not (u is None or isinstance(u, (bool, int))):
2519                raise TypeError('Invalid type for the only option')
2520            if t.endswith('*'):
2521                if u:
2522                    raise ValueError(
2523                        'Contradictory table name and only options')
2524                t = t[:-1].rstrip()
2525            t = self._escape_qualified_name(t)
2526            if u:
2527                t = 'ONLY %s' % t
2528            tables.append(t)
2529        q = ['TRUNCATE', ', '.join(tables)]
2530        if restart:
2531            q.append('RESTART IDENTITY')
2532        if cascade:
2533            q.append('CASCADE')
2534        q = ' '.join(q)
2535        self._do_debug(q)
2536        return self.db.query(q)
2537
2538    def get_as_list(self, table, what=None, where=None,
2539            order=None, limit=None, offset=None, scalar=False):
2540        """Get a table as a list.
2541
2542        This gets a convenient representation of the table as a list
2543        of named tuples in Python.  You only need to pass the name of
2544        the table (or any other SQL expression returning rows).  Note that
2545        by default this will return the full content of the table which
2546        can be huge and overflow your memory.  However, you can control
2547        the amount of data returned using the other optional parameters.
2548
2549        The parameter 'what' can restrict the query to only return a
2550        subset of the table columns.  It can be a string, list or a tuple.
2551        The parameter 'where' can restrict the query to only return a
2552        subset of the table rows.  It can be a string, list or a tuple
2553        of SQL expressions that all need to be fulfilled.  The parameter
2554        'order' specifies the ordering of the rows.  It can also be a
2555        other string, list or a tuple.  If no ordering is specified,
2556        the result will be ordered by the primary key(s) or all columns
2557        if no primary key exists.  You can set 'order' to False if you
2558        don't care about the ordering.  The parameters 'limit' and 'offset'
2559        can be integers specifying the maximum number of rows returned
2560        and a number of rows skipped over.
2561
2562        If you set the 'scalar' option to True, then instead of the
2563        named tuples you will get the first items of these tuples.
2564        This is useful if the result has only one column anyway.
2565        """
2566        if not table:
2567            raise TypeError('The table name is missing')
2568        if what:
2569            if isinstance(what, (list, tuple)):
2570                what = ', '.join(map(str, what))
2571            if order is None:
2572                order = what
2573        else:
2574            what = '*'
2575        q = ['SELECT', what, 'FROM', table]
2576        if where:
2577            if isinstance(where, (list, tuple)):
2578                where = ' AND '.join(map(str, where))
2579            q.extend(['WHERE', where])
2580        if order is None:
2581            try:
2582                order = self.pkey(table, True)
2583            except (KeyError, ProgrammingError):
2584                try:
2585                    order = list(self.get_attnames(table))
2586                except (KeyError, ProgrammingError):
2587                    pass
2588        if order:
2589            if isinstance(order, (list, tuple)):
2590                order = ', '.join(map(str, order))
2591            q.extend(['ORDER BY', order])
2592        if limit:
2593            q.append('LIMIT %d' % limit)
2594        if offset:
2595            q.append('OFFSET %d' % offset)
2596        q = ' '.join(q)
2597        self._do_debug(q)
2598        q = self.db.query(q)
2599        res = q.namedresult()
2600        if res and scalar:
2601            res = [row[0] for row in res]
2602        return res
2603
2604    def get_as_dict(self, table, keyname=None, what=None, where=None,
2605            order=None, limit=None, offset=None, scalar=False):
2606        """Get a table as a dictionary.
2607
2608        This method is similar to get_as_list(), but returns the table
2609        as a Python dict instead of a Python list, which can be even
2610        more convenient. The primary key column(s) of the table will
2611        be used as the keys of the dictionary, while the other column(s)
2612        will be the corresponding values.  The keys will be named tuples
2613        if the table has a composite primary key.  The rows will be also
2614        named tuples unless the 'scalar' option has been set to True.
2615        With the optional parameter 'keyname' you can specify an alternative
2616        set of columns to be used as the keys of the dictionary.  It must
2617        be set as a string, list or a tuple.
2618
2619        If the Python version supports it, the dictionary will be an
2620        OrderedDict using the order specified with the 'order' parameter
2621        or the key column(s) if not specified.  You can set 'order' to False
2622        if you don't care about the ordering.  In this case the returned
2623        dictionary will be an ordinary one.
2624        """
2625        if not table:
2626            raise TypeError('The table name is missing')
2627        if not keyname:
2628            try:
2629                keyname = self.pkey(table, True)
2630            except (KeyError, ProgrammingError):
2631                raise _prg_error('Table %s has no primary key' % table)
2632        if isinstance(keyname, basestring):
2633            keyname = [keyname]
2634        elif not isinstance(keyname, (list, tuple)):
2635            raise KeyError('The keyname must be a string, list or tuple')
2636        if what:
2637            if isinstance(what, (list, tuple)):
2638                what = ', '.join(map(str, what))
2639            if order is None:
2640                order = what
2641        else:
2642            what = '*'
2643        q = ['SELECT', what, 'FROM', table]
2644        if where:
2645            if isinstance(where, (list, tuple)):
2646                where = ' AND '.join(map(str, where))
2647            q.extend(['WHERE', where])
2648        if order is None:
2649            order = keyname
2650        if order:
2651            if isinstance(order, (list, tuple)):
2652                order = ', '.join(map(str, order))
2653            q.extend(['ORDER BY', order])
2654        if limit:
2655            q.append('LIMIT %d' % limit)
2656        if offset:
2657            q.append('OFFSET %d' % offset)
2658        q = ' '.join(q)
2659        self._do_debug(q)
2660        q = self.db.query(q)
2661        res = q.getresult()
2662        cls = OrderedDict if order else dict
2663        if not res:
2664            return cls()
2665        keyset = set(keyname)
2666        fields = q.listfields()
2667        if not keyset.issubset(fields):
2668            raise KeyError('Missing keyname in row')
2669        keyind, rowind = [], []
2670        for i, f in enumerate(fields):
2671            (keyind if f in keyset else rowind).append(i)
2672        keytuple = len(keyind) > 1
2673        getkey = itemgetter(*keyind)
2674        keys = map(getkey, res)
2675        if scalar:
2676            rowind = rowind[:1]
2677            rowtuple = False
2678        else:
2679            rowtuple = len(rowind) > 1
2680        if scalar or rowtuple:
2681            getrow = itemgetter(*rowind)
2682        else:
2683            rowind = rowind[0]
2684            getrow = lambda row: (row[rowind],)
2685            rowtuple = True
2686        rows = map(getrow, res)
2687        if keytuple or rowtuple:
2688            if keytuple:
2689                keys = _namediter(_MemoryQuery(keys, keyname))
2690            if rowtuple:
2691                fields = [f for f in fields if f not in keyset]
2692                rows = _namediter(_MemoryQuery(rows, fields))
2693        return cls(zip(keys, rows))
2694
2695    def notification_handler(self,
2696            event, callback, arg_dict=None, timeout=None, stop_event=None):
2697        """Get notification handler that will run the given callback."""
2698        return NotificationHandler(self,
2699            event, callback, arg_dict, timeout, stop_event)
2700
2701
2702# if run as script, print some information
2703
if __name__ == '__main__':
    # Pass label and version as separate arguments so that print() puts
    # a space between them (they were concatenated without a separator).
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.