source: trunk/pg.py @ 818

Last change on this file since 818 was 818, checked in by cito, 3 years ago

Add list of supported data types.

  • Property svn:keywords set to Id
File size: 86.8 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 818 2016-02-04 20:56:19Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from datetime import date, time, datetime, timedelta
38from decimal import Decimal
39from math import isnan, isinf
40from collections import namedtuple
41from operator import itemgetter
42from functools import partial
43from re import compile as regex
44from json import loads as jsondecode, dumps as jsonencode
45
# Python 3 has no separate ``long`` type; fall back to ``int`` there so
# that the name can be used on both major versions.
try:
    long = long
except NameError:  # Python >= 3.0
    long = int

# ``basestring`` only exists on Python 2; on Python 3 the pair
# (str, bytes) is used for isinstance() checks instead.
try:
    basestring = basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)
55
56
57# Auxiliary classes and functions that are independent from a DB connection:
58
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict


    class AttrDict(dict):
        """Read-only dictionary that keeps its keys in insertion order.

        This variant is used when collections.OrderedDict is unavailable.
        It accepts at most one positional argument, an iterable of
        (key, value) pairs; keyword arguments and plain dicts are rejected
        with a TypeError.
        """

        def __init__(self, *args, **kw):
            if kw or len(args) > 1:
                raise TypeError
            pairs = args[0] if args else []
            if isinstance(pairs, dict):
                raise TypeError
            pairs = list(pairs)
            # remember the order in which the keys were passed
            self._keys = [pair[0] for pair in pairs]
            dict.__init__(self, pairs)
            self._read_only = True
            # shadow the mutating methods on the instance
            error = self._read_only_error
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, error)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[k] for k in self]

        def items(self):
            return [(k, self[k]) for k in self]

        def iterkeys(self):
            return iter(self)

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Read-only ordered dictionary based on collections.OrderedDict.

        The dictionary is filled by the constructor; afterwards item
        assignment, item deletion and the mutating methods raise TypeError.
        """

        def __init__(self, *args, **kw):
            # writing must be allowed while the base class fills the dict
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow the mutating methods on the instance
            error = self._read_only_error
            for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, name, error)

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
143
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the list of argument names of the given function."""
        argspec = getargspec(func)
        return argspec.args
else:

    def get_args(func):
        """Return the list of parameter names of the given callable."""
        parameters = signature(func).parameters
        return list(parameters)
155
try:
    # check whether strptime() can parse the %z directive
    if datetime.strptime('+0100', '%z') is None:
        raise ValueError
except ValueError:  # Python < 3.2
    timezones = None
else:
    # time zones used in Postgres timestamptz output
    timezones = {
        'CET': '+0100', 'EET': '+0200', 'EST': '-0500',
        'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700',
        'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000'}
166
167
168def _oid_key(table):
169    """Build oid key from a table name."""
170    return 'oid(%s)' % table
171
172
173class _SimpleTypes(dict):
174    """Dictionary mapping pg_type names to simple type names."""
175
176    _types = {'bool': 'bool',
177        'bytea': 'bytea',
178        'date': 'date interval time timetz timestamp timestamptz'
179            ' abstime reltime',  # these are very old
180        'float': 'float4 float8',
181        'int': 'cid int2 int4 int8 oid xid',
182        'hstore': 'hstore', 'json': 'json jsonb',
183        'num': 'numeric',
184        'money': 'money',
185        'text': 'bpchar char name text varchar'}
186
187    def __init__(self):
188        for typ, keys in self._types.items():
189            for key in keys.split():
190                self[key] = typ
191                self['_%s' % key] = '%s[]' % typ
192
193    @staticmethod
194    def __missing__(key):
195        return 'text'
196
197_simpletypes = _SimpleTypes()
198
199
def _quote_if_unqualified(param, name):
    """Wrap a parameter in quote_ident() if the name is unqualified.

    The given parameter is put inside a quote_ident() call when the name
    is a string that does not contain a dot.  A name with a dot is
    ambiguous (it could be a qualified name or just a name with a dot
    in it), so in that case the parameter is returned unchanged and the
    caller must take care of the quoting.
    """
    unqualified = isinstance(name, basestring) and '.' not in name
    if unqualified:
        return 'quote_ident(%s)' % (param,)
    return param
211
212
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    Instances need an ``adapt`` attribute (a callable taking a value and
    a type) which is not set by this class itself.
    """

    def add(self, value, typ=None):
        """Adapt the value for the given type and register it as parameter.

        A Literal is returned as is and does not become a parameter.  Any
        other adapted value is appended to this list, and the matching
        placeholder ('$1', '$2', ...) is returned.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            return '$%d' % len(self)
        return adapted
227
228
class Bytea(bytes):
    """Marker type: a bytes value that shall be treated as bytea."""
231
232
class Hstore(dict):
    """Wrapper class for marking hstore values.

    Converting an instance with str() gives the hstore text representation
    of the dictionary, i.e. comma separated 'key=>value' pairs.
    """

    # keys and values matching this must be put in double quotes
    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
    # characters that must be escaped with a backslash
    _re_escape = regex(r'(["\\])')

    @classmethod
    def _quote(cls, s):
        """Return a key or value in the form needed inside an hstore literal.

        None becomes NULL and the empty string becomes "".  Double quotes
        and backslashes are escaped with a backslash (a backslash that is
        not escaped would be taken as an escape character and get lost).
        The result is put in double quotes if the string contains a space,
        a comma, '=' or '>', or if it spells 'null' in any letter case.
        """
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        needs_quotes = cls._re_quote.search(s)
        s = cls._re_escape.sub(r'\\\1', s)
        if needs_quotes:
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
252
253
class Json:
    """Marker type wrapping an object that shall be treated as JSON.

    The wrapped object is available as the attribute ``obj``.
    """

    def __init__(self, obj):
        self.obj = obj
259
260
class Literal(str):
    """Marker type: a string that shall be treated as literal SQL."""
263
264
class Adapter:
    """Class providing methods for adapting parameters to the database.

    Values can be adapted in two ways: adapt() prepares a value that is
    sent separately as a query parameter, while adapt_inline() creates a
    quoted representation that is put directly into the SQL command.
    format_query() uses these methods to format a whole command.

    Values wrapped with the Json class are unwrapped before they are
    encoded, in both of these ways.
    """

    # lower case strings that are interpreted as a true boolean value
    _bool_true_values = frozenset('t true 1 y yes on'.split())

    # date/time strings that are passed as SQL literals, not as parameters
    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    # array/record elements matching the *_quote patterns must be put in
    # double quotes, with the characters matched by *_escape escaped
    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    _re_record_quote = regex(r'[(,"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        """Create an adapter for the given database wrapper.

        The passed object must provide an encode_json method and a db
        attribute with escape_bytea and escape_string methods.
        """
        self.db = db
        self.encode_json = db.encode_json
        db = db.db
        self.escape_bytea = db.escape_bytea
        self.escape_string = db.escape_string

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        An empty string gives None; other strings are true if their lower
        case form is one of the known true values.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        Strings like 'current_date' are returned as Literal values.
        """
        if not v:
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter (false values except zero give None)."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Strings are assumed to be encoded already and are passed as is.
        """
        if not v:
            return None
        if isinstance(v, Json):
            # encode the wrapped object, not the wrapper itself
            v = v.obj
        if isinstance(v, basestring):
            return v
        return self.encode_json(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if not v:
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(v) for v in v) + b'}'
        if v is None:
            return b'null'
        # backslashes must be doubled inside of an array
        return self.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v:
            return 'null'
        if isinstance(v, Json):
            # encode the wrapped object, not the wrapper itself
            v = v.obj
        if not isinstance(v, basestring):
            v = self.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Raises a TypeError if the number of values does not match the
        number of attributes of the type.
        """
        typ = self.get_attnames(typ).values()
        if len(typ) != len(v):
            raise TypeError('Record parameter %s has wrong size' % v)
        adapt = self.adapt
        value = []
        for v, t in zip(v, typ):
            v = adapt(v, t)
            if v is None:
                v = ''
            elif not v:
                v = '""'
            else:
                if isinstance(v, bytes):
                    if str is not bytes:
                        v = v.decode('ascii')
                else:
                    v = str(v)
                if self._re_record_quote.search(v):
                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
            value.append(v)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        None and Literal values are returned unchanged.  If no type is
        passed, it is guessed from the value, with 'text' as fallback.
        A __pg_str__ method of the value is called with the type if the
        value has such a method.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            try:
                value = value.__pg_str__(typ)
            except AttributeError:
                pass
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns None if the type cannot be guessed.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, Json):
            return 'json'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, list):
            return '%s[]' % cls.guess_simple_base_type(value)
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type
            # the attributes are named '1', '2', ... and their types
            # are guessed from the elements of the tuple
            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))
            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array.

        Returns the first type that can be guessed from an element of the
        (possibly nested) list, or None if there is no such element.
        """
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Raises an InterfaceError if the value has a type that cannot be
        adapted and does not provide a __pg_repr__ method.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # The plain Json wrapper has no encode method, so this must
            # be looked up defensively; only wrappers that provide such
            # a method are asked to encode themselves.
            encode = getattr(value, 'encode', None)
            if encode:
                return encode()
            # encode the wrapped object, not the wrapper itself;
            # strings are assumed to be encoded already
            value = value.obj
            if not isinstance(value, basestring):
                value = self.encode_json(value)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            # only the outermost list gets the ARRAY keyword
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        The values can be passed as a list or tuple for a command with
        positional format specifiers, or as a dict for a command with
        named format specifiers.  If inline is set, the values are put
        into the command as quoted literals; otherwise they are replaced
        with placeholders and collected in a parameter list.

        Returns a tuple with the formatted command and the parameter list.

        Raises a ValueError if types are passed together with inline, and
        a TypeError if the types do not match the values or the values
        are neither a list, a tuple nor a dict.
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command = command % tuple(literals)
        elif isinstance(values, dict):
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                # the keys are sorted so that the numbering of the
                # placeholders does not depend on the order of the dict
                if types:
                    if (not isinstance(types, dict) or
                            len(types) < len(values)):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types[key])
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command = command % literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
598
599
def cast_bool(value):
    """Cast a boolean value.

    The value is returned unchanged if get_bool() gives a false result.
    Otherwise it is True exactly if the value starts with a 't'.
    """
    if get_bool():
        return value[0] == 't'
    return value
605
606
def cast_json(value):
    """Cast a JSON value.

    The value is decoded with the function returned by get_jsondecode();
    if that gives a false result, the value is returned unchanged.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
613
614
def cast_num(value):
    """Cast a numeric value.

    The value is converted with the type returned by get_decimal(),
    or with float if that gives a false result.
    """
    factory = get_decimal() or float
    return factory(value)
618
619
def cast_money(value):
    """Cast a money value.

    The value is returned unchanged if get_decimal_point() gives a false
    result.  Otherwise the decimal point is normalized to a dot, an
    opening parenthesis is taken as a minus sign, and everything except
    digits, dots and minus signs is removed before the conversion.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        value = value.replace(point, '.')
    value = value.replace('(', '-')
    kept = [c for c in value if c.isdigit() or c in '.-']
    factory = get_decimal() or float
    return factory(''.join(kept))
630
631
def cast_int2vector(value):
    """Cast an int2vector value to a list of integers."""
    return list(map(int, value.split()))
635
636
def cast_date(value, connection):
    """Cast a date value.

    The format used for parsing is taken from the date_format() method
    of the connection.  Values that cannot be represented are mapped to
    date.min ('-infinity', dates BC) or date.max ('infinity', dates with
    more than ten characters).
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    special = {'-infinity': date.min, 'infinity': date.max}
    if value in special:
        return special[value]
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    text = parts[0]
    if len(text) > 10:
        return date.max
    parsed = datetime.strptime(text, connection.date_format())
    return parsed.date()
655
656
def cast_time(value):
    """Cast a time value.

    Values with more than eight characters are parsed with a
    fractional part for the seconds.
    """
    fmt = '%H:%M:%S'
    if len(value) > 8:
        fmt += '.%f'
    return datetime.strptime(value, fmt).time()
661
662
# splits a value into the part before the last sign and the signed rest
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    The time zone is only evaluated if the module level timezones
    dictionary is available; otherwise it is split off and ignored.
    A value without a signed offset is treated as '+0000'.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S'
    if len(value) > 8:
        fmt += '.%f'
    if timezones:
        if tz.startswith(('+', '-')):
            # bring the offset to the form +HHMM
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            tz = timezones.get(tz, '+0000')
        value += tz
        fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
687
688
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The date format is taken from the date_format() method of the
    connection.  Values that cannot be represented are mapped to
    datetime.min ('-infinity', dates BC) or datetime.max ('infinity',
    dates or years that are too long).
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the first part and keep the four parts that are
        # parsed as month and day (in either order), time and year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        fmts = [day_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        fmts = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(fmts))
710
711
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    The date format is taken from the date_format() method of the
    connection.  The time zone is only evaluated if the module level
    timezones dictionary is available; otherwise it is split off and
    ignored.  Values that cannot be represented are mapped to
    datetime.min ('-infinity', dates BC) or datetime.max ('infinity',
    dates or years that are too long).
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # skip the first part; the rest is parsed as month and day
        # (in either order), time and year, followed by the time zone
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        fmts = [day_fmt, time_fmt, '%Y']
        tz = parts[-1]
        parts = parts[:-1]
    else:
        if date_fmt.startswith('%Y-'):
            # here the time zone is appended to the time as an offset
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], tz = match.groups()
            else:
                tz = '+0000'
        else:
            tz = parts[-1]
            parts = parts[:-1]
        if len(parts[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        fmts = [date_fmt, time_fmt]
    if timezones:
        if tz.startswith(('+', '-')):
            # bring the offset to the form +HHMM
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            tz = timezones.get(tz, '+0000')
        parts.append(tz)
        fmts.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(fmts))
754
# In all of the following patterns, the last group with digits holds the
# fractional part of the seconds (the digits after the decimal point).

_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(groups):
    """Prepare the matched groups of an interval for conversion to int.

    Missing groups are replaced with '0'.  The last group, which holds
    the digits after the decimal point of the seconds, is padded with
    zeros (or truncated) to exactly six digits so that its integer value
    is the number of microseconds ('5' stands for 500000 microseconds).
    """
    fields = [d or '0' for d in groups]
    fields[-1] = fields[-1].ljust(6, '0')[:6]
    return fields


def cast_interval(value):
    """Cast an interval value to a timedelta.

    All four interval output formats are tried one after the other.
    Years are counted as 365 days and months as 30 days.

    Raises a ValueError if the value matches none of the formats.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = _interval_fields(m.groups())
        secs_ago = m.pop(5) == '-'
        m = [int(d) for d in m]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = _interval_fields(m.groups()[:8]), m.group(9)
            secs_ago = m.pop(5) == '-'
            # a trailing "ago" reverses the sign of all fields
            m = [-int(d) for d in m] if ago else [int(d) for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = - secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            # the pattern also matches the empty string
            if m and any(m.groups()):
                m = _interval_fields(m.groups())
                hours_ago = m.pop(3) == '-'
                m = [int(d) for d in m]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    # the sign applies to the whole time part
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = _interval_fields(m.groups())
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = [int(d) for d in m]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
840
841
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values have already been
    handled before the cast function is called.

    Note that the basic types are already handled by the C extension.
    They only need to be handled here as record or array components.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'date': cast_date, 'interval': cast_interval,
        'time': cast_time, 'timetz': cast_timetz,
        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
        'int2vector': cast_int2vector,
        'anyarray': cast_array, 'record': cast_record}

    # cast functions with a "connection" parameter get this value bound
    connection = None  # will be set in a connection specific instance
870
871    def __missing__(self, typ):
872        """Create a cast function if it is not cached.
873       
874        Note that this class never raises a KeyError,
875        but returns None when no special cast function exists.
876        """
877        if not isinstance(typ, str):
878            raise TypeError('Invalid type: %s' % typ)
879        cast = self.defaults.get(typ)
880        if cast:
881            # store default for faster access
882            cast = self._add_connection(cast)
883            self[typ] = cast
884        elif typ.startswith('_'):
885            base_cast = self[typ[1:]]
886            cast = self.create_array_cast(base_cast)
887            if base_cast:
888                self[typ] = cast
889        else:
890            attnames = self.get_attnames(typ)
891            if attnames:
892                casts = [self[v.pgtype] for v in attnames.values()]
893                cast = self.create_record_cast(typ, attnames, casts)
894                self[typ] = cast
895        return cast
896
897    @staticmethod
898    def _needs_connection(func):
899        """Check if a typecast function needs a connection argument."""
900        try:
901            args = get_args(func)
902        except (TypeError, ValueError):
903            return False
904        else:
905            return 'connection' in args[1:]
906
907    def _add_connection(self, cast):
908        """Add a connection argument to the typecast function if necessary."""
909        if not self.connection or not self._needs_connection(cast):
910            return cast
911        return partial(cast, connection=self.connection)
912
    def get(self, typ, default=None):
        """Get the typecast function for the given database type.

        The lookup goes through item access, so a cast function that is
        not cached yet will be created on demand.  The default is returned
        when the lookup yields no (or a false) cast function.
        """
        return self[typ] or default
916
917    def set(self, typ, cast):
918        """Set a typecast function for the specified database type(s)."""
919        if isinstance(typ, basestring):
920            typ = [typ]
921        if cast is None:
922            for t in typ:
923                self.pop(t, None)
924                self.pop('_%s' % t, None)
925        else:
926            if not callable(cast):
927                raise TypeError("Cast parameter must be callable")
928            for t in typ:
929                self[t] = self._add_connection(cast)
930                self.pop('_%s' % t, None)
931
932    def reset(self, typ=None):
933        """Reset the typecasts for the specified type(s) to their defaults.
934
935        When no type is specified, all typecasts will be reset.
936        """
937        if typ is None:
938            self.clear()
939        else:
940            if isinstance(typ, basestring):
941                typ = [typ]
942            for t in typ:
943                self.pop(t, None)
944
    @classmethod
    def get_default(cls, typ):
        """Get the default typecast function for the given database type.

        Returns None if the class defaults have no entry for the type.
        """
        return cls.defaults.get(typ)
949
950    @classmethod
951    def set_default(cls, typ, cast):
952        """Set a default typecast function for the given database type(s)."""
953        if isinstance(typ, basestring):
954            typ = [typ]
955        defaults = cls.defaults
956        if cast is None:
957            for t in typ:
958                defaults.pop(t, None)
959                defaults.pop('_%s' % t, None)
960        else:
961            if not callable(cast):
962                raise TypeError("Cast parameter must be callable")
963            for t in typ:
964                defaults[t] = cast
965                defaults.pop('_%s' % t, None)
966
    def get_attnames(self, typ):
        """Return the fields for the given record type.

        This placeholder knows no record types and returns an empty dict.

        This method will be replaced with the get_attnames() method of DbTypes.
        """
        return {}
973
    def dateformat(self):
        """Return the current date format.

        This placeholder always returns the format string '%Y-%m-%d'.

        This method will be replaced with the dateformat() method of DbTypes.
        """
        return '%Y-%m-%d'
980
    def create_array_cast(self, basecast):
        """Create an array typecast for the given base cast.

        The returned function takes a single value and passes it on to
        cast_array() together with the base cast.
        """
        def cast(v):
            return cast_array(v, basecast)
        return cast
986
    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts.

        A named tuple class with the given name and fields is created once;
        the returned function passes a value with the casts to cast_record()
        and wraps the resulting items in an instance of that class.
        """
        record = namedtuple(name, fields)
        def cast(v):
            return record(*cast_record(v, casts))
        return cast
993
994
def get_typecast(typ):
    """Get the global typecast function for the given database type(s).

    This delegates to Typecasts.get_default().
    """
    return Typecasts.get_default(typ)
998
999
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    This delegates to Typecasts.set_default().

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1007
1008
class DbType(str):
    """Class augmenting the simple type name with additional info.

    The following additional information is provided:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, B = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types

    The attributes are not set by this class itself; they are expected
    to be assigned by the code creating the instance.
    """

    @property
    def attnames(self):
        """Get names and types of the fields of a composite type.

        The lookup is delegated to the _get_attnames callable that must
        have been attached to this instance.
        """
        return self._get_attnames(self)
1029
1030
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    The cache is a dictionary mapping type OIDs and type names to DbType
    objects that carry the information on the associated database type.
    Types that are not cached yet are looked up in the database.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize type cache for connection.

        The db parameter is the wrapper object providing get_attnames()
        and holding the underlying connection in its db attribute.
        """
        super(DbTypes, self).__init__()
        self._regtypes = False
        self._get_attnames = db.get_attnames
        typecasts = Typecasts()
        self._typecasts = typecasts
        typecasts.get_attnames = self.get_attnames
        typecasts.connection = db
        # query and escape_string are taken from the underlying connection
        cnx = db.db
        self.query = cnx.query
        self.escape_string = cnx.escape_string

    def add(self, oid, pgtype, regtype,
               typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        If the OID is already cached, the cached type is returned.
        Otherwise a new DbType is created and returned; note that this
        method does not store the new type in the cache itself.
        """
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletypes[pgtype]
        name = regtype if self._regtypes else simple
        typ = DbType(name)
        info = (('oid', oid), ('simple', simple), ('pgtype', pgtype),
            ('regtype', regtype), ('typtype', typtype),
            ('category', category), ('delim', delim), ('relid', relid),
            ('_get_attnames', self.get_attnames))
        for attr, value in info:
            setattr(typ, attr, value)
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The type found is cached under both its OID and its type name.
        A KeyError is raised if the database does not know the type.
        """
        try:
            rows = self.query("SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype" %
                (_quote_if_unqualified('$1', key),), (key,)).getresult()
        except ProgrammingError:
            rows = None
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        typ = self.add(*rows[0])
        self[typ.oid] = typ
        self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached.

        Other than the get() method of an ordinary dictionary, this
        triggers the database lookup for types that are not cached.
        """
        try:
            found = self[key]
        except KeyError:
            found = default
        return found

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or has no associated table.
        """
        if isinstance(typ, DbType):
            dbtype = typ
        else:
            dbtype = self.get(typ)
            if not dbtype:
                return None
        if dbtype.relid:
            return self._get_attnames(dbtype.relid, with_oid=False)
        return None

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        typecasts = self._typecasts
        return typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        typecasts = self._typecasts
        typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        typecasts = self._typecasts
        typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        The value is returned unchanged if it is None, if the type is
        unknown or if no typecast function other than str exists for it.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if not isinstance(typ, DbType):
            dbtype = self.get(typ)
            typ = dbtype.pgtype if dbtype else dbtype
        if not typ:
            return value
        cast = self.get_typecast(typ)
        if cast and cast is not str:
            return cast(value)
        # no typecast is necessary
        return value
1130
1131
1132def _namedresult(q):
1133    """Get query result as named tuples."""
1134    row = namedtuple('Row', q.listfields())
1135    return [row(*r) for r in q.getresult()]
1136
1137
class _MemoryQuery:
    """Class that embodies a given query result.

    It provides the listfields() and getresult() methods for result rows
    and field names that are kept in memory.
    """

    def __init__(self, result, fields):
        """Create query from given result rows and field names."""
        self.result = result
        self.fields = fields

    def listfields(self):
        """Return the stored field names of this query."""
        return self.fields

    def getresult(self):
        """Return the stored result of this query."""
        return self.result
1153
1154
def _db_error(msg, cls=DatabaseError):
    """Return DatabaseError with empty sqlstate attribute.

    A different exception class can be passed as cls.  The error is
    only created and returned, raising it is left to the caller.
    """
    error = cls(msg)
    # no SQLSTATE code is available for errors created on the client side
    error.sqlstate = None
    return error
1160
1161
def _int_error(msg):
    """Return InternalError with empty sqlstate attribute."""
    return _db_error(msg, InternalError)
1165
1166
def _prg_error(msg):
    """Return ProgrammingError with empty sqlstate attribute."""
    return _db_error(msg, ProgrammingError)
1170
1171
# Initialize the C module
# (register the Python level helpers with the functions it provides)

set_namedresult(_namedresult)  # function creating named tuple results
set_decimal(Decimal)  # the Decimal class from the decimal module
set_jsondecode(jsondecode)  # the loads function from the json module
1177
1178
1179# The notification handler
1180
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            # create a fresh dictionary for every handler
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is deleted."""
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is escaped with the escape_string() method of the
        connection before it is embedded in the command.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Returns the result of the query, or None if the handler is
        currently not listening, in which case nothing is sent.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # the payload is sent as a string literal, so quotes
                # inside of it must be escaped to keep the command intact
                q += ", '%s'" % db.escape_string('%s' % payload)
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications is
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        Otherwise it will keep listening until it receives a stop event
        or the timeout is reached.

        A database error is raised if a notification for any other
        event is received.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            # when polling, select is skipped and we do not wait for input
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        # this ends both loops after the callback has run
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1300
1301
def pgnotify(*args, **kw):
    """Same as NotificationHandler, under the traditional name.

    All arguments are passed on to NotificationHandler.  Calling this
    function issues a DeprecationWarning.
    """
    # stacklevel=2 lets the warning point to the caller of pgnotify
    warnings.warn("pgnotify is deprecated, use NotificationHandler instead",
        DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
1307
1308
# The actual PostgreSQL database connection interface:
1310
1311class DB:
1312    """Wrapper class for the _pg connection type."""
1313
1314    def __init__(self, *args, **kw):
1315        """Create a new connection
1316
1317        You can pass either the connection parameters or an existing
1318        _pg or pgdb connection. This allows you to use the methods
1319        of the classic pg interface with a DB-API 2 pgdb connection.
1320        """
1321        if not args and len(kw) == 1:
1322            db = kw.get('db')
1323        elif not kw and len(args) == 1:
1324            db = args[0]
1325        else:
1326            db = None
1327        if db:
1328            if isinstance(db, DB):
1329                db = db.db
1330            else:
1331                try:
1332                    db = db._cnx
1333                except AttributeError:
1334                    pass
1335        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
1336            db = connect(*args, **kw)
1337            self._closeable = True
1338        else:
1339            self._closeable = False
1340        self.db = db
1341        self.dbname = db.db
1342        self._regtypes = False
1343        self._attnames = {}
1344        self._pkeys = {}
1345        self._privileges = {}
1346        self._args = args, kw
1347        self.adapter = Adapter(self)
1348        self.dbtypes = DbTypes(self)
1349        db.set_cast_hook(self.dbtypes.typecast)
1350        self.debug = None  # For debugging scripts, this can be set
1351            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
1352            # * to a file object to write debug statements or
1353            # * to a callable object which takes a string argument
1354            # * to any other true value to just print debug statements
1355
1356    def __getattr__(self, name):
1357        # All undefined members are same as in underlying connection:
1358        if self.db:
1359            return getattr(self.db, name)
1360        else:
1361            raise _int_error('Connection is not valid')
1362
1363    def __dir__(self):
1364        # Custom dir function including the attributes of the connection:
1365        attrs = set(self.__class__.__dict__)
1366        attrs.update(self.__dict__)
1367        attrs.update(dir(self.db))
1368        return sorted(attrs)
1369
1370    # Context manager methods
1371
    def __enter__(self):
        """Enter the runtime context. This will start a transaction.

        Returns the instance itself.
        """
        self.begin()
        return self
1376
1377    def __exit__(self, et, ev, tb):
1378        """Exit the runtime context. This will end the transaction."""
1379        if et is None and ev is None and tb is None:
1380            self.commit()
1381        else:
1382            self.rollback()
1383
1384    # Auxiliary methods
1385
1386    def _do_debug(self, *args):
1387        """Print a debug message"""
1388        if self.debug:
1389            s = '\n'.join(str(arg) for arg in args)
1390            if isinstance(self.debug, basestring):
1391                print(self.debug % s)
1392            elif hasattr(self.debug, 'write'):
1393                self.debug.write(s + '\n')
1394            elif callable(self.debug):
1395                self.debug(s)
1396            else:
1397                print(s)
1398
1399    def _escape_qualified_name(self, s):
1400        """Escape a qualified name.
1401
1402        Escapes the name for use as an SQL identifier, unless the
1403        name contains a dot, in which case the name is ambiguous
1404        (could be a qualified name or just a name with a dot in it)
1405        and must be quoted manually by the caller.
1406        """
1407        if '.' not in s:
1408            s = self.escape_identifier(s)
1409        return s
1410
1411    @staticmethod
1412    def _make_bool(d):
1413        """Get boolean value corresponding to d."""
1414        return bool(d) if get_bool() else ('t' if d else 'f')
1415
1416    def _list_params(self, params):
1417        """Create a human readable parameter list."""
1418        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1419
1420    # Public methods
1421
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (a static method wrapping the module level function of that name)
    unescape_bytea = staticmethod(unescape_bytea)
1425
1426    def decode_json(self, s):
1427        """Decode a JSON string coming from the database."""
1428        return (get_jsondecode() or jsondecode)(s)
1429
    def encode_json(self, d):
        """Encode a JSON string for use within SQL.

        The dumps function of the json module is used for encoding.
        """
        return jsonencode(d)
1433
1434    def close(self):
1435        """Close the database connection."""
1436        # Wraps shared library function so we can track state.
1437        if self._closeable:
1438            if self.db:
1439                self.db.close()
1440                self.db = None
1441            else:
1442                raise _int_error('Connection already closed')
1443
1444    def reset(self):
1445        """Reset connection with current parameters.
1446
1447        All derived queries and large objects derived from this connection
1448        will not be usable after this call.
1449
1450        """
1451        if self.db:
1452            self.db.reset()
1453        else:
1454            raise _int_error('Connection already closed')
1455
1456    def reopen(self):
1457        """Reopen connection to the database.
1458
1459        Used in case we need another connection to the same database.
1460        Note that we can still reopen a database that we have closed.
1461
1462        """
1463        # There is no such shared library function.
1464        if self._closeable:
1465            db = connect(*self._args[0], **self._args[1])
1466            if self.db:
1467                self.db.close()
1468            self.db = db
1469
1470    def begin(self, mode=None):
1471        """Begin a transaction."""
1472        qstr = 'BEGIN'
1473        if mode:
1474            qstr += ' ' + mode
1475        return self.query(qstr)
1476
1477    start = begin
1478
    def commit(self):
        """Commit the current transaction.

        Returns the result of the COMMIT query.
        """
        return self.query('COMMIT')

    end = commit
1484
1485    def rollback(self, name=None):
1486        """Roll back the current transaction."""
1487        qstr = 'ROLLBACK'
1488        if name:
1489            qstr += ' TO ' + name
1490        return self.query(qstr)
1491
1492    abort = rollback
1493
    def savepoint(self, name):
        """Define a new savepoint within the current transaction.

        The name is put into the command as it is, without any quoting.
        """
        return self.query('SAVEPOINT ' + name)
1497
    def release(self, name):
        """Destroy a previously defined savepoint.

        The name is put into the command as it is, without any quoting.
        """
        return self.query('RELEASE ' + name)
1501
1502    def get_parameter(self, parameter):
1503        """Get the value of a run-time parameter.
1504
1505        If the parameter is a string, the return value will also be a string
1506        that is the current setting of the run-time parameter with that name.
1507
1508        You can get several parameters at once by passing a list, set or dict.
1509        When passing a list of parameter names, the return value will be a
1510        corresponding list of parameter settings.  When passing a set of
1511        parameter names, a new dict will be returned, mapping these parameter
1512        names to their settings.  Finally, if you pass a dict as parameter,
1513        its values will be set to the current parameter settings corresponding
1514        to its keys.
1515
1516        By passing the special name 'all' as the parameter, you can get a dict
1517        of all existing configuration parameters.
1518        """
1519        if isinstance(parameter, basestring):
1520            parameter = [parameter]
1521            values = None
1522        elif isinstance(parameter, (list, tuple)):
1523            values = []
1524        elif isinstance(parameter, (set, frozenset)):
1525            values = {}
1526        elif isinstance(parameter, dict):
1527            values = parameter
1528        else:
1529            raise TypeError(
1530                'The parameter must be a string, list, set or dict')
1531        if not parameter:
1532            raise TypeError('No parameter has been specified')
1533        params = {} if isinstance(values, dict) else []
1534        for key in parameter:
1535            param = key.strip().lower() if isinstance(
1536                key, basestring) else None
1537            if not param:
1538                raise TypeError('Invalid parameter')
1539            if param == 'all':
1540                q = 'SHOW ALL'
1541                values = self.db.query(q).getresult()
1542                values = dict(value[:2] for value in values)
1543                break
1544            if isinstance(values, dict):
1545                params[param] = key
1546            else:
1547                params.append(param)
1548        else:
1549            for param in params:
1550                q = 'SHOW %s' % (param,)
1551                value = self.db.query(q).getresult()[0][0]
1552                if values is None:
1553                    values = value
1554                elif isinstance(values, list):
1555                    values.append(value)
1556                else:
1557                    values[params[param]] = value
1558        return values
1559
1560    def set_parameter(self, parameter, value=None, local=False):
1561        """Set the value of a run-time parameter.
1562
1563        If the parameter and the value are strings, the run-time parameter
1564        will be set to that value.  If no value or None is passed as a value,
1565        then the run-time parameter will be restored to its default value.
1566
1567        You can set several parameters at once by passing a list of parameter
1568        names, together with a single value that all parameters should be
1569        set to or with a corresponding list of values.  You can also pass
1570        the parameters as a set if you only provide a single value.
1571        Finally, you can pass a dict with parameter names as keys.  In this
1572        case, you should not pass a value, since the values for the parameters
1573        will be taken from the dict.
1574
1575        By passing the special name 'all' as the parameter, you can reset
1576        all existing settable run-time parameters to their default values.
1577
1578        If you set local to True, then the command takes effect for only the
1579        current transaction.  After commit() or rollback(), the session-level
1580        setting takes effect again.  Setting local to True will appear to
1581        have no effect if it is executed outside a transaction, since the
1582        transaction will end immediately.
1583        """
1584        if isinstance(parameter, basestring):
1585            parameter = {parameter: value}
1586        elif isinstance(parameter, (list, tuple)):
1587            if isinstance(value, (list, tuple)):
1588                parameter = dict(zip(parameter, value))
1589            else:
1590                parameter = dict.fromkeys(parameter, value)
1591        elif isinstance(parameter, (set, frozenset)):
1592            if isinstance(value, (list, tuple, set, frozenset)):
1593                value = set(value)
1594                if len(value) == 1:
1595                    value = value.pop()
1596            if not(value is None or isinstance(value, basestring)):
1597                raise ValueError('A single value must be specified'
1598                    ' when parameter is a set')
1599            parameter = dict.fromkeys(parameter, value)
1600        elif isinstance(parameter, dict):
1601            if value is not None:
1602                raise ValueError('A value must not be specified'
1603                    ' when parameter is a dictionary')
1604        else:
1605            raise TypeError(
1606                'The parameter must be a string, list, set or dict')
1607        if not parameter:
1608            raise TypeError('No parameter has been specified')
1609        params = {}
1610        for key, value in parameter.items():
1611            param = key.strip().lower() if isinstance(
1612                key, basestring) else None
1613            if not param:
1614                raise TypeError('Invalid parameter')
1615            if param == 'all':
1616                if value is not None:
1617                    raise ValueError('A value must ot be specified'
1618                        " when parameter is 'all'")
1619                params = {'all': None}
1620                break
1621            params[param] = value
1622        local = ' LOCAL' if local else ''
1623        for param, value in params.items():
1624            if value is None:
1625                q = 'RESET%s %s' % (local, param)
1626            else:
1627                q = 'SET%s %s TO %s' % (local, param, value)
1628            self._do_debug(q)
1629            self.db.query(q)
1630
1631    def query(self, command, *args):
1632        """Execute a SQL command string.
1633
1634        This method simply sends a SQL query to the database.  If the query is
1635        an insert statement that inserted exactly one row into a table that
1636        has OIDs, the return value is the OID of the newly inserted row.
1637        If the query is an update or delete statement, or an insert statement
1638        that did not insert exactly one row in a table with OIDs, then the
1639        number of rows affected is returned as a string.  If it is a statement
1640        that returns rows as a result (usually a select statement, but maybe
1641        also an "insert/update ... returning" statement), this method returns
1642        a Query object that can be accessed via getresult() or dictresult()
1643        or simply printed.  Otherwise, it returns `None`.
1644
1645        The query can contain numbered parameters of the form $1 in place
1646        of any data constant.  Arguments given after the query string will
1647        be substituted for the corresponding numbered parameter.  Parameter
1648        values can also be given as a single list or tuple argument.
1649        """
1650        # Wraps shared library function for debugging.
1651        if not self.db:
1652            raise _int_error('Connection is not valid')
1653        if args:
1654            self._do_debug(command, args)
1655            return self.db.query(command, args)
1656        self._do_debug(command)
1657        return self.db.query(command)
1658
    def query_formatted(self, command, parameters, types=None, inline=False):
        """Execute a formatted SQL command string.

        Similar to query, but using Python format placeholders of the form
        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
        The parameters must be passed as a tuple, list or dict.  You can
        also pass a corresponding tuple, list or dict of database types in
        order to format the parameters properly in case there is ambiguity.

        If you set inline to True, the parameters will be sent to the database
        embedded in the SQL command, otherwise they will be sent separately.
        """
        # the adapter returns the arguments to be passed on to query()
        return self.query(*self.adapter.format_query(
            command, parameters, types, inline))
1673
1674    def pkey(self, table, composite=False, flush=False):
1675        """Get or set the primary key of a table.
1676
1677        Single primary keys are returned as strings unless you
1678        set the composite flag.  Composite primary keys are always
1679        represented as tuples.  Note that this raises a KeyError
1680        if the table does not have a primary key.
1681
1682        If flush is set then the internal cache for primary keys will
1683        be flushed.  This may be necessary after the database schema or
1684        the search path has been changed.
1685        """
1686        pkeys = self._pkeys
1687        if flush:
1688            pkeys.clear()
1689            self._do_debug('The pkey cache has been flushed')
1690        try:  # cache lookup
1691            pkey = pkeys[table]
1692        except KeyError:  # cache miss, check the database
1693            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1694                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1695                " AND a.attnum = ANY(i.indkey)"
1696                " AND NOT a.attisdropped"
1697                " WHERE i.indrelid=%s::regclass"
1698                " AND i.indisprimary ORDER BY a.attnum") % (
1699                    _quote_if_unqualified('$1', table),)
1700            pkey = self.db.query(q, (table,)).getresult()
1701            if not pkey:
1702                raise KeyError('Table %s has no primary key' % table)
1703            # we want to use the order defined in the primary key index here,
1704            # not the order as defined by the columns in the table
1705            if len(pkey) > 1:
1706                indkey = pkey[0][2]
1707                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1708                pkey = tuple(row[0] for row in pkey)
1709            else:
1710                pkey = pkey[0][0]
1711            pkeys[table] = pkey  # cache it
1712        if composite and not isinstance(pkey, tuple):
1713            pkey = (pkey,)
1714        return pkey
1715
1716    def get_databases(self):
1717        """Get list of databases in the system."""
1718        return [s[0] for s in
1719            self.db.query('SELECT datname FROM pg_database').getresult()]
1720
1721    def get_relations(self, kinds=None):
1722        """Get list of relations in connected database of specified kinds.
1723
1724        If kinds is None or empty, all kinds of relations are returned.
1725        Otherwise kinds can be a string or sequence of type letters
1726        specifying which kind of relations you want to list.
1727        """
1728        where = " AND r.relkind IN (%s)" % ','.join(
1729            ["'%s'" % k for k in kinds]) if kinds else ''
1730        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1731            " FROM pg_class r"
1732            " JOIN pg_namespace s ON s.oid = r.relnamespace"
1733            " WHERE s.nspname NOT SIMILAR"
1734            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
1735            " ORDER BY s.nspname, r.relname") % where
1736        return [r[0] for r in self.db.query(q).getresult()]
1737
1738    def get_tables(self):
1739        """Return list of tables in connected database."""
1740        return self.get_relations('r')
1741
1742    def get_attnames(self, table, with_oid=True, flush=False):
1743        """Given the name of a table, dig out the set of attribute names.
1744
1745        Returns a read-only dictionary of attribute names (the names are
1746        the keys, the values are the names of the attributes' types)
1747        with the column names in the proper order if you iterate over it.
1748
1749        If flush is set, then the internal cache for attribute names will
1750        be flushed. This may be necessary after the database schema or
1751        the search path has been changed.
1752
1753        By default, only a limited number of simple types will be returned.
1754        You can get the regular types after calling use_regtypes(True).
1755        """
1756        attnames = self._attnames
1757        if flush:
1758            attnames.clear()
1759            self._do_debug('The attnames cache has been flushed')
1760        try:  # cache lookup
1761            names = attnames[table]
1762        except KeyError:  # cache miss, check the database
1763            q = "a.attnum > 0"
1764            if with_oid:
1765                q = "(%s OR a.attname = 'oid')" % q
1766            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1767                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1768                " FROM pg_attribute a"
1769                " JOIN pg_type t ON t.oid = a.atttypid"
1770                " WHERE a.attrelid = %s::regclass AND %s"
1771                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1772                    _quote_if_unqualified('$1', table), q)
1773            names = self.db.query(q, (table,)).getresult()
1774            types = self.dbtypes
1775            names = ((name[0], types.add(*name[1:])) for name in names)
1776            names = AttrDict(names)
1777            attnames[table] = names  # cache it
1778        return names
1779
1780    def use_regtypes(self, regtypes=None):
1781        """Use regular type names instead of simplified type names."""
1782        if regtypes is None:
1783            return self.dbtypes._regtypes
1784        else:
1785            regtypes = bool(regtypes)
1786            if regtypes != self.dbtypes._regtypes:
1787                self.dbtypes._regtypes = regtypes
1788                self._attnames.clear()
1789                self.dbtypes.clear()
1790            return regtypes
1791
1792    def has_table_privilege(self, table, privilege='select'):
1793        """Check whether current user has specified table privilege."""
1794        privilege = privilege.lower()
1795        try:  # ask cache
1796            return self._privileges[(table, privilege)]
1797        except KeyError:  # cache miss, ask the database
1798            q = "SELECT has_table_privilege(%s, $2)" % (
1799                _quote_if_unqualified('$1', table),)
1800            q = self.db.query(q, (table, privilege))
1801            ret = q.getresult()[0][0] == self._make_bool(True)
1802            self._privileges[(table, privilege)] = ret  # cache it
1803            return ret
1804
1805    def get(self, table, row, keyname=None):
1806        """Get a row from a database table or view.
1807
1808        This method is the basic mechanism to get a single row.  It assumes
1809        that the keyname specifies a unique row.  It must be the name of a
1810        single column or a tuple of column names.  If the keyname is not
1811        specified, then the primary key for the table is used.
1812
1813        If row is a dictionary, then the value for the key is taken from it.
1814        Otherwise, the row must be a single value or a tuple of values
1815        corresponding to the passed keyname or primary key.  The fetched row
1816        from the table will be returned as a new dictionary or used to replace
1817        the existing values when row was passed as aa dictionary.
1818
1819        The OID is also put into the dictionary if the table has one, but
1820        in order to allow the caller to work with multiple tables, it is
1821        munged as "oid(table)" using the actual name of the table.
1822        """
1823        if table.endswith('*'):  # hint for descendant tables can be ignored
1824            table = table[:-1].rstrip()
1825        attnames = self.get_attnames(table)
1826        qoid = _oid_key(table) if 'oid' in attnames else None
1827        if keyname and isinstance(keyname, basestring):
1828            keyname = (keyname,)
1829        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
1830            row['oid'] = row[qoid]
1831        if not keyname:
1832            try:  # if keyname is not specified, try using the primary key
1833                keyname = self.pkey(table, True)
1834            except KeyError:  # the table has no primary key
1835                # try using the oid instead
1836                if qoid and isinstance(row, dict) and 'oid' in row:
1837                    keyname = ('oid',)
1838                else:
1839                    raise _prg_error('Table %s has no primary key' % table)
1840            else:  # the table has a primary key
1841                # check whether all key columns have values
1842                if isinstance(row, dict) and not set(keyname).issubset(row):
1843                    # try using the oid instead
1844                    if qoid and 'oid' in row:
1845                        keyname = ('oid',)
1846                    else:
1847                        raise KeyError(
1848                            'Missing value in row for specified keyname')
1849        if not isinstance(row, dict):
1850            if not isinstance(row, (tuple, list)):
1851                row = [row]
1852            if len(keyname) != len(row):
1853                raise KeyError(
1854                    'Differing number of items in keyname and row')
1855            row = dict(zip(keyname, row))
1856        params = self.adapter.parameter_list()
1857        adapt = params.add
1858        col = self.escape_identifier
1859        what = 'oid, *' if qoid else '*'
1860        where = ' AND '.join('%s = %s' % (
1861            col(k), adapt(row[k], attnames[k])) for k in keyname)
1862        if 'oid' in row:
1863            if qoid:
1864                row[qoid] = row['oid']
1865            del row['oid']
1866        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
1867            what, self._escape_qualified_name(table), where)
1868        self._do_debug(q, params)
1869        q = self.db.query(q, params)
1870        res = q.dictresult()
1871        if not res:
1872            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
1873                table, where, self._list_params(params)))
1874        for n, value in res[0].items():
1875            if qoid and n == 'oid':
1876                n = qoid
1877            row[n] = value
1878        return row
1879
1880    def insert(self, table, row=None, **kw):
1881        """Insert a row into a database table.
1882
1883        This method inserts a row into a table.  The name of the table must
1884        be passed as the first parameter.  The other parameters are used for
1885        providing the data of the row that shall be inserted into the table.
1886        If a dictionary is supplied as the second parameter, it starts with
1887        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1888        is updated from the keywords.
1889
1890        The dictionary is then reloaded with the values actually inserted in
1891        order to pick up values modified by rules, triggers, etc.
1892        """
1893        if table.endswith('*'):  # hint for descendant tables can be ignored
1894            table = table[:-1].rstrip()
1895        if row is None:
1896            row = {}
1897        row.update(kw)
1898        if 'oid' in row:
1899            del row['oid']  # do not insert oid
1900        attnames = self.get_attnames(table)
1901        qoid = _oid_key(table) if 'oid' in attnames else None
1902        params = self.adapter.parameter_list()
1903        adapt = params.add
1904        col = self.escape_identifier
1905        names, values = [], []
1906        for n in attnames:
1907            if n in row:
1908                names.append(col(n))
1909                values.append(adapt(row[n], attnames[n]))
1910        if not names:
1911            raise _prg_error('No column found that can be inserted')
1912        names, values = ', '.join(names), ', '.join(values)
1913        ret = 'oid, *' if qoid else '*'
1914        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1915            self._escape_qualified_name(table), names, values, ret)
1916        self._do_debug(q, params)
1917        q = self.db.query(q, params)
1918        res = q.dictresult()
1919        if res:  # this should always be true
1920            for n, value in res[0].items():
1921                if qoid and n == 'oid':
1922                    n = qoid
1923                row[n] = value
1924        return row
1925
1926    def update(self, table, row=None, **kw):
1927        """Update an existing row in a database table.
1928
1929        Similar to insert but updates an existing row.  The update is based
1930        on the primary key of the table or the OID value as munged by get
1931        or passed as keyword.
1932
1933        The dictionary is then modified to reflect any changes caused by the
1934        update due to triggers, rules, default values, etc.
1935        """
1936        if table.endswith('*'):
1937            table = table[:-1].rstrip()  # need parent table name
1938        attnames = self.get_attnames(table)
1939        qoid = _oid_key(table) if 'oid' in attnames else None
1940        if row is None:
1941            row = {}
1942        elif 'oid' in row:
1943            del row['oid']  # only accept oid key from named args for safety
1944        row.update(kw)
1945        if qoid and qoid in row and 'oid' not in row:
1946            row['oid'] = row[qoid]
1947        try:  # try using the primary key
1948            keyname = self.pkey(table, True)
1949        except KeyError:  # the table has no primary key
1950            # try using the oid instead
1951            if qoid and 'oid' in row:
1952                keyname = ('oid',)
1953            else:
1954                raise _prg_error('Table %s has no primary key' % table)
1955        else:  # the table has a primary key
1956            # check whether all key columns have values
1957            if not set(keyname).issubset(row):
1958                # try using the oid instead
1959                if qoid and 'oid' in row:
1960                    keyname = ('oid',)
1961                else:
1962                    raise KeyError('Missing primary key in row')
1963        params = self.adapter.parameter_list()
1964        adapt = params.add
1965        col = self.escape_identifier
1966        where = ' AND '.join('%s = %s' % (
1967            col(k), adapt(row[k], attnames[k])) for k in keyname)
1968        if 'oid' in row:
1969            if qoid:
1970                row[qoid] = row['oid']
1971            del row['oid']
1972        values = []
1973        keyname = set(keyname)
1974        for n in attnames:
1975            if n in row and n not in keyname:
1976                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
1977        if not values:
1978            return row
1979        values = ', '.join(values)
1980        ret = 'oid, *' if qoid else '*'
1981        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
1982            self._escape_qualified_name(table), values, where, ret)
1983        self._do_debug(q, params)
1984        q = self.db.query(q, params)
1985        res = q.dictresult()
1986        if res:  # may be empty when row does not exist
1987            for n, value in res[0].items():
1988                if qoid and n == 'oid':
1989                    n = qoid
1990                row[n] = value
1991        return row
1992
1993    def upsert(self, table, row=None, **kw):
1994        """Insert a row into a database table with conflict resolution
1995
1996        This method inserts a row into a table, but instead of raising a
1997        ProgrammingError exception in case a row with the same primary key
1998        already exists, an update will be executed instead.  This will be
1999        performed as a single atomic operation on the database, so race
2000        conditions can be avoided.
2001
2002        Like the insert method, the first parameter is the name of the
2003        table and the second parameter can be used to pass the values to
2004        be inserted as a dictionary.
2005
2006        Unlike the insert und update statement, keyword parameters are not
2007        used to modify the dictionary, but to specify which columns shall
2008        be updated in case of a conflict, and in which way:
2009
2010        A value of False or None means the column shall not be updated,
2011        a value of True means the column shall be updated with the value
2012        that has been proposed for insertion, i.e. has been passed as value
2013        in the dictionary.  Columns that are not specified by keywords but
2014        appear as keys in the dictionary are also updated like in the case
2015        keywords had been passed with the value True.
2016
2017        So if in the case of a conflict you want to update every column that
2018        has been passed in the dictionary row , you would call upsert(table, row).
2019        If you don't want to do anything in case of a conflict, i.e. leave
2020        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2021
2022        If you need more fine-grained control of what gets updated, you can
2023        also pass strings in the keyword parameters.  These strings will
2024        be used as SQL expressions for the update columns.  In these
2025        expressions you can refer to the value that already exists in
2026        the table by prefixing the column name with "included.", and to
2027        the value that has been proposed for insertion by prefixing the
2028        column name with the "excluded."
2029
2030        The dictionary is modified in any case to reflect the values in
2031        the database after the operation has completed.
2032
2033        Note: The method uses the PostgreSQL "upsert" feature which is
2034        only available since PostgreSQL 9.5.
2035        """
2036        if table.endswith('*'):  # hint for descendant tables can be ignored
2037            table = table[:-1].rstrip()
2038        if row is None:
2039            row = {}
2040        if 'oid' in row:
2041            del row['oid']  # do not insert oid
2042        if 'oid' in kw:
2043            del kw['oid']  # do not update oid
2044        attnames = self.get_attnames(table)
2045        qoid = _oid_key(table) if 'oid' in attnames else None
2046        params = self.adapter.parameter_list()
2047        adapt = params.add
2048        col = self.escape_identifier
2049        names, values, updates = [], [], []
2050        for n in attnames:
2051            if n in row:
2052                names.append(col(n))
2053                values.append(adapt(row[n], attnames[n]))
2054        names, values = ', '.join(names), ', '.join(values)
2055        try:
2056            keyname = self.pkey(table, True)
2057        except KeyError:
2058            raise _prg_error('Table %s has no primary key' % table)
2059        target = ', '.join(col(k) for k in keyname)
2060        update = []
2061        keyname = set(keyname)
2062        keyname.add('oid')
2063        for n in attnames:
2064            if n not in keyname:
2065                value = kw.get(n, True)
2066                if value:
2067                    if not isinstance(value, basestring):
2068                        value = 'excluded.%s' % col(n)
2069                    update.append('%s = %s' % (col(n), value))
2070        if not values:
2071            return row
2072        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2073        ret = 'oid, *' if qoid else '*'
2074        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2075            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2076                self._escape_qualified_name(table), names, values,
2077                target, do, ret)
2078        self._do_debug(q, params)
2079        try:
2080            q = self.db.query(q, params)
2081        except ProgrammingError:
2082            if self.server_version < 90500:
2083                raise _prg_error(
2084                    'Upsert operation is not supported by PostgreSQL version')
2085            raise  # re-raise original error
2086        res = q.dictresult()
2087        if res:  # may be empty with "do nothing"
2088            for n, value in res[0].items():
2089                if qoid and n == 'oid':
2090                    n = qoid
2091                row[n] = value
2092        else:
2093            self.get(table, row)
2094        return row
2095
2096    def clear(self, table, row=None):
2097        """Clear all the attributes to values determined by the types.
2098
2099        Numeric types are set to 0, Booleans are set to false, and everything
2100        else is set to the empty string.  If the row argument is present,
2101        it is used as the row dictionary and any entries matching attribute
2102        names are cleared with everything else left unchanged.
2103        """
2104        # At some point we will need a way to get defaults from a table.
2105        if row is None:
2106            row = {}  # empty if argument is not present
2107        attnames = self.get_attnames(table)
2108        for n, t in attnames.items():
2109            if n == 'oid':
2110                continue
2111            t = t.simple
2112            if t in DbTypes._num_types:
2113                row[n] = 0
2114            elif t == 'bool':
2115                row[n] = self._make_bool(False)
2116            else:
2117                row[n] = ''
2118        return row
2119
2120    def delete(self, table, row=None, **kw):
2121        """Delete an existing row in a database table.
2122
2123        This method deletes the row from a table.  It deletes based on the
2124        primary key of the table or the OID value as munged by get() or
2125        passed as keyword.
2126
2127        The return value is the number of deleted rows (i.e. 0 if the row
2128        did not exist and 1 if the row was deleted).
2129
2130        Note that if the row cannot be deleted because e.g. it is still
2131        referenced by another table, this method raises a ProgrammingError.
2132        """
2133        if table.endswith('*'):  # hint for descendant tables can be ignored
2134            table = table[:-1].rstrip()
2135        attnames = self.get_attnames(table)
2136        qoid = _oid_key(table) if 'oid' in attnames else None
2137        if row is None:
2138            row = {}
2139        elif 'oid' in row:
2140            del row['oid']  # only accept oid key from named args for safety
2141        row.update(kw)
2142        if qoid and qoid in row and 'oid' not in row:
2143            row['oid'] = row[qoid]
2144        try:  # try using the primary key
2145            keyname = self.pkey(table, True)
2146        except KeyError:  # the table has no primary key
2147            # try using the oid instead
2148            if qoid and 'oid' in row:
2149                keyname = ('oid',)
2150            else:
2151                raise _prg_error('Table %s has no primary key' % table)
2152        else:  # the table has a primary key
2153            # check whether all key columns have values
2154            if not set(keyname).issubset(row):
2155                # try using the oid instead
2156                if qoid and 'oid' in row:
2157                    keyname = ('oid',)
2158                else:
2159                    raise KeyError('Missing primary key in row')
2160        params = self.adapter.parameter_list()
2161        adapt = params.add
2162        col = self.escape_identifier
2163        where = ' AND '.join('%s = %s' % (
2164            col(k), adapt(row[k], attnames[k])) for k in keyname)
2165        if 'oid' in row:
2166            if qoid:
2167                row[qoid] = row['oid']
2168            del row['oid']
2169        q = 'DELETE FROM %s WHERE %s' % (
2170            self._escape_qualified_name(table), where)
2171        self._do_debug(q, params)
2172        res = self.db.query(q, params)
2173        return int(res)
2174
2175    def truncate(self, table, restart=False, cascade=False, only=False):
2176        """Empty a table or set of tables.
2177
2178        This method quickly removes all rows from the given table or set
2179        of tables.  It has the same effect as an unqualified DELETE on each
2180        table, but since it does not actually scan the tables it is faster.
2181        Furthermore, it reclaims disk space immediately, rather than requiring
2182        a subsequent VACUUM operation. This is most useful on large tables.
2183
2184        If restart is set to True, sequences owned by columns of the truncated
2185        table(s) are automatically restarted.  If cascade is set to True, it
2186        also truncates all tables that have foreign-key references to any of
2187        the named tables.  If the parameter only is not set to True, all the
2188        descendant tables (if any) will also be truncated. Optionally, a '*'
2189        can be specified after the table name to explicitly indicate that
2190        descendant tables are included.
2191        """
2192        if isinstance(table, basestring):
2193            only = {table: only}
2194            table = [table]
2195        elif isinstance(table, (list, tuple)):
2196            if isinstance(only, (list, tuple)):
2197                only = dict(zip(table, only))
2198            else:
2199                only = dict.fromkeys(table, only)
2200        elif isinstance(table, (set, frozenset)):
2201            only = dict.fromkeys(table, only)
2202        else:
2203            raise TypeError('The table must be a string, list or set')
2204        if not (restart is None or isinstance(restart, (bool, int))):
2205            raise TypeError('Invalid type for the restart option')
2206        if not (cascade is None or isinstance(cascade, (bool, int))):
2207            raise TypeError('Invalid type for the cascade option')
2208        tables = []
2209        for t in table:
2210            u = only.get(t)
2211            if not (u is None or isinstance(u, (bool, int))):
2212                raise TypeError('Invalid type for the only option')
2213            if t.endswith('*'):
2214                if u:
2215                    raise ValueError(
2216                        'Contradictory table name and only options')
2217                t = t[:-1].rstrip()
2218            t = self._escape_qualified_name(t)
2219            if u:
2220                t = 'ONLY %s' % t
2221            tables.append(t)
2222        q = ['TRUNCATE', ', '.join(tables)]
2223        if restart:
2224            q.append('RESTART IDENTITY')
2225        if cascade:
2226            q.append('CASCADE')
2227        q = ' '.join(q)
2228        self._do_debug(q)
2229        return self.db.query(q)
2230
2231    def get_as_list(self, table, what=None, where=None,
2232            order=None, limit=None, offset=None, scalar=False):
2233        """Get a table as a list.
2234
2235        This gets a convenient representation of the table as a list
2236        of named tuples in Python.  You only need to pass the name of
2237        the table (or any other SQL expression returning rows).  Note that
2238        by default this will return the full content of the table which
2239        can be huge and overflow your memory.  However, you can control
2240        the amount of data returned using the other optional parameters.
2241
2242        The parameter 'what' can restrict the query to only return a
2243        subset of the table columns.  It can be a string, list or a tuple.
2244        The parameter 'where' can restrict the query to only return a
2245        subset of the table rows.  It can be a string, list or a tuple
2246        of SQL expressions that all need to be fulfilled.  The parameter
2247        'order' specifies the ordering of the rows.  It can also be a
2248        other string, list or a tuple.  If no ordering is specified,
2249        the result will be ordered by the primary key(s) or all columns
2250        if no primary key exists.  You can set 'order' to False if you
2251        don't care about the ordering.  The parameters 'limit' and 'offset'
2252        can be integers specifying the maximum number of rows returned
2253        and a number of rows skipped over.
2254
2255        If you set the 'scalar' option to True, then instead of the
2256        named tuples you will get the first items of these tuples.
2257        This is useful if the result has only one column anyway.
2258        """
2259        if not table:
2260            raise TypeError('The table name is missing')
2261        if what:
2262            if isinstance(what, (list, tuple)):
2263                what = ', '.join(map(str, what))
2264            if order is None:
2265                order = what
2266        else:
2267            what = '*'
2268        q = ['SELECT', what, 'FROM', table]
2269        if where:
2270            if isinstance(where, (list, tuple)):
2271                where = ' AND '.join(map(str, where))
2272            q.extend(['WHERE', where])
2273        if order is None:
2274            try:
2275                order = self.pkey(table, True)
2276            except (KeyError, ProgrammingError):
2277                try:
2278                    order = list(self.get_attnames(table))
2279                except (KeyError, ProgrammingError):
2280                    pass
2281        if order:
2282            if isinstance(order, (list, tuple)):
2283                order = ', '.join(map(str, order))
2284            q.extend(['ORDER BY', order])
2285        if limit:
2286            q.append('LIMIT %d' % limit)
2287        if offset:
2288            q.append('OFFSET %d' % offset)
2289        q = ' '.join(q)
2290        self._do_debug(q)
2291        q = self.db.query(q)
2292        res = q.namedresult()
2293        if res and scalar:
2294            res = [row[0] for row in res]
2295        return res
2296
2297    def get_as_dict(self, table, keyname=None, what=None, where=None,
2298            order=None, limit=None, offset=None, scalar=False):
2299        """Get a table as a dictionary.
2300
2301        This method is similar to get_as_list(), but returns the table
2302        as a Python dict instead of a Python list, which can be even
2303        more convenient. The primary key column(s) of the table will
2304        be used as the keys of the dictionary, while the other column(s)
2305        will be the corresponding values.  The keys will be named tuples
2306        if the table has a composite primary key.  The rows will be also
2307        named tuples unless the 'scalar' option has been set to True.
2308        With the optional parameter 'keyname' you can specify an alternative
2309        set of columns to be used as the keys of the dictionary.  It must
2310        be set as a string, list or a tuple.
2311
2312        If the Python version supports it, the dictionary will be an
2313        OrderedDict using the order specified with the 'order' parameter
2314        or the key column(s) if not specified.  You can set 'order' to False
2315        if you don't care about the ordering.  In this case the returned
2316        dictionary will be an ordinary one.
2317        """
2318        if not table:
2319            raise TypeError('The table name is missing')
2320        if not keyname:
2321            try:
2322                keyname = self.pkey(table, True)
2323            except (KeyError, ProgrammingError):
2324                raise _prg_error('Table %s has no primary key' % table)
2325        if isinstance(keyname, basestring):
2326            keyname = [keyname]
2327        elif not isinstance(keyname, (list, tuple)):
2328            raise KeyError('The keyname must be a string, list or tuple')
2329        if what:
2330            if isinstance(what, (list, tuple)):
2331                what = ', '.join(map(str, what))
2332            if order is None:
2333                order = what
2334        else:
2335            what = '*'
2336        q = ['SELECT', what, 'FROM', table]
2337        if where:
2338            if isinstance(where, (list, tuple)):
2339                where = ' AND '.join(map(str, where))
2340            q.extend(['WHERE', where])
2341        if order is None:
2342            order = keyname
2343        if order:
2344            if isinstance(order, (list, tuple)):
2345                order = ', '.join(map(str, order))
2346            q.extend(['ORDER BY', order])
2347        if limit:
2348            q.append('LIMIT %d' % limit)
2349        if offset:
2350            q.append('OFFSET %d' % offset)
2351        q = ' '.join(q)
2352        self._do_debug(q)
2353        q = self.db.query(q)
2354        res = q.getresult()
2355        cls = OrderedDict if order else dict
2356        if not res:
2357            return cls()
2358        keyset = set(keyname)
2359        fields = q.listfields()
2360        if not keyset.issubset(fields):
2361            raise KeyError('Missing keyname in row')
2362        keyind, rowind = [], []
2363        for i, f in enumerate(fields):
2364            (keyind if f in keyset else rowind).append(i)
2365        keytuple = len(keyind) > 1
2366        getkey = itemgetter(*keyind)
2367        keys = map(getkey, res)
2368        if scalar:
2369            rowind = rowind[:1]
2370            rowtuple = False
2371        else:
2372            rowtuple = len(rowind) > 1
2373        if scalar or rowtuple:
2374            getrow = itemgetter(*rowind)
2375        else:
2376            rowind = rowind[0]
2377            getrow = lambda row: (row[rowind],)
2378            rowtuple = True
2379        rows = map(getrow, res)
2380        if keytuple or rowtuple:
2381            namedresult = get_namedresult()
2382            if namedresult:
2383                if keytuple:
2384                    keys = namedresult(_MemoryQuery(keys, keyname))
2385                if rowtuple:
2386                    fields = [f for f in fields if f not in keyset]
2387                    rows = namedresult(_MemoryQuery(rows, fields))
2388        return cls(zip(keys, rows))
2389
2390    def notification_handler(self,
2391            event, callback, arg_dict=None, timeout=None, stop_event=None):
2392        """Get notification handler that will run the given callback."""
2393        return NotificationHandler(self,
2394            event, callback, arg_dict, timeout, stop_event)
2395
2396
# if run as script, print some information

if __name__ == '__main__':
    # the version is passed as a separate argument so that print()
    # puts a space between the label and the version number
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.