source: trunk/pg.py @ 876

Last change to this file: revision 876, checked in by cito 3 years ago.

Fixed issues when remote server version < 9.0

Although this is not officially supported or tested,
sometimes you just have to access those legacy databases.

  • Property svn:keywords set to Id
File size: 89.1 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 876 2016-07-14 12:21:19Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function, division
31
32from _pg import *
33
34__version__ = version
35
36import select
37import warnings
38
39from datetime import date, time, datetime, timedelta, tzinfo
40from decimal import Decimal
41from math import isnan, isinf
42from collections import namedtuple
43from operator import itemgetter
44from functools import partial
45from re import compile as regex
46from json import loads as jsondecode, dumps as jsonencode
47from uuid import UUID
48
# Python 2 has separate long and basestring types; on Python 3 both are
# missing, so provide equivalents that work with isinstance().
try:
    long, basestring
except NameError:  # Python >= 3.0
    long, basestring = int, (str, bytes)
58
59
60# Auxiliary classes and functions that are independent from a DB connection:
61
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Simple read-only ordered dictionary for storing attribute names.

        Fallback for Pythons without OrderedDict: the key order is tracked
        in a separate list. Only a single iterable of (key, value) pairs is
        accepted, because a plain dict would not define any order.
        """

        def __init__(self, *args, **kw):
            if kw or len(args) > 1:
                raise TypeError
            pairs = args[0] if args else []
            if isinstance(pairs, dict):
                raise TypeError
            pairs = list(pairs)
            self._keys = [pair[0] for pair in pairs]
            dict.__init__(self, pairs)
            self._read_only = True
            # shadow all mutating methods on the instance
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, self._read_only_error)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[key] for key in self]

        def items(self):
            return [(key, self[key]) for key in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Simple read-only ordered dictionary for storing attribute names.

        The dictionary can only be filled while it is being constructed;
        afterwards any attempt to change it raises TypeError.
        """

        def __init__(self, *args, **kw):
            # allow __setitem__ while OrderedDict fills the dictionary
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow all mutating methods on the instance
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, self._read_only_error)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
146
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the names of the positional arguments of func."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the names of all parameters of func, in order."""
        params = signature(func).parameters
        return [name for name in params]
158
try:
    from datetime import timezone
except ImportError:  # Python < 3.2

    class timezone(tzinfo):
        """Simple fixed-offset timezone implementation."""

        def __init__(self, offset, name=None):
            self.offset = offset
            if not name:
                # Build a name like 'UTC+05:30'.  The sign is determined
                # separately from the hours, because for offsets between
                # -1h and 0 the hours are zero and '%+03d' would print '+00'.
                minutes = offset.days * 1440 + offset.seconds // 60
                sign = '-' if minutes < 0 else '+'
                hours, minutes = divmod(abs(minutes), 60)
                name = 'UTC%s%02d:%02d' % (sign, hours, minutes)
            self.name = name

        def utcoffset(self, dt):
            return self.offset

        def tzname(self, dt):
            return self.name

        def dst(self, dt):
            return None

    timezone.utc = timezone(timedelta(0), 'UTC')

    # strptime() cannot parse %z on these versions
    _has_timezone = False
else:
    _has_timezone = True
192
# Fixed UTC offsets for the time zone abbreviations that may show up
# in Postgres timestamptz output.
_timezones = {
    'CET': '+0100', 'EET': '+0200', 'EST': '-0500', 'GMT': '+0000',
    'HST': '-1000', 'MET': '+0100', 'MST': '-0700', 'UCT': '+0000',
    'UTC': '+0000', 'WET': '+0000'}


def _timezone_as_offset(tz):
    """Normalize a time zone to an offset string of the form '+HHMM'.

    Numeric offsets like '+05' or '-03:30' are padded or stripped of the
    colon.  Abbreviations are looked up in _timezones, and unknown ones
    are treated as UTC.
    """
    if not tz.startswith(('+', '-')):
        return _timezones.get(tz, '+0000')
    if len(tz) < 5:
        return tz + '00'
    return tz.replace(':', '')


def _get_timezone(tz):
    """Return a timezone object for the given offset or abbreviation.

    The normalized offset string is used as the name of the time zone.
    """
    offset = _timezone_as_offset(tz)
    sign = -1 if offset[0] == '-' else 1
    minutes = sign * (60 * int(offset[1:3]) + int(offset[3:5]))
    return timezone(timedelta(minutes=minutes), offset)
213
214
def _oid_key(table):
    """Return the key 'oid(<table>)' for the given table name."""
    key = 'oid(%s)' % table
    return key
218
219
class _SimpleTypes(dict):
    """Dictionary mapping pg_type names to simple type names.

    Every listed type name is also registered with a leading underscore
    (the pg_type name of its array type), mapping to '<simple>[]'.
    Names that are not listed map to 'text'.
    """

    _types = {
        'bool': 'bool',
        'bytea': 'bytea',
        'date': 'date interval time timetz timestamp timestamptz'
            ' abstime reltime',  # these are very old
        'float': 'float4 float8',
        'int': 'cid int2 int4 int8 oid xid',
        'hstore': 'hstore',
        'json': 'json jsonb',
        'uuid': 'uuid',
        'num': 'numeric',
        'money': 'money',
        'text': 'bpchar char name text varchar'}

    def __init__(self):
        pairs = ((name, simple)
            for simple, names in self._types.items()
            for name in names.split())
        for name, simple in pairs:
            self[name] = simple
            self['_%s' % name] = '%s[]' % simple

    # this could be a static method in Python > 2.6
    def __missing__(self, key):
        # unknown types are handled as text (the result is not stored)
        return 'text'

_simpletypes = _SimpleTypes()
244
245
def _quote_if_unqualified(param, name):
    """Quote parameter representing a qualified name.

    Wraps the given parameter in a quote_ident() call unless the name
    contains a dot.  A dotted name is ambiguous (it could be a qualified
    name or just a name containing a dot), so the caller must quote it
    manually; in that case, and for non-string names, the parameter is
    returned unchanged.
    """
    unqualified = isinstance(name, basestring) and '.' not in name
    return 'quote_ident(%s)' % (param,) if unqualified else param
257
258
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    The creator must set the ``adapt`` attribute to a callable taking
    (value, typ) before add() is used.
    """

    def add(self, value, typ=None):
        """Adapt a value with a known database type and register it.

        Adapted Literal values are returned unchanged so that they end up
        directly in the SQL.  Any other adapted value is appended to this
        list, and its '$n' placeholder is returned.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
273
274
class Bytea(bytes):
    """Marker subclass of bytes flagging a value as bytea data."""
277
278
class Hstore(dict):
    """Wrapper class for marking hstore values.

    str() of an instance gives the hstore text representation, with keys
    and values quoted and escaped as needed.
    """

    # Quote items that could be taken for NULL, that contain whitespace or
    # the separators ',', '=' and '>', or that contain backslash-escaped
    # characters.
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>"\\]')

    @classmethod
    def _quote(cls, s):
        """Quote a single key or value of a hstore.

        hstore requires both double quotes and backslashes to be escaped
        with a backslash.  Backslashes are escaped first, so that the
        escapes added for quotes are not doubled.
        """
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        s = s.replace('\\', '\\\\').replace('"', '\\"')
        if cls._re_quote.search(s):
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
298
299
class Json:
    """Marker wrapping a Python object that represents a JSON value.

    The wrapped object is kept unchanged in the ``obj`` attribute.
    """

    def __init__(self, obj):
        self.obj = obj
305
306
class Literal(str):
    """Marker subclass of str for text that goes into SQL verbatim.

    Values of this type are neither quoted nor sent as parameters.
    """
309
310
class Adapter:
    """Class providing methods for adapting parameters to the database.

    adapt() converts a value to the form used for a query parameter of a
    known or guessed database type.  adapt_inline() renders a value as an
    SQL literal that can be put directly into a command.  format_query()
    uses either of them to fill in a command template.
    """

    _bool_true_values = frozenset('t true 1 y yes on'.split())

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    # Composite fields containing parentheses (an unquoted ')' would end
    # the record), commas, double quotes or backslashes must be quoted.
    _re_record_quote = regex(r'[(),"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        # db must provide encode_json() and a .db attribute that offers
        # escape_bytea() and escape_string(); these are cached here.
        self.db = db
        self.encode_json = db.encode_json
        db = db.db
        self.escape_bytea = db.escape_bytea
        self.escape_string = db.escape_string

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Strings are compared case-insensitively with the known true values.
        An empty string becomes NULL.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        Empty values become NULL.  Special keywords such as current_date
        are passed as literals.
        """
        if not v:
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter (empty values except 0 become NULL)."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Json wrappers are unwrapped first.  Empty values become NULL,
        strings are taken as JSON text, and any other object is encoded.
        """
        if isinstance(v, Json):
            v = v.obj
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.encode_json(v)

    def _adapt_hstore(self, v):
        """Adapt a hstore parameter.

        Empty values become NULL and strings are taken as hstore text.
        Dictionaries are rendered via Hstore; other types raise TypeError.
        """
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        if isinstance(v, dict):
            return str(Hstore(v))
        raise TypeError('Hstore parameter %s has wrong type' % (v,))

    @staticmethod
    def _adapt_uuid(v):
        """Adapt a UUID parameter (empty values become NULL)."""
        if not v:
            return None
        return str(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if not v:
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_uuid_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(v) for v in v) + b'}'
        if v is None:
            return b'null'
        return self.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter.

        Nested lists become nested arrays.  Json wrappers are unwrapped, so
        Json([...]) can be used to pass a JSON array as a single element.
        """
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if isinstance(v, Json):
            v = v.obj
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_hstore_array(self, v):
        """Adapt a hstore array parameter."""
        if isinstance(v, list):
            adapt = self._adapt_hstore_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        v = self._adapt_hstore(v)
        if v is None:
            return 'null'
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Each field is adapted with the type of the corresponding attribute.
        NULL becomes an empty field and an empty string becomes '""'.
        Fields containing special characters are quoted and escaped.
        Emptiness is checked only after the conversion to text, so that
        falsy values such as the number 0 are not turned into '""'.

        Raises TypeError if the number of fields does not match the type.
        """
        typ = self.get_attnames(typ).values()
        if len(typ) != len(v):
            raise TypeError('Record parameter %s has wrong size' % v)
        adapt = self.adapt
        value = []
        for v, t in zip(v, typ):
            v = adapt(v, t)
            if v is None:
                v = ''
            else:
                if isinstance(v, bytes):
                    if str is not bytes:
                        v = v.decode('ascii')
                else:
                    v = str(v)
                if not v:
                    v = '""'
                elif self._re_record_quote.search(v):
                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
            value.append(v)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is given, it is guessed from the value.  Values having a
        __pg_str__(typ) method are converted with it first.  None and
        Literal values are returned unchanged.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            try:
                value = value.__pg_str__(typ)
            except AttributeError:
                pass
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type.

        DbType instances carry their simple name.  Other type names are
        looked up in _simpletypes, where unknown names map to 'text'.
        """
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns a simple type name, a record type for tuples, or None if
        nothing can be guessed.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, Json):
            return 'json'
        if isinstance(value, Hstore):
            return 'hstore'
        if isinstance(value, UUID):
            return 'uuid'
        if isinstance(value, list):
            # An empty list, or one with only NULLs, gives no hint.  Fall
            # back to text[] instead of producing the invalid type 'None[]'.
            return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type
            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))
            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array.

        Returns the type of the first element (searching nested lists)
        for which a type can be guessed, or None.
        """
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Strings, bytea, JSON, hstore, UUID and date/time values are
        returned as quoted string literals.  Numbers are returned as they
        are, and lists and tuples become ARRAY and ROW constructors.  Other
        objects need a __pg_repr__() method; otherwise InterfaceError is
        raised.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # strings are taken as JSON text, other objects are encoded
            value = value.obj
            if not isinstance(value, basestring):
                value = self.encode_json(value)
        elif isinstance(value, (datetime, date, time, timedelta,
                Hstore, UUID)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        values may be a list/tuple (for positional %s formats) or a dict
        (for named %(name)s formats).  With inline=True the values are put
        directly into the command, and types are not allowed.  Otherwise
        placeholders are used and the adapted values are collected in the
        returned parameter list.

        Returns the tuple (command, params).
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command = command % tuple(literals)
        elif isinstance(values, dict):
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                if types:
                    if (not isinstance(types, dict) or
                            len(types) < len(values)):
                        raise TypeError('The values and types do not match')
                    # sorted so that placeholder numbers are deterministic
                    for key in sorted(values):
                        literals[key] = add(values[key], types[key])
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command = command % literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
644
645
def cast_bool(value):
    """Cast a boolean value.

    If bool casting is enabled (get_bool()), True is returned when the
    text starts with 't'.  Otherwise the text is returned unchanged.
    """
    return value[0] == 't' if get_bool() else value
651
652
def cast_json(value):
    """Cast a JSON value with the configured decoder, if there is one."""
    decode = get_jsondecode()
    return decode(value) if decode else value
659
660
def cast_num(value):
    """Cast a numeric value to the configured decimal type or float."""
    number_type = get_decimal() or float
    return number_type(value)
664
665
def cast_money(value):
    """Cast a money value.

    Returns the text unchanged if get_decimal_point() gives no decimal
    point.  Otherwise the locale-formatted amount is cleaned up and
    converted to the configured decimal type, or float.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        # With a decimal point other than '.', a '.' in the text can only be
        # a thousands separator (as in '1.234,56').  It must be removed
        # before the decimal point is turned into '.', otherwise the number
        # ends up with several dots and cannot be converted.
        value = value.replace('.', '').replace(point, '.')
    # amounts in parentheses are negative
    value = value.replace('(', '-')
    # drop currency symbols, other separators and remaining decoration
    value = ''.join(c for c in value if c.isdigit() or c in '.-')
    return (get_decimal() or float)(value)
676
677
def cast_int2vector(value):
    """Cast an int2vector value (space separated integers) to a list."""
    return list(map(int, value.split()))
681
682
def cast_date(value, connection):
    """Cast a date value.

    The output format depends on the server setting DateStyle.  The
    default ISO setting and the German setting are unambiguous, but the
    order of days and months in the other two settings is not, so the
    format is taken from connection.date_format().

    Infinite dates, BC dates and dates with years beyond 9999 cannot be
    represented.  They are mapped to date.min or date.max.
    """
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
701
702
def cast_time(value):
    """Cast a time value, with optional fractional seconds."""
    fmt = '%H:%M:%S'
    if len(value) > 8:
        fmt += '.%f'
    return datetime.strptime(value, fmt).time()
707
708
# splits a time or timestamp text into the part before the time zone
# and a trailing signed offset
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    A missing offset is taken as '+0000'.  The result carries a tzinfo,
    which is parsed by strptime() where %z is supported and built
    separately otherwise.
    """
    match = _re_timezone.match(value)
    if match:
        value, zone = match.groups()
    else:
        zone = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if not _has_timezone:
        naive = datetime.strptime(value, fmt).timetz()
        return naive.replace(tzinfo=_get_timezone(zone))
    return datetime.strptime(
        value + _timezone_as_offset(zone), fmt + '%z').timetz()
726
727
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The date format is taken from connection.date_format().  If it ends
    with '-%Y' and the value has more than two parts, the value is parsed
    as weekday, day and month name, time and year.  Otherwise it is
    parsed as date followed by time.

    Infinite timestamps, BC timestamps and years beyond 9999 are mapped
    to datetime.min or datetime.max.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        parts = parts[1:5]  # drop the weekday
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
749
750
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    Parsed like cast_timestamp(), but the value also contains a time zone.
    If the date format starts with '%Y-', the zone is a signed offset
    attached to the time ('+0000' if none is found).  In the other formats
    it is the last part of the value.  The result carries a tzinfo built
    from that zone.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        parts = parts[1:]  # drop the weekday
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
        parts, zone = parts[:-1], parts[-1]
    else:
        if date_fmt.startswith('%Y-'):
            # the offset is attached to the time part
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], zone = match.groups()
            else:
                zone = '+0000'
        else:
            # the zone is a separate last part
            parts, zone = parts[:-1], parts[-1]
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    if _has_timezone:
        parts.append(_timezone_as_offset(zone))
        formats.append('%z')
        return datetime.strptime(' '.join(parts), ' '.join(formats))
    naive = datetime.strptime(' '.join(parts), ' '.join(formats))
    return naive.replace(tzinfo=_get_timezone(zone))
786
787
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(groups):
    """Replace missing regex groups with '0' and scale the seconds fraction.

    The last group holds the digits after the decimal point of the seconds.
    Postgres prints them without trailing zeros ('01.5' is 1.5 seconds), so
    they are right-padded to six digits to give a count of microseconds.
    """
    fields = [d or '0' for d in groups]
    fields[-1] = fields[-1].ljust(6, '0')[:6]
    return fields


def cast_interval(value):
    """Cast an interval value.

    The output format depends on the server setting IntervalStyle, but it
    is not necessary to consult this setting to parse it.  It is faster to
    just try all possible formats, and there is no ambiguity here.  Years
    and months are converted to days as 365 and 30 days.

    Raises ValueError if the value matches none of the formats.
    """
    m = _re_interval_iso_8601.match(value)
    if m:
        m = _interval_fields(m.groups())
        secs_ago = m.pop(5) == '-'
        m = [int(d) for d in m]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = _interval_fields(m.groups()[:8]), m.group(9)
            secs_ago = m.pop(5) == '-'
            m = [-int(d) for d in m] if ago else [int(d) for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = _interval_fields(m.groups())
                hours_ago = m.pop(3) == '-'
                m = [int(d) for d in m]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = _interval_fields(m.groups())
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = [int(d) for d in m]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
873
874
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Note that the basic types are already handled by the C extension.
    They only need to be handled here as record or array components.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': UUID,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        The default cast for the type name is tried first.  Names starting
        with an underscore are treated as arrays of the type named by the
        rest of the string.  Any other name is tried as a composite type
        whose fields are looked up with get_attnames().

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        A TypeError is raised if typ is not a str.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # array type: the array cast is always returned, but only
            # cached when the element type has a cast function of its own
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                self[typ] = cast
        else:
            attnames = self.get_attnames(typ)
            if attnames:
                casts = [self[v.pgtype] for v in attnames.values()]
                cast = self.create_record_cast(typ, attnames, casts)
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument.

        Returns True if 'connection' is among the argument names of func
        other than the first one (the value to be cast), and False if the
        argument names cannot be determined.
        """
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        Passing None as cast removes the cached typecast.  In both cases
        the cached array typecast for the type is discarded, since it was
        built from the previous typecast.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.  Otherwise
        the cached array typecasts of the given types are discarded as
        well, because they were built from the typecast being reset.
        """
        if typ is None:
            self.clear()
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.pop(t, None)
                # drop the array cast too, like set() does
                self.pop('_%s' % t, None)

    @classmethod
    def get_default(cls, typ):
        """Get the default typecast function for the given database type."""
        return cls.defaults.get(typ)

    @classmethod
    def set_default(cls, typ, cast):
        """Set a default typecast function for the given database type(s).

        Passing None as cast removes the default typecast.  The defaults
        are stored on the class and therefore shared by all instances.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = cls.defaults
        if cast is None:
            for t in typ:
                defaults.pop(t, None)
                defaults.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                defaults[t] = cast
                defaults.pop('_%s' % t, None)

    def get_attnames(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with the get_attnames() method of DbTypes.
        """
        return {}

    def dateformat(self):
        """Return the current date format.

        This method will be replaced with the dateformat() method of DbTypes.
        """
        return '%Y-%m-%d'

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
1026
1027
def get_typecast(typ):
    """Return the global (default) typecast function for a database type.

    None is returned when no default typecast exists for the type.
    """
    lookup_default = Typecasts.get_default
    return lookup_default(typ)
1031
1032
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The typ argument can be a single type name or a sequence of names.
    Passing None as cast removes the global typecast for these types.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1040
1041
class DbType(str):
    """Type name string that carries additional type information.

    The string value itself is the type name.  The following attributes
    are attached to each instance after creation:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Return names and types of the fields of a composite type."""
        fetch = self._get_attnames
        return fetch(self)
1062
1063
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    This cache maps type OIDs and names to DbType objects containing
    information on the associated database type.  Types that are not
    cached yet are looked up in the pg_type catalog when accessed.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize type cache for connection.

        The db argument is the DB wrapper instance.  Its get_attnames()
        method is used for composite types, and the query() and
        escape_string() methods of its underlying connection (db.db)
        are used for looking up types.
        """
        super(DbTypes, self).__init__()
        self._regtypes = False
        self._get_attnames = db.get_attnames
        self._typecasts = Typecasts()
        self._typecasts.get_attnames = self.get_attnames
        self._typecasts.connection = db
        db = db.db
        self.query = db.query
        self.escape_string = db.escape_string
        if db.server_version < 80400:
            # older remote databases (not officially supported):
            # query without the typcategory column
            self._query_pg_type = (
                "SELECT oid, typname, typname::text::regtype,"
                " typtype, null as typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")
        else:
            self._query_pg_type = (
                "SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype")

    def add(self, oid, pgtype, regtype,
               typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        If the OID is already cached, the cached DbType is returned.
        Note that a new DbType is not stored in the cache by this method.
        """
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid = oid
        typ.simple = simple
        typ.pgtype = pgtype
        typ.regtype = regtype
        typ.typtype = typtype
        typ.category = category
        typ.delim = delim
        typ.relid = relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The type is then cached under its OID and its pgtype name.
        Raises a KeyError if the type could not be found.
        """
        try:
            q = self._query_pg_type % (_quote_if_unqualified('$1', key),)
            res = self.query(q, (key,)).getresult()
        except ProgrammingError:
            res = None
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = res[0]
        typ = self.add(*res)
        self[typ.oid] = self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or is not a composite type.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None
        return self._get_attnames(typ.relid, with_oid=False)

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        The type can be passed as a DbType, or as a key of this cache such
        as a type OID or name.  The cast function is always looked up by
        the PostgreSQL type name (pgtype), which is how the typecasts are
        keyed.  The value is returned unchanged if it is None, if the type
        is unknown, or if no special cast function exists for the type.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if not isinstance(typ, DbType):
            typ = self.get(typ)
        # Use the pgtype for DbType arguments, too: the string value of a
        # DbType is its simple or regular type name (e.g. 'int'), while
        # the typecasts are keyed by pgtype names (e.g. 'int4').
        cast = self.get_typecast(typ.pgtype) if typ else None
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
1172
1173
1174def _namedresult(q):
1175    """Get query result as named tuples."""
1176    row = namedtuple('Row', q.listfields())
1177    return [row(*r) for r in q.getresult()]
1178
1179
1180class _MemoryQuery:
1181    """Class that embodies a given query result."""
1182
1183    def __init__(self, result, fields):
1184        """Create query from given result rows and field names."""
1185        self.result = result
1186        self.fields = fields
1187
1188    def listfields(self):
1189        """Return the stored field names of this query."""
1190        return self.fields
1191
1192    def getresult(self):
1193        """Return the stored result of this query."""
1194        return self.result
1195
1196
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The returned exception has its sqlstate attribute set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
1202
1203
def _int_error(msg):
    """Create an InternalError with an empty sqlstate attribute."""
    return _db_error(msg, cls=InternalError)
1207
1208
def _prg_error(msg):
    """Create a ProgrammingError with an empty sqlstate attribute."""
    return _db_error(msg, cls=ProgrammingError)
1212
1213
# Initialize the C module: pass it the named tuple factory, the Decimal
# type and the JSON decoder defined or imported above, in that order.

for _setter, _value in (
        (set_namedresult, _namedresult),
        (set_decimal, Decimal),
        (set_jsondecode, jsondecode)):
    _setter(_value)
del _setter, _value
1219
1220
1221# The notification handler
1222
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.  It is
        escaped with the escape_string() method of the connection, so it
        may contain quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent when the handler is not listening.  Otherwise,
        the result of the query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # the payload is embedded as a string literal, so it must
                # be escaped to keep quotes from breaking the statement
                q += ", '%s'" % db.escape_string(payload)
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1342
1343
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    A DeprecationWarning attributed to the caller is issued, and all
    arguments are passed on to NotificationHandler unchanged.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1349
1350
1351# The actual PostGreSQL database connection interface:
1352
1353class DB:
1354    """Wrapper class for the _pg connection type."""
1355
    def __init__(self, *args, **kw):
        """Create a new connection

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing connection is recognized when it is passed as the
        only positional argument or as the only keyword argument 'db'.
        It can also be another DB instance, whose connection is then
        shared.  Connections passed in are not closed or reopened by
        close() and reopen(); only connections opened here are.
        """
        # A single argument may be an existing connection to wrap.
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                db = db.db
            else:
                # presumably a pgdb connection keeping its _pg connection
                # in _cnx; other objects are left as they are
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # Anything without 'db' and 'query' attributes is not taken for a
        # connection, and all arguments are used as connection parameters.
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        # kept so that reopen() can connect with the same parameters
        self._args = args, kw
        # DbTypes(self) reads self.db, so it must be created after it is set
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        if db.server_version < 80400:
            # support older remote data bases
            # (variant of the query without the typcategory column)
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype,"
                " t.typtype, null as typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        else:
            self._query_attnames = (
                "SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass AND %s"
                " AND NOT a.attisdropped ORDER BY a.attnum")
        # let the connection use the type cache for casting values
        # (presumably for types the C extension does not cast by itself)
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1414
1415    def __getattr__(self, name):
1416        # All undefined members are same as in underlying connection:
1417        if self.db:
1418            return getattr(self.db, name)
1419        else:
1420            raise _int_error('Connection is not valid')
1421
1422    def __dir__(self):
1423        # Custom dir function including the attributes of the connection:
1424        attrs = set(self.__class__.__dict__)
1425        attrs.update(self.__dict__)
1426        attrs.update(dir(self.db))
1427        return sorted(attrs)
1428
1429    # Context manager methods
1430
1431    def __enter__(self):
1432        """Enter the runtime context. This will start a transactio."""
1433        self.begin()
1434        return self
1435
1436    def __exit__(self, et, ev, tb):
1437        """Exit the runtime context. This will end the transaction."""
1438        if et is None and ev is None and tb is None:
1439            self.commit()
1440        else:
1441            self.rollback()
1442
1443    # Auxiliary methods
1444
1445    def _do_debug(self, *args):
1446        """Print a debug message"""
1447        if self.debug:
1448            s = '\n'.join(str(arg) for arg in args)
1449            if isinstance(self.debug, basestring):
1450                print(self.debug % s)
1451            elif hasattr(self.debug, 'write'):
1452                self.debug.write(s + '\n')
1453            elif callable(self.debug):
1454                self.debug(s)
1455            else:
1456                print(s)
1457
1458    def _escape_qualified_name(self, s):
1459        """Escape a qualified name.
1460
1461        Escapes the name for use as an SQL identifier, unless the
1462        name contains a dot, in which case the name is ambiguous
1463        (could be a qualified name or just a name with a dot in it)
1464        and must be quoted manually by the caller.
1465        """
1466        if '.' not in s:
1467            s = self.escape_identifier(s)
1468        return s
1469
1470    @staticmethod
1471    def _make_bool(d):
1472        """Get boolean value corresponding to d."""
1473        return bool(d) if get_bool() else ('t' if d else 'f')
1474
1475    def _list_params(self, params):
1476        """Create a human readable parameter list."""
1477        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1478
1479    # Public methods
1480
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (a static method wrapping the module-level unescape_bytea function)
    unescape_bytea = staticmethod(unescape_bytea)
1484
1485    def decode_json(self, s):
1486        """Decode a JSON string coming from the database."""
1487        return (get_jsondecode() or jsondecode)(s)
1488
1489    def encode_json(self, d):
1490        """Encode a JSON string for use within SQL."""
1491        return jsonencode(d)
1492
1493    def close(self):
1494        """Close the database connection."""
1495        # Wraps shared library function so we can track state.
1496        if self._closeable:
1497            if self.db:
1498                self.db.close()
1499                self.db = None
1500            else:
1501                raise _int_error('Connection already closed')
1502
1503    def reset(self):
1504        """Reset connection with current parameters.
1505
1506        All derived queries and large objects derived from this connection
1507        will not be usable after this call.
1508
1509        """
1510        if self.db:
1511            self.db.reset()
1512        else:
1513            raise _int_error('Connection already closed')
1514
1515    def reopen(self):
1516        """Reopen connection to the database.
1517
1518        Used in case we need another connection to the same database.
1519        Note that we can still reopen a database that we have closed.
1520
1521        """
1522        # There is no such shared library function.
1523        if self._closeable:
1524            db = connect(*self._args[0], **self._args[1])
1525            if self.db:
1526                self.db.close()
1527            self.db = db
1528
1529    def begin(self, mode=None):
1530        """Begin a transaction."""
1531        qstr = 'BEGIN'
1532        if mode:
1533            qstr += ' ' + mode
1534        return self.query(qstr)
1535
1536    start = begin
1537
1538    def commit(self):
1539        """Commit the current transaction."""
1540        return self.query('COMMIT')
1541
1542    end = commit
1543
1544    def rollback(self, name=None):
1545        """Roll back the current transaction."""
1546        qstr = 'ROLLBACK'
1547        if name:
1548            qstr += ' TO ' + name
1549        return self.query(qstr)
1550
1551    abort = rollback
1552
1553    def savepoint(self, name):
1554        """Define a new savepoint within the current transaction."""
1555        return self.query('SAVEPOINT ' + name)
1556
1557    def release(self, name):
1558        """Destroy a previously defined savepoint."""
1559        return self.query('RELEASE ' + name)
1560
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises a TypeError if the parameter has an unsupported type, is
        empty, or contains a name that is not a non-blank string.
        """
        # Pick the container for the results: None for a single name
        # (the setting itself is returned), a list for a list or tuple,
        # and a dict for a set (new dict) or a dict (updated in place).
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # For dict results, map normalized names to the caller's keys,
        # so that the results are stored under the keys as passed.
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' replaces the result by a dict of all settings;
                # the break also skips the else clause below
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # one SHOW query per parameter; NOTE(review): the names are
            # interpolated unescaped, so they must not come from
            # untrusted input
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1618
1619    def set_parameter(self, parameter, value=None, local=False):
1620        """Set the value of a run-time parameter.
1621
1622        If the parameter and the value are strings, the run-time parameter
1623        will be set to that value.  If no value or None is passed as a value,
1624        then the run-time parameter will be restored to its default value.
1625
1626        You can set several parameters at once by passing a list of parameter
1627        names, together with a single value that all parameters should be
1628        set to or with a corresponding list of values.  You can also pass
1629        the parameters as a set if you only provide a single value.
1630        Finally, you can pass a dict with parameter names as keys.  In this
1631        case, you should not pass a value, since the values for the parameters
1632        will be taken from the dict.
1633
1634        By passing the special name 'all' as the parameter, you can reset
1635        all existing settable run-time parameters to their default values.
1636
1637        If you set local to True, then the command takes effect for only the
1638        current transaction.  After commit() or rollback(), the session-level
1639        setting takes effect again.  Setting local to True will appear to
1640        have no effect if it is executed outside a transaction, since the
1641        transaction will end immediately.
1642        """
1643        if isinstance(parameter, basestring):
1644            parameter = {parameter: value}
1645        elif isinstance(parameter, (list, tuple)):
1646            if isinstance(value, (list, tuple)):
1647                parameter = dict(zip(parameter, value))
1648            else:
1649                parameter = dict.fromkeys(parameter, value)
1650        elif isinstance(parameter, (set, frozenset)):
1651            if isinstance(value, (list, tuple, set, frozenset)):
1652                value = set(value)
1653                if len(value) == 1:
1654                    value = value.pop()
1655            if not(value is None or isinstance(value, basestring)):
1656                raise ValueError('A single value must be specified'
1657                    ' when parameter is a set')
1658            parameter = dict.fromkeys(parameter, value)
1659        elif isinstance(parameter, dict):
1660            if value is not None:
1661                raise ValueError('A value must not be specified'
1662                    ' when parameter is a dictionary')
1663        else:
1664            raise TypeError(
1665                'The parameter must be a string, list, set or dict')
1666        if not parameter:
1667            raise TypeError('No parameter has been specified')
1668        params = {}
1669        for key, value in parameter.items():
1670            param = key.strip().lower() if isinstance(
1671                key, basestring) else None
1672            if not param:
1673                raise TypeError('Invalid parameter')
1674            if param == 'all':
1675                if value is not None:
1676                    raise ValueError('A value must ot be specified'
1677                        " when parameter is 'all'")
1678                params = {'all': None}
1679                break
1680            params[param] = value
1681        local = ' LOCAL' if local else ''
1682        for param, value in params.items():
1683            if value is None:
1684                q = 'RESET%s %s' % (local, param)
1685            else:
1686                q = 'SET%s %s TO %s' % (local, param, value)
1687            self._do_debug(q)
1688            self.db.query(q)
1689
1690    def query(self, command, *args):
1691        """Execute a SQL command string.
1692
1693        This method simply sends a SQL query to the database.  If the query is
1694        an insert statement that inserted exactly one row into a table that
1695        has OIDs, the return value is the OID of the newly inserted row.
1696        If the query is an update or delete statement, or an insert statement
1697        that did not insert exactly one row in a table with OIDs, then the
1698        number of rows affected is returned as a string.  If it is a statement
1699        that returns rows as a result (usually a select statement, but maybe
1700        also an "insert/update ... returning" statement), this method returns
1701        a Query object that can be accessed via getresult() or dictresult()
1702        or simply printed.  Otherwise, it returns `None`.
1703
1704        The query can contain numbered parameters of the form $1 in place
1705        of any data constant.  Arguments given after the query string will
1706        be substituted for the corresponding numbered parameter.  Parameter
1707        values can also be given as a single list or tuple argument.
1708        """
1709        # Wraps shared library function for debugging.
1710        if not self.db:
1711            raise _int_error('Connection is not valid')
1712        if args:
1713            self._do_debug(command, args)
1714            return self.db.query(command, args)
1715        self._do_debug(command)
1716        return self.db.query(command)
1717
1718    def query_formatted(self, command, parameters, types=None, inline=False):
1719        """Execute a formatted SQL command string.
1720
1721        Similar to query, but using Python format placeholders of the form
1722        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1723        The parameters must be passed as a tuple, list or dict.  You can
1724        also pass a corresponding tuple, list or dict of database types in
1725        order to format the parameters properly in case there is ambiguity.
1726
1727        If you set inline to True, the parameters will be sent to the database
1728        embedded in the SQL command, otherwise they will be sent separately.
1729        """
1730        return self.query(*self.adapter.format_query(
1731            command, parameters, types, inline))
1732
1733    def pkey(self, table, composite=False, flush=False):
1734        """Get or set the primary key of a table.
1735
1736        Single primary keys are returned as strings unless you
1737        set the composite flag.  Composite primary keys are always
1738        represented as tuples.  Note that this raises a KeyError
1739        if the table does not have a primary key.
1740
1741        If flush is set then the internal cache for primary keys will
1742        be flushed.  This may be necessary after the database schema or
1743        the search path has been changed.
1744        """
1745        pkeys = self._pkeys
1746        if flush:
1747            pkeys.clear()
1748            self._do_debug('The pkey cache has been flushed')
1749        try:  # cache lookup
1750            pkey = pkeys[table]
1751        except KeyError:  # cache miss, check the database
1752            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1753                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1754                " AND a.attnum = ANY(i.indkey)"
1755                " AND NOT a.attisdropped"
1756                " WHERE i.indrelid=%s::regclass"
1757                " AND i.indisprimary ORDER BY a.attnum") % (
1758                    _quote_if_unqualified('$1', table),)
1759            pkey = self.db.query(q, (table,)).getresult()
1760            if not pkey:
1761                raise KeyError('Table %s has no primary key' % table)
1762            # we want to use the order defined in the primary key index here,
1763            # not the order as defined by the columns in the table
1764            if len(pkey) > 1:
1765                indkey = pkey[0][2]
1766                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1767                pkey = tuple(row[0] for row in pkey)
1768            else:
1769                pkey = pkey[0][0]
1770            pkeys[table] = pkey  # cache it
1771        if composite and not isinstance(pkey, tuple):
1772            pkey = (pkey,)
1773        return pkey
1774
1775    def get_databases(self):
1776        """Get list of databases in the system."""
1777        return [s[0] for s in
1778            self.db.query('SELECT datname FROM pg_database').getresult()]
1779
1780    def get_relations(self, kinds=None, system=False):
1781        """Get list of relations in connected database of specified kinds.
1782
1783        If kinds is None or empty, all kinds of relations are returned.
1784        Otherwise kinds can be a string or sequence of type letters
1785        specifying which kind of relations you want to list.
1786
1787        Set the system flag if you want to get the system relations as well.
1788        """
1789        where = []
1790        if kinds:
1791            where.append("r.relkind IN (%s)" %
1792                ','.join("'%s'" % k for k in kinds))
1793        if not system:
1794            where.append("s.nspname NOT SIMILAR"
1795                " TO 'pg/_%|information/_schema' ESCAPE '/'")
1796        where = " WHERE %s" % ' AND '.join(where) if where else ''
1797        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1798            " FROM pg_class r"
1799            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
1800            " ORDER BY s.nspname, r.relname") % where
1801        return [r[0] for r in self.db.query(q).getresult()]
1802
1803    def get_tables(self, system=False):
1804        """Return list of tables in connected database.
1805
1806        Set the system flag if you want to get the system tables as well.
1807        """
1808        return self.get_relations('r', system)
1809
1810    def get_attnames(self, table, with_oid=True, flush=False):
1811        """Given the name of a table, dig out the set of attribute names.
1812
1813        Returns a read-only dictionary of attribute names (the names are
1814        the keys, the values are the names of the attributes' types)
1815        with the column names in the proper order if you iterate over it.
1816
1817        If flush is set, then the internal cache for attribute names will
1818        be flushed. This may be necessary after the database schema or
1819        the search path has been changed.
1820
1821        By default, only a limited number of simple types will be returned.
1822        You can get the regular types after calling use_regtypes(True).
1823        """
1824        attnames = self._attnames
1825        if flush:
1826            attnames.clear()
1827            self._do_debug('The attnames cache has been flushed')
1828        try:  # cache lookup
1829            names = attnames[table]
1830        except KeyError:  # cache miss, check the database
1831            q = "a.attnum > 0"
1832            if with_oid:
1833                q = "(%s OR a.attname = 'oid')" % q
1834            q = self._query_attnames % (_quote_if_unqualified('$1', table), q)
1835            names = self.db.query(q, (table,)).getresult()
1836            types = self.dbtypes
1837            names = ((name[0], types.add(*name[1:])) for name in names)
1838            names = AttrDict(names)
1839            attnames[table] = names  # cache it
1840        return names
1841
1842    def use_regtypes(self, regtypes=None):
1843        """Use regular type names instead of simplified type names."""
1844        if regtypes is None:
1845            return self.dbtypes._regtypes
1846        else:
1847            regtypes = bool(regtypes)
1848            if regtypes != self.dbtypes._regtypes:
1849                self.dbtypes._regtypes = regtypes
1850                self._attnames.clear()
1851                self.dbtypes.clear()
1852            return regtypes
1853
1854    def has_table_privilege(self, table, privilege='select', flush=False):
1855        """Check whether current user has specified table privilege.
1856
1857        If flush is set, then the internal cache for table privileges will
1858        be flushed. This may be necessary after privileges have been changed.
1859        """
1860        privileges = self._privileges
1861        if flush:
1862            privileges.clear()
1863            self._do_debug('The privileges cache has been flushed')
1864        privilege = privilege.lower()
1865        try:  # ask cache
1866            ret = privileges[table, privilege]
1867        except KeyError:  # cache miss, ask the database
1868            q = "SELECT has_table_privilege(%s, $2)" % (
1869                _quote_if_unqualified('$1', table),)
1870            q = self.db.query(q, (table, privilege))
1871            ret = q.getresult()[0][0] == self._make_bool(True)
1872            privileges[table, privilege] = ret  # cache it
1873        return ret
1874
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        If no keyname is passed, the OID is used as the key instead of the
        primary key in two cases.  This requires that the table has OIDs and
        row is a dictionary containing an OID, either as 'oid' or under its
        munged name.  Then the OID is used if the table has no primary key,
        or if row lacks a value for some primary key column.

        Raises KeyError if key values are missing from row, or if the
        number of values does not match the number of key columns.
        Raises the error created by _prg_error if no key can be determined,
        and the error created by _db_error if no matching row exists.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # key under which the OID is kept in row dicts (the munged
        # "oid(table)" name), or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # an OID passed under its munged name can also serve as the key
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # a single value or a sequence of values is mapped onto the key
        # columns, giving a new dict that will receive the fetched row
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        # adapt() adds each key value to params; the text it returns is
        # what gets embedded into the WHERE clause
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # the plain 'oid' entry was only needed for building the key;
        # keep the OID under its munged name, or drop it if the table
        # has no oid column
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched columns into row, storing the OID under qoid
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1949
1950    def insert(self, table, row=None, **kw):
1951        """Insert a row into a database table.
1952
1953        This method inserts a row into a table.  The name of the table must
1954        be passed as the first parameter.  The other parameters are used for
1955        providing the data of the row that shall be inserted into the table.
1956        If a dictionary is supplied as the second parameter, it starts with
1957        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1958        is updated from the keywords.
1959
1960        The dictionary is then reloaded with the values actually inserted in
1961        order to pick up values modified by rules, triggers, etc.
1962        """
1963        if table.endswith('*'):  # hint for descendant tables can be ignored
1964            table = table[:-1].rstrip()
1965        if row is None:
1966            row = {}
1967        row.update(kw)
1968        if 'oid' in row:
1969            del row['oid']  # do not insert oid
1970        attnames = self.get_attnames(table)
1971        qoid = _oid_key(table) if 'oid' in attnames else None
1972        params = self.adapter.parameter_list()
1973        adapt = params.add
1974        col = self.escape_identifier
1975        names, values = [], []
1976        for n in attnames:
1977            if n in row:
1978                names.append(col(n))
1979                values.append(adapt(row[n], attnames[n]))
1980        if not names:
1981            raise _prg_error('No column found that can be inserted')
1982        names, values = ', '.join(names), ', '.join(values)
1983        ret = 'oid, *' if qoid else '*'
1984        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1985            self._escape_qualified_name(table), names, values, ret)
1986        self._do_debug(q, params)
1987        q = self.db.query(q, params)
1988        res = q.dictresult()
1989        if res:  # this should always be true
1990            for n, value in res[0].items():
1991                if qoid and n == 'oid':
1992                    n = qoid
1993                row[n] = value
1994        return row
1995
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert, but updates an existing row.  The update is based
        on the primary key of the table or the OID value as munged by get()
        or passed as keyword.  The OID will take precedence if provided, so
        that it is possible to update the primary key itself.

        The dictionary is then modified to reflect any changes caused by the
        update due to triggers, rules, default values, etc.

        The row dictionary is modified in place.  An 'oid' entry in it is
        discarded, so only an 'oid' keyword or the munged OID key is
        accepted.  The keywords are merged into it.  If there are no columns
        to set besides the key columns, no query is executed and row is
        returned as is.  If no row matches, no fetched values are copied
        into row.

        Raises the error created by _prg_error if no OID is given and the
        table has no primary key.  Raises KeyError if row lacks a value for
        a primary key column.
        """
        if table.endswith('*'):
            table = table[:-1].rstrip()  # need parent table name
        attnames = self.get_attnames(table)
        # key under which the OID is kept in row dicts, or None if the
        # table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # an OID passed under its munged name can also serve as the key
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if qoid and 'oid' in row:  # try using the oid
            keyname = ('oid',)
        else:  # try using the primary key
            try:
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                raise _prg_error('Table %s has no primary key' % table)
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                raise KeyError('Missing value for primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        # the key values are adapted (added to params) before the SET values
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # keep the OID only under its munged name; since that name is not
        # an attribute name, the OID never ends up in the SET clause
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        values = []
        keyname = set(keyname)
        # set all given columns except the ones used as key
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
        if not values:
            return row
        values = ', '.join(values)
        ret = 'oid, *' if qoid else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            # copy the updated values into row, storing the OID under qoid
            for n, value in res[0].items():
                if qoid and n == 'oid':
                    n = qoid
                row[n] = value
        return row
2057
2058    def upsert(self, table, row=None, **kw):
2059        """Insert a row into a database table with conflict resolution
2060
2061        This method inserts a row into a table, but instead of raising a
2062        ProgrammingError exception in case a row with the same primary key
2063        already exists, an update will be executed instead.  This will be
2064        performed as a single atomic operation on the database, so race
2065        conditions can be avoided.
2066
2067        Like the insert method, the first parameter is the name of the
2068        table and the second parameter can be used to pass the values to
2069        be inserted as a dictionary.
2070
2071        Unlike the insert und update statement, keyword parameters are not
2072        used to modify the dictionary, but to specify which columns shall
2073        be updated in case of a conflict, and in which way:
2074
2075        A value of False or None means the column shall not be updated,
2076        a value of True means the column shall be updated with the value
2077        that has been proposed for insertion, i.e. has been passed as value
2078        in the dictionary.  Columns that are not specified by keywords but
2079        appear as keys in the dictionary are also updated like in the case
2080        keywords had been passed with the value True.
2081
2082        So if in the case of a conflict you want to update every column that
2083        has been passed in the dictionary row, you would call upsert(table, row).
2084        If you don't want to do anything in case of a conflict, i.e. leave
2085        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2086
2087        If you need more fine-grained control of what gets updated, you can
2088        also pass strings in the keyword parameters.  These strings will
2089        be used as SQL expressions for the update columns.  In these
2090        expressions you can refer to the value that already exists in
2091        the table by prefixing the column name with "included.", and to
2092        the value that has been proposed for insertion by prefixing the
2093        column name with the "excluded."
2094
2095        The dictionary is modified in any case to reflect the values in
2096        the database after the operation has completed.
2097
2098        Note: The method uses the PostgreSQL "upsert" feature which is
2099        only available since PostgreSQL 9.5.
2100        """
2101        if table.endswith('*'):  # hint for descendant tables can be ignored
2102            table = table[:-1].rstrip()
2103        if row is None:
2104            row = {}
2105        if 'oid' in row:
2106            del row['oid']  # do not insert oid
2107        if 'oid' in kw:
2108            del kw['oid']  # do not update oid
2109        attnames = self.get_attnames(table)
2110        qoid = _oid_key(table) if 'oid' in attnames else None
2111        params = self.adapter.parameter_list()
2112        adapt = params.add
2113        col = self.escape_identifier
2114        names, values, updates = [], [], []
2115        for n in attnames:
2116            if n in row:
2117                names.append(col(n))
2118                values.append(adapt(row[n], attnames[n]))
2119        names, values = ', '.join(names), ', '.join(values)
2120        try:
2121            keyname = self.pkey(table, True)
2122        except KeyError:
2123            raise _prg_error('Table %s has no primary key' % table)
2124        target = ', '.join(col(k) for k in keyname)
2125        update = []
2126        keyname = set(keyname)
2127        keyname.add('oid')
2128        for n in attnames:
2129            if n not in keyname:
2130                value = kw.get(n, True)
2131                if value:
2132                    if not isinstance(value, basestring):
2133                        value = 'excluded.%s' % col(n)
2134                    update.append('%s = %s' % (col(n), value))
2135        if not values:
2136            return row
2137        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2138        ret = 'oid, *' if qoid else '*'
2139        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2140            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2141                self._escape_qualified_name(table), names, values,
2142                target, do, ret)
2143        self._do_debug(q, params)
2144        try:
2145            q = self.db.query(q, params)
2146        except ProgrammingError:
2147            if self.server_version < 90500:
2148                raise _prg_error(
2149                    'Upsert operation is not supported by PostgreSQL version')
2150            raise  # re-raise original error
2151        res = q.dictresult()
2152        if res:  # may be empty with "do nothing"
2153            for n, value in res[0].items():
2154                if qoid and n == 'oid':
2155                    n = qoid
2156                row[n] = value
2157        else:
2158            self.get(table, row)
2159        return row
2160
2161    def clear(self, table, row=None):
2162        """Clear all the attributes to values determined by the types.
2163
2164        Numeric types are set to 0, Booleans are set to false, and everything
2165        else is set to the empty string.  If the row argument is present,
2166        it is used as the row dictionary and any entries matching attribute
2167        names are cleared with everything else left unchanged.
2168        """
2169        # At some point we will need a way to get defaults from a table.
2170        if row is None:
2171            row = {}  # empty if argument is not present
2172        attnames = self.get_attnames(table)
2173        for n, t in attnames.items():
2174            if n == 'oid':
2175                continue
2176            t = t.simple
2177            if t in DbTypes._num_types:
2178                row[n] = 0
2179            elif t == 'bool':
2180                row[n] = self._make_bool(False)
2181            else:
2182                row[n] = ''
2183        return row
2184
    def delete(self, table, row=None, **kw):
        """Delete an existing row in a database table.

        This method deletes the row from a table.  It deletes based on the
        primary key of the table or the OID value as munged by get() or
        passed as keyword.  The OID will take precedence if provided.

        The return value is the number of deleted rows (i.e. 0 if the row
        did not exist and 1 if the row was deleted).

        Note that if the row cannot be deleted because e.g. it is still
        referenced by another table, this method raises a ProgrammingError.

        The row dictionary, if passed, is modified in place.  An 'oid' entry
        in it is discarded, so only an 'oid' keyword or the munged OID key
        is accepted.  The keywords are merged into it.

        Raises the error created by _prg_error if no OID is given and the
        table has no primary key.  Raises KeyError if row lacks a value for
        a primary key column.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # key under which the OID is kept in row dicts, or None if the
        # table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # an OID passed under its munged name can also serve as the key
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if qoid and 'oid' in row:  # try using the oid
            keyname = ('oid',)
        else:  # try using the primary key
            try:
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                raise _prg_error('Table %s has no primary key' % table)
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                raise KeyError('Missing value for primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # keep the OID only under its munged name in the caller's dict,
        # or drop it if the table has no oid column
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'DELETE FROM %s WHERE %s' % (
            self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        res = self.db.query(q, params)
        # the query result for a DELETE is converted to an int here
        return int(res)
2233
2234    def truncate(self, table, restart=False, cascade=False, only=False):
2235        """Empty a table or set of tables.
2236
2237        This method quickly removes all rows from the given table or set
2238        of tables.  It has the same effect as an unqualified DELETE on each
2239        table, but since it does not actually scan the tables it is faster.
2240        Furthermore, it reclaims disk space immediately, rather than requiring
2241        a subsequent VACUUM operation. This is most useful on large tables.
2242
2243        If restart is set to True, sequences owned by columns of the truncated
2244        table(s) are automatically restarted.  If cascade is set to True, it
2245        also truncates all tables that have foreign-key references to any of
2246        the named tables.  If the parameter only is not set to True, all the
2247        descendant tables (if any) will also be truncated. Optionally, a '*'
2248        can be specified after the table name to explicitly indicate that
2249        descendant tables are included.
2250        """
2251        if isinstance(table, basestring):
2252            only = {table: only}
2253            table = [table]
2254        elif isinstance(table, (list, tuple)):
2255            if isinstance(only, (list, tuple)):
2256                only = dict(zip(table, only))
2257            else:
2258                only = dict.fromkeys(table, only)
2259        elif isinstance(table, (set, frozenset)):
2260            only = dict.fromkeys(table, only)
2261        else:
2262            raise TypeError('The table must be a string, list or set')
2263        if not (restart is None or isinstance(restart, (bool, int))):
2264            raise TypeError('Invalid type for the restart option')
2265        if not (cascade is None or isinstance(cascade, (bool, int))):
2266            raise TypeError('Invalid type for the cascade option')
2267        tables = []
2268        for t in table:
2269            u = only.get(t)
2270            if not (u is None or isinstance(u, (bool, int))):
2271                raise TypeError('Invalid type for the only option')
2272            if t.endswith('*'):
2273                if u:
2274                    raise ValueError(
2275                        'Contradictory table name and only options')
2276                t = t[:-1].rstrip()
2277            t = self._escape_qualified_name(t)
2278            if u:
2279                t = 'ONLY %s' % t
2280            tables.append(t)
2281        q = ['TRUNCATE', ', '.join(tables)]
2282        if restart:
2283            q.append('RESTART IDENTITY')
2284        if cascade:
2285            q.append('CASCADE')
2286        q = ' '.join(q)
2287        self._do_debug(q)
2288        return self.db.query(q)
2289
2290    def get_as_list(self, table, what=None, where=None,
2291            order=None, limit=None, offset=None, scalar=False):
2292        """Get a table as a list.
2293
2294        This gets a convenient representation of the table as a list
2295        of named tuples in Python.  You only need to pass the name of
2296        the table (or any other SQL expression returning rows).  Note that
2297        by default this will return the full content of the table which
2298        can be huge and overflow your memory.  However, you can control
2299        the amount of data returned using the other optional parameters.
2300
2301        The parameter 'what' can restrict the query to only return a
2302        subset of the table columns.  It can be a string, list or a tuple.
2303        The parameter 'where' can restrict the query to only return a
2304        subset of the table rows.  It can be a string, list or a tuple
2305        of SQL expressions that all need to be fulfilled.  The parameter
2306        'order' specifies the ordering of the rows.  It can also be a
2307        other string, list or a tuple.  If no ordering is specified,
2308        the result will be ordered by the primary key(s) or all columns
2309        if no primary key exists.  You can set 'order' to False if you
2310        don't care about the ordering.  The parameters 'limit' and 'offset'
2311        can be integers specifying the maximum number of rows returned
2312        and a number of rows skipped over.
2313
2314        If you set the 'scalar' option to True, then instead of the
2315        named tuples you will get the first items of these tuples.
2316        This is useful if the result has only one column anyway.
2317        """
2318        if not table:
2319            raise TypeError('The table name is missing')
2320        if what:
2321            if isinstance(what, (list, tuple)):
2322                what = ', '.join(map(str, what))
2323            if order is None:
2324                order = what
2325        else:
2326            what = '*'
2327        q = ['SELECT', what, 'FROM', table]
2328        if where:
2329            if isinstance(where, (list, tuple)):
2330                where = ' AND '.join(map(str, where))
2331            q.extend(['WHERE', where])
2332        if order is None:
2333            try:
2334                order = self.pkey(table, True)
2335            except (KeyError, ProgrammingError):
2336                try:
2337                    order = list(self.get_attnames(table))
2338                except (KeyError, ProgrammingError):
2339                    pass
2340        if order:
2341            if isinstance(order, (list, tuple)):
2342                order = ', '.join(map(str, order))
2343            q.extend(['ORDER BY', order])
2344        if limit:
2345            q.append('LIMIT %d' % limit)
2346        if offset:
2347            q.append('OFFSET %d' % offset)
2348        q = ' '.join(q)
2349        self._do_debug(q)
2350        q = self.db.query(q)
2351        res = q.namedresult()
2352        if res and scalar:
2353            res = [row[0] for row in res]
2354        return res
2355
2356    def get_as_dict(self, table, keyname=None, what=None, where=None,
2357            order=None, limit=None, offset=None, scalar=False):
2358        """Get a table as a dictionary.
2359
2360        This method is similar to get_as_list(), but returns the table
2361        as a Python dict instead of a Python list, which can be even
2362        more convenient. The primary key column(s) of the table will
2363        be used as the keys of the dictionary, while the other column(s)
2364        will be the corresponding values.  The keys will be named tuples
2365        if the table has a composite primary key.  The rows will be also
2366        named tuples unless the 'scalar' option has been set to True.
2367        With the optional parameter 'keyname' you can specify an alternative
2368        set of columns to be used as the keys of the dictionary.  It must
2369        be set as a string, list or a tuple.
2370
2371        If the Python version supports it, the dictionary will be an
2372        OrderedDict using the order specified with the 'order' parameter
2373        or the key column(s) if not specified.  You can set 'order' to False
2374        if you don't care about the ordering.  In this case the returned
2375        dictionary will be an ordinary one.
2376        """
2377        if not table:
2378            raise TypeError('The table name is missing')
2379        if not keyname:
2380            try:
2381                keyname = self.pkey(table, True)
2382            except (KeyError, ProgrammingError):
2383                raise _prg_error('Table %s has no primary key' % table)
2384        if isinstance(keyname, basestring):
2385            keyname = [keyname]
2386        elif not isinstance(keyname, (list, tuple)):
2387            raise KeyError('The keyname must be a string, list or tuple')
2388        if what:
2389            if isinstance(what, (list, tuple)):
2390                what = ', '.join(map(str, what))
2391            if order is None:
2392                order = what
2393        else:
2394            what = '*'
2395        q = ['SELECT', what, 'FROM', table]
2396        if where:
2397            if isinstance(where, (list, tuple)):
2398                where = ' AND '.join(map(str, where))
2399            q.extend(['WHERE', where])
2400        if order is None:
2401            order = keyname
2402        if order:
2403            if isinstance(order, (list, tuple)):
2404                order = ', '.join(map(str, order))
2405            q.extend(['ORDER BY', order])
2406        if limit:
2407            q.append('LIMIT %d' % limit)
2408        if offset:
2409            q.append('OFFSET %d' % offset)
2410        q = ' '.join(q)
2411        self._do_debug(q)
2412        q = self.db.query(q)
2413        res = q.getresult()
2414        cls = OrderedDict if order else dict
2415        if not res:
2416            return cls()
2417        keyset = set(keyname)
2418        fields = q.listfields()
2419        if not keyset.issubset(fields):
2420            raise KeyError('Missing keyname in row')
2421        keyind, rowind = [], []
2422        for i, f in enumerate(fields):
2423            (keyind if f in keyset else rowind).append(i)
2424        keytuple = len(keyind) > 1
2425        getkey = itemgetter(*keyind)
2426        keys = map(getkey, res)
2427        if scalar:
2428            rowind = rowind[:1]
2429            rowtuple = False
2430        else:
2431            rowtuple = len(rowind) > 1
2432        if scalar or rowtuple:
2433            getrow = itemgetter(*rowind)
2434        else:
2435            rowind = rowind[0]
2436            getrow = lambda row: (row[rowind],)
2437            rowtuple = True
2438        rows = map(getrow, res)
2439        if keytuple or rowtuple:
2440            namedresult = get_namedresult()
2441            if namedresult:
2442                if keytuple:
2443                    keys = namedresult(_MemoryQuery(keys, keyname))
2444                if rowtuple:
2445                    fields = [f for f in fields if f not in keyset]
2446                    rows = namedresult(_MemoryQuery(rows, fields))
2447        return cls(zip(keys, rows))
2448
2449    def notification_handler(self,
2450            event, callback, arg_dict=None, timeout=None, stop_event=None):
2451        """Get notification handler that will run the given callback."""
2452        return NotificationHandler(self,
2453            event, callback, arg_dict, timeout, stop_event)
2454
2455
# if run as script, print some information
# (version string, a blank line and the module docstring)

if __name__ == '__main__':
    # pass the version as a separate argument so that print() inserts
    # the separating space (the old concatenation printed "version5.0")
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.