source: trunk/pg.py @ 857

Last change on this file since 857 was 857, checked in by cito, 3 years ago

Add system parameter to get_relations()

Also fix a regression in the 4.x branch when using temporary tables,
related to filtering system tables (as discussed on the mailing list).

  • Property svn:keywords set to Id
File size: 88.4 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 857 2016-03-18 12:22:21Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function, division
31
from _pg import *

# 'version' is presumably exported by the C extension module through the
# star import above (not visible here -- confirm); re-export it under the
# conventional __version__ name.
__version__ = version
35
36import select
37import warnings
38
39from datetime import date, time, datetime, timedelta, tzinfo
40from decimal import Decimal
41from math import isnan, isinf
42from collections import namedtuple
43from operator import itemgetter
44from functools import partial
45from re import compile as regex
46from json import loads as jsondecode, dumps as jsonencode
47from uuid import UUID
48
# Python 2 has long and basestring as built-in names; Python 3 does not,
# so provide equivalent aliases there.
try:
    long, basestring
except NameError:  # Python >= 3.0
    long, basestring = int, (str, bytes)
58
59
60# Auxiliary classes and functions that are independent from a DB connection:
61
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Read-only dictionary for attribute names with ordered keys.

        This is the fallback for Python versions without OrderedDict.
        The order of the keys is tracked in a separate list.
        """

        def __init__(self, *args, **kw):
            # keyword arguments or a dict would not preserve the order,
            # so only a single sequence of (key, value) pairs is accepted
            if kw or len(args) > 1:
                raise TypeError
            pairs = args[0] if args else []
            if isinstance(pairs, dict):
                raise TypeError
            pairs = list(pairs)
            self._keys = [pair[0] for pair in pairs]
            dict.__init__(self, pairs)
            self._read_only = True
            # shadow all mutating methods on the instance
            deny = self._read_only_error
            self.clear = self.update = deny
            self.pop = self.setdefault = self.popitem = deny

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

        def _check_writable(self):
            if self._read_only:
                self._read_only_error()

        def __setitem__(self, key, value):
            self._check_writable()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            self._check_writable()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def iterkeys(self):
            return iter(self)

        def keys(self):
            return self._keys[:]

        def values(self):
            return [self[key] for key in self]

        def itervalues(self):
            return iter(self.values())

        def items(self):
            return [(key, self[key]) for key in self]

        def iteritems(self):
            return iter(self.items())

else:

    class AttrDict(OrderedDict):
        """Read-only ordered dictionary for storing attribute names."""

        def __init__(self, *args, **kw):
            # stay writable while OrderedDict fills in the items
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow all mutating methods on the instance
            deny = self._read_only_error
            self.clear = self.update = deny
            self.pop = self.setdefault = self.popitem = deny

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

        def _check_writable(self):
            if self._read_only:
                self._read_only_error()

        def __setitem__(self, key, value):
            self._check_writable()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            self._check_writable()
            OrderedDict.__delitem__(self, key)
146
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given function."""
        parameters = signature(func).parameters
        return [name for name in parameters]
158
159try:
160    from datetime import timezone
161except ImportError:  # Python < 3.2
162
163    class timezone(tzinfo):
164        """Simple timezone implementation."""
165
166        def __init__(self, offset, name=None):
167            self.offset = offset
168            if not name:
169                minutes = self.offset.days * 1440 + self.offset.seconds // 60
170                if minutes < 0:
171                    hours, minutes = divmod(-minutes, 60)
172                    hours = -hours
173                else:
174                    hours, minutes = divmod(minutes, 60)
175                name = 'UTC%+03d:%02d' % (hours, minutes)
176            self.name = name
177
178        def utcoffset(self, dt):
179            return self.offset
180
181        def tzname(self, dt):
182            return self.name
183
184        def dst(self, dt):
185            return None
186
187    timezone.utc = timezone(timedelta(0), 'UTC')
188
189    _has_timezone = False
190else:
191    _has_timezone = True
192
193# time zones used in Postgres timestamptz output
194_timezones = dict(CET='+0100', EET='+0200', EST='-0500',
195    GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
196    UCT='+0000', UTC='+0000', WET='+0000')
197
198
199def _timezone_as_offset(tz):
200    if tz.startswith(('+', '-')):
201        if len(tz) < 5:
202            return tz + '00'
203        return tz.replace(':', '')
204    return _timezones.get(tz, '+0000')
205
206
207def _get_timezone(tz):
208    tz = _timezone_as_offset(tz)
209    minutes = 60 * int(tz[1:3]) + int(tz[3:5])
210    if tz[0] == '-':
211        minutes = -minutes
212    return timezone(timedelta(minutes=minutes), tz)
213
214
215def _oid_key(table):
216    """Build oid key from a table name."""
217    return 'oid(%s)' % table
218
219
220class _SimpleTypes(dict):
221    """Dictionary mapping pg_type names to simple type names."""
222
223    _types = {'bool': 'bool',
224        'bytea': 'bytea',
225        'date': 'date interval time timetz timestamp timestamptz'
226            ' abstime reltime',  # these are very old
227        'float': 'float4 float8',
228        'int': 'cid int2 int4 int8 oid xid',
229        'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid',
230        'num': 'numeric', 'money': 'money',
231        'text': 'bpchar char name text varchar'}
232
233    def __init__(self):
234        for typ, keys in self._types.items():
235            for key in keys.split():
236                self[key] = typ
237                self['_%s' % key] = '%s[]' % typ
238
239    # this could be a static method in Python > 2.6
240    def __missing__(self, key):
241        return 'text'
242
243_simpletypes = _SimpleTypes()
244
245
def _quote_if_unqualified(param, name):
    """Quote a parameter that represents a possibly qualified name.

    Returns an SQL expression wrapping the given parameter in a
    quote_ident() call, unless the name contains a dot.  Such a name is
    ambiguous (it may be schema-qualified or simply contain a dot), so it
    is returned unchanged and must be quoted manually by the caller.
    """
    unqualified = isinstance(name, basestring) and '.' not in name
    if not unqualified:
        return param
    return 'quote_ident(%s)' % (param,)
257
258
class _ParameterList(list):
    """Helper list for collecting typed query parameters.

    An adapt(value, typ) function must be attached to each instance
    before add() is used.
    """

    def add(self, value, typ=None):
        """Add a value with the given database type to the parameter list.

        The value is adapted first.  If the result is a Literal, it is
        returned unchanged and not added to the list.  Otherwise the adapted
        value is appended and its placeholder ($1, $2, ...) is returned.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
273
274
class Bytea(bytes):
    """Wrapper class for marking Bytea values."""


class Hstore(dict):
    """Wrapper class for marking hstore values.

    Converting an instance to a string gives its hstore text representation.
    """

    # keys or values containing whitespace, commas, equal or greater signs,
    # and the word NULL in any letter case must be put in double quotes
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')

    @classmethod
    def _quote(cls, s):
        """Quote a key or value for the hstore text representation."""
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        quote = cls._re_quote.search(s)
        # double quotes and backslashes must be escaped with a backslash;
        # backslashes are escaped first so the new escapes are not doubled
        s = s.replace('\\', '\\\\').replace('"', '\\"')
        if quote:
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())


class Json:
    """Wrapper class for marking Json values.

    The wrapped Python object is available as the obj attribute.
    """

    def __init__(self, obj):
        self.obj = obj


class Literal(str):
    """Wrapper class for marking literal SQL values."""
309
310
class Adapter:
    """Class providing methods for adapting parameters to the database."""

    _bool_true_values = frozenset('t true 1 y yes on'.split())

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    _re_record_quote = regex(r'[(,"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        """Create an adapter for the given DB wrapper instance.

        The JSON encoder is taken from the wrapper itself, the escape
        functions from its underlying connection object (db.db).
        """
        self.db = db
        self.encode_json = db.encode_json
        db = db.db
        self.escape_bytea = db.escape_bytea
        self.escape_string = db.escape_string

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Strings are true if they are one of the known true values (ignoring
        case); empty strings become NULL.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        Empty values become NULL.  Special date names like current_date are
        returned as SQL literals so that they are not sent as parameters.
        """
        if not v:
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter (empty values except zero become NULL)."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Json wrappers are unwrapped first.  Empty values become NULL,
        strings are passed through as already encoded JSON, and all
        other objects are encoded with the encode_json function.
        """
        if isinstance(v, Json):
            v = v.obj
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.encode_json(v)

    @staticmethod
    def _adapt_hstore(v):
        """Adapt a hstore parameter.

        Empty values become NULL, strings are passed through as already
        formatted hstore values, and dicts are formatted as hstore.

        Raises TypeError for values of any other type.
        """
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        if isinstance(v, dict):
            return str(Hstore(v))
        raise TypeError('Hstore parameter %s has wrong type' % (v,))

    @staticmethod
    def _adapt_uuid(v):
        """Adapt a UUID parameter (empty values become NULL)."""
        if not v:
            return None
        return str(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if not v:
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    # dates and UUIDs are passed to arrays in their string representation
    _adapt_date_array = _adapt_uuid_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(v) for v in v) + b'}'
        if v is None:
            return b'null'
        return self.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter (Json wrappers are unwrapped)."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if isinstance(v, Json):
            v = v.obj
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Raises TypeError if the number of values does not match the
        number of attributes of the record type.
        """
        attribute_types = self.get_attnames(typ).values()
        if len(attribute_types) != len(v):
            # wrap v in a tuple, since v itself is a tuple here
            raise TypeError('Record parameter %s has wrong size' % (v,))
        adapt = self.adapt
        value = []
        for item, item_type in zip(v, attribute_types):
            item = adapt(item, item_type)
            if item is None:
                item = ''
            elif not item:
                item = '""'
            else:
                if isinstance(item, bytes):
                    if str is not bytes:
                        item = item.decode('ascii')
                else:
                    item = str(item)
                if self._re_record_quote.search(item):
                    item = '"%s"' % self._re_record_escape.sub(r'\\\1', item)
            value.append(item)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is given, it is guessed from the value.  Values that
        have a __pg_str__() method are converted with that method first.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            try:
                value = value.__pg_str__(typ)
            except AttributeError:
                pass
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns None if the type cannot be guessed.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, Json):
            return 'json'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, list):
            # use text for empty arrays or arrays with only NULL values
            return '%s[]' % (cls.guess_simple_base_type(value) or 'text',)
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type
            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))
            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array.

        Returns None if no element allows guessing the type.
        """
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Raises InterfaceError if the type of the value is not supported.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # encode the wrapped object (the wrapper has no encode method);
            # strings are passed through as already encoded JSON
            value = value.obj
            if not isinstance(value, basestring):
                value = self.encode_json(value)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        The values can be passed as a list or tuple (for positional
        %-style placeholders in the command) or as a dict (for named ones).
        With inline=True, the adapted values are put into the command
        directly; otherwise the command gets $n placeholders and the values
        are collected in the returned parameter list.

        Returns a tuple of the formatted command and the parameter list.
        Raises ValueError if types are given together with inline=True,
        and TypeError if values or types have the wrong form or don't match.
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    literals = [add(value, typ)
                        for value, typ in zip(values, types)]
                else:
                    literals = [add(value) for value in values]
            command = command % tuple(literals)
        elif isinstance(values, dict):
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                if types:
                    # every value needs a type; a missing key would
                    # otherwise cause an uninformative KeyError below
                    if not isinstance(types, dict) or any(
                            key not in types for key in values):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types[key])
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command = command % literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
644
645
def cast_bool(value):
    """Cast a boolean value.

    The value is returned unchanged when get_bool() returns a false value;
    otherwise it is true if its first character is 't'.
    """
    if get_bool():
        return value[0] == 't'
    return value
651
652
def cast_json(value):
    """Cast a JSON value with the configured decode function.

    The value is returned unchanged if no decode function is set.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
659
660
def cast_num(value):
    """Cast a numeric value with the configured decimal type or float."""
    numeric_type = get_decimal() or float
    return numeric_type(value)
664
665
def cast_money(value):
    """Cast a money value.

    The value is returned unchanged when get_decimal_point() returns an
    empty value.  Otherwise currency symbols and thousands separators are
    removed, an opening parenthesis is taken as a minus sign, and the
    result is cast with the configured decimal type or float.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        # a dot cannot be the decimal point here, so it must be a thousands
        # separator; remove it before turning the decimal point into a dot
        value = value.replace('.', '').replace(point, '.')
    value = value.replace('(', '-')
    value = ''.join(c for c in value if c.isdigit() or c in '.-')
    return (get_decimal() or float)(value)
676
677
def cast_int2vector(value):
    """Cast an int2vector value (space separated integers) to a list."""
    return list(map(int, value.split()))
681
682
def cast_date(value, connection):
    """Cast a date value.

    Infinite dates and dates BC are mapped to date.min or date.max, and so
    are dates with a year beyond four digits.  Other dates are parsed using
    the date format of the connection.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value in ('-infinity', 'infinity'):
        return date.min if value.startswith('-') else date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    day = parts[0]
    if len(day) > 10:
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
701
702
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    if len(value) > 8:  # 'HH:MM:SS' followed by fractional seconds
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    parsed = datetime.strptime(value, fmt)
    return parsed.time()
707
708
# splits a time value into the time part and the trailing zone offset
_re_timezone = regex(r'(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    A missing zone offset is taken as UTC ('+0000').
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    if len(value) > 8:  # with fractional seconds
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    if not _has_timezone:
        naive = datetime.strptime(value, fmt).timetz()
        return naive.replace(tzinfo=_get_timezone(tz))
    value += _timezone_as_offset(tz)
    fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
726
727
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    Infinite timestamps and timestamps BC are mapped to datetime.min or
    datetime.max, and so are timestamps with a year beyond four digits.
    Other values are parsed using the date format of the connection.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # drop the first token (presumably the weekday name) and keep
        # month/day, time and year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
749
750
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    Infinite timestamps and timestamps BC are mapped to datetime.min or
    datetime.max, and so are timestamps with a year beyond four digits.
    Other values are parsed using the date format of the connection and
    get a fixed-offset time zone.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # drop the first token (presumably the weekday name)
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
        parts, tz = parts[:-1], parts[-1]
    else:
        if date_fmt.startswith('%Y-'):
            # the zone offset is attached directly to the time part
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            # the time zone is given as a separate last token
            parts, tz = parts[:-1], parts[-1]
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    if not _has_timezone:
        naive = datetime.strptime(' '.join(parts), ' '.join(formats))
        return naive.replace(tzinfo=_get_timezone(tz))
    parts.append(_timezone_as_offset(tz))
    formats.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
786
787
_re_interval_sql_standard = regex(
    r'(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    r'(?:([+-]?[0-9]+)(?!:) ?)?'
    r'(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\.([0-9]+))?)?')

_re_interval_postgres = regex(
    r'(?:([+-]?[0-9]+) ?years? ?)?'
    r'(?:([+-]?[0-9]+) ?mons? ?)?'
    r'(?:([+-]?[0-9]+) ?days? ?)?'
    r'(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    r'@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    r'(?:([+-]?[0-9]+) ?mons? ?)?'
    r'(?:([+-]?[0-9]+) ?days? ?)?'
    r'(?:([+-]?[0-9]+) ?hours? ?)?'
    r'(?:([+-]?[0-9]+) ?mins? ?)?'
    r'(?:([+-])?([0-9]+)(?:\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    r'P(?:([+-]?[0-9]+)Y)?'
    r'(?:([+-]?[0-9]+)M)?'
    r'(?:([+-]?[0-9]+)D)?'
    r'(?:T(?:([+-]?[0-9]+)H)?'
    r'(?:([+-]?[0-9]+)M)?'
    r'(?:([+-])?([0-9]+)(?:\.([0-9]+))?S)?)?')


def _interval_groups(groups):
    """Prepare the matched groups of an interval regex for conversion.

    Missing groups are replaced with '0'.  The last group holds the digits
    of the fractional seconds, which are padded or truncated to exactly six
    digits so that they can be converted to microseconds ('5' means 0.5s).
    """
    groups = [d or '0' for d in groups]
    groups[-1] = groups[-1][:6].ljust(6, '0')
    return groups


def cast_interval(value):
    """Cast an interval value to a timedelta.

    Years are counted as 365 days and months as 30 days.

    Raises ValueError if the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = _interval_groups(m.groups())
        secs_ago = m.pop(5) == '-'
        m = [int(d) for d in m]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = _interval_groups(m.groups()[:8]), m.group(9)
            secs_ago = m.pop(5) == '-'
            m = [-int(d) for d in m] if ago else [int(d) for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = _interval_groups(m.groups())
                hours_ago = m.pop(3) == '-'
                m = [int(d) for d in m]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = _interval_groups(m.groups())
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = [int(d) for d in m]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
873
874
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Note that the basic types are already handled by the C extension.
    They only need to be handled here as record or array components.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector, 'uuid': UUID,
        'anyarray': cast_array, 'record': cast_record}

    connection = None  # will be set in a connection specific instance

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError,
        but returns None when no special cast function exists.
        A TypeError is raised if the type is not given as a string.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            cast = self._add_connection(cast)
            self[typ] = cast
        elif typ.startswith('_'):
            # array type: wrap the cast of the element type
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            # only cache if there is a real element cast
            if base_cast:
                self[typ] = cast
        else:
            # possibly a composite type: build a record cast from its fields
            attnames = self.get_attnames(typ)
            if attnames:
                casts = [self[v.pgtype] for v in attnames.values()]
                cast = self.create_record_cast(typ, attnames, casts)
                self[typ] = cast
        return cast

    @staticmethod
    def _needs_connection(func):
        """Check if a typecast function needs a connection argument.

        This is the case when 'connection' is one of the parameter names
        after the first one.  Functions whose arguments cannot be inspected
        are assumed not to need a connection.
        """
        try:
            args = get_args(func)
        except (TypeError, ValueError):
            return False
        else:
            return 'connection' in args[1:]

    def _add_connection(self, cast):
        """Add a connection argument to the typecast function if necessary."""
        if not self.connection or not self._needs_connection(cast):
            return cast
        return partial(cast, connection=self.connection)

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        Passing None as cast removes the cached cast for the type(s).
        Cached array casts for the type(s) are always dropped, since they
        were built on top of the previous base cast.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = self._add_connection(cast)
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.
        Cached array casts for the given type(s) are dropped as well,
        so that they are rebuilt on top of the reset base cast.
        """
        if typ is None:
            self.clear()
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.pop(t, None)
                # the cached array cast wraps the old base cast
                self.pop('_%s' % t, None)

    @classmethod
    def get_default(cls, typ):
        """Get the default typecast function for the given database type."""
        return cls.defaults.get(typ)

    @classmethod
    def set_default(cls, typ, cast):
        """Set a default typecast function for the given database type(s)."""
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = cls.defaults
        if cast is None:
            for t in typ:
                defaults.pop(t, None)
                defaults.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                defaults[t] = cast
                defaults.pop('_%s' % t, None)

    def get_attnames(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with the get_attnames() method of DbTypes.
        """
        return {}

    def dateformat(self):
        """Return the current date format.

        This method will be replaced with the dateformat() method of DbTypes.
        """
        return '%Y-%m-%d'

    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast."""
        def cast(v):
            return cast_array(v, basecast)
        return cast

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
1026
1027
def get_typecast(typ):
    """Return the global default typecast function for a database type.

    None is returned when no default typecast is registered for the type.
    """
    defaults = Typecasts.defaults
    return defaults.get(typ)
1031
1032
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The type can be a single type name or an iterable of type names.
    Passing None as cast removes the global typecast for the type(s);
    otherwise the cast must be callable or a TypeError is raised.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1040
1041
class DbType(str):
    """Type name string that carries extra information about the type.

    Instances compare and hash like the plain type name.  The following
    attributes are attached to them:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types (looked up on access)
    """

    # resolved lazily through the _get_attnames callable of the instance
    attnames = property(
        lambda self: self._get_attnames(self),
        doc="Get names and types of the fields of a composite type.")
1062
1063
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    This cache maps type OIDs and names to DbType objects containing
    information on the associated database type.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize type cache for connection.

        The db argument is the DB wrapper instance; its underlying
        connection (db.db) is used for querying and escaping.
        """
        super(DbTypes, self).__init__()
        self._regtypes = False
        self._get_attnames = db.get_attnames
        self._typecasts = Typecasts()
        self._typecasts.get_attnames = self.get_attnames
        self._typecasts.connection = db
        db = db.db
        self.query = db.query
        self.escape_string = db.escape_string

    def add(self, oid, pgtype, regtype,
            typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        If a type with the given OID is already cached, the cached DbType
        is returned.  Note that this method does not store the new DbType
        in the cache itself.
        """
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid = oid
        typ.simple = simple
        typ.pgtype = pgtype
        typ.regtype = regtype
        typ.typtype = typtype
        typ.category = category
        typ.delim = delim
        typ.relid = relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The found type is cached under its OID and its pgtype name.
        Raises a KeyError if the type could not be found.
        """
        try:
            res = self.query("SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype" %
                (_quote_if_unqualified('$1', key),), (key,)).getresult()
        except ProgrammingError:
            # presumably raised for unknown or malformed type names
            res = None
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = res[0]
        typ = self.add(*res)
        self[typ.oid] = self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or not a composite type.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None
        return self._get_attnames(typ.relid, with_oid=False)

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        The type can be given as a DbType, a type OID or a type name.
        The typecast function is always looked up by the PostgreSQL
        type name (pgtype) of the type.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if isinstance(typ, DbType):
            # the string value of a DbType is its simple or regular
            # type name, but typecasts are keyed by the pgtype name
            typ = typ.pgtype
        else:
            typ = self.get(typ)
            if typ:
                typ = typ.pgtype
        cast = self.get_typecast(typ) if typ else None
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
1163
1164
1165def _namedresult(q):
1166    """Get query result as named tuples."""
1167    row = namedtuple('Row', q.listfields())
1168    return [row(*r) for r in q.getresult()]
1169
1170
1171class _MemoryQuery:
1172    """Class that embodies a given query result."""
1173
1174    def __init__(self, result, fields):
1175        """Create query from given result rows and field names."""
1176        self.result = result
1177        self.fields = fields
1178
1179    def listfields(self):
1180        """Return the stored field names of this query."""
1181        return self.fields
1182
1183    def getresult(self):
1184        """Return the stored result of this query."""
1185        return self.result
1186
1187
def _db_error(msg, cls=DatabaseError):
    """Create an error of the given class whose sqlstate is None."""
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
1193
1194
def _int_error(msg):
    """Create an InternalError (with sqlstate None) for the message."""
    return _db_error(msg, cls=InternalError)
1198
1199
def _prg_error(msg):
    """Create a ProgrammingError (with sqlstate None) for the message."""
    return _db_error(msg, cls=ProgrammingError)
1203
1204
# Initialize the C module
#
# Hand Python-level helpers over to the C extension module: the function
# building named-tuple results (_namedresult), the Decimal class and the
# JSON decoder (json.loads).  This runs once at import time.
# NOTE(review): how _pg uses these helpers is defined in the C code.

set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
1210
1211
1212# The notification handler
1213
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is escaped with the connection's escape_string(),
        so it may safely contain single quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent if the handler is not listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # convert to text like before, then escape it so that
                # quotes in the payload cannot break the string literal
                payload = db.escape_string('%s' % (payload,))
                q += ", '%s'" % payload
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1333
1334
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler."""
    msg = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1340
1341
# The actual PostgreSQL database connection interface:
1343
1344class DB:
1345    """Wrapper class for the _pg connection type."""
1346
    def __init__(self, *args, **kw):
        """Create a new connection

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        A DB instance passed here is unwrapped to its underlying connection.
        Only connections opened by this constructor are marked as closeable,
        i.e. can be closed or reopened through this instance.
        """
        # a single argument (positional or as keyword 'db') may be an
        # existing connection object rather than a connection parameter
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # reuse the underlying connection of another DB wrapper
                db = db.db
            else:
                # a pgdb connection keeps its underlying connection in _cnx
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # anything that does not look like a connection (e.g. a database
        # name string) is treated as connection parameters
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        # caches for table metadata
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        # remembered so that reopen() can connect again with the same args
        self._args = args, kw
        # these must come after self.db is set: DbTypes reads
        # self.get_attnames and self.db when it is created
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1388
1389    def __getattr__(self, name):
1390        # All undefined members are same as in underlying connection:
1391        if self.db:
1392            return getattr(self.db, name)
1393        else:
1394            raise _int_error('Connection is not valid')
1395
1396    def __dir__(self):
1397        # Custom dir function including the attributes of the connection:
1398        attrs = set(self.__class__.__dict__)
1399        attrs.update(self.__dict__)
1400        attrs.update(dir(self.db))
1401        return sorted(attrs)
1402
1403    # Context manager methods
1404
1405    def __enter__(self):
1406        """Enter the runtime context. This will start a transactio."""
1407        self.begin()
1408        return self
1409
1410    def __exit__(self, et, ev, tb):
1411        """Exit the runtime context. This will end the transaction."""
1412        if et is None and ev is None and tb is None:
1413            self.commit()
1414        else:
1415            self.rollback()
1416
1417    # Auxiliary methods
1418
1419    def _do_debug(self, *args):
1420        """Print a debug message"""
1421        if self.debug:
1422            s = '\n'.join(str(arg) for arg in args)
1423            if isinstance(self.debug, basestring):
1424                print(self.debug % s)
1425            elif hasattr(self.debug, 'write'):
1426                self.debug.write(s + '\n')
1427            elif callable(self.debug):
1428                self.debug(s)
1429            else:
1430                print(s)
1431
1432    def _escape_qualified_name(self, s):
1433        """Escape a qualified name.
1434
1435        Escapes the name for use as an SQL identifier, unless the
1436        name contains a dot, in which case the name is ambiguous
1437        (could be a qualified name or just a name with a dot in it)
1438        and must be quoted manually by the caller.
1439        """
1440        if '.' not in s:
1441            s = self.escape_identifier(s)
1442        return s
1443
1444    @staticmethod
1445    def _make_bool(d):
1446        """Get boolean value corresponding to d."""
1447        return bool(d) if get_bool() else ('t' if d else 'f')
1448
1449    def _list_params(self, params):
1450        """Create a human readable parameter list."""
1451        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1452
1453    # Public methods
1454
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (wraps the module-level function of the same name as a staticmethod,
    # so it can be called on the instance without being bound to self)
    unescape_bytea = staticmethod(unescape_bytea)
1458
1459    def decode_json(self, s):
1460        """Decode a JSON string coming from the database."""
1461        return (get_jsondecode() or jsondecode)(s)
1462
1463    def encode_json(self, d):
1464        """Encode a JSON string for use within SQL."""
1465        return jsonencode(d)
1466
1467    def close(self):
1468        """Close the database connection."""
1469        # Wraps shared library function so we can track state.
1470        if self._closeable:
1471            if self.db:
1472                self.db.close()
1473                self.db = None
1474            else:
1475                raise _int_error('Connection already closed')
1476
1477    def reset(self):
1478        """Reset connection with current parameters.
1479
1480        All derived queries and large objects derived from this connection
1481        will not be usable after this call.
1482
1483        """
1484        if self.db:
1485            self.db.reset()
1486        else:
1487            raise _int_error('Connection already closed')
1488
1489    def reopen(self):
1490        """Reopen connection to the database.
1491
1492        Used in case we need another connection to the same database.
1493        Note that we can still reopen a database that we have closed.
1494
1495        """
1496        # There is no such shared library function.
1497        if self._closeable:
1498            db = connect(*self._args[0], **self._args[1])
1499            if self.db:
1500                self.db.close()
1501            self.db = db
1502
1503    def begin(self, mode=None):
1504        """Begin a transaction."""
1505        qstr = 'BEGIN'
1506        if mode:
1507            qstr += ' ' + mode
1508        return self.query(qstr)
1509
1510    start = begin
1511
1512    def commit(self):
1513        """Commit the current transaction."""
1514        return self.query('COMMIT')
1515
1516    end = commit
1517
1518    def rollback(self, name=None):
1519        """Roll back the current transaction."""
1520        qstr = 'ROLLBACK'
1521        if name:
1522            qstr += ' TO ' + name
1523        return self.query(qstr)
1524
1525    abort = rollback
1526
1527    def savepoint(self, name):
1528        """Define a new savepoint within the current transaction."""
1529        return self.query('SAVEPOINT ' + name)
1530
1531    def release(self, name):
1532        """Destroy a previously defined savepoint."""
1533        return self.query('RELEASE ' + name)
1534
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list (or tuple) of parameter names, the return value
        will be a corresponding list of parameter settings.  When passing a
        set of parameter names, a new dict will be returned, mapping these
        parameter names to their settings.  Finally, if you pass a dict as
        parameter, its values will be set to the current parameter settings
        corresponding to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises a TypeError if the parameter has an unsupported type, is
        empty, or contains a name that is not a non-blank string.
        """
        # values holds the result: None for a single name, a list for
        # sequences, or a dict (new for sets, the passed one for dicts)
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # for dict results, params maps normalized names to original keys
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            # names are normalized to stripped lower case
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' overrides everything else: return a new dict built
                # from the first two columns of SHOW ALL (presumably the
                # parameter name and its setting)
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # only reached if the loop did not break on 'all':
            # query each parameter separately
            # NOTE(review): names are interpolated into SQL unescaped
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1592
1593    def set_parameter(self, parameter, value=None, local=False):
1594        """Set the value of a run-time parameter.
1595
1596        If the parameter and the value are strings, the run-time parameter
1597        will be set to that value.  If no value or None is passed as a value,
1598        then the run-time parameter will be restored to its default value.
1599
1600        You can set several parameters at once by passing a list of parameter
1601        names, together with a single value that all parameters should be
1602        set to or with a corresponding list of values.  You can also pass
1603        the parameters as a set if you only provide a single value.
1604        Finally, you can pass a dict with parameter names as keys.  In this
1605        case, you should not pass a value, since the values for the parameters
1606        will be taken from the dict.
1607
1608        By passing the special name 'all' as the parameter, you can reset
1609        all existing settable run-time parameters to their default values.
1610
1611        If you set local to True, then the command takes effect for only the
1612        current transaction.  After commit() or rollback(), the session-level
1613        setting takes effect again.  Setting local to True will appear to
1614        have no effect if it is executed outside a transaction, since the
1615        transaction will end immediately.
1616        """
1617        if isinstance(parameter, basestring):
1618            parameter = {parameter: value}
1619        elif isinstance(parameter, (list, tuple)):
1620            if isinstance(value, (list, tuple)):
1621                parameter = dict(zip(parameter, value))
1622            else:
1623                parameter = dict.fromkeys(parameter, value)
1624        elif isinstance(parameter, (set, frozenset)):
1625            if isinstance(value, (list, tuple, set, frozenset)):
1626                value = set(value)
1627                if len(value) == 1:
1628                    value = value.pop()
1629            if not(value is None or isinstance(value, basestring)):
1630                raise ValueError('A single value must be specified'
1631                    ' when parameter is a set')
1632            parameter = dict.fromkeys(parameter, value)
1633        elif isinstance(parameter, dict):
1634            if value is not None:
1635                raise ValueError('A value must not be specified'
1636                    ' when parameter is a dictionary')
1637        else:
1638            raise TypeError(
1639                'The parameter must be a string, list, set or dict')
1640        if not parameter:
1641            raise TypeError('No parameter has been specified')
1642        params = {}
1643        for key, value in parameter.items():
1644            param = key.strip().lower() if isinstance(
1645                key, basestring) else None
1646            if not param:
1647                raise TypeError('Invalid parameter')
1648            if param == 'all':
1649                if value is not None:
1650                    raise ValueError('A value must ot be specified'
1651                        " when parameter is 'all'")
1652                params = {'all': None}
1653                break
1654            params[param] = value
1655        local = ' LOCAL' if local else ''
1656        for param, value in params.items():
1657            if value is None:
1658                q = 'RESET%s %s' % (local, param)
1659            else:
1660                q = 'SET%s %s TO %s' % (local, param, value)
1661            self._do_debug(q)
1662            self.db.query(q)
1663
1664    def query(self, command, *args):
1665        """Execute a SQL command string.
1666
1667        This method simply sends a SQL query to the database.  If the query is
1668        an insert statement that inserted exactly one row into a table that
1669        has OIDs, the return value is the OID of the newly inserted row.
1670        If the query is an update or delete statement, or an insert statement
1671        that did not insert exactly one row in a table with OIDs, then the
1672        number of rows affected is returned as a string.  If it is a statement
1673        that returns rows as a result (usually a select statement, but maybe
1674        also an "insert/update ... returning" statement), this method returns
1675        a Query object that can be accessed via getresult() or dictresult()
1676        or simply printed.  Otherwise, it returns `None`.
1677
1678        The query can contain numbered parameters of the form $1 in place
1679        of any data constant.  Arguments given after the query string will
1680        be substituted for the corresponding numbered parameter.  Parameter
1681        values can also be given as a single list or tuple argument.
1682        """
1683        # Wraps shared library function for debugging.
1684        if not self.db:
1685            raise _int_error('Connection is not valid')
1686        if args:
1687            self._do_debug(command, args)
1688            return self.db.query(command, args)
1689        self._do_debug(command)
1690        return self.db.query(command)
1691
1692    def query_formatted(self, command, parameters, types=None, inline=False):
1693        """Execute a formatted SQL command string.
1694
1695        Similar to query, but using Python format placeholders of the form
1696        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1697        The parameters must be passed as a tuple, list or dict.  You can
1698        also pass a corresponding tuple, list or dict of database types in
1699        order to format the parameters properly in case there is ambiguity.
1700
1701        If you set inline to True, the parameters will be sent to the database
1702        embedded in the SQL command, otherwise they will be sent separately.
1703        """
1704        return self.query(*self.adapter.format_query(
1705            command, parameters, types, inline))
1706
1707    def pkey(self, table, composite=False, flush=False):
1708        """Get or set the primary key of a table.
1709
1710        Single primary keys are returned as strings unless you
1711        set the composite flag.  Composite primary keys are always
1712        represented as tuples.  Note that this raises a KeyError
1713        if the table does not have a primary key.
1714
1715        If flush is set then the internal cache for primary keys will
1716        be flushed.  This may be necessary after the database schema or
1717        the search path has been changed.
1718        """
1719        pkeys = self._pkeys
1720        if flush:
1721            pkeys.clear()
1722            self._do_debug('The pkey cache has been flushed')
1723        try:  # cache lookup
1724            pkey = pkeys[table]
1725        except KeyError:  # cache miss, check the database
1726            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1727                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1728                " AND a.attnum = ANY(i.indkey)"
1729                " AND NOT a.attisdropped"
1730                " WHERE i.indrelid=%s::regclass"
1731                " AND i.indisprimary ORDER BY a.attnum") % (
1732                    _quote_if_unqualified('$1', table),)
1733            pkey = self.db.query(q, (table,)).getresult()
1734            if not pkey:
1735                raise KeyError('Table %s has no primary key' % table)
1736            # we want to use the order defined in the primary key index here,
1737            # not the order as defined by the columns in the table
1738            if len(pkey) > 1:
1739                indkey = pkey[0][2]
1740                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1741                pkey = tuple(row[0] for row in pkey)
1742            else:
1743                pkey = pkey[0][0]
1744            pkeys[table] = pkey  # cache it
1745        if composite and not isinstance(pkey, tuple):
1746            pkey = (pkey,)
1747        return pkey
1748
1749    def get_databases(self):
1750        """Get list of databases in the system."""
1751        return [s[0] for s in
1752            self.db.query('SELECT datname FROM pg_database').getresult()]
1753
1754    def get_relations(self, kinds=None, system=False):
1755        """Get list of relations in connected database of specified kinds.
1756
1757        If kinds is None or empty, all kinds of relations are returned.
1758        Otherwise kinds can be a string or sequence of type letters
1759        specifying which kind of relations you want to list.
1760
1761        Set the system flag if you want to get the system relations as well.
1762        """
1763        where = []
1764        if kinds:
1765            where.append("r.relkind IN (%s)" %
1766                ','.join("'%s'" % k for k in kinds))
1767        if not system:
1768            where.append("s.nspname NOT SIMILAR"
1769                " TO 'pg/_%|information/_schema' ESCAPE '/'")
1770        where = " WHERE %s" % ' AND '.join(where) if where else ''
1771        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1772            " FROM pg_class r"
1773            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
1774            " ORDER BY s.nspname, r.relname") % where
1775        return [r[0] for r in self.db.query(q).getresult()]
1776
1777    def get_tables(self, system=False):
1778        """Return list of tables in connected database.
1779
1780        Set the system flag if you want to get the system tables as well.
1781        """
1782        return self.get_relations('r', system)
1783
1784    def get_attnames(self, table, with_oid=True, flush=False):
1785        """Given the name of a table, dig out the set of attribute names.
1786
1787        Returns a read-only dictionary of attribute names (the names are
1788        the keys, the values are the names of the attributes' types)
1789        with the column names in the proper order if you iterate over it.
1790
1791        If flush is set, then the internal cache for attribute names will
1792        be flushed. This may be necessary after the database schema or
1793        the search path has been changed.
1794
1795        By default, only a limited number of simple types will be returned.
1796        You can get the regular types after calling use_regtypes(True).
1797        """
1798        attnames = self._attnames
1799        if flush:
1800            attnames.clear()
1801            self._do_debug('The attnames cache has been flushed')
1802        try:  # cache lookup
1803            names = attnames[table]
1804        except KeyError:  # cache miss, check the database
1805            q = "a.attnum > 0"
1806            if with_oid:
1807                q = "(%s OR a.attname = 'oid')" % q
1808            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1809                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1810                " FROM pg_attribute a"
1811                " JOIN pg_type t ON t.oid = a.atttypid"
1812                " WHERE a.attrelid = %s::regclass AND %s"
1813                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1814                    _quote_if_unqualified('$1', table), q)
1815            names = self.db.query(q, (table,)).getresult()
1816            types = self.dbtypes
1817            names = ((name[0], types.add(*name[1:])) for name in names)
1818            names = AttrDict(names)
1819            attnames[table] = names  # cache it
1820        return names
1821
1822    def use_regtypes(self, regtypes=None):
1823        """Use regular type names instead of simplified type names."""
1824        if regtypes is None:
1825            return self.dbtypes._regtypes
1826        else:
1827            regtypes = bool(regtypes)
1828            if regtypes != self.dbtypes._regtypes:
1829                self.dbtypes._regtypes = regtypes
1830                self._attnames.clear()
1831                self.dbtypes.clear()
1832            return regtypes
1833
1834    def has_table_privilege(self, table, privilege='select', flush=False):
1835        """Check whether current user has specified table privilege.
1836
1837        If flush is set, then the internal cache for table privileges will
1838        be flushed. This may be necessary after privileges have been changed.
1839        """
1840        privileges = self._privileges
1841        if flush:
1842            privileges.clear()
1843            self._do_debug('The privileges cache has been flushed')
1844        privilege = privilege.lower()
1845        try:  # ask cache
1846            ret = privileges[table, privilege]
1847        except KeyError:  # cache miss, ask the database
1848            q = "SELECT has_table_privilege(%s, $2)" % (
1849                _quote_if_unqualified('$1', table),)
1850            q = self.db.query(q, (table, privilege))
1851            ret = q.getresult()[0][0] == self._make_bool(True)
1852            privileges[table, privilege] = ret  # cache it
1853        return ret
1854
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple (or list) of
        values corresponding to the passed keyname or primary key.  The
        fetched row from the table will be returned as a new dictionary or
        used to replace the existing values when row was passed as a
        dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        Errors: the error built by _prg_error() if no keyname is given and
        the table has neither a primary key nor a usable oid; KeyError if
        key values are missing from a row dictionary or the number of key
        values does not match the keyname; the error built by _db_error()
        if no matching row exists.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # munged key under which the oid is kept in row dicts,
        # or None if the table has no oid attribute
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # make a munged oid from the caller available as plain 'oid',
        # so that it can serve as a fallback key below
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # a scalar or sequence of key values is turned into a new dict
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # store the oid only under its munged key in the returned dict
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row, renaming oid to its munged key
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1929
1930    def insert(self, table, row=None, **kw):
1931        """Insert a row into a database table.
1932
1933        This method inserts a row into a table.  The name of the table must
1934        be passed as the first parameter.  The other parameters are used for
1935        providing the data of the row that shall be inserted into the table.
1936        If a dictionary is supplied as the second parameter, it starts with
1937        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1938        is updated from the keywords.
1939
1940        The dictionary is then reloaded with the values actually inserted in
1941        order to pick up values modified by rules, triggers, etc.
1942        """
1943        if table.endswith('*'):  # hint for descendant tables can be ignored
1944            table = table[:-1].rstrip()
1945        if row is None:
1946            row = {}
1947        row.update(kw)
1948        if 'oid' in row:
1949            del row['oid']  # do not insert oid
1950        attnames = self.get_attnames(table)
1951        qoid = _oid_key(table) if 'oid' in attnames else None
1952        params = self.adapter.parameter_list()
1953        adapt = params.add
1954        col = self.escape_identifier
1955        names, values = [], []
1956        for n in attnames:
1957            if n in row:
1958                names.append(col(n))
1959                values.append(adapt(row[n], attnames[n]))
1960        if not names:
1961            raise _prg_error('No column found that can be inserted')
1962        names, values = ', '.join(names), ', '.join(values)
1963        ret = 'oid, *' if qoid else '*'
1964        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1965            self._escape_qualified_name(table), names, values, ret)
1966        self._do_debug(q, params)
1967        q = self.db.query(q, params)
1968        res = q.dictresult()
1969        if res:  # this should always be true
1970            for n, value in res[0].items():
1971                if qoid and n == 'oid':
1972                    n = qoid
1973                row[n] = value
1974        return row
1975
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the primary key of the table or the OID value as munged by get
        or passed as keyword.

        The dictionary is then modified to reflect any changes caused by the
        update due to triggers, rules, default values, etc.

        Returns the (possibly modified) row dictionary.  If the row contains
        no attribute besides the key columns, no statement is executed.  If
        no matching row exists, the dictionary is returned without values
        from the database.

        Errors: the error built by _prg_error() if the table has neither a
        primary key nor a usable oid; KeyError if primary key values are
        missing and no oid is available.
        """
        if table.endswith('*'):
            table = table[:-1].rstrip()  # need parent table name
        attnames = self.get_attnames(table)
        # munged key under which the oid is kept in row dicts,
        # or None if the table has no oid attribute
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # make a munged oid available as plain 'oid' for the key fallback
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        try:  # try using the primary key
            keyname = self.pkey(table, True)
        except KeyError:  # the table has no primary key
            # try using the oid instead
            if qoid and 'oid' in row:
                keyname = ('oid',)
            else:
                raise _prg_error('Table %s has no primary key' % table)
        else:  # the table has a primary key
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                # try using the oid instead
                if qoid and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise KeyError('Missing primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        # the WHERE parameters are added before the SET parameters;
        # presumably the placeholders are numbered, so their order in the
        # statement text does not matter -- NOTE(review): confirm in adapter
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # store the oid only under its munged key in the returned dict
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        values = []
        keyname = set(keyname)
        # key columns are used for matching only, never assigned
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
        if not values:
            return row
        values = ', '.join(values)
        ret = 'oid, *' if qoid else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            # copy the stored values back, renaming oid to its munged key
            for n, value in res[0].items():
                if qoid and n == 'oid':
                    n = qoid
                row[n] = value
        return row
2042
2043    def upsert(self, table, row=None, **kw):
2044        """Insert a row into a database table with conflict resolution
2045
2046        This method inserts a row into a table, but instead of raising a
2047        ProgrammingError exception in case a row with the same primary key
2048        already exists, an update will be executed instead.  This will be
2049        performed as a single atomic operation on the database, so race
2050        conditions can be avoided.
2051
2052        Like the insert method, the first parameter is the name of the
2053        table and the second parameter can be used to pass the values to
2054        be inserted as a dictionary.
2055
2056        Unlike the insert und update statement, keyword parameters are not
2057        used to modify the dictionary, but to specify which columns shall
2058        be updated in case of a conflict, and in which way:
2059
2060        A value of False or None means the column shall not be updated,
2061        a value of True means the column shall be updated with the value
2062        that has been proposed for insertion, i.e. has been passed as value
2063        in the dictionary.  Columns that are not specified by keywords but
2064        appear as keys in the dictionary are also updated like in the case
2065        keywords had been passed with the value True.
2066
2067        So if in the case of a conflict you want to update every column that
2068        has been passed in the dictionary row , you would call upsert(table, row).
2069        If you don't want to do anything in case of a conflict, i.e. leave
2070        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2071
2072        If you need more fine-grained control of what gets updated, you can
2073        also pass strings in the keyword parameters.  These strings will
2074        be used as SQL expressions for the update columns.  In these
2075        expressions you can refer to the value that already exists in
2076        the table by prefixing the column name with "included.", and to
2077        the value that has been proposed for insertion by prefixing the
2078        column name with the "excluded."
2079
2080        The dictionary is modified in any case to reflect the values in
2081        the database after the operation has completed.
2082
2083        Note: The method uses the PostgreSQL "upsert" feature which is
2084        only available since PostgreSQL 9.5.
2085        """
2086        if table.endswith('*'):  # hint for descendant tables can be ignored
2087            table = table[:-1].rstrip()
2088        if row is None:
2089            row = {}
2090        if 'oid' in row:
2091            del row['oid']  # do not insert oid
2092        if 'oid' in kw:
2093            del kw['oid']  # do not update oid
2094        attnames = self.get_attnames(table)
2095        qoid = _oid_key(table) if 'oid' in attnames else None
2096        params = self.adapter.parameter_list()
2097        adapt = params.add
2098        col = self.escape_identifier
2099        names, values, updates = [], [], []
2100        for n in attnames:
2101            if n in row:
2102                names.append(col(n))
2103                values.append(adapt(row[n], attnames[n]))
2104        names, values = ', '.join(names), ', '.join(values)
2105        try:
2106            keyname = self.pkey(table, True)
2107        except KeyError:
2108            raise _prg_error('Table %s has no primary key' % table)
2109        target = ', '.join(col(k) for k in keyname)
2110        update = []
2111        keyname = set(keyname)
2112        keyname.add('oid')
2113        for n in attnames:
2114            if n not in keyname:
2115                value = kw.get(n, True)
2116                if value:
2117                    if not isinstance(value, basestring):
2118                        value = 'excluded.%s' % col(n)
2119                    update.append('%s = %s' % (col(n), value))
2120        if not values:
2121            return row
2122        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2123        ret = 'oid, *' if qoid else '*'
2124        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2125            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2126                self._escape_qualified_name(table), names, values,
2127                target, do, ret)
2128        self._do_debug(q, params)
2129        try:
2130            q = self.db.query(q, params)
2131        except ProgrammingError:
2132            if self.server_version < 90500:
2133                raise _prg_error(
2134                    'Upsert operation is not supported by PostgreSQL version')
2135            raise  # re-raise original error
2136        res = q.dictresult()
2137        if res:  # may be empty with "do nothing"
2138            for n, value in res[0].items():
2139                if qoid and n == 'oid':
2140                    n = qoid
2141                row[n] = value
2142        else:
2143            self.get(table, row)
2144        return row
2145
2146    def clear(self, table, row=None):
2147        """Clear all the attributes to values determined by the types.
2148
2149        Numeric types are set to 0, Booleans are set to false, and everything
2150        else is set to the empty string.  If the row argument is present,
2151        it is used as the row dictionary and any entries matching attribute
2152        names are cleared with everything else left unchanged.
2153        """
2154        # At some point we will need a way to get defaults from a table.
2155        if row is None:
2156            row = {}  # empty if argument is not present
2157        attnames = self.get_attnames(table)
2158        for n, t in attnames.items():
2159            if n == 'oid':
2160                continue
2161            t = t.simple
2162            if t in DbTypes._num_types:
2163                row[n] = 0
2164            elif t == 'bool':
2165                row[n] = self._make_bool(False)
2166            else:
2167                row[n] = ''
2168        return row
2169
    def delete(self, table, row=None, **kw):
        """Delete an existing row in a database table.

        This method deletes the row from a table.  It deletes based on the
        primary key of the table or the OID value as munged by get() or
        passed as keyword.

        The return value is the number of deleted rows (i.e. 0 if the row
        did not exist and 1 if the row was deleted).

        Note that if the row cannot be deleted because e.g. it is still
        referenced by another table, this method raises a ProgrammingError.

        Errors raised before anything is deleted: the error built by
        _prg_error() if the table has neither a primary key nor a usable
        oid; KeyError if primary key values are missing and no oid is
        available.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # munged key under which the oid is kept in row dicts,
        # or None if the table has no oid attribute
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # make a munged oid available as plain 'oid' for the key fallback
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        try:  # try using the primary key
            keyname = self.pkey(table, True)
        except KeyError:  # the table has no primary key
            # try using the oid instead
            if qoid and 'oid' in row:
                keyname = ('oid',)
            else:
                raise _prg_error('Table %s has no primary key' % table)
        else:  # the table has a primary key
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                # try using the oid instead
                if qoid and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise KeyError('Missing primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # leave the oid only under its munged key in the caller's dict
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'DELETE FROM %s WHERE %s' % (
            self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        res = self.db.query(q, params)
        # per the query() docstring, a DELETE yields the number of affected
        # rows as a string, hence the conversion
        return int(res)
2224
2225    def truncate(self, table, restart=False, cascade=False, only=False):
2226        """Empty a table or set of tables.
2227
2228        This method quickly removes all rows from the given table or set
2229        of tables.  It has the same effect as an unqualified DELETE on each
2230        table, but since it does not actually scan the tables it is faster.
2231        Furthermore, it reclaims disk space immediately, rather than requiring
2232        a subsequent VACUUM operation. This is most useful on large tables.
2233
2234        If restart is set to True, sequences owned by columns of the truncated
2235        table(s) are automatically restarted.  If cascade is set to True, it
2236        also truncates all tables that have foreign-key references to any of
2237        the named tables.  If the parameter only is not set to True, all the
2238        descendant tables (if any) will also be truncated. Optionally, a '*'
2239        can be specified after the table name to explicitly indicate that
2240        descendant tables are included.
2241        """
2242        if isinstance(table, basestring):
2243            only = {table: only}
2244            table = [table]
2245        elif isinstance(table, (list, tuple)):
2246            if isinstance(only, (list, tuple)):
2247                only = dict(zip(table, only))
2248            else:
2249                only = dict.fromkeys(table, only)
2250        elif isinstance(table, (set, frozenset)):
2251            only = dict.fromkeys(table, only)
2252        else:
2253            raise TypeError('The table must be a string, list or set')
2254        if not (restart is None or isinstance(restart, (bool, int))):
2255            raise TypeError('Invalid type for the restart option')
2256        if not (cascade is None or isinstance(cascade, (bool, int))):
2257            raise TypeError('Invalid type for the cascade option')
2258        tables = []
2259        for t in table:
2260            u = only.get(t)
2261            if not (u is None or isinstance(u, (bool, int))):
2262                raise TypeError('Invalid type for the only option')
2263            if t.endswith('*'):
2264                if u:
2265                    raise ValueError(
2266                        'Contradictory table name and only options')
2267                t = t[:-1].rstrip()
2268            t = self._escape_qualified_name(t)
2269            if u:
2270                t = 'ONLY %s' % t
2271            tables.append(t)
2272        q = ['TRUNCATE', ', '.join(tables)]
2273        if restart:
2274            q.append('RESTART IDENTITY')
2275        if cascade:
2276            q.append('CASCADE')
2277        q = ' '.join(q)
2278        self._do_debug(q)
2279        return self.db.query(q)
2280
2281    def get_as_list(self, table, what=None, where=None,
2282            order=None, limit=None, offset=None, scalar=False):
2283        """Get a table as a list.
2284
2285        This gets a convenient representation of the table as a list
2286        of named tuples in Python.  You only need to pass the name of
2287        the table (or any other SQL expression returning rows).  Note that
2288        by default this will return the full content of the table which
2289        can be huge and overflow your memory.  However, you can control
2290        the amount of data returned using the other optional parameters.
2291
2292        The parameter 'what' can restrict the query to only return a
2293        subset of the table columns.  It can be a string, list or a tuple.
2294        The parameter 'where' can restrict the query to only return a
2295        subset of the table rows.  It can be a string, list or a tuple
2296        of SQL expressions that all need to be fulfilled.  The parameter
2297        'order' specifies the ordering of the rows.  It can also be a
2298        other string, list or a tuple.  If no ordering is specified,
2299        the result will be ordered by the primary key(s) or all columns
2300        if no primary key exists.  You can set 'order' to False if you
2301        don't care about the ordering.  The parameters 'limit' and 'offset'
2302        can be integers specifying the maximum number of rows returned
2303        and a number of rows skipped over.
2304
2305        If you set the 'scalar' option to True, then instead of the
2306        named tuples you will get the first items of these tuples.
2307        This is useful if the result has only one column anyway.
2308        """
2309        if not table:
2310            raise TypeError('The table name is missing')
2311        if what:
2312            if isinstance(what, (list, tuple)):
2313                what = ', '.join(map(str, what))
2314            if order is None:
2315                order = what
2316        else:
2317            what = '*'
2318        q = ['SELECT', what, 'FROM', table]
2319        if where:
2320            if isinstance(where, (list, tuple)):
2321                where = ' AND '.join(map(str, where))
2322            q.extend(['WHERE', where])
2323        if order is None:
2324            try:
2325                order = self.pkey(table, True)
2326            except (KeyError, ProgrammingError):
2327                try:
2328                    order = list(self.get_attnames(table))
2329                except (KeyError, ProgrammingError):
2330                    pass
2331        if order:
2332            if isinstance(order, (list, tuple)):
2333                order = ', '.join(map(str, order))
2334            q.extend(['ORDER BY', order])
2335        if limit:
2336            q.append('LIMIT %d' % limit)
2337        if offset:
2338            q.append('OFFSET %d' % offset)
2339        q = ' '.join(q)
2340        self._do_debug(q)
2341        q = self.db.query(q)
2342        res = q.namedresult()
2343        if res and scalar:
2344            res = [row[0] for row in res]
2345        return res
2346
2347    def get_as_dict(self, table, keyname=None, what=None, where=None,
2348            order=None, limit=None, offset=None, scalar=False):
2349        """Get a table as a dictionary.
2350
2351        This method is similar to get_as_list(), but returns the table
2352        as a Python dict instead of a Python list, which can be even
2353        more convenient. The primary key column(s) of the table will
2354        be used as the keys of the dictionary, while the other column(s)
2355        will be the corresponding values.  The keys will be named tuples
2356        if the table has a composite primary key.  The rows will be also
2357        named tuples unless the 'scalar' option has been set to True.
2358        With the optional parameter 'keyname' you can specify an alternative
2359        set of columns to be used as the keys of the dictionary.  It must
2360        be set as a string, list or a tuple.
2361
2362        If the Python version supports it, the dictionary will be an
2363        OrderedDict using the order specified with the 'order' parameter
2364        or the key column(s) if not specified.  You can set 'order' to False
2365        if you don't care about the ordering.  In this case the returned
2366        dictionary will be an ordinary one.
2367        """
2368        if not table:
2369            raise TypeError('The table name is missing')
2370        if not keyname:
2371            try:
2372                keyname = self.pkey(table, True)
2373            except (KeyError, ProgrammingError):
2374                raise _prg_error('Table %s has no primary key' % table)
2375        if isinstance(keyname, basestring):
2376            keyname = [keyname]
2377        elif not isinstance(keyname, (list, tuple)):
2378            raise KeyError('The keyname must be a string, list or tuple')
2379        if what:
2380            if isinstance(what, (list, tuple)):
2381                what = ', '.join(map(str, what))
2382            if order is None:
2383                order = what
2384        else:
2385            what = '*'
2386        q = ['SELECT', what, 'FROM', table]
2387        if where:
2388            if isinstance(where, (list, tuple)):
2389                where = ' AND '.join(map(str, where))
2390            q.extend(['WHERE', where])
2391        if order is None:
2392            order = keyname
2393        if order:
2394            if isinstance(order, (list, tuple)):
2395                order = ', '.join(map(str, order))
2396            q.extend(['ORDER BY', order])
2397        if limit:
2398            q.append('LIMIT %d' % limit)
2399        if offset:
2400            q.append('OFFSET %d' % offset)
2401        q = ' '.join(q)
2402        self._do_debug(q)
2403        q = self.db.query(q)
2404        res = q.getresult()
2405        cls = OrderedDict if order else dict
2406        if not res:
2407            return cls()
2408        keyset = set(keyname)
2409        fields = q.listfields()
2410        if not keyset.issubset(fields):
2411            raise KeyError('Missing keyname in row')
2412        keyind, rowind = [], []
2413        for i, f in enumerate(fields):
2414            (keyind if f in keyset else rowind).append(i)
2415        keytuple = len(keyind) > 1
2416        getkey = itemgetter(*keyind)
2417        keys = map(getkey, res)
2418        if scalar:
2419            rowind = rowind[:1]
2420            rowtuple = False
2421        else:
2422            rowtuple = len(rowind) > 1
2423        if scalar or rowtuple:
2424            getrow = itemgetter(*rowind)
2425        else:
2426            rowind = rowind[0]
2427            getrow = lambda row: (row[rowind],)
2428            rowtuple = True
2429        rows = map(getrow, res)
2430        if keytuple or rowtuple:
2431            namedresult = get_namedresult()
2432            if namedresult:
2433                if keytuple:
2434                    keys = namedresult(_MemoryQuery(keys, keyname))
2435                if rowtuple:
2436                    fields = [f for f in fields if f not in keyset]
2437                    rows = namedresult(_MemoryQuery(rows, fields))
2438        return cls(zip(keys, rows))
2439
2440    def notification_handler(self,
2441            event, callback, arg_dict=None, timeout=None, stop_event=None):
2442        """Get notification handler that will run the given callback."""
2443        return NotificationHandler(self,
2444            event, callback, arg_dict, timeout, stop_event)
2445
2446
# if run as script, print some information

if __name__ == '__main__':
    # separate the label from the version number
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.