source: trunk/pg.py @ 814

Last change on this file since 814 was 814, checked in by cito, 3 years ago

Add typecasting of dates, times, timestamps, intervals

So far, PyGreSQL has returned these types only as strings (in various
formats depending on the DateStyle setting) and left it to the user
to parse and interpret the strings. These types are now properly cast
into the corresponding datetime types of Python, and this works with
any setting of DateStyle, even if you change DateStyle in the middle
of a database session.

To implement this, a fast method for getting the DateStyle (cached and
without a round trip to the database) has been added. Also, the typecast
mechanism has been extended so that typecast functions can optionally
also take the connection as an argument.

The date and time typecast functions have been implemented in Python
using the new typecast registry and added to both pg and pgdb. Some
duplication of code in the two modules was unavoidable, since we don't
want the modules to be dependent on each other or to install additional
helper modules. One day we might want to change this, put everything
in one package and factor out some of the functionality.

  • Property svn:keywords set to Id
File size: 86.2 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 814 2016-02-03 20:23:20Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from datetime import date, time, datetime, timedelta
38from decimal import Decimal
39from math import isnan, isinf
40from collections import namedtuple
41from operator import itemgetter
42from re import compile as regex
43from json import loads as jsondecode, dumps as jsonencode
44
# Python 2 provides both names as builtins; Python 3 has neither, so
# substitute int for long and use a (str, bytes) tuple for isinstance checks.
try:
    long, basestring
except NameError:  # Python >= 3.0
    long, basestring = int, (str, bytes)
54
55
56# Auxiliary classes and functions that are independent from a DB connection:
57
# Auxiliary classes and functions that are independent from a DB connection:

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Simple read-only ordered dictionary for storing attribute names.

        Fallback for Python versions without OrderedDict: the key order is
        remembered in a separate list.  Only a sequence of (key, value)
        pairs is accepted, because a dict or keyword arguments would not
        carry a defined order.  All mutating operations raise TypeError.
        """

        def __init__(self, *args, **kw):
            if len(args) > 1 or kw:
                raise TypeError
            items = args[0] if args else []
            if isinstance(items, dict):
                raise TypeError
            items = list(items)
            # Remember every key only once, at the position of its first
            # occurrence, like OrderedDict does (a later duplicate only
            # replaces the value); otherwise iteration would yield it twice.
            keys = []
            seen = set()
            for item in items:
                key = item[0]
                if key not in seen:
                    seen.add(key)
                    keys.append(key)
            self._keys = keys
            dict.__init__(self, items)
            self._read_only = True
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[key] for key in self]

        def items(self):
            return [(key, self[key]) for key in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Simple read-only ordered dictionary for storing attribute names.

        The dictionary can be filled only through the constructor; all
        mutating operations afterwards raise TypeError.
        """

        def __init__(self, *args, **kw):
            # allow OrderedDict.__init__ to fill the dict via __setitem__
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow the mutating methods on the instance
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
142
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        return getargspec(func).args
else:
    def get_args(func):
        """Return the list of parameter names of the given function."""
        return list(signature(func).parameters)
151
try:
    # strptime supports the %z directive only since Python 3.2
    if datetime.strptime('+0100', '%z') is None:
        raise ValueError
except ValueError:  # Python < 3.2
    timezones = None
else:
    # Time zone abbreviations used in Postgres timestamptz output, mapped
    # to UTC offsets in the +HHMM format understood by %z.  Zones like CET
    # switch to a daylight saving abbreviation (CEST) in summer, so those
    # variants are listed too; otherwise summer values would be treated
    # as UTC.  Ambiguous abbreviations (e.g. CST, IST, PST) are left out.
    timezones = dict(CET='+0100', EET='+0200', EST='-0500',
        GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
        UCT='+0000', UTC='+0000', WET='+0000',
        BST='+0100', CEST='+0200', EEST='+0300', MEST='+0200',
        WEST='+0100', EDT='-0400', MDT='-0600', HDT='-0900')
162
163
def _oid_key(table):
    """Build the oid key for the given table name, e.g. 'oid(foo)'.

    The name is passed as a one-element tuple to the % operator so that
    any kind of argument is formatted as a single value (the same idiom
    is used by _quote_if_unqualified).
    """
    return 'oid(%s)' % (table,)
167
168
169class _SimpleTypes(dict):
170    """Dictionary mapping pg_type names to simple type names."""
171
172    _types = {'bool': 'bool',
173        'bytea': 'bytea',
174        'date': 'date interval time timetz timestamp timestamptz'
175            ' abstime reltime',  # these are very old
176        'float': 'float4 float8',
177        'int': 'cid int2 int4 int8 oid xid',
178        'json': 'json jsonb',
179        'num': 'numeric',
180        'money': 'money',
181        'text': 'bpchar char name text varchar'}
182
183    def __init__(self):
184        for typ, keys in self._types.items():
185            for key in keys.split():
186                self[key] = typ
187                self['_%s' % key] = '%s[]' % typ
188
189    @staticmethod
190    def __missing__(key):
191        return 'text'
192
193_simpletypes = _SimpleTypes()
194
195
def _quote_if_unqualified(param, name):
    """Quote parameter representing a qualified name.

    Wraps the given parameter in a quote_ident() call, but only if name
    is a string without a dot.  A dotted name is ambiguous (it could be a
    qualified name or a plain name containing a dot), so the parameter is
    returned unchanged and the caller must quote it manually.  Non-string
    names also leave the parameter unchanged.
    """
    leave_as_is = not isinstance(name, basestring) or '.' in name
    if leave_as_is:
        return param
    return 'quote_ident(%s)' % (param,)
207
208
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    An ``adapt(value, typ)`` callable must be assigned to the ``adapt``
    attribute before ``add`` is used (Adapter.parameter_list does this).
    """

    def add(self, value, typ=None):
        """Typecast value with known database type and build parameter list.

        Literal values are returned as they are.  Any other adapted value
        is appended to this list and its placeholder ($1, $2, ...) is
        returned instead.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
223
224
class Literal(str):
    """Wrapper class for marking literal SQL values.

    Values of this type are never adapted or turned into parameters; they
    are used in the SQL command exactly as given.
    """
227
228
class Json:
    """Wrapper class for marking Json values.

    The wrapped Python object is stored unchanged in the ``obj`` attribute.
    """

    def __init__(self, obj):
        self.obj = obj
234
235
class Bytea(bytes):
    """Wrapper class for marking Bytea values.

    Plain bytes are treated like text when adapting parameters; wrapping
    them in this class makes them be escaped as binary data instead.
    """
238
239
class Adapter:
    """Class providing methods for adapting parameters to the database.

    The adapter takes the JSON encoder from the given DB wrapper and the
    escape_bytea/escape_string functions from the wrapper's ``db``
    attribute (presumably the low-level connection -- verify).
    """

    _bool_true_values = frozenset('t true 1 y yes on'.split())

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    # Array elements must be quoted if they contain braces, commas, double
    # quotes, backslashes or white space, or if they read as NULL.
    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    # Record fields must be quoted if they contain parentheses, commas,
    # double quotes or backslashes; an unquoted ")" would end the record.
    _re_record_quote = regex(r'[(),"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        self.db = db
        self.encode_json = db.encode_json
        db = db.db
        self.escape_bytea = db.escape_bytea
        self.escape_string = db.escape_string

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Strings are interpreted case-insensitively and the empty string
        becomes None (NULL).  Everything else is judged by its truth value.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        Date, time and timedelta objects are always passed through, even if
        they are falsy (like timedelta(0)); other empty values become None
        (NULL), and SQL keywords like current_date become literals.
        """
        if isinstance(v, (date, time, timedelta)):
            return v
        if not v:
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter; empty values other than 0 become None."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter by escaping it."""
        return self.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Json wrappers are always encoded.  Otherwise, empty values become
        None (NULL), strings are assumed to be JSON already and anything
        else is encoded with the JSON encoder.
        """
        if isinstance(v, Json):
            return self.encode_json(v.obj)
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.encode_json(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter (lists may be nested)."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if v is None:
            return 'null'
        # Stringify first, so that falsy non-strings like 0 keep their value.
        v = str(v)
        if not v:
            return '""'
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter (lists may be nested)."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter (lists may be nested)."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter (lists may be nested).

        The escaped elements get their backslashes doubled, since the array
        literal parser consumes one level of backslash escaping.
        """
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(item) for item in v) + b'}'
        if v is None:
            return b'null'
        return self.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter (lists may be nested)."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if isinstance(v, Json):
            v = self.encode_json(v.obj)
        elif not v:
            return 'null'
        elif not isinstance(v, basestring):
            v = self.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Every field of the tuple v is adapted with the corresponding
        attribute type of typ, and the results are combined into a record
        literal.  Raises TypeError if the number of fields does not match.
        """
        typ = self.get_attnames(typ).values()
        if len(typ) != len(v):
            raise TypeError('Record parameter %s has wrong size' % (v,))
        adapt = self.adapt
        value = []
        for field, attr_type in zip(v, typ):
            field = adapt(field, attr_type)
            if field is None:
                field = ''  # an empty unquoted field means NULL
            else:
                if isinstance(field, bytes):
                    if str is not bytes:  # Python >= 3.0
                        field = field.decode('ascii')
                else:
                    field = str(field)
                # Test emptiness only after stringification, so that
                # falsy values like 0 are not turned into empty strings.
                if not field:
                    field = '""'  # quoted, to distinguish it from NULL
                elif self._re_record_quote.search(field):
                    field = '"%s"' % self._re_record_escape.sub(
                        r'\\\1', field)
            value.append(field)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is given, a simple type is guessed from the value.
        Objects with a __pg_str__(typ) method are converted with it first.
        None and Literal values are returned unchanged.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            # Look the hook up explicitly instead of catching AttributeError
            # around the call, which would hide errors raised inside it.
            pg_str = getattr(value, '__pg_str__', None)
            if pg_str:
                value = pg_str(typ)
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with the given name.

        DbType is not defined in this module (presumably provided by _pg).
        """
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns None if no type can be guessed.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, Json):
            return 'json'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, list):
            return '%s[]' % cls.guess_simple_base_type(value)
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type
            typ._get_attnames = lambda _self: AttrDict(
                (str(n + 1), simple_type(guess(v)))
                for n, v in enumerate(value))
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given (possibly nested) array."""
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Lists become ARRAY[...] constructors (nested lists just [...]) and
        tuples become parenthesized lists.  Raises InterfaceError for values
        that cannot be adapted.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # encode the wrapped object; the result is quoted below
            value = self.encode_json(value.obj)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        pg_repr = getattr(value, '__pg_repr__', None)
        if not pg_repr:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        value = pg_repr()
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        Values are either inlined into the command or replaced by
        placeholders, in which case they are collected in the returned
        parameter list.  Raises ValueError if typed values shall be inlined
        and TypeError if values and types do not match.
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command = command % tuple(literals)
        elif isinstance(values, dict):
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                if types:
                    # every value needs a type (extra types are allowed)
                    if (not isinstance(types, dict) or
                            any(key not in types for key in values)):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types[key])
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command = command % literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
572
573
def cast_bool(value):
    """Cast a boolean value.

    If get_bool() returns a false value, the string is returned as is;
    otherwise the result is True exactly when the string starts with 't'.
    """
    if get_bool():
        return value[0] == 't'
    return value
579
580
def cast_json(value):
    """Cast a JSON value with the decoder returned by get_jsondecode().

    If no decoder is set, the JSON string is returned unchanged.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
587
588
def cast_num(value):
    """Cast a numeric value with get_decimal(), falling back to float."""
    number_type = get_decimal()
    if not number_type:
        number_type = float
    return number_type(value)
592
593
def cast_money(value):
    """Cast a money value.

    The money output is locale dependent.  Only the digits, the decimal
    point returned by get_decimal_point() and the sign are significant; an
    opening parenthesis also marks a negative amount.  Everything else
    (currency symbols, thousands separators, blanks, closing parenthesis)
    is dropped.  If no decimal point is set, the string is returned as is.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        # Here a '.' cannot be the decimal point, so it can only be a
        # thousands separator (e.g. "1.234,56") and must be removed
        # before the real decimal point is normalized to '.'.
        value = value.replace('.', '').replace(point, '.')
    value = value.replace('(', '-')
    value = ''.join(c for c in value if c.isdigit() or c in '.-')
    return (get_decimal() or float)(value)
604
605
def cast_int2vector(value):
    """Cast an int2vector value (blank separated integers) to a list."""
    return list(map(int, value.split()))
609
610
def cast_date(value, connection):
    """Cast a date value.

    The date is parsed with the format returned by connection.date_format().
    Infinite and BC dates are mapped to date.min or date.max, and dates
    that are too long to parse (presumably years beyond 9999) to date.max.
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value in ('-infinity', 'infinity'):
        return date.min if value.startswith('-') else date.max
    tokens = value.split()
    if tokens[-1] == 'BC':
        return date.min
    text = tokens[0]
    if len(text) > 10:
        return date.max
    return datetime.strptime(text, connection.date_format()).date()
629
630
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    pattern = '%H:%M:%S'
    if len(value) > 8:
        pattern += '.%f'
    return datetime.strptime(value, pattern).time()
635
636
# splits "time+offset" at the last sign into the time and the offset
_re_timezone = regex(r'(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    A missing offset is taken as '+0000'.  When timezones is set (strptime
    supports %z), the offset is normalized to +HHMM and an aware time is
    returned; otherwise the offset is dropped and the time is naive.
    """
    found = _re_timezone.match(value)
    if found:
        value, offset = found.groups()
    else:
        offset = '+0000'
    pattern = '%H:%M:%S'
    if len(value) > 8:
        pattern += '.%f'
    if timezones:
        if offset.startswith(('+', '-')):
            if len(offset) < 5:
                offset += '00'  # e.g. +01 -> +0100
            else:
                offset = offset.replace(':', '')  # e.g. +05:30 -> +0530
        else:
            offset = timezones.get(offset, '+0000')
        value += offset
        pattern += '%z'
    return datetime.strptime(value, pattern).timetz()
661
662
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    Date formats from connection.date_format() that end with '-%Y' are
    taken to mean the verbose layout "Wed Feb 03 20:23:20 2016" (the
    weekday is skipped); all other formats the layout "<date> <time>".
    Infinite and BC timestamps and overlong dates or years are mapped to
    datetime.min or datetime.max.
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    verbose = date_fmt.endswith('-%Y') and len(parts) > 2
    if verbose:
        parts = parts[1:5]  # e.g. ['Feb', '03', '20:23:20', '2016']
        if len(parts[3]) > 4:
            return datetime.max
        clock = parts[2]
        formats = ['%d %b' if date_fmt.startswith('%d') else '%b %d']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        clock = parts[1]
        formats = [date_fmt]
    formats.append('%H:%M:%S.%f' if len(clock) > 8 else '%H:%M:%S')
    if verbose:
        formats.append('%Y')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
684
685
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    The layout is derived from connection.date_format():
    - formats ending in '-%Y': "Wed Feb 03 20:23:20 2016 CET" (the weekday
      is skipped and the last token is the time zone; presumably DateStyle
      Postgres -- confirm against date_format());
    - formats starting with '%Y-': "2016-02-03 20:23:20+01" (the offset is
      attached to the time);
    - other formats: "<date> <time> <zone>".
    Infinite and BC timestamps and overlong dates or years are mapped to
    datetime.min or datetime.max.  When timezones is set (strptime supports
    %z), an aware datetime is returned; otherwise the zone is dropped.
    NOTE(review): zone abbreviations missing from timezones are silently
    treated as +0000.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    value = value.split()
    if value[-1] == 'BC':
        return datetime.min
    fmt = connection.date_format()
    if fmt.endswith('-%Y') and len(value) > 2:
        # drop the weekday: [month/day, day/month, time, year, zone]
        value = value[1:]
        if len(value[3]) > 4:  # year with more than four digits
            return datetime.max
        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
        value, tz = value[:-1], value[-1]
    else:
        if fmt.startswith('%Y-'):
            # the offset is appended to the time, e.g. "20:23:20+01"
            tz = _re_timezone.match(value[1])
            if tz:
                value[1], tz = tz.groups()
            else:
                tz = '+0000'
        else:
            # the zone is a separate last token
            value, tz = value[:-1], value[-1]
        if len(value[0]) > 10:  # date too long to parse
            return datetime.max
        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
    if timezones:
        # normalize the zone to the +HHMM form expected by %z
        if tz.startswith(('+', '-')):
            if len(tz) < 5:
                tz += '00'
            else:
                tz = tz.replace(':', '')
        elif tz in timezones:
            tz = timezones[tz]
        else:
            tz = '+0000'
        value.append(tz)
        fmt.append('%z')
    return datetime.strptime(' '.join(value), ' '.join(fmt))
728
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(fields):
    """Convert interval field strings to ints, the last being a fraction.

    The last field holds the digits after the decimal point of the seconds.
    There may be fewer than six of them (e.g. "00:00:01.5"), so they are
    right-padded with zeros to get the number of microseconds.
    """
    fields = list(fields)
    fields[-1] = fields[-1][:6].ljust(6, '0')
    return [int(d) for d in fields]


def cast_interval(value):
    """Cast an interval value.

    All IntervalStyle output formats are supported.  Years and months are
    converted to days with 365 days per year and 30 days per month, since
    timedelta has no notion of calendar units.  Raises ValueError if the
    value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        years, mons, days, hours, mins, secs, usecs = _interval_fields(m)
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            m = _interval_fields(m)
            if ago:
                m = [-d for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            # this pattern matches any string, so require at least one field
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                years, mons, days, hours, mins, secs, usecs = \
                    _interval_fields(m)
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    years, mons, days, hours, mins, secs, usecs = \
                        _interval_fields(m)
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
814
815
816class Typecasts(dict):
817    """Dictionary mapping database types to typecast functions.
818
819    The cast functions get passed the string representation of a value in
820    the database which they need to convert to a Python object.  The
821    passed string will never be None since NULL values are already be
822    handled before the cast function is called.
823
824    Note that the basic types are already handled by the C extension.
825    They only need to be handled here as record or array components.
826    """
827
828    # the default cast functions
829    # (str functions are ignored but have been added for faster access)
830    defaults = {'char': str, 'bpchar': str, 'name': str,
831        'text': str, 'varchar': str,
832        'bool': cast_bool, 'bytea': unescape_bytea,
833        'int2': int, 'int4': int, 'serial': int,
834        'int8': long, 'json': cast_json, 'jsonb': cast_json,
835        'oid': long, 'oid8': long,
836        'float4': float, 'float8': float,
837        'numeric': cast_num, 'money': cast_money,
838        'date': cast_date, 'interval': cast_interval,
839        'time': cast_time, 'timetz': cast_timetz,
840        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
841        'int2vector': cast_int2vector,
842        'anyarray': cast_array, 'record': cast_record}
843
844    connection = None  # will be set in a connection specific instance
845
846    def __missing__(self, typ):
847        """Create a cast function if it is not cached.
848       
849        Note that this class never raises a KeyError,
850        but returns None when no special cast function exists.
851        """
852        if not isinstance(typ, str):
853            raise TypeError('Invalid type: %s' % typ)
854        cast = self.defaults.get(typ)
855        if cast:
856            # store default for faster access
857            cast = self._add_connection(cast)
858            self[typ] = cast
859        elif typ.startswith('_'):
860            base_cast = self[typ[1:]]
861            cast = self.create_array_cast(base_cast)
862            if base_cast:
863                self[typ] = cast
864        else:
865            attnames = self.get_attnames(typ)
866            if attnames:
867                casts = [self[v.pgtype] for v in attnames.values()]
868                cast = self.create_record_cast(typ, attnames, casts)
869                self[typ] = cast
870        return cast
871
872    @staticmethod
873    def _needs_connection(func):
874        """Check if a typecast function needs a connection argument."""
875        try:
876            args = get_args(func)
877        except (TypeError, ValueError):
878            return False
879        else:
880            return 'connection' in args[1:]
881
882    def _add_connection(self, cast):
883        """Add a connection argument to the typecast function if necessary."""
884        if not self.connection or not self._needs_connection(cast):
885            return cast
886        connection = self.connection
887        return lambda value: cast(value, connection=connection)
888
889    def get(self, typ, default=None):
890        """Get the typecast function for the given database type."""
891        return self[typ] or default
892
893    def set(self, typ, cast):
894        """Set a typecast function for the specified database type(s)."""
895        if isinstance(typ, basestring):
896            typ = [typ]
897        if cast is None:
898            for t in typ:
899                self.pop(t, None)
900                self.pop('_%s' % t, None)
901        else:
902            if not callable(cast):
903                raise TypeError("Cast parameter must be callable")
904            for t in typ:
905                self[t] = self._add_connection(cast)
906                self.pop('_%s' % t, None)
907
908    def reset(self, typ=None):
909        """Reset the typecasts for the specified type(s) to their defaults.
910
911        When no type is specified, all typecasts will be reset.
912        """
913        if typ is None:
914            self.clear()
915        else:
916            if isinstance(typ, basestring):
917                typ = [typ]
918            for t in typ:
919                self.pop(t, None)
920
921    @classmethod
922    def get_default(cls, typ):
923        """Get the default typecast function for the given database type."""
924        return cls.defaults.get(typ)
925
926    @classmethod
927    def set_default(cls, typ, cast):
928        """Set a default typecast function for the given database type(s)."""
929        if isinstance(typ, basestring):
930            typ = [typ]
931        defaults = cls.defaults
932        if cast is None:
933            for t in typ:
934                defaults.pop(t, None)
935                defaults.pop('_%s' % t, None)
936        else:
937            if not callable(cast):
938                raise TypeError("Cast parameter must be callable")
939            for t in typ:
940                defaults[t] = cast
941                defaults.pop('_%s' % t, None)
942
943    def get_attnames(self, typ):
944        """Return the fields for the given record type.
945
946        This method will be replaced with the get_attnames() method of DbTypes.
947        """
948        return {}
949
950    def dateformat(self):
951        """Return the current date format.
952
953        This method will be replaced with the dateformat() method of DbTypes.
954        """
955        return '%Y-%m-%d'
956
957    def create_array_cast(self, cast):
958        """Create an array typecast for the given base cast."""
959        return lambda v: cast_array(v, cast)
960
961    def create_record_cast(self, name, fields, casts):
962        """Create a named record typecast for the given fields and casts."""
963        record = namedtuple(name, fields)
964        return lambda v: record(*cast_record(v, casts))
965
966
def get_typecast(typ):
    """Get the global typecast function for the given database type.

    Returns None if there is no global typecast for this type.
    """
    default_cast = Typecasts.get_default(typ)
    return default_cast
970
971
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The typ argument can be a single type name or a sequence of names.
    Pass None as cast to remove the global typecast for these types.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
979
980
class DbType(str):
    """Class augmenting the simple type name with additional info.

    The following additional information is provided:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, B = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types

    These attributes are set by DbTypes.add().  The string value itself
    is the regular type name if regular type names are used, otherwise
    the simple type name.
    """

    @property
    def attnames(self):
        """Get names and types of the fields of a composite type."""
        return self._get_attnames(self)


class DbTypes(dict):
    """Cache for PostgreSQL data types.

    This cache maps type OIDs and names to DbType objects containing
    information on the associated database type.  Types which are not
    cached yet are looked up in the pg_type system catalog.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize type cache for connection.

        The db argument is the DB wrapper.  Its get_attnames() method is
        used for composite types, and the connection it wraps (db.db) is
        used for querying the system catalog.
        """
        super(DbTypes, self).__init__()
        self._regtypes = False
        self._get_attnames = db.get_attnames
        self._typecasts = Typecasts()
        self._typecasts.get_attnames = self.get_attnames
        self._typecasts.connection = db
        db = db.db
        self.query = db.query
        self.escape_string = db.escape_string

    def add(self, oid, pgtype, regtype, typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        If the OID is already cached, the cached type is returned.
        Note that a newly created type is not stored in the cache here.
        """
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid = oid
        typ.simple = simple
        typ.pgtype = pgtype
        typ.regtype = regtype
        typ.typtype = typtype
        typ.category = category
        typ.delim = delim
        typ.relid = relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The type is cached under its OID and its PostgreSQL type name.
        Raises a KeyError if the type cannot be found.
        """
        try:
            res = self.query("SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype" %
                (_quote_if_unqualified('$1', key),), (key,)).getresult()
        except ProgrammingError:
            # presumably an unknown or invalid type name
            res = None
        if not res:
            raise KeyError('Type %s could not be found' % (key,))
        res = res[0]
        typ = self.add(*res)
        self[typ.oid] = self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or not a composite type.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None
        return self._get_attnames(typ.relid, with_oid=False)

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        The type can be passed as a DbType or as a key of this cache
        (type name or OID).  The cast function is always looked up by the
        PostgreSQL type name, since the string value of a DbType may be the
        simple or regular type name under which no cast is registered.
        NULL values and values without a special cast are returned as is.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if not isinstance(typ, DbType):
            typ = self.get(typ)
        # use the PostgreSQL type name (DbTypes created without a pgtype
        # attribute are still looked up by their string value)
        typ = getattr(typ, 'pgtype', typ)
        cast = self.get_typecast(typ) if typ else None
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
1102
1103
1104def _namedresult(q):
1105    """Get query result as named tuples."""
1106    row = namedtuple('Row', q.listfields())
1107    return [row(*r) for r in q.getresult()]
1108
1109
1110class _MemoryQuery:
1111    """Class that embodies a given query result."""
1112
1113    def __init__(self, result, fields):
1114        """Create query from given result rows and field names."""
1115        self.result = result
1116        self.fields = fields
1117
1118    def listfields(self):
1119        """Return the stored field names of this query."""
1120        return self.fields
1121
1122    def getresult(self):
1123        """Return the stored result of this query."""
1124        return self.result
1125
1126
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The sqlstate attribute of the returned exception is set to None.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
1132
1133
def _int_error(msg):
    """Create an InternalError with an empty sqlstate attribute."""
    error_class = InternalError
    return _db_error(msg, error_class)
1137
1138
def _prg_error(msg):
    """Create a ProgrammingError with an empty sqlstate attribute."""
    error_class = ProgrammingError
    return _db_error(msg, error_class)
1142
1143
# Initialize the C module
# (register Python helpers with the C extension: the factory that turns
# query results into named tuples, the Decimal class -- presumably used
# for numeric values -- and the JSON decoding function)

set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
1149
1150
1151# The notification handler
1152
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # the attribute is missing if __init__ did not complete
        if getattr(self, 'listening', False):
            self.unlisten()

    @staticmethod
    def _quote_name(name):
        """Quote an event (channel) name as an SQL identifier."""
        # double quotes inside a quoted identifier must be doubled
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_name(self.event))
            self.db.query('listen %s' % self._quote_name(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_name(self.event))
            self.db.query('unlisten %s' % self._quote_name(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is escaped with the escape_string() method of the
        connection, so it can contain arbitrary characters.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            event = self.stop_event if stop else self.event
            q = 'notify %s' % self._quote_name(event)
            if payload:
                # the payload is embedded as a string literal, so it must be
                # escaped (otherwise quotes break or inject into the command)
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1272
1273
def pgnotify(*args, **kw):
    """Same as NotificationHandler, under the traditional name.

    Deprecated: a DeprecationWarning is issued on every call.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
1279
1280
1281# The actual PostGreSQL database connection interface:
1282
1283class DB:
1284    """Wrapper class for the _pg connection type."""
1285
    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing connection is recognized when it is the only positional
        argument or the only keyword argument (named db).  It can also be
        another DB instance, whose connection is then shared.  Only a
        connection opened here will be closed by close() or reopen().
        """
        # check whether an existing connection has been passed
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # share the connection of another DB wrapper
                db = db.db
            else:
                try:
                    # presumably a pgdb connection holding a _pg connection
                    db = db._cnx
                except AttributeError:
                    pass
        # anything that does not look like a connection is taken as
        # connection parameters (e.g. a single database name)
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        # per-connection caches
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        self._args = args, kw  # needed by reopen()
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        # let the connection cast values using the type cache
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1327
1328    def __getattr__(self, name):
1329        # All undefined members are same as in underlying connection:
1330        if self.db:
1331            return getattr(self.db, name)
1332        else:
1333            raise _int_error('Connection is not valid')
1334
1335    def __dir__(self):
1336        # Custom dir function including the attributes of the connection:
1337        attrs = set(self.__class__.__dict__)
1338        attrs.update(self.__dict__)
1339        attrs.update(dir(self.db))
1340        return sorted(attrs)
1341
1342    # Context manager methods
1343
1344    def __enter__(self):
1345        """Enter the runtime context. This will start a transactio."""
1346        self.begin()
1347        return self
1348
1349    def __exit__(self, et, ev, tb):
1350        """Exit the runtime context. This will end the transaction."""
1351        if et is None and ev is None and tb is None:
1352            self.commit()
1353        else:
1354            self.rollback()
1355
1356    # Auxiliary methods
1357
1358    def _do_debug(self, *args):
1359        """Print a debug message"""
1360        if self.debug:
1361            s = '\n'.join(str(arg) for arg in args)
1362            if isinstance(self.debug, basestring):
1363                print(self.debug % s)
1364            elif hasattr(self.debug, 'write'):
1365                self.debug.write(s + '\n')
1366            elif callable(self.debug):
1367                self.debug(s)
1368            else:
1369                print(s)
1370
1371    def _escape_qualified_name(self, s):
1372        """Escape a qualified name.
1373
1374        Escapes the name for use as an SQL identifier, unless the
1375        name contains a dot, in which case the name is ambiguous
1376        (could be a qualified name or just a name with a dot in it)
1377        and must be quoted manually by the caller.
1378        """
1379        if '.' not in s:
1380            s = self.escape_identifier(s)
1381        return s
1382
1383    @staticmethod
1384    def _make_bool(d):
1385        """Get boolean value corresponding to d."""
1386        return bool(d) if get_bool() else ('t' if d else 'f')
1387
1388    def _list_params(self, params):
1389        """Create a human readable parameter list."""
1390        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1391
1392    # Public methods
1393
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (wrapping the module level function as a static method,
    # so that it is called without the instance as first argument)
    unescape_bytea = staticmethod(unescape_bytea)
1397
1398    def decode_json(self, s):
1399        """Decode a JSON string coming from the database."""
1400        return (get_jsondecode() or jsondecode)(s)
1401
1402    def encode_json(self, d):
1403        """Encode a JSON string for use within SQL."""
1404        return jsonencode(d)
1405
1406    def close(self):
1407        """Close the database connection."""
1408        # Wraps shared library function so we can track state.
1409        if self._closeable:
1410            if self.db:
1411                self.db.close()
1412                self.db = None
1413            else:
1414                raise _int_error('Connection already closed')
1415
1416    def reset(self):
1417        """Reset connection with current parameters.
1418
1419        All derived queries and large objects derived from this connection
1420        will not be usable after this call.
1421
1422        """
1423        if self.db:
1424            self.db.reset()
1425        else:
1426            raise _int_error('Connection already closed')
1427
1428    def reopen(self):
1429        """Reopen connection to the database.
1430
1431        Used in case we need another connection to the same database.
1432        Note that we can still reopen a database that we have closed.
1433
1434        """
1435        # There is no such shared library function.
1436        if self._closeable:
1437            db = connect(*self._args[0], **self._args[1])
1438            if self.db:
1439                self.db.close()
1440            self.db = db
1441
1442    def begin(self, mode=None):
1443        """Begin a transaction."""
1444        qstr = 'BEGIN'
1445        if mode:
1446            qstr += ' ' + mode
1447        return self.query(qstr)
1448
1449    start = begin
1450
1451    def commit(self):
1452        """Commit the current transaction."""
1453        return self.query('COMMIT')
1454
1455    end = commit
1456
1457    def rollback(self, name=None):
1458        """Roll back the current transaction."""
1459        qstr = 'ROLLBACK'
1460        if name:
1461            qstr += ' TO ' + name
1462        return self.query(qstr)
1463
1464    abort = rollback
1465
1466    def savepoint(self, name):
1467        """Define a new savepoint within the current transaction."""
1468        return self.query('SAVEPOINT ' + name)
1469
1470    def release(self, name):
1471        """Destroy a previously defined savepoint."""
1472        return self.query('RELEASE ' + name)
1473
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.
        """
        # values will hold the result: None stands for a single value,
        # otherwise a list or dict collects the settings
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter  # the passed dict is filled in place
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # for a dict result, params maps normalized names to the original
        # keys; otherwise it is the list of normalized names
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # return a new dict with the names and settings of all
                # parameters, regardless of the type of the argument
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:  # 'all' was not requested, query each parameter separately
            # NOTE(review): the parameter name is interpolated into the
            # SQL command unescaped -- only pass trusted names
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1531
1532    def set_parameter(self, parameter, value=None, local=False):
1533        """Set the value of a run-time parameter.
1534
1535        If the parameter and the value are strings, the run-time parameter
1536        will be set to that value.  If no value or None is passed as a value,
1537        then the run-time parameter will be restored to its default value.
1538
1539        You can set several parameters at once by passing a list of parameter
1540        names, together with a single value that all parameters should be
1541        set to or with a corresponding list of values.  You can also pass
1542        the parameters as a set if you only provide a single value.
1543        Finally, you can pass a dict with parameter names as keys.  In this
1544        case, you should not pass a value, since the values for the parameters
1545        will be taken from the dict.
1546
1547        By passing the special name 'all' as the parameter, you can reset
1548        all existing settable run-time parameters to their default values.
1549
1550        If you set local to True, then the command takes effect for only the
1551        current transaction.  After commit() or rollback(), the session-level
1552        setting takes effect again.  Setting local to True will appear to
1553        have no effect if it is executed outside a transaction, since the
1554        transaction will end immediately.
1555        """
1556        if isinstance(parameter, basestring):
1557            parameter = {parameter: value}
1558        elif isinstance(parameter, (list, tuple)):
1559            if isinstance(value, (list, tuple)):
1560                parameter = dict(zip(parameter, value))
1561            else:
1562                parameter = dict.fromkeys(parameter, value)
1563        elif isinstance(parameter, (set, frozenset)):
1564            if isinstance(value, (list, tuple, set, frozenset)):
1565                value = set(value)
1566                if len(value) == 1:
1567                    value = value.pop()
1568            if not(value is None or isinstance(value, basestring)):
1569                raise ValueError('A single value must be specified'
1570                    ' when parameter is a set')
1571            parameter = dict.fromkeys(parameter, value)
1572        elif isinstance(parameter, dict):
1573            if value is not None:
1574                raise ValueError('A value must not be specified'
1575                    ' when parameter is a dictionary')
1576        else:
1577            raise TypeError(
1578                'The parameter must be a string, list, set or dict')
1579        if not parameter:
1580            raise TypeError('No parameter has been specified')
1581        params = {}
1582        for key, value in parameter.items():
1583            param = key.strip().lower() if isinstance(
1584                key, basestring) else None
1585            if not param:
1586                raise TypeError('Invalid parameter')
1587            if param == 'all':
1588                if value is not None:
1589                    raise ValueError('A value must ot be specified'
1590                        " when parameter is 'all'")
1591                params = {'all': None}
1592                break
1593            params[param] = value
1594        local = ' LOCAL' if local else ''
1595        for param, value in params.items():
1596            if value is None:
1597                q = 'RESET%s %s' % (local, param)
1598            else:
1599                q = 'SET%s %s TO %s' % (local, param, value)
1600            self._do_debug(q)
1601            self.db.query(q)
1602
1603    def query(self, command, *args):
1604        """Execute a SQL command string.
1605
1606        This method simply sends a SQL query to the database.  If the query is
1607        an insert statement that inserted exactly one row into a table that
1608        has OIDs, the return value is the OID of the newly inserted row.
1609        If the query is an update or delete statement, or an insert statement
1610        that did not insert exactly one row in a table with OIDs, then the
1611        number of rows affected is returned as a string.  If it is a statement
1612        that returns rows as a result (usually a select statement, but maybe
1613        also an "insert/update ... returning" statement), this method returns
1614        a Query object that can be accessed via getresult() or dictresult()
1615        or simply printed.  Otherwise, it returns `None`.
1616
1617        The query can contain numbered parameters of the form $1 in place
1618        of any data constant.  Arguments given after the query string will
1619        be substituted for the corresponding numbered parameter.  Parameter
1620        values can also be given as a single list or tuple argument.
1621        """
1622        # Wraps shared library function for debugging.
1623        if not self.db:
1624            raise _int_error('Connection is not valid')
1625        if args:
1626            self._do_debug(command, args)
1627            return self.db.query(command, args)
1628        self._do_debug(command)
1629        return self.db.query(command)
1630
1631    def query_formatted(self, command, parameters, types=None, inline=False):
1632        """Execute a formatted SQL command string.
1633
1634        Similar to query, but using Python format placeholders of the form
1635        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1636        The parameters must be passed as a tuple, list or dict.  You can
1637        also pass a corresponding tuple, list or dict of database types in
1638        order to format the parameters properly in case there is ambiguity.
1639
1640        If you set inline to True, the parameters will be sent to the database
1641        embedded in the SQL command, otherwise they will be sent separately.
1642        """
1643        return self.query(*self.adapter.format_query(
1644            command, parameters, types, inline))
1645
1646    def pkey(self, table, composite=False, flush=False):
1647        """Get or set the primary key of a table.
1648
1649        Single primary keys are returned as strings unless you
1650        set the composite flag.  Composite primary keys are always
1651        represented as tuples.  Note that this raises a KeyError
1652        if the table does not have a primary key.
1653
1654        If flush is set then the internal cache for primary keys will
1655        be flushed.  This may be necessary after the database schema or
1656        the search path has been changed.
1657        """
1658        pkeys = self._pkeys
1659        if flush:
1660            pkeys.clear()
1661            self._do_debug('The pkey cache has been flushed')
1662        try:  # cache lookup
1663            pkey = pkeys[table]
1664        except KeyError:  # cache miss, check the database
1665            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1666                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1667                " AND a.attnum = ANY(i.indkey)"
1668                " AND NOT a.attisdropped"
1669                " WHERE i.indrelid=%s::regclass"
1670                " AND i.indisprimary ORDER BY a.attnum") % (
1671                    _quote_if_unqualified('$1', table),)
1672            pkey = self.db.query(q, (table,)).getresult()
1673            if not pkey:
1674                raise KeyError('Table %s has no primary key' % table)
1675            # we want to use the order defined in the primary key index here,
1676            # not the order as defined by the columns in the table
1677            if len(pkey) > 1:
1678                indkey = pkey[0][2]
1679                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1680                pkey = tuple(row[0] for row in pkey)
1681            else:
1682                pkey = pkey[0][0]
1683            pkeys[table] = pkey  # cache it
1684        if composite and not isinstance(pkey, tuple):
1685            pkey = (pkey,)
1686        return pkey
1687
1688    def get_databases(self):
1689        """Get list of databases in the system."""
1690        return [s[0] for s in
1691            self.db.query('SELECT datname FROM pg_database').getresult()]
1692
1693    def get_relations(self, kinds=None):
1694        """Get list of relations in connected database of specified kinds.
1695
1696        If kinds is None or empty, all kinds of relations are returned.
1697        Otherwise kinds can be a string or sequence of type letters
1698        specifying which kind of relations you want to list.
1699        """
1700        where = " AND r.relkind IN (%s)" % ','.join(
1701            ["'%s'" % k for k in kinds]) if kinds else ''
1702        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1703            " FROM pg_class r"
1704            " JOIN pg_namespace s ON s.oid = r.relnamespace"
1705            " WHERE s.nspname NOT SIMILAR"
1706            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
1707            " ORDER BY s.nspname, r.relname") % where
1708        return [r[0] for r in self.db.query(q).getresult()]
1709
1710    def get_tables(self):
1711        """Return list of tables in connected database."""
1712        return self.get_relations('r')
1713
1714    def get_attnames(self, table, with_oid=True, flush=False):
1715        """Given the name of a table, dig out the set of attribute names.
1716
1717        Returns a read-only dictionary of attribute names (the names are
1718        the keys, the values are the names of the attributes' types)
1719        with the column names in the proper order if you iterate over it.
1720
1721        If flush is set, then the internal cache for attribute names will
1722        be flushed. This may be necessary after the database schema or
1723        the search path has been changed.
1724
1725        By default, only a limited number of simple types will be returned.
1726        You can get the regular types after calling use_regtypes(True).
1727        """
1728        attnames = self._attnames
1729        if flush:
1730            attnames.clear()
1731            self._do_debug('The attnames cache has been flushed')
1732        try:  # cache lookup
1733            names = attnames[table]
1734        except KeyError:  # cache miss, check the database
1735            q = "a.attnum > 0"
1736            if with_oid:
1737                q = "(%s OR a.attname = 'oid')" % q
1738            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1739                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1740                " FROM pg_attribute a"
1741                " JOIN pg_type t ON t.oid = a.atttypid"
1742                " WHERE a.attrelid = %s::regclass AND %s"
1743                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1744                    _quote_if_unqualified('$1', table), q)
1745            names = self.db.query(q, (table,)).getresult()
1746            types = self.dbtypes
1747            names = ((name[0], types.add(*name[1:])) for name in names)
1748            names = AttrDict(names)
1749            attnames[table] = names  # cache it
1750        return names
1751
1752    def use_regtypes(self, regtypes=None):
1753        """Use regular type names instead of simplified type names."""
1754        if regtypes is None:
1755            return self.dbtypes._regtypes
1756        else:
1757            regtypes = bool(regtypes)
1758            if regtypes != self.dbtypes._regtypes:
1759                self.dbtypes._regtypes = regtypes
1760                self._attnames.clear()
1761                self.dbtypes.clear()
1762            return regtypes
1763
1764    def has_table_privilege(self, table, privilege='select'):
1765        """Check whether current user has specified table privilege."""
1766        privilege = privilege.lower()
1767        try:  # ask cache
1768            return self._privileges[(table, privilege)]
1769        except KeyError:  # cache miss, ask the database
1770            q = "SELECT has_table_privilege(%s, $2)" % (
1771                _quote_if_unqualified('$1', table),)
1772            q = self.db.query(q, (table, privilege))
1773            ret = q.getresult()[0][0] == self._make_bool(True)
1774            self._privileges[(table, privilege)] = ret  # cache it
1775            return ret
1776
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        If no keyname is given and the table has no primary key, or a row
        dictionary lacks values for it, the oid is used as key if possible.
        Otherwise the error built by _prg_error() (no primary key) or a
        KeyError (missing key values, or a mismatch between the number of
        key columns and given values) is raised.  If no matching row is
        found, the error built by _db_error() is raised.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # the key for storing the oid in the row dictionary,
        # or None if the table has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)  # a single column name
        # accept the munged oid key as an alias for the plain 'oid' key
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        if not isinstance(row, dict):
            # a plain value or a sequence of values: map them to the
            # key columns, creating a new row dictionary
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        # adapt() records each key value in params and returns the SQL
        # text to use in its place (presumably a $n placeholder)
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # the plain 'oid' key is not kept in the row; if the table
        # has an oid, its value moves to the munged oid key
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row, storing the oid
        # under the munged oid key
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1823                row = [row]
1824            if len(keyname) != len(row):
1825                raise KeyError(
1826                    'Differing number of items in keyname and row')
1827            row = dict(zip(keyname, row))
1828        params = self.adapter.parameter_list()
1829        adapt = params.add
1830        col = self.escape_identifier
1831        what = 'oid, *' if qoid else '*'
1832        where = ' AND '.join('%s = %s' % (
1833            col(k), adapt(row[k], attnames[k])) for k in keyname)
1834        if 'oid' in row:
1835            if qoid:
1836                row[qoid] = row['oid']
1837            del row['oid']
1838        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
1839            what, self._escape_qualified_name(table), where)
1840        self._do_debug(q, params)
1841        q = self.db.query(q, params)
1842        res = q.dictresult()
1843        if not res:
1844            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
1845                table, where, self._list_params(params)))
1846        for n, value in res[0].items():
1847            if qoid and n == 'oid':
1848                n = qoid
1849            row[n] = value
1850        return row
1851
1852    def insert(self, table, row=None, **kw):
1853        """Insert a row into a database table.
1854
1855        This method inserts a row into a table.  The name of the table must
1856        be passed as the first parameter.  The other parameters are used for
1857        providing the data of the row that shall be inserted into the table.
1858        If a dictionary is supplied as the second parameter, it starts with
1859        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1860        is updated from the keywords.
1861
1862        The dictionary is then reloaded with the values actually inserted in
1863        order to pick up values modified by rules, triggers, etc.
1864        """
1865        if table.endswith('*'):  # hint for descendant tables can be ignored
1866            table = table[:-1].rstrip()
1867        if row is None:
1868            row = {}
1869        row.update(kw)
1870        if 'oid' in row:
1871            del row['oid']  # do not insert oid
1872        attnames = self.get_attnames(table)
1873        qoid = _oid_key(table) if 'oid' in attnames else None
1874        params = self.adapter.parameter_list()
1875        adapt = params.add
1876        col = self.escape_identifier
1877        names, values = [], []
1878        for n in attnames:
1879            if n in row:
1880                names.append(col(n))
1881                values.append(adapt(row[n], attnames[n]))
1882        if not names:
1883            raise _prg_error('No column found that can be inserted')
1884        names, values = ', '.join(names), ', '.join(values)
1885        ret = 'oid, *' if qoid else '*'
1886        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1887            self._escape_qualified_name(table), names, values, ret)
1888        self._do_debug(q, params)
1889        q = self.db.query(q, params)
1890        res = q.dictresult()
1891        if res:  # this should always be true
1892            for n, value in res[0].items():
1893                if qoid and n == 'oid':
1894                    n = qoid
1895                row[n] = value
1896        return row
1897
1898    def update(self, table, row=None, **kw):
1899        """Update an existing row in a database table.
1900
1901        Similar to insert but updates an existing row.  The update is based
1902        on the primary key of the table or the OID value as munged by get
1903        or passed as keyword.
1904
1905        The dictionary is then modified to reflect any changes caused by the
1906        update due to triggers, rules, default values, etc.
1907        """
1908        if table.endswith('*'):
1909            table = table[:-1].rstrip()  # need parent table name
1910        attnames = self.get_attnames(table)
1911        qoid = _oid_key(table) if 'oid' in attnames else None
1912        if row is None:
1913            row = {}
1914        elif 'oid' in row:
1915            del row['oid']  # only accept oid key from named args for safety
1916        row.update(kw)
1917        if qoid and qoid in row and 'oid' not in row:
1918            row['oid'] = row[qoid]
1919        try:  # try using the primary key
1920            keyname = self.pkey(table, True)
1921        except KeyError:  # the table has no primary key
1922            # try using the oid instead
1923            if qoid and 'oid' in row:
1924                keyname = ('oid',)
1925            else:
1926                raise _prg_error('Table %s has no primary key' % table)
1927        else:  # the table has a primary key
1928            # check whether all key columns have values
1929            if not set(keyname).issubset(row):
1930                # try using the oid instead
1931                if qoid and 'oid' in row:
1932                    keyname = ('oid',)
1933                else:
1934                    raise KeyError('Missing primary key in row')
1935        params = self.adapter.parameter_list()
1936        adapt = params.add
1937        col = self.escape_identifier
1938        where = ' AND '.join('%s = %s' % (
1939            col(k), adapt(row[k], attnames[k])) for k in keyname)
1940        if 'oid' in row:
1941            if qoid:
1942                row[qoid] = row['oid']
1943            del row['oid']
1944        values = []
1945        keyname = set(keyname)
1946        for n in attnames:
1947            if n in row and n not in keyname:
1948                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
1949        if not values:
1950            return row
1951        values = ', '.join(values)
1952        ret = 'oid, *' if qoid else '*'
1953        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
1954            self._escape_qualified_name(table), values, where, ret)
1955        self._do_debug(q, params)
1956        q = self.db.query(q, params)
1957        res = q.dictresult()
1958        if res:  # may be empty when row does not exist
1959            for n, value in res[0].items():
1960                if qoid and n == 'oid':
1961                    n = qoid
1962                row[n] = value
1963        return row
1964
1965    def upsert(self, table, row=None, **kw):
1966        """Insert a row into a database table with conflict resolution
1967
1968        This method inserts a row into a table, but instead of raising a
1969        ProgrammingError exception in case a row with the same primary key
1970        already exists, an update will be executed instead.  This will be
1971        performed as a single atomic operation on the database, so race
1972        conditions can be avoided.
1973
1974        Like the insert method, the first parameter is the name of the
1975        table and the second parameter can be used to pass the values to
1976        be inserted as a dictionary.
1977
1978        Unlike the insert und update statement, keyword parameters are not
1979        used to modify the dictionary, but to specify which columns shall
1980        be updated in case of a conflict, and in which way:
1981
1982        A value of False or None means the column shall not be updated,
1983        a value of True means the column shall be updated with the value
1984        that has been proposed for insertion, i.e. has been passed as value
1985        in the dictionary.  Columns that are not specified by keywords but
1986        appear as keys in the dictionary are also updated like in the case
1987        keywords had been passed with the value True.
1988
1989        So if in the case of a conflict you want to update every column that
1990        has been passed in the dictionary row , you would call upsert(table, row).
1991        If you don't want to do anything in case of a conflict, i.e. leave
1992        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
1993
1994        If you need more fine-grained control of what gets updated, you can
1995        also pass strings in the keyword parameters.  These strings will
1996        be used as SQL expressions for the update columns.  In these
1997        expressions you can refer to the value that already exists in
1998        the table by prefixing the column name with "included.", and to
1999        the value that has been proposed for insertion by prefixing the
2000        column name with the "excluded."
2001
2002        The dictionary is modified in any case to reflect the values in
2003        the database after the operation has completed.
2004
2005        Note: The method uses the PostgreSQL "upsert" feature which is
2006        only available since PostgreSQL 9.5.
2007        """
2008        if table.endswith('*'):  # hint for descendant tables can be ignored
2009            table = table[:-1].rstrip()
2010        if row is None:
2011            row = {}
2012        if 'oid' in row:
2013            del row['oid']  # do not insert oid
2014        if 'oid' in kw:
2015            del kw['oid']  # do not update oid
2016        attnames = self.get_attnames(table)
2017        qoid = _oid_key(table) if 'oid' in attnames else None
2018        params = self.adapter.parameter_list()
2019        adapt = params.add
2020        col = self.escape_identifier
2021        names, values, updates = [], [], []
2022        for n in attnames:
2023            if n in row:
2024                names.append(col(n))
2025                values.append(adapt(row[n], attnames[n]))
2026        names, values = ', '.join(names), ', '.join(values)
2027        try:
2028            keyname = self.pkey(table, True)
2029        except KeyError:
2030            raise _prg_error('Table %s has no primary key' % table)
2031        target = ', '.join(col(k) for k in keyname)
2032        update = []
2033        keyname = set(keyname)
2034        keyname.add('oid')
2035        for n in attnames:
2036            if n not in keyname:
2037                value = kw.get(n, True)
2038                if value:
2039                    if not isinstance(value, basestring):
2040                        value = 'excluded.%s' % col(n)
2041                    update.append('%s = %s' % (col(n), value))
2042        if not values:
2043            return row
2044        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2045        ret = 'oid, *' if qoid else '*'
2046        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2047            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2048                self._escape_qualified_name(table), names, values,
2049                target, do, ret)
2050        self._do_debug(q, params)
2051        try:
2052            q = self.db.query(q, params)
2053        except ProgrammingError:
2054            if self.server_version < 90500:
2055                raise _prg_error(
2056                    'Upsert operation is not supported by PostgreSQL version')
2057            raise  # re-raise original error
2058        res = q.dictresult()
2059        if res:  # may be empty with "do nothing"
2060            for n, value in res[0].items():
2061                if qoid and n == 'oid':
2062                    n = qoid
2063                row[n] = value
2064        else:
2065            self.get(table, row)
2066        return row
2067
2068    def clear(self, table, row=None):
2069        """Clear all the attributes to values determined by the types.
2070
2071        Numeric types are set to 0, Booleans are set to false, and everything
2072        else is set to the empty string.  If the row argument is present,
2073        it is used as the row dictionary and any entries matching attribute
2074        names are cleared with everything else left unchanged.
2075        """
2076        # At some point we will need a way to get defaults from a table.
2077        if row is None:
2078            row = {}  # empty if argument is not present
2079        attnames = self.get_attnames(table)
2080        for n, t in attnames.items():
2081            if n == 'oid':
2082                continue
2083            t = t.simple
2084            if t in DbTypes._num_types:
2085                row[n] = 0
2086            elif t == 'bool':
2087                row[n] = self._make_bool(False)
2088            else:
2089                row[n] = ''
2090        return row
2091
2092    def delete(self, table, row=None, **kw):
2093        """Delete an existing row in a database table.
2094
2095        This method deletes the row from a table.  It deletes based on the
2096        primary key of the table or the OID value as munged by get() or
2097        passed as keyword.
2098
2099        The return value is the number of deleted rows (i.e. 0 if the row
2100        did not exist and 1 if the row was deleted).
2101
2102        Note that if the row cannot be deleted because e.g. it is still
2103        referenced by another table, this method raises a ProgrammingError.
2104        """
2105        if table.endswith('*'):  # hint for descendant tables can be ignored
2106            table = table[:-1].rstrip()
2107        attnames = self.get_attnames(table)
2108        qoid = _oid_key(table) if 'oid' in attnames else None
2109        if row is None:
2110            row = {}
2111        elif 'oid' in row:
2112            del row['oid']  # only accept oid key from named args for safety
2113        row.update(kw)
2114        if qoid and qoid in row and 'oid' not in row:
2115            row['oid'] = row[qoid]
2116        try:  # try using the primary key
2117            keyname = self.pkey(table, True)
2118        except KeyError:  # the table has no primary key
2119            # try using the oid instead
2120            if qoid and 'oid' in row:
2121                keyname = ('oid',)
2122            else:
2123                raise _prg_error('Table %s has no primary key' % table)
2124        else:  # the table has a primary key
2125            # check whether all key columns have values
2126            if not set(keyname).issubset(row):
2127                # try using the oid instead
2128                if qoid and 'oid' in row:
2129                    keyname = ('oid',)
2130                else:
2131                    raise KeyError('Missing primary key in row')
2132        params = self.adapter.parameter_list()
2133        adapt = params.add
2134        col = self.escape_identifier
2135        where = ' AND '.join('%s = %s' % (
2136            col(k), adapt(row[k], attnames[k])) for k in keyname)
2137        if 'oid' in row:
2138            if qoid:
2139                row[qoid] = row['oid']
2140            del row['oid']
2141        q = 'DELETE FROM %s WHERE %s' % (
2142            self._escape_qualified_name(table), where)
2143        self._do_debug(q, params)
2144        res = self.db.query(q, params)
2145        return int(res)
2146
2147    def truncate(self, table, restart=False, cascade=False, only=False):
2148        """Empty a table or set of tables.
2149
2150        This method quickly removes all rows from the given table or set
2151        of tables.  It has the same effect as an unqualified DELETE on each
2152        table, but since it does not actually scan the tables it is faster.
2153        Furthermore, it reclaims disk space immediately, rather than requiring
2154        a subsequent VACUUM operation. This is most useful on large tables.
2155
2156        If restart is set to True, sequences owned by columns of the truncated
2157        table(s) are automatically restarted.  If cascade is set to True, it
2158        also truncates all tables that have foreign-key references to any of
2159        the named tables.  If the parameter only is not set to True, all the
2160        descendant tables (if any) will also be truncated. Optionally, a '*'
2161        can be specified after the table name to explicitly indicate that
2162        descendant tables are included.
2163        """
2164        if isinstance(table, basestring):
2165            only = {table: only}
2166            table = [table]
2167        elif isinstance(table, (list, tuple)):
2168            if isinstance(only, (list, tuple)):
2169                only = dict(zip(table, only))
2170            else:
2171                only = dict.fromkeys(table, only)
2172        elif isinstance(table, (set, frozenset)):
2173            only = dict.fromkeys(table, only)
2174        else:
2175            raise TypeError('The table must be a string, list or set')
2176        if not (restart is None or isinstance(restart, (bool, int))):
2177            raise TypeError('Invalid type for the restart option')
2178        if not (cascade is None or isinstance(cascade, (bool, int))):
2179            raise TypeError('Invalid type for the cascade option')
2180        tables = []
2181        for t in table:
2182            u = only.get(t)
2183            if not (u is None or isinstance(u, (bool, int))):
2184                raise TypeError('Invalid type for the only option')
2185            if t.endswith('*'):
2186                if u:
2187                    raise ValueError(
2188                        'Contradictory table name and only options')
2189                t = t[:-1].rstrip()
2190            t = self._escape_qualified_name(t)
2191            if u:
2192                t = 'ONLY %s' % t
2193            tables.append(t)
2194        q = ['TRUNCATE', ', '.join(tables)]
2195        if restart:
2196            q.append('RESTART IDENTITY')
2197        if cascade:
2198            q.append('CASCADE')
2199        q = ' '.join(q)
2200        self._do_debug(q)
2201        return self.db.query(q)
2202
2203    def get_as_list(self, table, what=None, where=None,
2204            order=None, limit=None, offset=None, scalar=False):
2205        """Get a table as a list.
2206
2207        This gets a convenient representation of the table as a list
2208        of named tuples in Python.  You only need to pass the name of
2209        the table (or any other SQL expression returning rows).  Note that
2210        by default this will return the full content of the table which
2211        can be huge and overflow your memory.  However, you can control
2212        the amount of data returned using the other optional parameters.
2213
2214        The parameter 'what' can restrict the query to only return a
2215        subset of the table columns.  It can be a string, list or a tuple.
2216        The parameter 'where' can restrict the query to only return a
2217        subset of the table rows.  It can be a string, list or a tuple
2218        of SQL expressions that all need to be fulfilled.  The parameter
2219        'order' specifies the ordering of the rows.  It can also be a
2220        other string, list or a tuple.  If no ordering is specified,
2221        the result will be ordered by the primary key(s) or all columns
2222        if no primary key exists.  You can set 'order' to False if you
2223        don't care about the ordering.  The parameters 'limit' and 'offset'
2224        can be integers specifying the maximum number of rows returned
2225        and a number of rows skipped over.
2226
2227        If you set the 'scalar' option to True, then instead of the
2228        named tuples you will get the first items of these tuples.
2229        This is useful if the result has only one column anyway.
2230        """
2231        if not table:
2232            raise TypeError('The table name is missing')
2233        if what:
2234            if isinstance(what, (list, tuple)):
2235                what = ', '.join(map(str, what))
2236            if order is None:
2237                order = what
2238        else:
2239            what = '*'
2240        q = ['SELECT', what, 'FROM', table]
2241        if where:
2242            if isinstance(where, (list, tuple)):
2243                where = ' AND '.join(map(str, where))
2244            q.extend(['WHERE', where])
2245        if order is None:
2246            try:
2247                order = self.pkey(table, True)
2248            except (KeyError, ProgrammingError):
2249                try:
2250                    order = list(self.get_attnames(table))
2251                except (KeyError, ProgrammingError):
2252                    pass
2253        if order:
2254            if isinstance(order, (list, tuple)):
2255                order = ', '.join(map(str, order))
2256            q.extend(['ORDER BY', order])
2257        if limit:
2258            q.append('LIMIT %d' % limit)
2259        if offset:
2260            q.append('OFFSET %d' % offset)
2261        q = ' '.join(q)
2262        self._do_debug(q)
2263        q = self.db.query(q)
2264        res = q.namedresult()
2265        if res and scalar:
2266            res = [row[0] for row in res]
2267        return res
2268
2269    def get_as_dict(self, table, keyname=None, what=None, where=None,
2270            order=None, limit=None, offset=None, scalar=False):
2271        """Get a table as a dictionary.
2272
2273        This method is similar to get_as_list(), but returns the table
2274        as a Python dict instead of a Python list, which can be even
2275        more convenient. The primary key column(s) of the table will
2276        be used as the keys of the dictionary, while the other column(s)
2277        will be the corresponding values.  The keys will be named tuples
2278        if the table has a composite primary key.  The rows will be also
2279        named tuples unless the 'scalar' option has been set to True.
2280        With the optional parameter 'keyname' you can specify an alternative
2281        set of columns to be used as the keys of the dictionary.  It must
2282        be set as a string, list or a tuple.
2283
2284        If the Python version supports it, the dictionary will be an
2285        OrderedDict using the order specified with the 'order' parameter
2286        or the key column(s) if not specified.  You can set 'order' to False
2287        if you don't care about the ordering.  In this case the returned
2288        dictionary will be an ordinary one.
2289        """
2290        if not table:
2291            raise TypeError('The table name is missing')
2292        if not keyname:
2293            try:
2294                keyname = self.pkey(table, True)
2295            except (KeyError, ProgrammingError):
2296                raise _prg_error('Table %s has no primary key' % table)
2297        if isinstance(keyname, basestring):
2298            keyname = [keyname]
2299        elif not isinstance(keyname, (list, tuple)):
2300            raise KeyError('The keyname must be a string, list or tuple')
2301        if what:
2302            if isinstance(what, (list, tuple)):
2303                what = ', '.join(map(str, what))
2304            if order is None:
2305                order = what
2306        else:
2307            what = '*'
2308        q = ['SELECT', what, 'FROM', table]
2309        if where:
2310            if isinstance(where, (list, tuple)):
2311                where = ' AND '.join(map(str, where))
2312            q.extend(['WHERE', where])
2313        if order is None:
2314            order = keyname
2315        if order:
2316            if isinstance(order, (list, tuple)):
2317                order = ', '.join(map(str, order))
2318            q.extend(['ORDER BY', order])
2319        if limit:
2320            q.append('LIMIT %d' % limit)
2321        if offset:
2322            q.append('OFFSET %d' % offset)
2323        q = ' '.join(q)
2324        self._do_debug(q)
2325        q = self.db.query(q)
2326        res = q.getresult()
2327        cls = OrderedDict if order else dict
2328        if not res:
2329            return cls()
2330        keyset = set(keyname)
2331        fields = q.listfields()
2332        if not keyset.issubset(fields):
2333            raise KeyError('Missing keyname in row')
2334        keyind, rowind = [], []
2335        for i, f in enumerate(fields):
2336            (keyind if f in keyset else rowind).append(i)
2337        keytuple = len(keyind) > 1
2338        getkey = itemgetter(*keyind)
2339        keys = map(getkey, res)
2340        if scalar:
2341            rowind = rowind[:1]
2342            rowtuple = False
2343        else:
2344            rowtuple = len(rowind) > 1
2345        if scalar or rowtuple:
2346            getrow = itemgetter(*rowind)
2347        else:
2348            rowind = rowind[0]
2349            getrow = lambda row: (row[rowind],)
2350            rowtuple = True
2351        rows = map(getrow, res)
2352        if keytuple or rowtuple:
2353            namedresult = get_namedresult()
2354            if namedresult:
2355                if keytuple:
2356                    keys = namedresult(_MemoryQuery(keys, keyname))
2357                if rowtuple:
2358                    fields = [f for f in fields if f not in keyset]
2359                    rows = namedresult(_MemoryQuery(rows, fields))
2360        return cls(zip(keys, rows))
2361
2362    def notification_handler(self,
2363            event, callback, arg_dict=None, timeout=None, stop_event=None):
2364        """Get notification handler that will run the given callback."""
2365        return NotificationHandler(self,
2366            event, callback, arg_dict, timeout, stop_event)
2367
2368
2369# if run as script, print some information
2370
2371if __name__ == '__main__':
2372    print('PyGreSQL version' + version)
2373    print('')
2374    print(__doc__)
Note: See TracBrowser for help on using the repository browser.