source: trunk/pg.py

Last change on this file was 1018, checked in by cito, 4 weeks ago

Change all internal queries to consider CVE-2018-1058

  • Property svn:keywords set to Id
File size: 98.3 KB
Line 
1#!/usr/bin/python
2#
3# $Id: pg.py 1018 2019-09-27 16:16:00Z cito $
4#
5# PyGreSQL - a Python interface for the PostgreSQL database.
6#
7# This file contains the classic pg module.
8#
9# Copyright (c) 2019 by the PyGreSQL Development Team
10#
11# The notification handler is based on pgnotify which is
12# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
13#
14# Please see the LICENSE.TXT file for specific restrictions.
15
16"""PyGreSQL classic interface.
17
18This pg module implements some basic database management stuff.
19It includes the _pg module and builds on it, providing the higher
20level wrapper class named DB with additional functionality.
21This is known as the "classic" ("old style") PyGreSQL interface.
22For a DB-API 2 compliant interface use the newer pgdb module.
23"""
24
25from __future__ import print_function, division
26
27from _pg import *
28
29__version__ = version
30
31import select
32import warnings
33import weakref
34
35from datetime import date, time, datetime, timedelta, tzinfo
36from decimal import Decimal
37from math import isnan, isinf
38from collections import namedtuple
39from keyword import iskeyword
40from operator import itemgetter
41from functools import partial
42from re import compile as regex
43from json import loads as jsondecode, dumps as jsonencode
44from uuid import UUID
45
# Python 2 / Python 3 compatibility aliases for integer and string types.
try:  # noinspection PyUnresolvedReferences
    long = long
except NameError:  # Python >= 3.0 has no separate long type
    long = int

try:  # noinspection PyUnresolvedReferences
    basestring = basestring
except NameError:  # Python >= 3.0: both text and bytes count as strings
    basestring = (str, bytes)
55
try:
    from functools import lru_cache
except ImportError:  # Python < 3.2
    from functools import update_wrapper
    try:
        # C implementation of RLock, only available in Python >= 3.2
        from _thread import RLock
    except ImportError:
        try:
            # Python 2 and Python 3.0/3.1 provide RLock in threading only;
            # without this fallback the no-op lock below would always be
            # used and the cache would not be thread-safe
            from threading import RLock
        except ImportError:
            class RLock:  # for builds without threads
                """Dummy lock used when threads are not available."""

                def __enter__(self):
                    pass

                def __exit__(self, exctype, excinst, exctb):
                    pass

    def lru_cache(maxsize=128):
        """Simplified functools.lru_cache decorator for one argument.

        The decorated function must take exactly one hashable argument.
        With maxsize=0 nothing is cached, with maxsize=None the cache is
        unbounded, otherwise it holds at most maxsize results and evicts
        the least recently used one.
        """

        def decorator(function):
            sentinel = object()
            cache = {}
            get = cache.get
            lock = RLock()
            # Circular doubly linked list of [prev, next, arg, result]
            # links, ordered from least to most recently used.  The root
            # link is a dummy that gets reused for new entries when full.
            root = []
            root_full = [root, False]
            root[:] = [root, root, None, None]

            if maxsize == 0:

                def wrapper(arg):
                    return function(arg)

            elif maxsize is None:

                def wrapper(arg):
                    res = get(arg, sentinel)
                    if res is not sentinel:
                        return res
                    res = function(arg)
                    cache[arg] = res
                    return res

            else:

                def wrapper(arg):
                    with lock:
                        link = get(arg)
                        if link is not None:
                            # cache hit: move link to most recently used
                            root = root_full[0]
                            prev_link, next_link, _arg, res = link
                            prev_link[1] = next_link
                            next_link[0] = prev_link
                            last = root[0]
                            last[1] = root[0] = link
                            link[0] = last
                            link[1] = root
                            return res
                    # cache miss: call the function without holding the lock
                    res = function(arg)
                    with lock:
                        root, full = root_full
                        if arg in cache:
                            # another thread has cached this argument already
                            pass
                        elif full:
                            # store the new entry in the old root and turn the
                            # least recently used link into the new root
                            oldroot = root
                            oldroot[2] = arg
                            oldroot[3] = res
                            root = root_full[0] = oldroot[1]
                            oldarg = root[2]
                            oldres = root[3]  # keep reference until return
                            root[2] = root[3] = None
                            del cache[oldarg]
                            cache[arg] = oldroot
                        else:
                            last = root[0]
                            link = [last, root, arg, res]
                            last[1] = root[0] = cache[arg] = link
                            if len(cache) >= maxsize:
                                root_full[1] = True
                    return res

            wrapper.__wrapped__ = function
            return update_wrapper(wrapper, function)

        return decorator
129                            last[1] = root[0] = cache[arg] = link
130                            if len(cache) >= maxsize:
131                                root_full[1] = True
132                    return res
133
134            wrapper.__wrapped__ = function
135            return update_wrapper(wrapper, function)
136
137        return decorator
138
139
140# Auxiliary classes and functions that are independent from a DB connection:
141
142try:
143    from collections import OrderedDict
144except ImportError:  # Python 2.6 or 3.0
145    OrderedDict = dict
146
147
148    class AttrDict(dict):
149        """Simple read-only ordered dictionary for storing attribute names."""
150
151        def __init__(self, *args, **kw):
152            if len(args) > 1 or kw:
153                raise TypeError
154            items = args[0] if args else []
155            if isinstance(items, dict):
156                raise TypeError
157            items = list(items)
158            self._keys = [item[0] for item in items]
159            dict.__init__(self, items)
160            self._read_only = True
161            error = self._read_only_error
162            self.clear = self.update = error
163            self.pop = self.setdefault = self.popitem = error
164
165        def __setitem__(self, key, value):
166            if self._read_only:
167                self._read_only_error()
168            dict.__setitem__(self, key, value)
169
170        def __delitem__(self, key):
171            if self._read_only:
172                self._read_only_error()
173            dict.__delitem__(self, key)
174
175        def __iter__(self):
176            return iter(self._keys)
177
178        def keys(self):
179            return list(self._keys)
180
181        def values(self):
182            return [self[key] for key in self]
183
184        def items(self):
185            return [(key, self[key]) for key in self]
186
187        def iterkeys(self):
188            return self.__iter__()
189
190        def itervalues(self):
191            return iter(self.values())
192
193        def iteritems(self):
194            return iter(self.items())
195
196        @staticmethod
197        def _read_only_error(*args, **kw):
198            raise TypeError('This object is read-only')
199
200else:
201
202     class AttrDict(OrderedDict):
203        """Simple read-only ordered dictionary for storing attribute names."""
204
205        def __init__(self, *args, **kw):
206            self._read_only = False
207            OrderedDict.__init__(self, *args, **kw)
208            self._read_only = True
209            error = self._read_only_error
210            self.clear = self.update = error
211            self.pop = self.setdefault = self.popitem = error
212
213        def __setitem__(self, key, value):
214            if self._read_only:
215                self._read_only_error()
216            OrderedDict.__setitem__(self, key, value)
217
218        def __delitem__(self, key):
219            if self._read_only:
220                self._read_only_error()
221            OrderedDict.__delitem__(self, key)
222
223        @staticmethod
224        def _read_only_error(*args, **kw):
225            raise TypeError('This object is read-only')
226
227try:
228    from inspect import signature
229except ImportError:  # Python < 3.3
230    from inspect import getargspec
231
232    def get_args(func):
233        return getargspec(func).args
234else:
235
236    def get_args(func):
237        return list(signature(func).parameters)
238
239try:
240    from datetime import timezone
241except ImportError:  # Python < 3.2
242
243    class timezone(tzinfo):
244        """Simple timezone implementation."""
245
246        def __init__(self, offset, name=None):
247            self.offset = offset
248            if not name:
249                minutes = self.offset.days * 1440 + self.offset.seconds // 60
250                if minutes < 0:
251                    hours, minutes = divmod(-minutes, 60)
252                    hours = -hours
253                else:
254                    hours, minutes = divmod(minutes, 60)
255                name = 'UTC%+03d:%02d' % (hours, minutes)
256            self.name = name
257
258        def utcoffset(self, dt):
259            return self.offset
260
261        def tzname(self, dt):
262            return self.name
263
264        def dst(self, dt):
265            return None
266
267    timezone.utc = timezone(timedelta(0), 'UTC')
268
269    _has_timezone = False
270else:
271    _has_timezone = True
272
273# time zones used in Postgres timestamptz output
274_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
275    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
276    UCT='+0000', UTC='+0000', WET='+0000')
277
278
279def _timezone_as_offset(tz):
280    if tz.startswith(('+', '-')):
281        if len(tz) < 5:
282            return tz + '00'
283        return tz.replace(':', '')
284    return _timezones.get(tz, '+0000')
285
286
287def _get_timezone(tz):
288    tz = _timezone_as_offset(tz)
289    minutes = 60 * int(tz[1:3]) + int(tz[3:5])
290    if tz[0] == '-':
291        minutes = -minutes
292    return timezone(timedelta(minutes=minutes), tz)
293
294
295def _oid_key(table):
296    """Build oid key from a table name."""
297    return 'oid(%s)' % table
298
299
300class _SimpleTypes(dict):
301    """Dictionary mapping pg_type names to simple type names."""
302
303    _types = {'bool': 'bool',
304        'bytea': 'bytea',
305        'date': 'date interval time timetz timestamp timestamptz'
306            ' abstime reltime',  # these are very old
307        'float': 'float4 float8',
308        'int': 'cid int2 int4 int8 oid xid',
309        'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid',
310        'num': 'numeric', 'money': 'money',
311        'text': 'bpchar char name text varchar'}
312
313    def __init__(self):
314        for typ, keys in self._types.items():
315            for key in keys.split():
316                self[key] = typ
317                self['_%s' % key] = '%s[]' % typ
318
319    # this could be a static method in Python > 2.6
320    def __missing__(self, key):
321        return 'text'
322
323_simpletypes = _SimpleTypes()
324
325
326def _quote_if_unqualified(param, name):
327    """Quote parameter representing a qualified name.
328
329    Puts a quote_ident() call around the give parameter unless
330    the name contains a dot, in which case the name is ambiguous
331    (could be a qualified name or just a name with a dot in it)
332    and must be quoted manually by the caller.
333    """
334    if isinstance(name, basestring) and '.' not in name:
335        return 'quote_ident(%s)' % (param,)
336    return param
337
338
339class _ParameterList(list):
340    """Helper class for building typed parameter lists."""
341
342    def add(self, value, typ=None):
343        """Typecast value with known database type and build parameter list.
344
345        If this is a literal value, it will be returned as is.  Otherwise, a
346        placeholder will be returned and the parameter list will be augmented.
347        """
348        value = self.adapt(value, typ)
349        if isinstance(value, Literal):
350            return value
351        self.append(value)
352        return '$%d' % len(self)
353
354
355class Bytea(bytes):
356    """Wrapper class for marking Bytea values."""
357
358
359class Hstore(dict):
360    """Wrapper class for marking hstore values."""
361
362    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
363
364    @classmethod
365    def _quote(cls, s):
366        if s is None:
367            return 'NULL'
368        if not s:
369            return '""'
370        s = s.replace('"', '\\"')
371        if cls._re_quote.search(s):
372            s = '"%s"' % s
373        return s
374
375    def __str__(self):
376        q = self._quote
377        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
378
379
380class Json:
381    """Wrapper class for marking Json values."""
382
383    def __init__(self, obj):
384        self.obj = obj
385
386
387class Literal(str):
388    """Wrapper class for marking literal SQL values."""
389
390
391class Adapter:
392    """Class providing methods for adapting parameters to the database."""
393
394    _bool_true_values = frozenset('t true 1 y yes on'.split())
395
396    _date_literals = frozenset('current_date current_time'
397        ' current_timestamp localtime localtimestamp'.split())
398
399    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
400    _re_record_quote = regex(r'[(,"\\]')
401    _re_array_escape = _re_record_escape = regex(r'(["\\])')
402
403    def __init__(self, db):
404        self.db = weakref.proxy(db)
405
406    @classmethod
407    def _adapt_bool(cls, v):
408        """Adapt a boolean parameter."""
409        if isinstance(v, basestring):
410            if not v:
411                return None
412            v = v.lower() in cls._bool_true_values
413        return 't' if v else 'f'
414
415    @classmethod
416    def _adapt_date(cls, v):
417        """Adapt a date parameter."""
418        if not v:
419            return None
420        if isinstance(v, basestring) and v.lower() in cls._date_literals:
421            return Literal(v)
422        return v
423
424    @staticmethod
425    def _adapt_num(v):
426        """Adapt a numeric parameter."""
427        if not v and v != 0:
428            return None
429        return v
430
431    _adapt_int = _adapt_float = _adapt_money = _adapt_num
432
433    def _adapt_bytea(self, v):
434        """Adapt a bytea parameter."""
435        return self.db.escape_bytea(v)
436
437    def _adapt_json(self, v):
438        """Adapt a json parameter."""
439        if not v:
440            return None
441        if isinstance(v, basestring):
442            return v
443        return self.db.encode_json(v)
444
445    @classmethod
446    def _adapt_text_array(cls, v):
447        """Adapt a text type array parameter."""
448        if isinstance(v, list):
449            adapt = cls._adapt_text_array
450            return '{%s}' % ','.join(adapt(v) for v in v)
451        if v is None:
452            return 'null'
453        if not v:
454            return '""'
455        v = str(v)
456        if cls._re_array_quote.search(v):
457            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
458        return v
459
460    _adapt_date_array = _adapt_text_array
461
462    @classmethod
463    def _adapt_bool_array(cls, v):
464        """Adapt a boolean array parameter."""
465        if isinstance(v, list):
466            adapt = cls._adapt_bool_array
467            return '{%s}' % ','.join(adapt(v) for v in v)
468        if v is None:
469            return 'null'
470        if isinstance(v, basestring):
471            if not v:
472                return 'null'
473            v = v.lower() in cls._bool_true_values
474        return 't' if v else 'f'
475
476    @classmethod
477    def _adapt_num_array(cls, v):
478        """Adapt a numeric array parameter."""
479        if isinstance(v, list):
480            adapt = cls._adapt_num_array
481            return '{%s}' % ','.join(adapt(v) for v in v)
482        if not v and v != 0:
483            return 'null'
484        return str(v)
485
486    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
487            _adapt_num_array
488
489    def _adapt_bytea_array(self, v):
490        """Adapt a bytea array parameter."""
491        if isinstance(v, list):
492            return b'{' + b','.join(
493                self._adapt_bytea_array(v) for v in v) + b'}'
494        if v is None:
495            return b'null'
496        return self.db.escape_bytea(v).replace(b'\\', b'\\\\')
497
498    def _adapt_json_array(self, v):
499        """Adapt a json array parameter."""
500        if isinstance(v, list):
501            adapt = self._adapt_json_array
502            return '{%s}' % ','.join(adapt(v) for v in v)
503        if not v:
504            return 'null'
505        if not isinstance(v, basestring):
506            v = self.db.encode_json(v)
507        if self._re_array_quote.search(v):
508            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
509        return v
510
511    def _adapt_record(self, v, typ):
512        """Adapt a record parameter with given type."""
513        typ = self.get_attnames(typ).values()
514        if len(typ) != len(v):
515            raise TypeError('Record parameter %s has wrong size' % v)
516        adapt = self.adapt
517        value = []
518        for v, t in zip(v, typ):
519            v = adapt(v, t)
520            if v is None:
521                v = ''
522            elif not v:
523                v = '""'
524            else:
525                if isinstance(v, bytes):
526                    if str is not bytes:
527                        v = v.decode('ascii')
528                else:
529                    v = str(v)
530                if self._re_record_quote.search(v):
531                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
532            value.append(v)
533        return '(%s)' % ','.join(value)
534
535    def adapt(self, value, typ=None):
536        """Adapt a value with known database type."""
537        if value is not None and not isinstance(value, Literal):
538            if typ:
539                simple = self.get_simple_name(typ)
540            else:
541                typ = simple = self.guess_simple_type(value) or 'text'
542            pg_str = getattr(value, '__pg_str__', None)
543            if pg_str:
544                value = pg_str(typ)
545            if simple == 'text':
546                pass
547            elif simple == 'record':
548                if isinstance(value, tuple):
549                    value = self._adapt_record(value, typ)
550            elif simple.endswith('[]'):
551                if isinstance(value, list):
552                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
553                    value = adapt(value)
554            else:
555                adapt = getattr(self, '_adapt_%s' % simple)
556                value = adapt(value)
557        return value
558
559    @staticmethod
560    def simple_type(name):
561        """Create a simple database type with given attribute names."""
562        typ = DbType(name)
563        typ.simple = name
564        return typ
565
566    @staticmethod
567    def get_simple_name(typ):
568        """Get the simple name of a database type."""
569        if isinstance(typ, DbType):
570            return typ.simple
571        return _simpletypes[typ]
572
573    @staticmethod
574    def get_attnames(typ):
575        """Get the attribute names of a composite database type."""
576        if isinstance(typ, DbType):
577            return typ.attnames
578        return {}
579
580    _frequent_simple_types = {
581        Bytea: 'bytea',
582        str: 'text',
583        bytes: 'text',
584        bool: 'bool',
585        int: 'int',
586        long: 'int',
587        float: 'float',
588        Decimal: 'num',
589        date: 'date',
590        time: 'date',
591        datetime: 'date',
592        timedelta: 'date'
593    }
594
595    @classmethod
596    def guess_simple_type(cls, value):
597        """Try to guess which database type the given value has."""
598        # optimize for most frequent types
599        simple_type = cls._frequent_simple_types.get(type(value))
600        if simple_type:
601            return simple_type
602        if isinstance(value, Bytea):
603            return 'bytea'
604        if isinstance(value, basestring):
605            return 'text'
606        if isinstance(value, bool):
607            return 'bool'
608        if isinstance(value, (int, long)):
609            return 'int'
610        if isinstance(value, float):
611            return 'float'
612        if isinstance(value, Decimal):
613            return 'num'
614        if isinstance(value, (date, time, datetime, timedelta)):
615            return 'date'
616        if isinstance(value, list):
617            return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
618        if isinstance(value, tuple):
619            simple_type = cls.simple_type
620            guess = cls.guess_simple_type
621
622            def get_attnames(self):
623                return AttrDict((str(n + 1), simple_type(guess(v)))
624                    for n, v in enumerate(value))
625
626            typ = simple_type('record')
627            typ._get_attnames = get_attnames
628            return typ
629
630    @classmethod
631    def guess_simple_base_type(cls, value):
632        """Try to guess the base type of a given array."""
633        for v in value:
634            if isinstance(v, list):
635                typ = cls.guess_simple_base_type(v)
636            else:
637                typ = cls.guess_simple_type(v)
638            if typ:
639                return typ
640
641    def adapt_inline(self, value, nested=False):
642        """Adapt a value that is put into the SQL and needs to be quoted."""
643        if value is None:
644            return 'NULL'
645        if isinstance(value, Literal):
646            return value
647        if isinstance(value, Bytea):
648            value = self.db.escape_bytea(value)
649            if bytes is not str:  # Python >= 3.0
650                value = value.decode('ascii')
651        elif isinstance(value, Json):
652            if value.encode:
653                return value.encode()
654            value = self.db.encode_json(value)
655        elif isinstance(value, (datetime, date, time, timedelta)):
656            value = str(value)
657        if isinstance(value, basestring):
658            value = self.db.escape_string(value)
659            return "'%s'" % value
660        if isinstance(value, bool):
661            return 'true' if value else 'false'
662        if isinstance(value, float):
663            if isinf(value):
664                return "'-Infinity'" if value < 0 else "'Infinity'"
665            if isnan(value):
666                return "'NaN'"
667            return value
668        if isinstance(value, (int, long, Decimal)):
669            return value
670        if isinstance(value, list):
671            q = self.adapt_inline
672            s = '[%s]' if nested else 'ARRAY[%s]'
673            return s % ','.join(str(q(v, nested=True)) for v in value)
674        if isinstance(value, tuple):
675            q = self.adapt_inline
676            return '(%s)' % ','.join(str(q(v)) for v in value)
677        pg_repr = getattr(value, '__pg_repr__', None)
678        if not pg_repr:
679            raise InterfaceError(
680                'Do not know how to adapt type %s' % type(value))
681        value = pg_repr()
682        if isinstance(value, (tuple, list)):
683            value = self.adapt_inline(value)
684        return value
685
686    def parameter_list(self):
687        """Return a parameter list for parameters with known database types.
688
689        The list has an add(value, typ) method that will build up the
690        list and return either the literal value or a placeholder.
691        """
692        params = _ParameterList()
693        params.adapt = self.adapt
694        return params
695
696    def format_query(self, command, values=None, types=None, inline=False):
697        """Format a database query using the given values and types."""
698        if not values:
699            return command, []
700        if inline and types:
701            raise ValueError('Typed parameters must be sent separately')
702        params = self.parameter_list()
703        if isinstance(values, (list, tuple)):
704            if inline:
705                adapt = self.adapt_inline
706                literals = [adapt(value) for value in values]
707            else:
708                add = params.add
709                if types:
710                    if (not isinstance(types, (list, tuple)) or
711                            len(types) != len(values)):
712                        raise TypeError('The values and types do not match')
713                    literals = [add(value, typ)
714                        for value, typ in zip(values, types)]
715                else:
716                    literals = [add(value) for value in values]
717            command %= tuple(literals)
718        elif isinstance(values, dict):
719            # we want to allow extra keys in the dictionary,
720            # so we first must find the values actually used in the command
721            used_values = {}
722            literals = dict.fromkeys(values, '')
723            for key in values:
724                del literals[key]
725                try:
726                    command % literals
727                except KeyError:
728                    used_values[key] = values[key]
729                literals[key] = ''
730            values = used_values
731            if inline:
732                adapt = self.adapt_inline
733                literals = dict((key, adapt(value))
734                    for key, value in values.items())
735            else:
736                add = params.add
737                if types:
738                    if not isinstance(types, dict):
739                        raise TypeError('The values and types do not match')
740                    literals = dict((key, add(values[key], types.get(key)))
741                        for key in sorted(values))
742                else:
743                    literals = dict((key, add(values[key]))
744                        for key in sorted(values))
745            command %= literals
746        else:
747            raise TypeError('The values must be passed as tuple, list or dict')
748        return command, params
749
750
751def cast_bool(value):
752    """Cast a boolean value."""
753    if not get_bool():
754        return value
755    return value[0] == 't'
756
757
758def cast_json(value):
759    """Cast a JSON value."""
760    cast = get_jsondecode()
761    if not cast:
762        return value
763    return cast(value)
764
765
766def cast_num(value):
767    """Cast a numeric value."""
768    return (get_decimal() or float)(value)
769
770
771def cast_money(value):
772    """Cast a money value."""
773    point = get_decimal_point()
774    if not point:
775        return value
776    if point != '.':
777        value = value.replace(point, '.')
778    value = value.replace('(', '-')
779    value = ''.join(c for c in value if c.isdigit() or c in '.-')
780    return (get_decimal() or float)(value)
781
782
783def cast_int2vector(value):
784    """Cast an int2vector value."""
785    return [int(v) for v in value.split()]
786
787
788def cast_date(value, connection):
789    """Cast a date value."""
790    # The output format depends on the server setting DateStyle.  The default
791    # setting ISO and the setting for German are actually unambiguous.  The
792    # order of days and months in the other two settings is however ambiguous,
793    # so at least here we need to consult the setting to properly parse values.
794    if value == '-infinity':
795        return date.min
796    if value == 'infinity':
797        return date.max
798    value = value.split()
799    if value[-1] == 'BC':
800        return date.min
801    value = value[0]
802    if len(value) > 10:
803        return date.max
804    fmt = connection.date_format()
805    return datetime.strptime(value, fmt).date()
806
807
808def cast_time(value):
809    """Cast a time value."""
810    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
811    return datetime.strptime(value, fmt).time()
812
813
814_re_timezone = regex('(.*)([+-].*)')
815
816
817def cast_timetz(value):
818    """Cast a timetz value."""
819    tz = _re_timezone.match(value)
820    if tz:
821        value, tz = tz.groups()
822    else:
823        tz = '+0000'
824    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
825    if _has_timezone:
826        value += _timezone_as_offset(tz)
827        fmt += '%z'
828        return datetime.strptime(value, fmt).timetz()
829    return datetime.strptime(value, fmt).timetz().replace(
830        tzinfo=_get_timezone(tz))
831
832
833def cast_timestamp(value, connection):
834    """Cast a timestamp value."""
835    if value == '-infinity':
836        return datetime.min
837    if value == 'infinity':
838        return datetime.max
839    value = value.split()
840    if value[-1] == 'BC':
841        return datetime.min
842    fmt = connection.date_format()
843    if fmt.endswith('-%Y') and len(value) > 2:
844        value = value[1:5]
845        if len(value[3]) > 4:
846            return datetime.max
847        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
848            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
849    else:
850        if len(value[0]) > 10:
851            return datetime.max
852        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
853    return datetime.strptime(' '.join(value), ' '.join(fmt))
854
855
856def cast_timestamptz(value, connection):
857    """Cast a timestamptz value."""
858    if value == '-infinity':
859        return datetime.min
860    if value == 'infinity':
861        return datetime.max
862    value = value.split()
863    if value[-1] == 'BC':
864        return datetime.min
865    fmt = connection.date_format()
866    if fmt.endswith('-%Y') and len(value) > 2:
867        value = value[1:]
868        if len(value[3]) > 4:
869            return datetime.max
870        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
871            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
872        value, tz = value[:-1], value[-1]
873    else:
874        if fmt.startswith('%Y-'):
875            tz = _re_timezone.match(value[1])
876            if tz:
877                value[1], tz = tz.groups()
878            else:
879                tz = '+0000'
880        else:
881            value, tz = value[:-1], value[-1]
882        if len(value[0]) > 10:
883            return datetime.max
884        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
885    if _has_timezone:
886        value.append(_timezone_as_offset(tz))
887        fmt.append('%z')
888        return datetime.strptime(' '.join(value), ' '.join(fmt))
889    return datetime.strptime(' '.join(value), ' '.join(fmt)).replace(
890        tzinfo=_get_timezone(tz))
891
892
893_re_interval_sql_standard = regex(
894    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
895    '(?:([+-]?[0-9]+)(?!:) ?)?'
896    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
897
898_re_interval_postgres = regex(
899    '(?:([+-]?[0-9]+) ?years? ?)?'
900    '(?:([+-]?[0-9]+) ?mons? ?)?'
901    '(?:([+-]?[0-9]+) ?days? ?)?'
902    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
903
904_re_interval_postgres_verbose = regex(
905    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
906    '(?:([+-]?[0-9]+) ?mons? ?)?'
907    '(?:([+-]?[0-9]+) ?days? ?)?'
908    '(?:([+-]?[0-9]+) ?hours? ?)?'
909    '(?:([+-]?[0-9]+) ?mins? ?)?'
910    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')
911
912_re_interval_iso_8601 = regex(
913    'P(?:([+-]?[0-9]+)Y)?'
914    '(?:([+-]?[0-9]+)M)?'
915    '(?:([+-]?[0-9]+)D)?'
916    '(?:T(?:([+-]?[0-9]+)H)?'
917    '(?:([+-]?[0-9]+)M)?'
918    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
919
920
921def cast_interval(value):
922    """Cast an interval value."""
923    # The output format depends on the server setting IntervalStyle, but it's
924    # not necessary to consult this setting to parse it.  It's faster to just
925    # check all possible formats, and there is no ambiguity here.
926    m = _re_interval_iso_8601.match(value)
927    if m:
928        m = [d or '0' for d in m.groups()]
929        secs_ago = m.pop(5) == '-'
930        m = [int(d) for d in m]
931        years, mons, days, hours, mins, secs, usecs = m
932        if secs_ago:
933            secs = -secs
934            usecs = -usecs
935    else:
936        m = _re_interval_postgres_verbose.match(value)
937        if m:
938            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
939            secs_ago = m.pop(5) == '-'
940            m = [-int(d) for d in m] if ago else [int(d) for d in m]
941            years, mons, days, hours, mins, secs, usecs = m
942            if secs_ago:
943                secs = - secs
944                usecs = -usecs
945        else:
946            m = _re_interval_postgres.match(value)
947            if m and any(m.groups()):
948                m = [d or '0' for d in m.groups()]
949                hours_ago = m.pop(3) == '-'
950                m = [int(d) for d in m]
951                years, mons, days, hours, mins, secs, usecs = m
952                if hours_ago:
953                    hours = -hours
954                    mins = -mins
955                    secs = -secs
956                    usecs = -usecs
957            else:
958                m = _re_interval_sql_standard.match(value)
959                if m and any(m.groups()):
960                    m = [d or '0' for d in m.groups()]
961                    years_ago = m.pop(0) == '-'
962                    hours_ago = m.pop(3) == '-'
963                    m = [int(d) for d in m]
964                    years, mons, days, hours, mins, secs, usecs = m
965                    if years_ago:
966                        years = -years
967                        mons = -mons
968                    if hours_ago:
969                        hours = -hours
970                        mins = -mins
971                        secs = -secs
972                        usecs = -usecs
973                else:
974                    raise ValueError('Cannot parse interval: %s' % value)
975    days += 365 * years + 30 * mons
976    return timedelta(days=days, hours=hours, minutes=mins,
977        seconds=secs, microseconds=usecs)
978
979
980class Typecasts(dict):
981    """Dictionary mapping database types to typecast functions.
982
983    The cast functions get passed the string representation of a value in
984    the database which they need to convert to a Python object.  The
985    passed string will never be None since NULL values are already
986    handled before the cast function is called.
987
988    Note that the basic types are already handled by the C extension.
989    They only need to be handled here as record or array components.
990    """
991
992    # the default cast functions
993    # (str functions are ignored but have been added for faster access)
994    defaults = {'char': str, 'bpchar': str, 'name': str,
995        'text': str, 'varchar': str,
996        'bool': cast_bool, 'bytea': unescape_bytea,
997        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
998        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
999        'float4': float, 'float8': float,
1000        'numeric': cast_num, 'money': cast_money,
1001        'date': cast_date, 'interval': cast_interval,
1002        'time': cast_time, 'timetz': cast_timetz,
1003        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
1004        'int2vector': cast_int2vector, 'uuid': UUID,
1005        'anyarray': cast_array, 'record': cast_record}
1006
1007    connection = None  # will be set in a connection specific instance
1008
1009    def __missing__(self, typ):
1010        """Create a cast function if it is not cached.
1011
1012        Note that this class never raises a KeyError,
1013        but returns None when no special cast function exists.
1014        """
1015        if not isinstance(typ, str):
1016            raise TypeError('Invalid type: %s' % typ)
1017        cast = self.defaults.get(typ)
1018        if cast:
1019            # store default for faster access
1020            cast = self._add_connection(cast)
1021            self[typ] = cast
1022        elif typ.startswith('_'):
1023            base_cast = self[typ[1:]]
1024            cast = self.create_array_cast(base_cast)
1025            if base_cast:
1026                self[typ] = cast
1027        else:
1028            attnames = self.get_attnames(typ)
1029            if attnames:
1030                casts = [self[v.pgtype] for v in attnames.values()]
1031                cast = self.create_record_cast(typ, attnames, casts)
1032                self[typ] = cast
1033        return cast
1034
1035    @staticmethod
1036    def _needs_connection(func):
1037        """Check if a typecast function needs a connection argument."""
1038        try:
1039            args = get_args(func)
1040        except (TypeError, ValueError):
1041            return False
1042        else:
1043            return 'connection' in args[1:]
1044
1045    def _add_connection(self, cast):
1046        """Add a connection argument to the typecast function if necessary."""
1047        if not self.connection or not self._needs_connection(cast):
1048            return cast
1049        return partial(cast, connection=self.connection)
1050
1051    def get(self, typ, default=None):
1052        """Get the typecast function for the given database type."""
1053        return self[typ] or default
1054
1055    def set(self, typ, cast):
1056        """Set a typecast function for the specified database type(s)."""
1057        if isinstance(typ, basestring):
1058            typ = [typ]
1059        if cast is None:
1060            for t in typ:
1061                self.pop(t, None)
1062                self.pop('_%s' % t, None)
1063        else:
1064            if not callable(cast):
1065                raise TypeError("Cast parameter must be callable")
1066            for t in typ:
1067                self[t] = self._add_connection(cast)
1068                self.pop('_%s' % t, None)
1069
1070    def reset(self, typ=None):
1071        """Reset the typecasts for the specified type(s) to their defaults.
1072
1073        When no type is specified, all typecasts will be reset.
1074        """
1075        if typ is None:
1076            self.clear()
1077        else:
1078            if isinstance(typ, basestring):
1079                typ = [typ]
1080            for t in typ:
1081                self.pop(t, None)
1082
1083    @classmethod
1084    def get_default(cls, typ):
1085        """Get the default typecast function for the given database type."""
1086        return cls.defaults.get(typ)
1087
1088    @classmethod
1089    def set_default(cls, typ, cast):
1090        """Set a default typecast function for the given database type(s)."""
1091        if isinstance(typ, basestring):
1092            typ = [typ]
1093        defaults = cls.defaults
1094        if cast is None:
1095            for t in typ:
1096                defaults.pop(t, None)
1097                defaults.pop('_%s' % t, None)
1098        else:
1099            if not callable(cast):
1100                raise TypeError("Cast parameter must be callable")
1101            for t in typ:
1102                defaults[t] = cast
1103                defaults.pop('_%s' % t, None)
1104
1105    def get_attnames(self, typ):
1106        """Return the fields for the given record type.
1107
1108        This method will be replaced with the get_attnames() method of DbTypes.
1109        """
1110        return {}
1111
1112    def dateformat(self):
1113        """Return the current date format.
1114
1115        This method will be replaced with the dateformat() method of DbTypes.
1116        """
1117        return '%Y-%m-%d'
1118
1119    def create_array_cast(self, basecast):
1120        """Create an array typecast for the given base cast."""
1121        cast_array = self['anyarray']
1122        def cast(v):
1123            return cast_array(v, basecast)
1124        return cast
1125
1126    def create_record_cast(self, name, fields, casts):
1127        """Create a named record typecast for the given fields and casts."""
1128        cast_record = self['record']
1129        record = namedtuple(name, fields)
1130        def cast(v):
1131            return record(*cast_record(v, casts))
1132        return cast
1133
1134
1135def get_typecast(typ):
1136    """Get the global typecast function for the given database type(s)."""
1137    return Typecasts.get_default(typ)
1138
1139
1140def set_typecast(typ, cast):
1141    """Set a global typecast function for the given database type(s).
1142
1143    Note that connections cache cast functions. To be sure a global change
1144    is picked up by a running connection, call db.db_types.reset_typecast().
1145    """
1146    Typecasts.set_default(typ, cast)
1147
1148
1149class DbType(str):
1150    """Class augmenting the simple type name with additional info.
1151
1152    The following additional information is provided:
1153
1154        oid: the PostgreSQL type OID
1155        pgtype: the internal PostgreSQL data type name
1156        regtype: the registered PostgreSQL data type name
1157        simple: the more coarse-grained PyGreSQL type name
1158        typtype: b = base type, c = composite type etc.
1159        category: A = Array, b = Boolean, C = Composite etc.
1160        delim: delimiter for array types
1161        relid: corresponding table for composite types
1162        attnames: attributes for composite types
1163    """
1164
1165    @property
1166    def attnames(self):
1167        """Get names and types of the fields of a composite type."""
1168        return self._get_attnames(self)
1169
1170
1171class DbTypes(dict):
1172    """Cache for PostgreSQL data types.
1173
1174    This cache maps type OIDs and names to DbType objects containing
1175    information on the associated database type.
1176    """
1177
1178    _num_types = frozenset('int float num money'
1179        ' int2 int4 int8 float4 float8 numeric money'.split())
1180
1181    def __init__(self, db):
1182        """Initialize type cache for connection."""
1183        super(DbTypes, self).__init__()
1184        self._db = weakref.proxy(db)
1185        self._regtypes = False
1186        self._typecasts = Typecasts()
1187        self._typecasts.get_attnames = self.get_attnames
1188        self._typecasts.connection = self._db
1189        if db.server_version < 80400:
1190            # older remote databases (not officially supported)
1191            self._query_pg_type = (
1192                "SELECT oid, typname, typname::text::regtype,"
1193                " typtype, null as typcategory, typdelim, typrelid"
1194                " FROM pg_catalog.pg_type"
1195                " WHERE oid OPERATOR(pg_catalog.=) %s::regtype")
1196        else:
1197            self._query_pg_type = (
1198                "SELECT oid, typname, typname::regtype,"
1199                " typtype, typcategory, typdelim, typrelid"
1200                " FROM pg_catalog.pg_type"
1201                " WHERE oid OPERATOR(pg_catalog.=) %s::regtype")
1202
1203    def add(self, oid, pgtype, regtype,
1204               typtype, category, delim, relid):
1205        """Create a PostgreSQL type name with additional info."""
1206        if oid in self:
1207            return self[oid]
1208        simple = 'record' if relid else _simpletypes[pgtype]
1209        typ = DbType(regtype if self._regtypes else simple)
1210        typ.oid = oid
1211        typ.simple = simple
1212        typ.pgtype = pgtype
1213        typ.regtype = regtype
1214        typ.typtype = typtype
1215        typ.category = category
1216        typ.delim = delim
1217        typ.relid = relid
1218        typ._get_attnames = self.get_attnames
1219        return typ
1220
1221    def __missing__(self, key):
1222        """Get the type info from the database if it is not cached."""
1223        try:
1224            q = self._query_pg_type % (_quote_if_unqualified('$1', key),)
1225            res = self._db.query(q, (key,)).getresult()
1226        except ProgrammingError:
1227            res = None
1228        if not res:
1229            raise KeyError('Type %s could not be found' % key)
1230        res = res[0]
1231        typ = self.add(*res)
1232        self[typ.oid] = self[typ.pgtype] = typ
1233        return typ
1234
1235    def get(self, key, default=None):
1236        """Get the type even if it is not cached."""
1237        try:
1238            return self[key]
1239        except KeyError:
1240            return default
1241
1242    def get_attnames(self, typ):
1243        """Get names and types of the fields of a composite type."""
1244        if not isinstance(typ, DbType):
1245            typ = self.get(typ)
1246            if not typ:
1247                return None
1248        if not typ.relid:
1249            return None
1250        return self._db.get_attnames(typ.relid, with_oid=False)
1251
1252    def get_typecast(self, typ):
1253        """Get the typecast function for the given database type."""
1254        return self._typecasts.get(typ)
1255
1256    def set_typecast(self, typ, cast):
1257        """Set a typecast function for the specified database type(s)."""
1258        self._typecasts.set(typ, cast)
1259
1260    def reset_typecast(self, typ=None):
1261        """Reset the typecast function for the specified database type(s)."""
1262        self._typecasts.reset(typ)
1263
1264    def typecast(self, value, typ):
1265        """Cast the given value according to the given database type."""
1266        if value is None:
1267            # for NULL values, no typecast is necessary
1268            return None
1269        if not isinstance(typ, DbType):
1270            typ = self.get(typ)
1271            if typ:
1272                typ = typ.pgtype
1273        cast = self.get_typecast(typ) if typ else None
1274        if not cast or cast is str:
1275            # no typecast is necessary
1276            return value
1277        return cast(value)
1278
1279
1280_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$')
1281
1282# The result rows for database operations are returned as named tuples
1283# by default. Since creating namedtuple classes is a somewhat expensive
1284# operation, we cache up to 1024 of these classes by default.
1285
1286@lru_cache(maxsize=1024)
1287def _row_factory(names):
1288    """Get a namedtuple factory for row results with the given names."""
1289    try:
1290        try:
1291            return namedtuple('Row', names, rename=True)._make
1292        except TypeError:  # Python 2.6 and 3.0 do not support rename
1293            names = [v if _re_fieldname.match(v) and not iskeyword(v)
1294                        else 'column_%d' % (n,)
1295                     for n, v in enumerate(names)]
1296            return namedtuple('Row', names)._make
1297    except ValueError:  # there is still a problem with the field names
1298        names = ['column_%d' % (n,) for n in range(len(names))]
1299        return namedtuple('Row', names)._make
1300
1301
1302def set_row_factory_size(maxsize):
1303    """Change the size of the namedtuple factory cache.
1304
1305    If maxsize is set to None, the cache can grow without bound.
1306    """
1307    global _row_factory
1308    _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__)
1309
1310
1311# Helper functions used by the query object
1312
1313def _dictiter(q):
1314    """Get query result as an iterator of dictionaries."""
1315    fields = q.listfields()
1316    for r in q:
1317        yield dict(zip(fields, r))
1318
1319
1320def _namediter(q):
1321    """Get query result as an iterator of named tuples."""
1322    row = _row_factory(q.listfields())
1323    for r in q:
1324        yield row(r)
1325
1326
1327def _namednext(q):
1328    """Get next row from query result as a named tuple."""
1329    return _row_factory(q.listfields())(next(q))
1330
1331
1332def _scalariter(q):
1333    """Get query result as an iterator of scalar values."""
1334    for r in q:
1335        yield r[0]
1336
1337
1338class _MemoryQuery:
1339    """Class that embodies a given query result."""
1340
1341    def __init__(self, result, fields):
1342        """Create query from given result rows and field names."""
1343        self.result = result
1344        self.fields = tuple(fields)
1345
1346    def listfields(self):
1347        """Return the stored field names of this query."""
1348        return self.fields
1349
1350    def getresult(self):
1351        """Return the stored result of this query."""
1352        return self.result
1353
1354    def __iter__(self):
1355        return iter(self.result)
1356
1357
1358def _db_error(msg, cls=DatabaseError):
1359    """Return DatabaseError with empty sqlstate attribute."""
1360    error = cls(msg)
1361    error.sqlstate = None
1362    return error
1363
1364
1365def _int_error(msg):
1366    """Return InternalError."""
1367    return _db_error(msg, InternalError)
1368
1369
1370def _prg_error(msg):
1371    """Return ProgrammingError."""
1372    return _db_error(msg, ProgrammingError)
1373
1374
1375# Initialize the C module
1376
1377set_decimal(Decimal)
1378set_jsondecode(jsondecode)
1379set_query_helpers(_dictiter, _namediter, _namednext, _scalariter)
1380
1381
1382# The notification handler
1383
1384class NotificationHandler(object):
1385    """A PostgreSQL client-side asynchronous notification handler."""
1386
1387    def __init__(self, db, event, callback=None,
1388            arg_dict=None, timeout=None, stop_event=None):
1389        """Initialize the notification handler.
1390
1391        You must pass a PyGreSQL database connection, the name of an
1392        event (notification channel) to listen for and a callback function.
1393
1394        You can also specify a dictionary arg_dict that will be passed as
1395        the single argument to the callback function, and a timeout value
1396        in seconds (a floating point number denotes fractions of seconds).
1397        If it is absent or None, the callers will never time out.  If the
1398        timeout is reached, the callback function will be called with a
1399        single argument that is None.  If you set the timeout to zero,
1400        the handler will poll notifications synchronously and return.
1401
1402        You can specify the name of the event that will be used to signal
1403        the handler to stop listening as stop_event. By default, it will
1404        be the event name prefixed with 'stop_'.
1405        """
1406        self.db = db
1407        self.event = event
1408        self.stop_event = stop_event or 'stop_%s' % event
1409        self.listening = False
1410        self.callback = callback
1411        if arg_dict is None:
1412            arg_dict = {}
1413        self.arg_dict = arg_dict
1414        self.timeout = timeout
1415
1416    def __del__(self):
1417        self.unlisten()
1418
1419    def close(self):
1420        """Stop listening and close the connection."""
1421        if self.db:
1422            self.unlisten()
1423            self.db.close()
1424            self.db = None
1425
1426    def listen(self):
1427        """Start listening for the event and the stop event."""
1428        if not self.listening:
1429            self.db.query('listen "%s"' % self.event)
1430            self.db.query('listen "%s"' % self.stop_event)
1431            self.listening = True
1432
1433    def unlisten(self):
1434        """Stop listening for the event and the stop event."""
1435        if self.listening:
1436            self.db.query('unlisten "%s"' % self.event)
1437            self.db.query('unlisten "%s"' % self.stop_event)
1438            self.listening = False
1439
1440    def notify(self, db=None, stop=False, payload=None):
1441        """Generate a notification.
1442
1443        Optionally, you can pass a payload with the notification.
1444
1445        If you set the stop flag, a stop notification will be sent that
1446        will cause the handler to stop listening.
1447
1448        Note: If the notification handler is running in another thread, you
1449        must pass a different database connection since PyGreSQL database
1450        connections are not thread-safe.
1451        """
1452        if self.listening:
1453            if not db:
1454                db = self.db
1455            q = 'notify "%s"' % (self.stop_event if stop else self.event)
1456            if payload:
1457                q += ", '%s'" % payload
1458            return db.query(q)
1459
1460    def __call__(self):
1461        """Invoke the notification handler.
1462
1463        The handler is a loop that listens for notifications on the event
1464        and stop event channels.  When either of these notifications are
1465        received, its associated 'pid', 'event' and 'extra' (the payload
1466        passed with the notification) are inserted into its arg_dict
1467        dictionary and the callback is invoked with this dictionary as
1468        a single argument.  When the handler receives a stop event, it
1469        stops listening to both events and return.
1470
1471        In the special case that the timeout of the handler has been set
1472        to zero, the handler will poll all events synchronously and return.
1473        If will keep listening until it receives a stop event.
1474
1475        Note: If you run this loop in another thread, don't use the same
1476        database connection for database operations in the main thread.
1477        """
1478        self.listen()
1479        poll = self.timeout == 0
1480        if not poll:
1481            rlist = [self.db.fileno()]
1482        while self.listening:
1483            if poll or select.select(rlist, [], [], self.timeout)[0]:
1484                while self.listening:
1485                    notice = self.db.getnotify()
1486                    if not notice:  # no more messages
1487                        break
1488                    event, pid, extra = notice
1489                    if event not in (self.event, self.stop_event):
1490                        self.unlisten()
1491                        raise _db_error(
1492                            'Listening for "%s" and "%s", but notified of "%s"'
1493                            % (self.event, self.stop_event, event))
1494                    if event == self.stop_event:
1495                        self.unlisten()
1496                    self.arg_dict.update(pid=pid, event=event, extra=extra)
1497                    self.callback(self.arg_dict)
1498                if poll:
1499                    break
1500            else:   # we timed out
1501                self.unlisten()
1502                self.callback(None)
1503
1504
1505def pgnotify(*args, **kw):
1506    """Same as NotificationHandler, under the traditional name."""
1507    warnings.warn("pgnotify is deprecated, use NotificationHandler instead",
1508        DeprecationWarning, stacklevel=2)
1509    return NotificationHandler(*args, **kw)
1510
1511
1512# The actual PostgreSQL database connection interface:
1513
1514class DB:
1515    """Wrapper class for the _pg connection type."""
1516
1517    db = None  # invalid fallback for underlying connection
1518
1519    def __init__(self, *args, **kw):
1520        """Create a new connection
1521
1522        You can pass either the connection parameters or an existing
1523        _pg or pgdb connection. This allows you to use the methods
1524        of the classic pg interface with a DB-API 2 pgdb connection.
1525        """
1526        if not args and len(kw) == 1:
1527            db = kw.get('db')
1528        elif not kw and len(args) == 1:
1529            db = args[0]
1530        else:
1531            db = None
1532        if db:
1533            if isinstance(db, DB):
1534                db = db.db
1535            else:
1536                try:
1537                    db = db._cnx
1538                except AttributeError:
1539                    pass
1540        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
1541            db = connect(*args, **kw)
1542            self._db_args = args, kw
1543            self._closeable = True
1544        else:
1545            self._db_args = db
1546            self._closeable = False
1547        self.db = db
1548        self.dbname = db.db
1549        self._regtypes = False
1550        self._attnames = {}
1551        self._pkeys = {}
1552        self._privileges = {}
1553        self.adapter = Adapter(self)
1554        self.dbtypes = DbTypes(self)
1555        if db.server_version < 80400:
1556            # support older remote data bases (not officially supported)
1557            self._query_attnames = (
1558                "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype,"
1559                " t.typtype, null as typcategory, t.typdelim, t.typrelid"
1560                " FROM pg_catalog.pg_attribute a"
1561                " JOIN pg_catalog.pg_type t"
1562                " ON t.oid OPERATOR(pg_catalog.=) a.atttypid"
1563                " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass AND %s"
1564                " AND NOT a.attisdropped ORDER BY a.attnum")
1565        else:
1566            self._query_attnames = (
1567                "SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1568                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1569                " FROM pg_catalog.pg_attribute a"
1570                " JOIN pg_catalog.pg_type t"
1571                " ON t.oid OPERATOR(pg_catalog.=) a.atttypid"
1572                " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass AND %s"
1573                " AND NOT a.attisdropped ORDER BY a.attnum")
1574        db.set_cast_hook(self.dbtypes.typecast)
1575        self.debug = None  # For debugging scripts, this can be set
1576            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
1577            # * to a file object to write debug statements or
1578            # * to a callable object which takes a string argument
1579            # * to any other true value to just print debug statements
1580
1581    def __getattr__(self, name):
1582        # All undefined members are same as in underlying connection:
1583        if self.db:
1584            return getattr(self.db, name)
1585        else:
1586            raise _int_error('Connection is not valid')
1587
1588    def __dir__(self):
1589        # Custom dir function including the attributes of the connection:
1590        attrs = set(self.__class__.__dict__)
1591        attrs.update(self.__dict__)
1592        attrs.update(dir(self.db))
1593        return sorted(attrs)
1594
1595    # Context manager methods
1596
1597    def __enter__(self):
1598        """Enter the runtime context. This will start a transaction."""
1599        self.begin()
1600        return self
1601
1602    def __exit__(self, et, ev, tb):
1603        """Exit the runtime context. This will end the transaction."""
1604        if et is None and ev is None and tb is None:
1605            self.commit()
1606        else:
1607            self.rollback()
1608
1609    def __del__(self):
1610        try:
1611            db = self.db
1612        except AttributeError:
1613            db = None
1614        if db:
1615            try:
1616                db.set_cast_hook(None)
1617            except TypeError:
1618                pass  # probably already closed
1619            if self._closeable:
1620                try:
1621                    db.close()
1622                except InternalError:
1623                    pass  # probably already closed
1624
1625    # Auxiliary methods
1626
1627    def _do_debug(self, *args):
1628        """Print a debug message"""
1629        if self.debug:
1630            s = '\n'.join(str(arg) for arg in args)
1631            if isinstance(self.debug, basestring):
1632                print(self.debug % s)
1633            elif hasattr(self.debug, 'write'):
1634                self.debug.write(s + '\n')
1635            elif callable(self.debug):
1636                self.debug(s)
1637            else:
1638                print(s)
1639
1640    def _escape_qualified_name(self, s):
1641        """Escape a qualified name.
1642
1643        Escapes the name for use as an SQL identifier, unless the
1644        name contains a dot, in which case the name is ambiguous
1645        (could be a qualified name or just a name with a dot in it)
1646        and must be quoted manually by the caller.
1647        """
1648        if '.' not in s:
1649            s = self.escape_identifier(s)
1650        return s
1651
1652    @staticmethod
1653    def _make_bool(d):
1654        """Get boolean value corresponding to d."""
1655        return bool(d) if get_bool() else ('t' if d else 'f')
1656
1657    def _list_params(self, params):
1658        """Create a human readable parameter list."""
1659        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1660
1661    # Public methods
1662
1663    # escape_string and escape_bytea exist as methods,
1664    # so we define unescape_bytea as a method as well
1665    unescape_bytea = staticmethod(unescape_bytea)
1666
1667    def decode_json(self, s):
1668        """Decode a JSON string coming from the database."""
1669        return (get_jsondecode() or jsondecode)(s)
1670
1671    def encode_json(self, d):
1672        """Encode a JSON string for use within SQL."""
1673        return jsonencode(d)
1674
1675    def close(self):
1676        """Close the database connection."""
1677        # Wraps shared library function so we can track state.
1678        db = self.db
1679        if db:
1680            try:
1681                db.set_cast_hook(None)
1682            except TypeError:
1683                pass  # probably already closed
1684            if self._closeable:
1685                db.close()
1686            self.db = None
1687        else:
1688            raise _int_error('Connection already closed')
1689
1690    def reset(self):
1691        """Reset connection with current parameters.
1692
1693        All derived queries and large objects derived from this connection
1694        will not be usable after this call.
1695
1696        """
1697        if self.db:
1698            self.db.reset()
1699        else:
1700            raise _int_error('Connection already closed')
1701
1702    def reopen(self):
1703        """Reopen connection to the database.
1704
1705        Used in case we need another connection to the same database.
1706        Note that we can still reopen a database that we have closed.
1707
1708        """
1709        # There is no such shared library function.
1710        if self._closeable:
1711            db = connect(*self._db_args[0], **self._db_args[1])
1712            if self.db:
1713                self.db.set_cast_hook(None)
1714                self.db.close()
1715            db.set_cast_hook(self.dbtypes.typecast)
1716            self.db = db
1717        else:
1718            self.db = self._db_args
1719
1720    def begin(self, mode=None):
1721        """Begin a transaction."""
1722        qstr = 'BEGIN'
1723        if mode:
1724            qstr += ' ' + mode
1725        return self.query(qstr)
1726
1727    start = begin
1728
1729    def commit(self):
1730        """Commit the current transaction."""
1731        return self.query('COMMIT')
1732
1733    end = commit
1734
1735    def rollback(self, name=None):
1736        """Roll back the current transaction."""
1737        qstr = 'ROLLBACK'
1738        if name:
1739            qstr += ' TO ' + name
1740        return self.query(qstr)
1741
1742    abort = rollback
1743
1744    def savepoint(self, name):
1745        """Define a new savepoint within the current transaction."""
1746        return self.query('SAVEPOINT ' + name)
1747
1748    def release(self, name):
1749        """Destroy a previously defined savepoint."""
1750        return self.query('RELEASE ' + name)
1751
1752    def get_parameter(self, parameter):
1753        """Get the value of a run-time parameter.
1754
1755        If the parameter is a string, the return value will also be a string
1756        that is the current setting of the run-time parameter with that name.
1757
1758        You can get several parameters at once by passing a list, set or dict.
1759        When passing a list of parameter names, the return value will be a
1760        corresponding list of parameter settings.  When passing a set of
1761        parameter names, a new dict will be returned, mapping these parameter
1762        names to their settings.  Finally, if you pass a dict as parameter,
1763        its values will be set to the current parameter settings corresponding
1764        to its keys.
1765
1766        By passing the special name 'all' as the parameter, you can get a dict
1767        of all existing configuration parameters.
1768        """
1769        if isinstance(parameter, basestring):
1770            parameter = [parameter]
1771            values = None
1772        elif isinstance(parameter, (list, tuple)):
1773            values = []
1774        elif isinstance(parameter, (set, frozenset)):
1775            values = {}
1776        elif isinstance(parameter, dict):
1777            values = parameter
1778        else:
1779            raise TypeError(
1780                'The parameter must be a string, list, set or dict')
1781        if not parameter:
1782            raise TypeError('No parameter has been specified')
1783        params = {} if isinstance(values, dict) else []
1784        for key in parameter:
1785            param = key.strip().lower() if isinstance(
1786                key, basestring) else None
1787            if not param:
1788                raise TypeError('Invalid parameter')
1789            if param == 'all':
1790                q = 'SHOW ALL'
1791                values = self.db.query(q).getresult()
1792                values = dict(value[:2] for value in values)
1793                break
1794            if isinstance(values, dict):
1795                params[param] = key
1796            else:
1797                params.append(param)
1798        else:
1799            for param in params:
1800                q = 'SHOW %s' % (param,)
1801                value = self.db.query(q).getresult()[0][0]
1802                if values is None:
1803                    values = value
1804                elif isinstance(values, list):
1805                    values.append(value)
1806                else:
1807                    values[params[param]] = value
1808        return values
1809
1810    def set_parameter(self, parameter, value=None, local=False):
1811        """Set the value of a run-time parameter.
1812
1813        If the parameter and the value are strings, the run-time parameter
1814        will be set to that value.  If no value or None is passed as a value,
1815        then the run-time parameter will be restored to its default value.
1816
1817        You can set several parameters at once by passing a list of parameter
1818        names, together with a single value that all parameters should be
1819        set to or with a corresponding list of values.  You can also pass
1820        the parameters as a set if you only provide a single value.
1821        Finally, you can pass a dict with parameter names as keys.  In this
1822        case, you should not pass a value, since the values for the parameters
1823        will be taken from the dict.
1824
1825        By passing the special name 'all' as the parameter, you can reset
1826        all existing settable run-time parameters to their default values.
1827
1828        If you set local to True, then the command takes effect for only the
1829        current transaction.  After commit() or rollback(), the session-level
1830        setting takes effect again.  Setting local to True will appear to
1831        have no effect if it is executed outside a transaction, since the
1832        transaction will end immediately.
1833        """
1834        if isinstance(parameter, basestring):
1835            parameter = {parameter: value}
1836        elif isinstance(parameter, (list, tuple)):
1837            if isinstance(value, (list, tuple)):
1838                parameter = dict(zip(parameter, value))
1839            else:
1840                parameter = dict.fromkeys(parameter, value)
1841        elif isinstance(parameter, (set, frozenset)):
1842            if isinstance(value, (list, tuple, set, frozenset)):
1843                value = set(value)
1844                if len(value) == 1:
1845                    value = value.pop()
1846            if not(value is None or isinstance(value, basestring)):
1847                raise ValueError('A single value must be specified'
1848                    ' when parameter is a set')
1849            parameter = dict.fromkeys(parameter, value)
1850        elif isinstance(parameter, dict):
1851            if value is not None:
1852                raise ValueError('A value must not be specified'
1853                    ' when parameter is a dictionary')
1854        else:
1855            raise TypeError(
1856                'The parameter must be a string, list, set or dict')
1857        if not parameter:
1858            raise TypeError('No parameter has been specified')
1859        params = {}
1860        for key, value in parameter.items():
1861            param = key.strip().lower() if isinstance(
1862                key, basestring) else None
1863            if not param:
1864                raise TypeError('Invalid parameter')
1865            if param == 'all':
1866                if value is not None:
1867                    raise ValueError('A value must ot be specified'
1868                        " when parameter is 'all'")
1869                params = {'all': None}
1870                break
1871            params[param] = value
1872        local = ' LOCAL' if local else ''
1873        for param, value in params.items():
1874            if value is None:
1875                q = 'RESET%s %s' % (local, param)
1876            else:
1877                q = 'SET%s %s TO %s' % (local, param, value)
1878            self._do_debug(q)
1879            self.db.query(q)
1880
1881    def query(self, command, *args):
1882        """Execute a SQL command string.
1883
1884        This method simply sends a SQL query to the database.  If the query is
1885        an insert statement that inserted exactly one row into a table that
1886        has OIDs, the return value is the OID of the newly inserted row.
1887        If the query is an update or delete statement, or an insert statement
1888        that did not insert exactly one row in a table with OIDs, then the
1889        number of rows affected is returned as a string.  If it is a statement
1890        that returns rows as a result (usually a select statement, but maybe
1891        also an "insert/update ... returning" statement), this method returns
1892        a Query object that can be accessed via getresult() or dictresult()
1893        or simply printed.  Otherwise, it returns `None`.
1894
1895        The query can contain numbered parameters of the form $1 in place
1896        of any data constant.  Arguments given after the query string will
1897        be substituted for the corresponding numbered parameter.  Parameter
1898        values can also be given as a single list or tuple argument.
1899        """
1900        # Wraps shared library function for debugging.
1901        if not self.db:
1902            raise _int_error('Connection is not valid')
1903        if args:
1904            self._do_debug(command, args)
1905            return self.db.query(command, args)
1906        self._do_debug(command)
1907        return self.db.query(command)
1908
1909    def query_formatted(self, command,
1910            parameters=None, types=None, inline=False):
1911        """Execute a formatted SQL command string.
1912
1913        Similar to query, but using Python format placeholders of the form
1914        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1915        The parameters must be passed as a tuple, list or dict.  You can
1916        also pass a corresponding tuple, list or dict of database types in
1917        order to format the parameters properly in case there is ambiguity.
1918
1919        If you set inline to True, the parameters will be sent to the database
1920        embedded in the SQL command, otherwise they will be sent separately.
1921        """
1922        return self.query(*self.adapter.format_query(
1923            command, parameters, types, inline))
1924
1925    def query_prepared(self, name, *args):
1926        """Execute a prepared SQL statement.
1927
1928        This works like the query() method, except that instead of passing
1929        the SQL command, you pass the name of a prepared statement.  If you
1930        pass an empty name, the unnamed statement will be executed.
1931        """
1932        if not self.db:
1933            raise _int_error('Connection is not valid')
1934        if name is None:
1935            name = ''
1936        if args:
1937            self._do_debug('EXECUTE', name, args)
1938            return self.db.query_prepared(name, args)
1939        self._do_debug('EXECUTE', name)
1940        return self.db.query_prepared(name)
1941
1942    def prepare(self, name, command):
1943        """Create a prepared SQL statement.
1944
1945        This creates a prepared statement for the given command with the
1946        given name for later execution with the query_prepared() method.
1947
1948        The name can be empty to create an unnamed statement, in which case
1949        any pre-existing unnamed statement is automatically replaced;
1950        otherwise it is an error if the statement name is already
1951        defined in the current database session. We recommend always using
1952        named queries, since unnamed queries have a limited lifetime and
1953        can be automatically replaced or destroyed by various operations.
1954        """
1955        if not self.db:
1956            raise _int_error('Connection is not valid')
1957        if name is None:
1958            name = ''
1959        self._do_debug('prepare', name, command)
1960        return self.db.prepare(name, command)
1961
1962    def describe_prepared(self, name=None):
1963        """Describe a prepared SQL statement.
1964
1965        This method returns a Query object describing the result columns of
1966        the prepared statement with the given name. If you omit the name,
1967        the unnamed statement will be described if you created one before.
1968        """
1969        if name is None:
1970            name = ''
1971        return self.db.describe_prepared(name)
1972
1973    def delete_prepared(self, name=None):
1974        """Delete a prepared SQL statement
1975
1976        This deallocates a previously prepared SQL statement with the given
1977        name, or deallocates all prepared statements if you do not specify a
1978        name. Note that prepared statements are also deallocated automatically
1979        when the current session ends.
1980        """
1981        q = "DEALLOCATE %s" % (name or 'ALL',)
1982        self._do_debug(q)
1983        return self.db.query(q)
1984
1985    def pkey(self, table, composite=False, flush=False):
1986        """Get or set the primary key of a table.
1987
1988        Single primary keys are returned as strings unless you
1989        set the composite flag.  Composite primary keys are always
1990        represented as tuples.  Note that this raises a KeyError
1991        if the table does not have a primary key.
1992
1993        If flush is set then the internal cache for primary keys will
1994        be flushed.  This may be necessary after the database schema or
1995        the search path has been changed.
1996        """
1997        pkeys = self._pkeys
1998        if flush:
1999            pkeys.clear()
2000            self._do_debug('The pkey cache has been flushed')
2001        try:  # cache lookup
2002            pkey = pkeys[table]
2003        except KeyError:  # cache miss, check the database
2004            q = ("SELECT a.attname, a.attnum, i.indkey"
2005                " FROM pg_catalog.pg_index i"
2006                " JOIN pg_catalog.pg_attribute a"
2007                " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid"
2008                " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)"
2009                " AND NOT a.attisdropped"
2010                " WHERE i.indrelid OPERATOR(pg_catalog.=) %s::regclass"
2011                " AND i.indisprimary ORDER BY a.attnum") % (
2012                    _quote_if_unqualified('$1', table),)
2013            pkey = self.db.query(q, (table,)).getresult()
2014            if not pkey:
2015                raise KeyError('Table %s has no primary key' % table)
2016            # we want to use the order defined in the primary key index here,
2017            # not the order as defined by the columns in the table
2018            if len(pkey) > 1:
2019                indkey = pkey[0][2]
2020                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
2021                pkey = tuple(row[0] for row in pkey)
2022            else:
2023                pkey = pkey[0][0]
2024            pkeys[table] = pkey  # cache it
2025        if composite and not isinstance(pkey, tuple):
2026            pkey = (pkey,)
2027        return pkey
2028
2029    def get_databases(self):
2030        """Get list of databases in the system."""
2031        return [s[0] for s in
2032            self.db.query(
2033                'SELECT datname FROM pg_catalog.pg_database').getresult()]
2034
2035    def get_relations(self, kinds=None, system=False):
2036        """Get list of relations in connected database of specified kinds.
2037
2038        If kinds is None or empty, all kinds of relations are returned.
2039        Otherwise kinds can be a string or sequence of type letters
2040        specifying which kind of relations you want to list.
2041
2042        Set the system flag if you want to get the system relations as well.
2043        """
2044        where = []
2045        if kinds:
2046            where.append("r.relkind IN (%s)" %
2047                ','.join("'%s'" % k for k in kinds))
2048        if not system:
2049            where.append("s.nspname NOT SIMILAR"
2050                " TO 'pg/_%|information/_schema' ESCAPE '/'")
2051        where = " WHERE %s" % ' AND '.join(where) if where else ''
2052        q = ("SELECT pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)"
2053            " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)"
2054            " FROM pg_catalog.pg_class r"
2055            " JOIN pg_catalog.pg_namespace s"
2056            " ON s.oid OPERATOR(pg_catalog.=) r.relnamespace%s"
2057            " ORDER BY s.nspname, r.relname") % where
2058        return [r[0] for r in self.db.query(q).getresult()]
2059
2060    def get_tables(self, system=False):
2061        """Return list of tables in connected database.
2062
2063        Set the system flag if you want to get the system tables as well.
2064        """
2065        return self.get_relations('r', system)
2066
2067    def get_attnames(self, table, with_oid=True, flush=False):
2068        """Given the name of a table, dig out the set of attribute names.
2069
2070        Returns a read-only dictionary of attribute names (the names are
2071        the keys, the values are the names of the attributes' types)
2072        with the column names in the proper order if you iterate over it.
2073
2074        If flush is set, then the internal cache for attribute names will
2075        be flushed. This may be necessary after the database schema or
2076        the search path has been changed.
2077
2078        By default, only a limited number of simple types will be returned.
2079        You can get the registered types after calling use_regtypes(True).
2080        """
2081        attnames = self._attnames
2082        if flush:
2083            attnames.clear()
2084            self._do_debug('The attnames cache has been flushed')
2085        try:  # cache lookup
2086            names = attnames[table]
2087        except KeyError:  # cache miss, check the database
2088            q = "a.attnum OPERATOR(pg_catalog.>) 0"
2089            if with_oid:
2090                q = "(%s OR a.attname OPERATOR(pg_catalog.=) 'oid')" % q
2091            q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
2092            names = self.db.query(q, (table,)).getresult()
2093            types = self.dbtypes
2094            names = ((name[0], types.add(*name[1:])) for name in names)
2095            names = AttrDict(names)
2096            attnames[table] = names  # cache it
2097        return names
2098
2099    def use_regtypes(self, regtypes=None):
2100        """Use registered type names instead of simplified type names."""
2101        if regtypes is None:
2102            return self.dbtypes._regtypes
2103        else:
2104            regtypes = bool(regtypes)
2105            if regtypes != self.dbtypes._regtypes:
2106                self.dbtypes._regtypes = regtypes
2107                self._attnames.clear()
2108                self.dbtypes.clear()
2109            return regtypes
2110
2111    def has_table_privilege(self, table, privilege='select', flush=False):
2112        """Check whether current user has specified table privilege.
2113
2114        If flush is set, then the internal cache for table privileges will
2115        be flushed. This may be necessary after privileges have been changed.
2116        """
2117        privileges = self._privileges
2118        if flush:
2119            privileges.clear()
2120            self._do_debug('The privileges cache has been flushed')
2121        privilege = privilege.lower()
2122        try:  # ask cache
2123            ret = privileges[table, privilege]
2124        except KeyError:  # cache miss, ask the database
2125            q = "SELECT pg_catalog.has_table_privilege(%s, $2)" % (
2126                _quote_if_unqualified('$1', table),)
2127            q = self.db.query(q, (table, privilege))
2128            ret = q.getresult()[0][0] == self._make_bool(True)
2129            privileges[table, privilege] = ret  # cache it
2130        return ret
2131
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        Raises KeyError if key values are missing from the row or do not
        match the number of key columns, the error built by _prg_error if
        no usable key (primary key or oid) exists, and the error built by
        _db_error if no matching row is found.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # munged oid key for this table (presumably "oid(table)"),
        # or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        # a single column name is turned into a 1-tuple of names
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # accept an oid that was stored under the munged key, e.g. by an
        # earlier get(), as plain 'oid' so it can serve as a key below
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # a non-dict row holds the key value(s) positionally; turn it into
        # a new dict keyed by the key column names
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        # key values are passed as bound parameters; the schema-qualified
        # operator guards against search path manipulation (CVE-2018-1058)
        where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # in the returned row the oid is kept only under its munged key
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            # make where clause in error message better readable
            where = where.replace('OPERATOR(pg_catalog.=)', '=')
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row, munging the oid column name
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
2208
2209    def insert(self, table, row=None, **kw):
2210        """Insert a row into a database table.
2211
2212        This method inserts a row into a table.  The name of the table must
2213        be passed as the first parameter.  The other parameters are used for
2214        providing the data of the row that shall be inserted into the table.
2215        If a dictionary is supplied as the second parameter, it starts with
2216        that.  Otherwise it uses a blank dictionary. Either way the dictionary
2217        is updated from the keywords.
2218
2219        The dictionary is then reloaded with the values actually inserted in
2220        order to pick up values modified by rules, triggers, etc.
2221        """
2222        if table.endswith('*'):  # hint for descendant tables can be ignored
2223            table = table[:-1].rstrip()
2224        if row is None:
2225            row = {}
2226        row.update(kw)
2227        if 'oid' in row:
2228            del row['oid']  # do not insert oid
2229        attnames = self.get_attnames(table)
2230        qoid = _oid_key(table) if 'oid' in attnames else None
2231        params = self.adapter.parameter_list()
2232        adapt = params.add
2233        col = self.escape_identifier
2234        names, values = [], []
2235        for n in attnames:
2236            if n in row:
2237                names.append(col(n))
2238                values.append(adapt(row[n], attnames[n]))
2239        if not names:
2240            raise _prg_error('No column found that can be inserted')
2241        names, values = ', '.join(names), ', '.join(values)
2242        ret = 'oid, *' if qoid else '*'
2243        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
2244            self._escape_qualified_name(table), names, values, ret)
2245        self._do_debug(q, params)
2246        q = self.db.query(q, params)
2247        res = q.dictresult()
2248        if res:  # this should always be true
2249            for n, value in res[0].items():
2250                if qoid and n == 'oid':
2251                    n = qoid
2252                row[n] = value
2253        return row
2254
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert, but updates an existing row.  The update is based
        on the primary key of the table or the OID value as munged by get()
        or passed as keyword.  The OID will take precedence if provided, so
        that it is possible to update the primary key itself.

        The dictionary is then modified to reflect any changes caused by the
        update due to triggers, rules, default values, etc.

        Returns the (possibly modified) row dictionary.  Raises the error
        built by _prg_error if no oid is given and the table has no primary
        key, and KeyError if a primary key column has no value in the row.
        """
        if table.endswith('*'):
            table = table[:-1].rstrip()  # need parent table name
        attnames = self.get_attnames(table)
        # munged oid key for this table (presumably "oid(table)"),
        # or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # an oid stored under the munged key (e.g. by get()) is used as well
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if qoid and 'oid' in row:  # try using the oid
            keyname = ('oid',)
        else:  # try using the primary key
            try:
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                raise _prg_error('Table %s has no primary key' % table)
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                raise KeyError('Missing value for primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        # the key parameters are added first, the SET values follow
        where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # in the returned row the oid is kept only under its munged key
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        values = []
        keyname = set(keyname)
        # only columns present in the row are updated, key columns never
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
        if not values:
            return row  # nothing to update besides the key columns
        values = ', '.join(values)
        ret = 'oid, *' if qoid else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            for n, value in res[0].items():
                if qoid and n == 'oid':
                    n = qoid
                row[n] = value
        return row
2316
2317    def upsert(self, table, row=None, **kw):
2318        """Insert a row into a database table with conflict resolution
2319
2320        This method inserts a row into a table, but instead of raising a
2321        ProgrammingError exception in case a row with the same primary key
2322        already exists, an update will be executed instead.  This will be
2323        performed as a single atomic operation on the database, so race
2324        conditions can be avoided.
2325
2326        Like the insert method, the first parameter is the name of the
2327        table and the second parameter can be used to pass the values to
2328        be inserted as a dictionary.
2329
2330        Unlike the insert und update statement, keyword parameters are not
2331        used to modify the dictionary, but to specify which columns shall
2332        be updated in case of a conflict, and in which way:
2333
2334        A value of False or None means the column shall not be updated,
2335        a value of True means the column shall be updated with the value
2336        that has been proposed for insertion, i.e. has been passed as value
2337        in the dictionary.  Columns that are not specified by keywords but
2338        appear as keys in the dictionary are also updated like in the case
2339        keywords had been passed with the value True.
2340
2341        So if in the case of a conflict you want to update every column that
2342        has been passed in the dictionary row, you would call upsert(table, row).
2343        If you don't want to do anything in case of a conflict, i.e. leave
2344        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2345
2346        If you need more fine-grained control of what gets updated, you can
2347        also pass strings in the keyword parameters.  These strings will
2348        be used as SQL expressions for the update columns.  In these
2349        expressions you can refer to the value that already exists in
2350        the table by prefixing the column name with "included.", and to
2351        the value that has been proposed for insertion by prefixing the
2352        column name with the "excluded."
2353
2354        The dictionary is modified in any case to reflect the values in
2355        the database after the operation has completed.
2356
2357        Note: The method uses the PostgreSQL "upsert" feature which is
2358        only available since PostgreSQL 9.5.
2359        """
2360        if table.endswith('*'):  # hint for descendant tables can be ignored
2361            table = table[:-1].rstrip()
2362        if row is None:
2363            row = {}
2364        if 'oid' in row:
2365            del row['oid']  # do not insert oid
2366        if 'oid' in kw:
2367            del kw['oid']  # do not update oid
2368        attnames = self.get_attnames(table)
2369        qoid = _oid_key(table) if 'oid' in attnames else None
2370        params = self.adapter.parameter_list()
2371        adapt = params.add
2372        col = self.escape_identifier
2373        names, values, updates = [], [], []
2374        for n in attnames:
2375            if n in row:
2376                names.append(col(n))
2377                values.append(adapt(row[n], attnames[n]))
2378        names, values = ', '.join(names), ', '.join(values)
2379        try:
2380            keyname = self.pkey(table, True)
2381        except KeyError:
2382            raise _prg_error('Table %s has no primary key' % table)
2383        target = ', '.join(col(k) for k in keyname)
2384        update = []
2385        keyname = set(keyname)
2386        keyname.add('oid')
2387        for n in attnames:
2388            if n not in keyname:
2389                value = kw.get(n, True)
2390                if value:
2391                    if not isinstance(value, basestring):
2392                        value = 'excluded.%s' % col(n)
2393                    update.append('%s = %s' % (col(n), value))
2394        if not values:
2395            return row
2396        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2397        ret = 'oid, *' if qoid else '*'
2398        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2399            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2400                self._escape_qualified_name(table), names, values,
2401                target, do, ret)
2402        self._do_debug(q, params)
2403        try:
2404            q = self.db.query(q, params)
2405        except ProgrammingError:
2406            if self.server_version < 90500:
2407                raise _prg_error(
2408                    'Upsert operation is not supported by PostgreSQL version')
2409            raise  # re-raise original error
2410        res = q.dictresult()
2411        if res:  # may be empty with "do nothing"
2412            for n, value in res[0].items():
2413                if qoid and n == 'oid':
2414                    n = qoid
2415                row[n] = value
2416        else:
2417            self.get(table, row)
2418        return row
2419
2420    def clear(self, table, row=None):
2421        """Clear all the attributes to values determined by the types.
2422
2423        Numeric types are set to 0, Booleans are set to false, and everything
2424        else is set to the empty string.  If the row argument is present,
2425        it is used as the row dictionary and any entries matching attribute
2426        names are cleared with everything else left unchanged.
2427        """
2428        # At some point we will need a way to get defaults from a table.
2429        if row is None:
2430            row = {}  # empty if argument is not present
2431        attnames = self.get_attnames(table)
2432        for n, t in attnames.items():
2433            if n == 'oid':
2434                continue
2435            t = t.simple
2436            if t in DbTypes._num_types:
2437                row[n] = 0
2438            elif t == 'bool':
2439                row[n] = self._make_bool(False)
2440            else:
2441                row[n] = ''
2442        return row
2443
2444    def delete(self, table, row=None, **kw):
2445        """Delete an existing row in a database table.
2446
2447        This method deletes the row from a table.  It deletes based on the
2448        primary key of the table or the OID value as munged by get() or
2449        passed as keyword.  The OID will take precedence if provided.
2450
2451        The return value is the number of deleted rows (i.e. 0 if the row
2452        did not exist and 1 if the row was deleted).
2453
2454        Note that if the row cannot be deleted because e.g. it is still
2455        referenced by another table, this method raises a ProgrammingError.
2456        """
2457        if table.endswith('*'):  # hint for descendant tables can be ignored
2458            table = table[:-1].rstrip()
2459        attnames = self.get_attnames(table)
2460        qoid = _oid_key(table) if 'oid' in attnames else None
2461        if row is None:
2462            row = {}
2463        elif 'oid' in row:
2464            del row['oid']  # only accept oid key from named args for safety
2465        row.update(kw)
2466        if qoid and qoid in row and 'oid' not in row:
2467            row['oid'] = row[qoid]
2468        if qoid and 'oid' in row:  # try using the oid
2469            keyname = ('oid',)
2470        else:  # try using the primary key
2471            try:
2472                keyname = self.pkey(table, True)
2473            except KeyError:  # the table has no primary key
2474                raise _prg_error('Table %s has no primary key' % table)
2475            # check whether all key columns have values
2476            if not set(keyname).issubset(row):
2477                raise KeyError('Missing value for primary key in row')
2478        params = self.adapter.parameter_list()
2479        adapt = params.add
2480        col = self.escape_identifier
2481        where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % (
2482            col(k), adapt(row[k], attnames[k])) for k in keyname)
2483        if 'oid' in row:
2484            if qoid:
2485                row[qoid] = row['oid']
2486            del row['oid']
2487        q = 'DELETE FROM %s WHERE %s' % (
2488            self._escape_qualified_name(table), where)
2489        self._do_debug(q, params)
2490        res = self.db.query(q, params)
2491        return int(res)
2492
2493    def truncate(self, table, restart=False, cascade=False, only=False):
2494        """Empty a table or set of tables.
2495
2496        This method quickly removes all rows from the given table or set
2497        of tables.  It has the same effect as an unqualified DELETE on each
2498        table, but since it does not actually scan the tables it is faster.
2499        Furthermore, it reclaims disk space immediately, rather than requiring
2500        a subsequent VACUUM operation. This is most useful on large tables.
2501
2502        If restart is set to True, sequences owned by columns of the truncated
2503        table(s) are automatically restarted.  If cascade is set to True, it
2504        also truncates all tables that have foreign-key references to any of
2505        the named tables.  If the parameter only is not set to True, all the
2506        descendant tables (if any) will also be truncated. Optionally, a '*'
2507        can be specified after the table name to explicitly indicate that
2508        descendant tables are included.
2509        """
2510        if isinstance(table, basestring):
2511            only = {table: only}
2512            table = [table]
2513        elif isinstance(table, (list, tuple)):
2514            if isinstance(only, (list, tuple)):
2515                only = dict(zip(table, only))
2516            else:
2517                only = dict.fromkeys(table, only)
2518        elif isinstance(table, (set, frozenset)):
2519            only = dict.fromkeys(table, only)
2520        else:
2521            raise TypeError('The table must be a string, list or set')
2522        if not (restart is None or isinstance(restart, (bool, int))):
2523            raise TypeError('Invalid type for the restart option')
2524        if not (cascade is None or isinstance(cascade, (bool, int))):
2525            raise TypeError('Invalid type for the cascade option')
2526        tables = []
2527        for t in table:
2528            u = only.get(t)
2529            if not (u is None or isinstance(u, (bool, int))):
2530                raise TypeError('Invalid type for the only option')
2531            if t.endswith('*'):
2532                if u:
2533                    raise ValueError(
2534                        'Contradictory table name and only options')
2535                t = t[:-1].rstrip()
2536            t = self._escape_qualified_name(t)
2537            if u:
2538                t = 'ONLY %s' % t
2539            tables.append(t)
2540        q = ['TRUNCATE', ', '.join(tables)]
2541        if restart:
2542            q.append('RESTART IDENTITY')
2543        if cascade:
2544            q.append('CASCADE')
2545        q = ' '.join(q)
2546        self._do_debug(q)
2547        return self.db.query(q)
2548
2549    def get_as_list(self, table, what=None, where=None,
2550            order=None, limit=None, offset=None, scalar=False):
2551        """Get a table as a list.
2552
2553        This gets a convenient representation of the table as a list
2554        of named tuples in Python.  You only need to pass the name of
2555        the table (or any other SQL expression returning rows).  Note that
2556        by default this will return the full content of the table which
2557        can be huge and overflow your memory.  However, you can control
2558        the amount of data returned using the other optional parameters.
2559
2560        The parameter 'what' can restrict the query to only return a
2561        subset of the table columns.  It can be a string, list or a tuple.
2562        The parameter 'where' can restrict the query to only return a
2563        subset of the table rows.  It can be a string, list or a tuple
2564        of SQL expressions that all need to be fulfilled.  The parameter
2565        'order' specifies the ordering of the rows.  It can also be a
2566        other string, list or a tuple.  If no ordering is specified,
2567        the result will be ordered by the primary key(s) or all columns
2568        if no primary key exists.  You can set 'order' to False if you
2569        don't care about the ordering.  The parameters 'limit' and 'offset'
2570        can be integers specifying the maximum number of rows returned
2571        and a number of rows skipped over.
2572
2573        If you set the 'scalar' option to True, then instead of the
2574        named tuples you will get the first items of these tuples.
2575        This is useful if the result has only one column anyway.
2576        """
2577        if not table:
2578            raise TypeError('The table name is missing')
2579        if what:
2580            if isinstance(what, (list, tuple)):
2581                what = ', '.join(map(str, what))
2582            if order is None:
2583                order = what
2584        else:
2585            what = '*'
2586        q = ['SELECT', what, 'FROM', table]
2587        if where:
2588            if isinstance(where, (list, tuple)):
2589                where = ' AND '.join(map(str, where))
2590            q.extend(['WHERE', where])
2591        if order is None:
2592            try:
2593                order = self.pkey(table, True)
2594            except (KeyError, ProgrammingError):
2595                try:
2596                    order = list(self.get_attnames(table))
2597                except (KeyError, ProgrammingError):
2598                    pass
2599        if order:
2600            if isinstance(order, (list, tuple)):
2601                order = ', '.join(map(str, order))
2602            q.extend(['ORDER BY', order])
2603        if limit:
2604            q.append('LIMIT %d' % limit)
2605        if offset:
2606            q.append('OFFSET %d' % offset)
2607        q = ' '.join(q)
2608        self._do_debug(q)
2609        q = self.db.query(q)
2610        res = q.namedresult()
2611        if res and scalar:
2612            res = [row[0] for row in res]
2613        return res
2614
2615    def get_as_dict(self, table, keyname=None, what=None, where=None,
2616            order=None, limit=None, offset=None, scalar=False):
2617        """Get a table as a dictionary.
2618
2619        This method is similar to get_as_list(), but returns the table
2620        as a Python dict instead of a Python list, which can be even
2621        more convenient. The primary key column(s) of the table will
2622        be used as the keys of the dictionary, while the other column(s)
2623        will be the corresponding values.  The keys will be named tuples
2624        if the table has a composite primary key.  The rows will be also
2625        named tuples unless the 'scalar' option has been set to True.
2626        With the optional parameter 'keyname' you can specify an alternative
2627        set of columns to be used as the keys of the dictionary.  It must
2628        be set as a string, list or a tuple.
2629
2630        If the Python version supports it, the dictionary will be an
2631        OrderedDict using the order specified with the 'order' parameter
2632        or the key column(s) if not specified.  You can set 'order' to False
2633        if you don't care about the ordering.  In this case the returned
2634        dictionary will be an ordinary one.
2635        """
2636        if not table:
2637            raise TypeError('The table name is missing')
2638        if not keyname:
2639            try:
2640                keyname = self.pkey(table, True)
2641            except (KeyError, ProgrammingError):
2642                raise _prg_error('Table %s has no primary key' % table)
2643        if isinstance(keyname, basestring):
2644            keyname = [keyname]
2645        elif not isinstance(keyname, (list, tuple)):
2646            raise KeyError('The keyname must be a string, list or tuple')
2647        if what:
2648            if isinstance(what, (list, tuple)):
2649                what = ', '.join(map(str, what))
2650            if order is None:
2651                order = what
2652        else:
2653            what = '*'
2654        q = ['SELECT', what, 'FROM', table]
2655        if where:
2656            if isinstance(where, (list, tuple)):
2657                where = ' AND '.join(map(str, where))
2658            q.extend(['WHERE', where])
2659        if order is None:
2660            order = keyname
2661        if order:
2662            if isinstance(order, (list, tuple)):
2663                order = ', '.join(map(str, order))
2664            q.extend(['ORDER BY', order])
2665        if limit:
2666            q.append('LIMIT %d' % limit)
2667        if offset:
2668            q.append('OFFSET %d' % offset)
2669        q = ' '.join(q)
2670        self._do_debug(q)
2671        q = self.db.query(q)
2672        res = q.getresult()
2673        cls = OrderedDict if order else dict
2674        if not res:
2675            return cls()
2676        keyset = set(keyname)
2677        fields = q.listfields()
2678        if not keyset.issubset(fields):
2679            raise KeyError('Missing keyname in row')
2680        keyind, rowind = [], []
2681        for i, f in enumerate(fields):
2682            (keyind if f in keyset else rowind).append(i)
2683        keytuple = len(keyind) > 1
2684        getkey = itemgetter(*keyind)
2685        keys = map(getkey, res)
2686        if scalar:
2687            rowind = rowind[:1]
2688            rowtuple = False
2689        else:
2690            rowtuple = len(rowind) > 1
2691        if scalar or rowtuple:
2692            getrow = itemgetter(*rowind)
2693        else:
2694            rowind = rowind[0]
2695            getrow = lambda row: (row[rowind],)
2696            rowtuple = True
2697        rows = map(getrow, res)
2698        if keytuple or rowtuple:
2699            if keytuple:
2700                keys = _namediter(_MemoryQuery(keys, keyname))
2701            if rowtuple:
2702                fields = [f for f in fields if f not in keyset]
2703                rows = _namediter(_MemoryQuery(rows, fields))
2704        return cls(zip(keys, rows))
2705
2706    def notification_handler(self,
2707            event, callback, arg_dict=None, timeout=None, stop_event=None):
2708        """Get notification handler that will run the given callback."""
2709        return NotificationHandler(self,
2710            event, callback, arg_dict, timeout, stop_event)
2711
2712
# if run as script, print some information

if __name__ == '__main__':
    # separate the label from the version number (previously they were
    # concatenated without a space, e.g. "PyGreSQL version5.1")
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.