source: trunk/pg.py @ 815

Last change on this file since 815 was 815, checked in by cito, 4 years ago

PEP8 recommends not assigning lambda expressions

  • Property svn:keywords set to Id
File size: 86.3 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 815 2016-02-03 21:11:13Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from datetime import date, time, datetime, timedelta
38from decimal import Decimal
39from math import isnan, isinf
40from collections import namedtuple
41from operator import itemgetter
42from functools import partial
43from re import compile as regex
44from json import loads as jsondecode, dumps as jsonencode
45
46try:
47    long
48except NameError:  # Python >= 3.0
49    long = int
50
51try:
52    basestring
53except NameError:  # Python >= 3.0
54    basestring = (str, bytes)
55
56
57# Auxiliary classes and functions that are independent from a DB connection:
58
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Simple read-only ordered dictionary for storing attribute names.

        Fallback for Python versions without OrderedDict: the key order is
        kept in a separate list, so the dictionary can only be created from
        a single sequence of (key, value) pairs.
        """

        def __init__(self, *args, **kw):
            if len(args) > 1 or kw:
                raise TypeError(
                    'AttrDict expects a single sequence of (key, value) pairs')
            items = args[0] if args else []
            if isinstance(items, dict):
                # a plain dict has no defined order in these Python versions
                raise TypeError(
                    'AttrDict expects ordered (key, value) pairs, not a dict')
            items = list(items)
            # remember each key only once, at its first position,
            # so that iteration agrees with the dictionary contents
            keys, seen = [], set()
            for item in items:
                key = item[0]
                if key not in seen:
                    seen.add(key)
                    keys.append(key)
            self._read_only = False
            self._keys = keys
            dict.__init__(self, items)
            self._read_only = True
            # shadow all mutating methods with the read-only error
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[key] for key in self]

        def items(self):
            return [(key, self[key]) for key in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Simple read-only ordered dictionary for storing attribute names.

        The dictionary can be filled only while it is being constructed;
        afterwards every mutating operation raises a TypeError.
        """

        def __init__(self, *args, **kw):
            # allow __setitem__ while OrderedDict fills the dictionary
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow all mutating methods with the read-only error
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
143
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the argument names of func (using the old getargspec)."""
        spec = getargspec(func)
        return spec.args
else:

    def get_args(func):
        """Return the names of all parameters of func in definition order."""
        parameters = signature(func).parameters
        return [name for name in parameters]
155
try:
    # probe whether strptime() understands the %z directive
    if datetime.strptime('+0100', '%z') is None:
        raise ValueError
except ValueError:  # Python < 3.2
    timezones = None
else:
    # time zones used in Postgres timestamptz output,
    # mapped to their numeric offsets in +HHMM form
    timezones = {
        'CET': '+0100', 'EET': '+0200', 'EST': '-0500',
        'GMT': '+0000', 'HST': '-1000', 'MET': '+0100',
        'MST': '-0700', 'UCT': '+0000', 'UTC': '+0000',
        'WET': '+0000'}
166
167
168def _oid_key(table):
169    """Build oid key from a table name."""
170    return 'oid(%s)' % table
171
172
173class _SimpleTypes(dict):
174    """Dictionary mapping pg_type names to simple type names."""
175
176    _types = {'bool': 'bool',
177        'bytea': 'bytea',
178        'date': 'date interval time timetz timestamp timestamptz'
179            ' abstime reltime',  # these are very old
180        'float': 'float4 float8',
181        'int': 'cid int2 int4 int8 oid xid',
182        'json': 'json jsonb',
183        'num': 'numeric',
184        'money': 'money',
185        'text': 'bpchar char name text varchar'}
186
187    def __init__(self):
188        for typ, keys in self._types.items():
189            for key in keys.split():
190                self[key] = typ
191                self['_%s' % key] = '%s[]' % typ
192
193    @staticmethod
194    def __missing__(key):
195        return 'text'
196
197_simpletypes = _SimpleTypes()
198
199
def _quote_if_unqualified(param, name):
    """Wrap a parameter in quote_ident() if the name is unqualified.

    When name is a string without a dot, the given parameter is returned
    wrapped in a quote_ident() call.  A name containing a dot is ambiguous
    (it could be a qualified name or just a name with a dot in it), so in
    that case, and for non-string names, the parameter is returned as is
    and the caller must quote the name manually.
    """
    if not isinstance(name, basestring) or '.' in name:
        return param
    return 'quote_ident(%s)' % (param,)
211
212
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    An adapt(value, typ) callable must be attached to the instance as the
    attribute 'adapt' before add() is used.
    """

    def add(self, value, typ=None):
        """Typecast value with known database type and build parameter list.

        The value is adapted first.  A literal SQL value is returned as is;
        any other value is appended to the list and its positional
        placeholder ($1, $2, ...) is returned instead.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
227
228
class Literal(str):
    """String subclass marking values that are literal SQL.

    The adapters return such values unchanged instead of quoting them
    or passing them as separate parameters.
    """
231
232
class Json:
    """Wrapper class marking a Python object as a JSON value.

    The wrapped object is available unchanged as the obj attribute.
    """

    def __init__(self, obj):
        # keep the object as it is; encoding happens when it is adapted
        self.obj = obj
238
239
class Bytea(bytes):
    """Bytes subclass marking values that are to be treated as bytea."""
242
243
class Adapter:
    """Class providing methods for adapting parameters to the database.

    Values can be adapted for passing them as separate query parameters
    (see adapt() and parameter_list()) or for putting them directly into
    the SQL command (see adapt_inline()).
    """

    # lower case string values interpreted as boolean true
    _bool_true_values = frozenset('t true 1 y yes on'.split())

    # date/time keywords that are passed on as literal SQL
    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    # array elements containing these characters or reading like NULL
    # must be put into double quotes
    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    # record fields containing parentheses, commas, double quotes or
    # backslashes must be put into double quotes (an unquoted ')' would
    # otherwise terminate the record literal)
    _re_record_quote = regex(r'[(),"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        """Initialize the adapter for the given database wrapper.

        The JSON encoder is taken from db, the escape functions from its
        underlying connection object db.db.
        """
        self.db = db
        self.encode_json = db.encode_json
        db = db.db
        self.escape_bytea = db.escape_bytea
        self.escape_string = db.escape_string

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Strings are interpreted (the empty string means NULL), other values
        are judged by their truth value.  Returns 't', 'f' or None.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        Empty values become NULL and date keywords such as 'current_date'
        are returned as literal SQL; other values are returned unchanged.
        """
        if not v:
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter (empty values other than 0 are NULL)."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Empty values become NULL, strings are passed as they are and all
        other objects are encoded with the JSON encoder.
        """
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.encode_json(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter (lists may be nested)."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(w) for w in v)
        if v is None:
            return 'null'
        if not v:
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter (lists may be nested)."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(w) for w in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter (lists may be nested)."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(w) for w in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter, returning bytes."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(w) for w in v) + b'}'
        if v is None:
            return b'null'
        # backslashes must be escaped once more inside the array literal
        return self.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter (lists may be nested)."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(w) for w in v)
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Every field is adapted to the type of the corresponding attribute
        and written in composite literal syntax: NULL as an empty field,
        the empty string as "" and values with special characters quoted.

        Raises TypeError if the number of fields does not match the type.
        """
        typ = self.get_attnames(typ).values()
        if len(typ) != len(v):
            # wrap v in a tuple, since v itself is a tuple of values
            raise TypeError('Record parameter %s has wrong size' % (v,))
        adapt = self.adapt
        value = []
        for v, t in zip(v, typ):
            v = adapt(v, t)
            if v is None:
                v = ''
            else:
                if isinstance(v, bytes):
                    if str is not bytes:
                        v = v.decode('ascii')
                else:
                    v = str(v)
                # check emptiness only after conversion to a string, so
                # that adapted values like 0 are not taken for empty
                if not v:
                    v = '""'
                elif self._re_record_quote.search(v):
                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
            value.append(v)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is given, it is guessed from the value.  None and
        Literal values are returned unchanged.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            try:
                value = value.__pg_str__(typ)
            except AttributeError:
                pass
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns a simple type name, a record type for tuples, or None if
        the type cannot be guessed.  Lists whose element type cannot be
        guessed (e.g. empty lists) are taken as text arrays.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, list):
            # without the fallback, this would give 'None[]' for which
            # no adapter exists
            return '%s[]' % (cls.guess_simple_base_type(value) or 'text')
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type

            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))

            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array.

        Returns the first type that can be guessed from the (possibly
        nested) elements, or None if no type can be guessed.
        """
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Returns a string or number that can be inserted into the SQL
        command.  Lists become ARRAY constructors (or bracketed lists when
        nested), tuples become row constructors.  Other objects must have
        a __pg_repr__() method, otherwise InterfaceError is raised.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # the wrapper only holds the object, which is encoded here;
            # strings are taken as JSON text already (as in _adapt_json)
            value = value.obj
            if not isinstance(value, basestring):
                value = self.encode_json(value)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        Values can be passed as a tuple or list (for positional %s format
        specifiers) or as a dict (for named %(name)s specifiers).  With
        inline set, the adapted values are put into the command itself;
        otherwise placeholders are used and the values are collected in
        the returned parameter list.  Returns (command, params).
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command = command % tuple(literals)
        elif isinstance(values, dict):
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                if types:
                    if (not isinstance(types, dict) or
                            len(types) < len(values)):
                        raise TypeError('The values and types do not match')
                    # sorted keys give a deterministic placeholder order
                    for key in sorted(values):
                        literals[key] = add(values[key], types[key])
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command = command % literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
577
578
def cast_bool(value):
    """Cast a boolean value.

    If get_bool() returns a false value, the text is returned unchanged;
    otherwise the result is True exactly when the text starts with 't'.
    """
    if get_bool():
        return value[0] == 't'
    return value
584
585
def cast_json(value):
    """Cast a JSON value with the decoder from get_jsondecode().

    If no decoder is set, the JSON text is returned unchanged.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
592
593
def cast_num(value):
    """Cast a numeric value with the type from get_decimal(), or float."""
    number_type = get_decimal()
    if not number_type:
        number_type = float
    return number_type(value)
597
598
def cast_money(value):
    """Cast a money value.

    If get_decimal_point() returns no decimal point, the text is returned
    unchanged.  Otherwise the decimal point is mapped to '.', an opening
    parenthesis is taken as a minus sign, and all characters other than
    digits, '.' and '-' (currency symbols, grouping separators, spaces)
    are dropped.  The result is converted with the type from get_decimal(),
    or float if none is set.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        # With another decimal point, a period can only be a grouping
        # separator (e.g. '1.234,56' in a German locale).  It must be
        # removed before the decimal point is mapped to a period,
        # otherwise the number would end up with two periods.
        value = value.replace('.', '').replace(point, '.')
    value = value.replace('(', '-')
    value = ''.join(c for c in value if c.isdigit() or c in '.-')
    return (get_decimal() or float)(value)
609
610
def cast_int2vector(value):
    """Cast an int2vector value (space separated integers) to a list."""
    return list(map(int, value.split()))
614
615
def cast_date(value, connection):
    """Cast a date value.

    The 'infinity' values and BC dates are mapped to date.min/date.max,
    as are dates with years beyond what fits in 10 characters.  Other
    dates are parsed with the format from connection.date_format().
    """
    # The output format depends on the server setting DateStyle.  The
    # default setting ISO and the setting for German are unambiguous, but
    # the order of days and months in the other two settings is not, so
    # the connection must be consulted to parse values properly.
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    parts = value.split()
    if parts[-1] == 'BC':
        return date.min
    text = parts[0]
    if len(text) > 10:
        return date.max
    return datetime.strptime(text, connection.date_format()).date()
634
635
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    fmt = '%H:%M:%S'
    if len(value) > 8:
        fmt += '.%f'
    return datetime.strptime(value, fmt).time()
640
641
# splits a value at its last '+' or '-' into the rest and the offset
_re_timezone = regex('(.*)([+-].*)')


def cast_timetz(value):
    """Cast a timetz value.

    The time zone offset is split off.  If strptime() supports %z (i.e.
    timezones is set), the offset is normalized to the +HHMM form and
    parsed into the result; otherwise it is ignored.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if timezones:
        if not tz.startswith(('+', '-')):
            # named time zone, unknown names are taken as UTC
            tz = timezones.get(tz, '+0000')
        elif len(tz) < 5:
            # only hours given, like '+01'
            tz += '00'
        else:
            # remove the colon, like in '+05:30'
            tz = tz.replace(':', '')
        value += tz
        fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
666
667
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The 'infinity' values and BC timestamps are mapped to datetime.min
    and datetime.max, as are timestamps with too large years.  The date
    format is taken from connection.date_format().
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # textual format like 'Wed Dec 17 07:37:16 1997':
        # skip the day of the week, the year comes last
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        if date_fmt.startswith('%d'):
            day_fmt = '%d %b'
        else:
            day_fmt = '%b %d'
        if len(parts[2]) > 8:
            time_fmt = '%H:%M:%S.%f'
        else:
            time_fmt = '%H:%M:%S'
        formats = [day_fmt, time_fmt, '%Y']
    else:
        if len(parts[0]) > 10:
            return datetime.max
        if len(parts[1]) > 8:
            time_fmt = '%H:%M:%S.%f'
        else:
            time_fmt = '%H:%M:%S'
        formats = [date_fmt, time_fmt]
    return datetime.strptime(' '.join(parts), ' '.join(formats))
689
690
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    The 'infinity' values and BC timestamps are mapped to datetime.min
    and datetime.max, as are timestamps with too large years.  The date
    format is taken from connection.date_format().  The time zone is
    split off and, if strptime() supports %z (timezones is set), parsed
    as a numeric +HHMM offset; otherwise it is dropped and the result
    is a naive datetime.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    value = value.split()
    if value[-1] == 'BC':
        return datetime.min
    fmt = connection.date_format()
    if fmt.endswith('-%Y') and len(value) > 2:
        # textual format: day of week, month/day, time, year, time zone;
        # the day of the week is skipped
        value = value[1:]
        if len(value[3]) > 4:
            return datetime.max
        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
        value, tz = value[:-1], value[-1]
    else:
        if fmt.startswith('%Y-'):
            # the numeric offset is attached directly to the time part
            tz = _re_timezone.match(value[1])
            if tz:
                value[1], tz = tz.groups()
            else:
                tz = '+0000'
        else:
            # the time zone is the last whitespace separated part
            value, tz = value[:-1], value[-1]
        if len(value[0]) > 10:
            return datetime.max
        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
    if timezones:
        # normalize the time zone to the +HHMM form understood by %z;
        # unknown time zone names are taken as UTC
        if tz.startswith(('+', '-')):
            if len(tz) < 5:
                tz += '00'
            else:
                tz = tz.replace(':', '')
        elif tz in timezones:
            tz = timezones[tz]
        else:
            tz = '+0000'
        value.append(tz)
        fmt.append('%z')
    return datetime.strptime(' '.join(value), ' '.join(fmt))
733
# IntervalStyle sql_standard, e.g. '1-2', '3 4:05:06', '+1-2 +3 -4:05:06'
# (the day number must neither be followed by ':' nor by another digit,
# so that the regex cannot backtrack into the digits of the hours)
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?![0-9:]) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# IntervalStyle postgres, e.g. '1 year 2 mons 3 days 04:05:06.5'
_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# IntervalStyle postgres_verbose, e.g. '@ 1 year 2 mons 6.5 secs ago'
_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

# IntervalStyle iso_8601, e.g. 'P1Y2M3DT4H5M6.5S'
_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(groups):
    """Turn the groups of an interval regex match into digit strings.

    Missing groups become '0'.  The last group always holds the digits
    after the decimal point of the seconds.  Trailing zeros may have been
    trimmed there (e.g. '01.5'), so these digits are scaled to a count of
    microseconds ('5' becomes '500000').
    """
    fields = [d or '0' for d in groups]
    fields[-1] = fields[-1][:6].ljust(6, '0')
    return fields


def cast_interval(value):
    """Cast an interval value to a timedelta.

    All four IntervalStyle output formats are understood.  Years are
    counted as 365 days and months as 30 days.  Raises ValueError if
    the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = _interval_fields(m.groups())
        secs_ago = m.pop(5) == '-'
        m = [int(d) for d in m]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            groups = m.groups()
            m, ago = _interval_fields(groups[:8]), groups[8]
            secs_ago = m.pop(5) == '-'
            m = [-int(d) for d in m] if ago else [int(d) for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = _interval_fields(m.groups())
                hours_ago = m.pop(3) == '-'
                m = [int(d) for d in m]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = _interval_fields(m.groups())
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = [int(d) for d in m]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
819
820
821class Typecasts(dict):
822    """Dictionary mapping database types to typecast functions.
823
824    The cast functions get passed the string representation of a value in
825    the database which they need to convert to a Python object.  The
826    passed string will never be None since NULL values are already be
827    handled before the cast function is called.
828
829    Note that the basic types are already handled by the C extension.
830    They only need to be handled here as record or array components.
831    """
832
833    # the default cast functions
834    # (str functions are ignored but have been added for faster access)
835    defaults = {'char': str, 'bpchar': str, 'name': str,
836        'text': str, 'varchar': str,
837        'bool': cast_bool, 'bytea': unescape_bytea,
838        'int2': int, 'int4': int, 'serial': int,
839        'int8': long, 'json': cast_json, 'jsonb': cast_json,
840        'oid': long, 'oid8': long,
841        'float4': float, 'float8': float,
842        'numeric': cast_num, 'money': cast_money,
843        'date': cast_date, 'interval': cast_interval,
844        'time': cast_time, 'timetz': cast_timetz,
845        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
846        'int2vector': cast_int2vector,
847        'anyarray': cast_array, 'record': cast_record}
848
849    connection = None  # will be set in a connection specific instance
850
851    def __missing__(self, typ):
852        """Create a cast function if it is not cached.
853       
854        Note that this class never raises a KeyError,
855        but returns None when no special cast function exists.
856        """
857        if not isinstance(typ, str):
858            raise TypeError('Invalid type: %s' % typ)
859        cast = self.defaults.get(typ)
860        if cast:
861            # store default for faster access
862            cast = self._add_connection(cast)
863            self[typ] = cast
864        elif typ.startswith('_'):
865            base_cast = self[typ[1:]]
866            cast = self.create_array_cast(base_cast)
867            if base_cast:
868                self[typ] = cast
869        else:
870            attnames = self.get_attnames(typ)
871            if attnames:
872                casts = [self[v.pgtype] for v in attnames.values()]
873                cast = self.create_record_cast(typ, attnames, casts)
874                self[typ] = cast
875        return cast
876
877    @staticmethod
878    def _needs_connection(func):
879        """Check if a typecast function needs a connection argument."""
880        try:
881            args = get_args(func)
882        except (TypeError, ValueError):
883            return False
884        else:
885            return 'connection' in args[1:]
886
887    def _add_connection(self, cast):
888        """Add a connection argument to the typecast function if necessary."""
889        if not self.connection or not self._needs_connection(cast):
890            return cast
891        return partial(cast, connection=self.connection)
892
893    def get(self, typ, default=None):
894        """Get the typecast function for the given database type."""
895        return self[typ] or default
896
897    def set(self, typ, cast):
898        """Set a typecast function for the specified database type(s)."""
899        if isinstance(typ, basestring):
900            typ = [typ]
901        if cast is None:
902            for t in typ:
903                self.pop(t, None)
904                self.pop('_%s' % t, None)
905        else:
906            if not callable(cast):
907                raise TypeError("Cast parameter must be callable")
908            for t in typ:
909                self[t] = self._add_connection(cast)
910                self.pop('_%s' % t, None)
911
912    def reset(self, typ=None):
913        """Reset the typecasts for the specified type(s) to their defaults.
914
915        When no type is specified, all typecasts will be reset.
916        """
917        if typ is None:
918            self.clear()
919        else:
920            if isinstance(typ, basestring):
921                typ = [typ]
922            for t in typ:
923                self.pop(t, None)
924
925    @classmethod
926    def get_default(cls, typ):
927        """Get the default typecast function for the given database type."""
928        return cls.defaults.get(typ)
929
930    @classmethod
931    def set_default(cls, typ, cast):
932        """Set a default typecast function for the given database type(s)."""
933        if isinstance(typ, basestring):
934            typ = [typ]
935        defaults = cls.defaults
936        if cast is None:
937            for t in typ:
938                defaults.pop(t, None)
939                defaults.pop('_%s' % t, None)
940        else:
941            if not callable(cast):
942                raise TypeError("Cast parameter must be callable")
943            for t in typ:
944                defaults[t] = cast
945                defaults.pop('_%s' % t, None)
946
947    def get_attnames(self, typ):
948        """Return the fields for the given record type.
949
950        This method will be replaced with the get_attnames() method of DbTypes.
951        """
952        return {}
953
954    def dateformat(self):
955        """Return the current date format.
956
957        This method will be replaced with the dateformat() method of DbTypes.
958        """
959        return '%Y-%m-%d'
960
961    def create_array_cast(self, basecast):
962        """Create an array typecast for the given base cast."""
963        def cast(v):
964            return cast_array(v, basecast)
965        return cast
966
967    def create_record_cast(self, name, fields, casts):
968        """Create a named record typecast for the given fields and casts."""
969        record = namedtuple(name, fields)
970        def cast(v):
971            return record(*cast_record(v, casts))
972        return cast
973
974
def get_typecast(typ):
    """Return the global default typecast function for a database type."""
    default_cast = Typecasts.get_default(typ)
    return default_cast
978
979
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
987
988
class DbType(str):
    """Class augmenting the simple type name with additional info.

    The following additional information is provided as attributes
    (they are set by DbTypes.add()):

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, B = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Get names and types of the fields of a composite type."""
        return self._get_attnames(self)
1009
1010
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    The cache maps type OIDs and type names to DbType objects carrying
    additional information on the database type.  Types that are not yet
    cached are looked up in the pg_type system catalog on first access.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize the type cache for the given DB instance."""
        super(DbTypes, self).__init__()
        self._regtypes = False
        self._get_attnames = db.get_attnames
        typecasts = Typecasts()
        typecasts.get_attnames = self.get_attnames
        typecasts.connection = db
        self._typecasts = typecasts
        # catalog queries go directly to the underlying connection
        cnx = db.db
        self.query = cnx.query
        self.escape_string = cnx.escape_string

    def add(self, oid, pgtype, regtype,
            typtype, category, delim, relid):
        """Create a DbType for the given type info unless already cached.

        Note that a newly created type is returned but not stored.
        """
        if oid in self:
            return self[oid]
        if relid:
            simple = 'record'
        else:
            simple = _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid, typ.pgtype, typ.regtype = oid, pgtype, regtype
        typ.simple, typ.typtype, typ.category = simple, typtype, category
        typ.delim, typ.relid = delim, relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Look up the type info in the database if it is not cached.

        The type is then cached under both its OID and its PostgreSQL name.
        Raises a KeyError if the type cannot be found.
        """
        try:
            sql = ("SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype" %
                (_quote_if_unqualified('$1', key),))
            rows = self.query(sql, (key,)).getresult()
        except ProgrammingError:
            rows = None
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        typ = self.add(*rows[0])
        self[typ.oid] = self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Return the type for the key, even if not cached, or default."""
        try:
            typ = self[key]
        except KeyError:
            typ = default
        return typ

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or has no relid.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if typ.relid:
            return self._get_attnames(typ.relid, with_oid=False)
        return None

    def get_typecast(self, typ):
        """Return the typecast function used for the given database type."""
        typecasts = self._typecasts
        return typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set the typecast function for the given database type(s)."""
        typecasts = self._typecasts
        typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the given database type(s)."""
        typecasts = self._typecasts
        typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        NULL values (None) are returned unchanged, as are values for which
        no typecast function exists or for which the typecast is str.  If
        the type is not a DbType, it is looked up and its PostgreSQL type
        name is used to find the typecast function.
        """
        if value is None:
            return None
        if not isinstance(typ, DbType):
            found = self.get(typ)
            typ = found.pgtype if found else found
        cast = self.get_typecast(typ) if typ else None
        if cast and cast is not str:
            return cast(value)
        return value
1110
1111
def _namedresult(q):
    """Get query result as named tuples.

    Column names that cannot be used as named tuple fields, like the
    "?column?" name PostgreSQL gives to unnamed expressions or duplicate
    column names, are replaced with positional names instead of failing.
    """
    fields = q.listfields()
    try:
        row = namedtuple('Row', fields, rename=True)
    except TypeError:  # Python 2.6 and 3.0 do not support rename
        row = namedtuple('Row', fields)
    return [row(*r) for r in q.getresult()]
1116
1117
class _MemoryQuery:
    """Query-like object serving a stored result from memory."""

    def __init__(self, result, fields):
        """Keep the given result rows and field names."""
        self.fields = fields
        self.result = result

    def listfields(self):
        """Return the field names of the stored result."""
        return self.fields

    def getresult(self):
        """Return the stored result rows."""
        return self.result
1133
1134
def _db_error(msg, cls=DatabaseError):
    """Create an error of the given class with an empty sqlstate.

    The class defaults to DatabaseError; its sqlstate attribute is None.
    """
    err = cls(msg)
    err.sqlstate = None
    return err
1140
1141
def _int_error(msg):
    """Create an InternalError with the given message."""
    return _db_error(msg, cls=InternalError)
1145
1146
def _prg_error(msg):
    """Create a ProgrammingError with the given message."""
    return _db_error(msg, cls=ProgrammingError)
1150
1151
# Initialize the C module

# Hand Python-level helpers to the C module (the set_* functions
# presumably come from `from _pg import *`): the factory that turns
# query results into named tuples, the Decimal class and the JSON decoder.
set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
1157
1158
1159# The notification handler
1160
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is finalized."""
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is escaped with the escape_string() method of the
        connection, so it may safely contain quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.

        Returns the result of the query, or None if not listening.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # escape the payload so that quotes cannot terminate the
                # string literal early (broken query or SQL injection)
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1280
1281
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its deprecated traditional name."""
    warnings.warn("pgnotify is deprecated, use NotificationHandler instead",
        DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1287
1288
# The actual PostgreSQL database connection interface:
1290
1291class DB:
1292    """Wrapper class for the _pg connection type."""
1293
    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.
        An existing DB instance can be passed as well, in which case
        its underlying connection is shared.

        Only connections opened here can be closed or reopened through
        this instance (see close() and reopen()).
        """
        # A single positional argument or a single keyword argument 'db'
        # may be an existing connection instead of connection parameters.
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # share the underlying connection of the other DB instance
                db = db.db
            else:
                # presumably a pgdb connection, which keeps the underlying
                # _pg connection in its _cnx attribute
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # If we do not have something that looks like a _pg connection,
        # the arguments are connection parameters: open a new connection.
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        # per-table caches (e.g. pkey() caches primary keys in _pkeys)
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        # remember the arguments so that reopen() can reconnect
        self._args = args, kw
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        # install the type cache's typecast() as the connection's cast hook
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1335
1336    def __getattr__(self, name):
1337        # All undefined members are same as in underlying connection:
1338        if self.db:
1339            return getattr(self.db, name)
1340        else:
1341            raise _int_error('Connection is not valid')
1342
1343    def __dir__(self):
1344        # Custom dir function including the attributes of the connection:
1345        attrs = set(self.__class__.__dict__)
1346        attrs.update(self.__dict__)
1347        attrs.update(dir(self.db))
1348        return sorted(attrs)
1349
1350    # Context manager methods
1351
1352    def __enter__(self):
1353        """Enter the runtime context. This will start a transactio."""
1354        self.begin()
1355        return self
1356
1357    def __exit__(self, et, ev, tb):
1358        """Exit the runtime context. This will end the transaction."""
1359        if et is None and ev is None and tb is None:
1360            self.commit()
1361        else:
1362            self.rollback()
1363
1364    # Auxiliary methods
1365
1366    def _do_debug(self, *args):
1367        """Print a debug message"""
1368        if self.debug:
1369            s = '\n'.join(str(arg) for arg in args)
1370            if isinstance(self.debug, basestring):
1371                print(self.debug % s)
1372            elif hasattr(self.debug, 'write'):
1373                self.debug.write(s + '\n')
1374            elif callable(self.debug):
1375                self.debug(s)
1376            else:
1377                print(s)
1378
1379    def _escape_qualified_name(self, s):
1380        """Escape a qualified name.
1381
1382        Escapes the name for use as an SQL identifier, unless the
1383        name contains a dot, in which case the name is ambiguous
1384        (could be a qualified name or just a name with a dot in it)
1385        and must be quoted manually by the caller.
1386        """
1387        if '.' not in s:
1388            s = self.escape_identifier(s)
1389        return s
1390
1391    @staticmethod
1392    def _make_bool(d):
1393        """Get boolean value corresponding to d."""
1394        return bool(d) if get_bool() else ('t' if d else 'f')
1395
1396    def _list_params(self, params):
1397        """Create a human readable parameter list."""
1398        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1399
1400    # Public methods
1401
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (staticmethod makes sure it is called without the instance)
    unescape_bytea = staticmethod(unescape_bytea)
1405
1406    def decode_json(self, s):
1407        """Decode a JSON string coming from the database."""
1408        return (get_jsondecode() or jsondecode)(s)
1409
1410    def encode_json(self, d):
1411        """Encode a JSON string for use within SQL."""
1412        return jsonencode(d)
1413
1414    def close(self):
1415        """Close the database connection."""
1416        # Wraps shared library function so we can track state.
1417        if self._closeable:
1418            if self.db:
1419                self.db.close()
1420                self.db = None
1421            else:
1422                raise _int_error('Connection already closed')
1423
1424    def reset(self):
1425        """Reset connection with current parameters.
1426
1427        All derived queries and large objects derived from this connection
1428        will not be usable after this call.
1429
1430        """
1431        if self.db:
1432            self.db.reset()
1433        else:
1434            raise _int_error('Connection already closed')
1435
1436    def reopen(self):
1437        """Reopen connection to the database.
1438
1439        Used in case we need another connection to the same database.
1440        Note that we can still reopen a database that we have closed.
1441
1442        """
1443        # There is no such shared library function.
1444        if self._closeable:
1445            db = connect(*self._args[0], **self._args[1])
1446            if self.db:
1447                self.db.close()
1448            self.db = db
1449
1450    def begin(self, mode=None):
1451        """Begin a transaction."""
1452        qstr = 'BEGIN'
1453        if mode:
1454            qstr += ' ' + mode
1455        return self.query(qstr)
1456
1457    start = begin
1458
1459    def commit(self):
1460        """Commit the current transaction."""
1461        return self.query('COMMIT')
1462
1463    end = commit
1464
1465    def rollback(self, name=None):
1466        """Roll back the current transaction."""
1467        qstr = 'ROLLBACK'
1468        if name:
1469            qstr += ' TO ' + name
1470        return self.query(qstr)
1471
1472    abort = rollback
1473
1474    def savepoint(self, name):
1475        """Define a new savepoint within the current transaction."""
1476        return self.query('SAVEPOINT ' + name)
1477
1478    def release(self, name):
1479        """Destroy a previously defined savepoint."""
1480        return self.query('RELEASE ' + name)
1481
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises a TypeError for an unsupported argument type, an empty
        argument, or names that are not non-blank strings.
        """
        # Choose the result container:  None for a single name (the value
        # itself is returned), a new list for a list or tuple, a new dict
        # for a set, and the passed dict itself for a dict.
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # For dict results, params maps normalized names to the original
        # keys; otherwise it is a list of the normalized names.
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # a new dict of all settings is returned and the other
                # names are ignored (the for-else clause below is skipped)
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # one SHOW query per parameter;
            # NOTE(review): the name is interpolated unquoted, so only
            # trusted parameter names should be passed
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1539
1540    def set_parameter(self, parameter, value=None, local=False):
1541        """Set the value of a run-time parameter.
1542
1543        If the parameter and the value are strings, the run-time parameter
1544        will be set to that value.  If no value or None is passed as a value,
1545        then the run-time parameter will be restored to its default value.
1546
1547        You can set several parameters at once by passing a list of parameter
1548        names, together with a single value that all parameters should be
1549        set to or with a corresponding list of values.  You can also pass
1550        the parameters as a set if you only provide a single value.
1551        Finally, you can pass a dict with parameter names as keys.  In this
1552        case, you should not pass a value, since the values for the parameters
1553        will be taken from the dict.
1554
1555        By passing the special name 'all' as the parameter, you can reset
1556        all existing settable run-time parameters to their default values.
1557
1558        If you set local to True, then the command takes effect for only the
1559        current transaction.  After commit() or rollback(), the session-level
1560        setting takes effect again.  Setting local to True will appear to
1561        have no effect if it is executed outside a transaction, since the
1562        transaction will end immediately.
1563        """
1564        if isinstance(parameter, basestring):
1565            parameter = {parameter: value}
1566        elif isinstance(parameter, (list, tuple)):
1567            if isinstance(value, (list, tuple)):
1568                parameter = dict(zip(parameter, value))
1569            else:
1570                parameter = dict.fromkeys(parameter, value)
1571        elif isinstance(parameter, (set, frozenset)):
1572            if isinstance(value, (list, tuple, set, frozenset)):
1573                value = set(value)
1574                if len(value) == 1:
1575                    value = value.pop()
1576            if not(value is None or isinstance(value, basestring)):
1577                raise ValueError('A single value must be specified'
1578                    ' when parameter is a set')
1579            parameter = dict.fromkeys(parameter, value)
1580        elif isinstance(parameter, dict):
1581            if value is not None:
1582                raise ValueError('A value must not be specified'
1583                    ' when parameter is a dictionary')
1584        else:
1585            raise TypeError(
1586                'The parameter must be a string, list, set or dict')
1587        if not parameter:
1588            raise TypeError('No parameter has been specified')
1589        params = {}
1590        for key, value in parameter.items():
1591            param = key.strip().lower() if isinstance(
1592                key, basestring) else None
1593            if not param:
1594                raise TypeError('Invalid parameter')
1595            if param == 'all':
1596                if value is not None:
1597                    raise ValueError('A value must ot be specified'
1598                        " when parameter is 'all'")
1599                params = {'all': None}
1600                break
1601            params[param] = value
1602        local = ' LOCAL' if local else ''
1603        for param, value in params.items():
1604            if value is None:
1605                q = 'RESET%s %s' % (local, param)
1606            else:
1607                q = 'SET%s %s TO %s' % (local, param, value)
1608            self._do_debug(q)
1609            self.db.query(q)
1610
1611    def query(self, command, *args):
1612        """Execute a SQL command string.
1613
1614        This method simply sends a SQL query to the database.  If the query is
1615        an insert statement that inserted exactly one row into a table that
1616        has OIDs, the return value is the OID of the newly inserted row.
1617        If the query is an update or delete statement, or an insert statement
1618        that did not insert exactly one row in a table with OIDs, then the
1619        number of rows affected is returned as a string.  If it is a statement
1620        that returns rows as a result (usually a select statement, but maybe
1621        also an "insert/update ... returning" statement), this method returns
1622        a Query object that can be accessed via getresult() or dictresult()
1623        or simply printed.  Otherwise, it returns `None`.
1624
1625        The query can contain numbered parameters of the form $1 in place
1626        of any data constant.  Arguments given after the query string will
1627        be substituted for the corresponding numbered parameter.  Parameter
1628        values can also be given as a single list or tuple argument.
1629        """
1630        # Wraps shared library function for debugging.
1631        if not self.db:
1632            raise _int_error('Connection is not valid')
1633        if args:
1634            self._do_debug(command, args)
1635            return self.db.query(command, args)
1636        self._do_debug(command)
1637        return self.db.query(command)
1638
1639    def query_formatted(self, command, parameters, types=None, inline=False):
1640        """Execute a formatted SQL command string.
1641
1642        Similar to query, but using Python format placeholders of the form
1643        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1644        The parameters must be passed as a tuple, list or dict.  You can
1645        also pass a corresponding tuple, list or dict of database types in
1646        order to format the parameters properly in case there is ambiguity.
1647
1648        If you set inline to True, the parameters will be sent to the database
1649        embedded in the SQL command, otherwise they will be sent separately.
1650        """
1651        return self.query(*self.adapter.format_query(
1652            command, parameters, types, inline))
1653
1654    def pkey(self, table, composite=False, flush=False):
1655        """Get or set the primary key of a table.
1656
1657        Single primary keys are returned as strings unless you
1658        set the composite flag.  Composite primary keys are always
1659        represented as tuples.  Note that this raises a KeyError
1660        if the table does not have a primary key.
1661
1662        If flush is set then the internal cache for primary keys will
1663        be flushed.  This may be necessary after the database schema or
1664        the search path has been changed.
1665        """
1666        pkeys = self._pkeys
1667        if flush:
1668            pkeys.clear()
1669            self._do_debug('The pkey cache has been flushed')
1670        try:  # cache lookup
1671            pkey = pkeys[table]
1672        except KeyError:  # cache miss, check the database
1673            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1674                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1675                " AND a.attnum = ANY(i.indkey)"
1676                " AND NOT a.attisdropped"
1677                " WHERE i.indrelid=%s::regclass"
1678                " AND i.indisprimary ORDER BY a.attnum") % (
1679                    _quote_if_unqualified('$1', table),)
1680            pkey = self.db.query(q, (table,)).getresult()
1681            if not pkey:
1682                raise KeyError('Table %s has no primary key' % table)
1683            # we want to use the order defined in the primary key index here,
1684            # not the order as defined by the columns in the table
1685            if len(pkey) > 1:
1686                indkey = pkey[0][2]
1687                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1688                pkey = tuple(row[0] for row in pkey)
1689            else:
1690                pkey = pkey[0][0]
1691            pkeys[table] = pkey  # cache it
1692        if composite and not isinstance(pkey, tuple):
1693            pkey = (pkey,)
1694        return pkey
1695
1696    def get_databases(self):
1697        """Get list of databases in the system."""
1698        return [s[0] for s in
1699            self.db.query('SELECT datname FROM pg_database').getresult()]
1700
1701    def get_relations(self, kinds=None):
1702        """Get list of relations in connected database of specified kinds.
1703
1704        If kinds is None or empty, all kinds of relations are returned.
1705        Otherwise kinds can be a string or sequence of type letters
1706        specifying which kind of relations you want to list.
1707        """
1708        where = " AND r.relkind IN (%s)" % ','.join(
1709            ["'%s'" % k for k in kinds]) if kinds else ''
1710        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1711            " FROM pg_class r"
1712            " JOIN pg_namespace s ON s.oid = r.relnamespace"
1713            " WHERE s.nspname NOT SIMILAR"
1714            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
1715            " ORDER BY s.nspname, r.relname") % where
1716        return [r[0] for r in self.db.query(q).getresult()]
1717
1718    def get_tables(self):
1719        """Return list of tables in connected database."""
1720        return self.get_relations('r')
1721
1722    def get_attnames(self, table, with_oid=True, flush=False):
1723        """Given the name of a table, dig out the set of attribute names.
1724
1725        Returns a read-only dictionary of attribute names (the names are
1726        the keys, the values are the names of the attributes' types)
1727        with the column names in the proper order if you iterate over it.
1728
1729        If flush is set, then the internal cache for attribute names will
1730        be flushed. This may be necessary after the database schema or
1731        the search path has been changed.
1732
1733        By default, only a limited number of simple types will be returned.
1734        You can get the regular types after calling use_regtypes(True).
1735        """
1736        attnames = self._attnames
1737        if flush:
1738            attnames.clear()
1739            self._do_debug('The attnames cache has been flushed')
1740        try:  # cache lookup
1741            names = attnames[table]
1742        except KeyError:  # cache miss, check the database
1743            q = "a.attnum > 0"
1744            if with_oid:
1745                q = "(%s OR a.attname = 'oid')" % q
1746            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1747                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1748                " FROM pg_attribute a"
1749                " JOIN pg_type t ON t.oid = a.atttypid"
1750                " WHERE a.attrelid = %s::regclass AND %s"
1751                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1752                    _quote_if_unqualified('$1', table), q)
1753            names = self.db.query(q, (table,)).getresult()
1754            types = self.dbtypes
1755            names = ((name[0], types.add(*name[1:])) for name in names)
1756            names = AttrDict(names)
1757            attnames[table] = names  # cache it
1758        return names
1759
1760    def use_regtypes(self, regtypes=None):
1761        """Use regular type names instead of simplified type names."""
1762        if regtypes is None:
1763            return self.dbtypes._regtypes
1764        else:
1765            regtypes = bool(regtypes)
1766            if regtypes != self.dbtypes._regtypes:
1767                self.dbtypes._regtypes = regtypes
1768                self._attnames.clear()
1769                self.dbtypes.clear()
1770            return regtypes
1771
1772    def has_table_privilege(self, table, privilege='select'):
1773        """Check whether current user has specified table privilege."""
1774        privilege = privilege.lower()
1775        try:  # ask cache
1776            return self._privileges[(table, privilege)]
1777        except KeyError:  # cache miss, ask the database
1778            q = "SELECT has_table_privilege(%s, $2)" % (
1779                _quote_if_unqualified('$1', table),)
1780            q = self.db.query(q, (table, privilege))
1781            ret = q.getresult()[0][0] == self._make_bool(True)
1782            self._privileges[(table, privilege)] = ret  # cache it
1783            return ret
1784
1785    def get(self, table, row, keyname=None):
1786        """Get a row from a database table or view.
1787
1788        This method is the basic mechanism to get a single row.  It assumes
1789        that the keyname specifies a unique row.  It must be the name of a
1790        single column or a tuple of column names.  If the keyname is not
1791        specified, then the primary key for the table is used.
1792
1793        If row is a dictionary, then the value for the key is taken from it.
1794        Otherwise, the row must be a single value or a tuple of values
1795        corresponding to the passed keyname or primary key.  The fetched row
1796        from the table will be returned as a new dictionary or used to replace
1797        the existing values when row was passed as aa dictionary.
1798
1799        The OID is also put into the dictionary if the table has one, but
1800        in order to allow the caller to work with multiple tables, it is
1801        munged as "oid(table)" using the actual name of the table.
1802        """
1803        if table.endswith('*'):  # hint for descendant tables can be ignored
1804            table = table[:-1].rstrip()
1805        attnames = self.get_attnames(table)
1806        qoid = _oid_key(table) if 'oid' in attnames else None
1807        if keyname and isinstance(keyname, basestring):
1808            keyname = (keyname,)
1809        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
1810            row['oid'] = row[qoid]
1811        if not keyname:
1812            try:  # if keyname is not specified, try using the primary key
1813                keyname = self.pkey(table, True)
1814            except KeyError:  # the table has no primary key
1815                # try using the oid instead
1816                if qoid and isinstance(row, dict) and 'oid' in row:
1817                    keyname = ('oid',)
1818                else:
1819                    raise _prg_error('Table %s has no primary key' % table)
1820            else:  # the table has a primary key
1821                # check whether all key columns have values
1822                if isinstance(row, dict) and not set(keyname).issubset(row):
1823                    # try using the oid instead
1824                    if qoid and 'oid' in row:
1825                        keyname = ('oid',)
1826                    else:
1827                        raise KeyError(
1828                            'Missing value in row for specified keyname')
1829        if not isinstance(row, dict):
1830            if not isinstance(row, (tuple, list)):
1831                row = [row]
1832            if len(keyname) != len(row):
1833                raise KeyError(
1834                    'Differing number of items in keyname and row')
1835            row = dict(zip(keyname, row))
1836        params = self.adapter.parameter_list()
1837        adapt = params.add
1838        col = self.escape_identifier
1839        what = 'oid, *' if qoid else '*'
1840        where = ' AND '.join('%s = %s' % (
1841            col(k), adapt(row[k], attnames[k])) for k in keyname)
1842        if 'oid' in row:
1843            if qoid:
1844                row[qoid] = row['oid']
1845            del row['oid']
1846        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
1847            what, self._escape_qualified_name(table), where)
1848        self._do_debug(q, params)
1849        q = self.db.query(q, params)
1850        res = q.dictresult()
1851        if not res:
1852            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
1853                table, where, self._list_params(params)))
1854        for n, value in res[0].items():
1855            if qoid and n == 'oid':
1856                n = qoid
1857            row[n] = value
1858        return row
1859
1860    def insert(self, table, row=None, **kw):
1861        """Insert a row into a database table.
1862
1863        This method inserts a row into a table.  The name of the table must
1864        be passed as the first parameter.  The other parameters are used for
1865        providing the data of the row that shall be inserted into the table.
1866        If a dictionary is supplied as the second parameter, it starts with
1867        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1868        is updated from the keywords.
1869
1870        The dictionary is then reloaded with the values actually inserted in
1871        order to pick up values modified by rules, triggers, etc.
1872        """
1873        if table.endswith('*'):  # hint for descendant tables can be ignored
1874            table = table[:-1].rstrip()
1875        if row is None:
1876            row = {}
1877        row.update(kw)
1878        if 'oid' in row:
1879            del row['oid']  # do not insert oid
1880        attnames = self.get_attnames(table)
1881        qoid = _oid_key(table) if 'oid' in attnames else None
1882        params = self.adapter.parameter_list()
1883        adapt = params.add
1884        col = self.escape_identifier
1885        names, values = [], []
1886        for n in attnames:
1887            if n in row:
1888                names.append(col(n))
1889                values.append(adapt(row[n], attnames[n]))
1890        if not names:
1891            raise _prg_error('No column found that can be inserted')
1892        names, values = ', '.join(names), ', '.join(values)
1893        ret = 'oid, *' if qoid else '*'
1894        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1895            self._escape_qualified_name(table), names, values, ret)
1896        self._do_debug(q, params)
1897        q = self.db.query(q, params)
1898        res = q.dictresult()
1899        if res:  # this should always be true
1900            for n, value in res[0].items():
1901                if qoid and n == 'oid':
1902                    n = qoid
1903                row[n] = value
1904        return row
1905
1906    def update(self, table, row=None, **kw):
1907        """Update an existing row in a database table.
1908
1909        Similar to insert but updates an existing row.  The update is based
1910        on the primary key of the table or the OID value as munged by get
1911        or passed as keyword.
1912
1913        The dictionary is then modified to reflect any changes caused by the
1914        update due to triggers, rules, default values, etc.
1915        """
1916        if table.endswith('*'):
1917            table = table[:-1].rstrip()  # need parent table name
1918        attnames = self.get_attnames(table)
1919        qoid = _oid_key(table) if 'oid' in attnames else None
1920        if row is None:
1921            row = {}
1922        elif 'oid' in row:
1923            del row['oid']  # only accept oid key from named args for safety
1924        row.update(kw)
1925        if qoid and qoid in row and 'oid' not in row:
1926            row['oid'] = row[qoid]
1927        try:  # try using the primary key
1928            keyname = self.pkey(table, True)
1929        except KeyError:  # the table has no primary key
1930            # try using the oid instead
1931            if qoid and 'oid' in row:
1932                keyname = ('oid',)
1933            else:
1934                raise _prg_error('Table %s has no primary key' % table)
1935        else:  # the table has a primary key
1936            # check whether all key columns have values
1937            if not set(keyname).issubset(row):
1938                # try using the oid instead
1939                if qoid and 'oid' in row:
1940                    keyname = ('oid',)
1941                else:
1942                    raise KeyError('Missing primary key in row')
1943        params = self.adapter.parameter_list()
1944        adapt = params.add
1945        col = self.escape_identifier
1946        where = ' AND '.join('%s = %s' % (
1947            col(k), adapt(row[k], attnames[k])) for k in keyname)
1948        if 'oid' in row:
1949            if qoid:
1950                row[qoid] = row['oid']
1951            del row['oid']
1952        values = []
1953        keyname = set(keyname)
1954        for n in attnames:
1955            if n in row and n not in keyname:
1956                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
1957        if not values:
1958            return row
1959        values = ', '.join(values)
1960        ret = 'oid, *' if qoid else '*'
1961        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
1962            self._escape_qualified_name(table), values, where, ret)
1963        self._do_debug(q, params)
1964        q = self.db.query(q, params)
1965        res = q.dictresult()
1966        if res:  # may be empty when row does not exist
1967            for n, value in res[0].items():
1968                if qoid and n == 'oid':
1969                    n = qoid
1970                row[n] = value
1971        return row
1972
1973    def upsert(self, table, row=None, **kw):
1974        """Insert a row into a database table with conflict resolution
1975
1976        This method inserts a row into a table, but instead of raising a
1977        ProgrammingError exception in case a row with the same primary key
1978        already exists, an update will be executed instead.  This will be
1979        performed as a single atomic operation on the database, so race
1980        conditions can be avoided.
1981
1982        Like the insert method, the first parameter is the name of the
1983        table and the second parameter can be used to pass the values to
1984        be inserted as a dictionary.
1985
1986        Unlike the insert und update statement, keyword parameters are not
1987        used to modify the dictionary, but to specify which columns shall
1988        be updated in case of a conflict, and in which way:
1989
1990        A value of False or None means the column shall not be updated,
1991        a value of True means the column shall be updated with the value
1992        that has been proposed for insertion, i.e. has been passed as value
1993        in the dictionary.  Columns that are not specified by keywords but
1994        appear as keys in the dictionary are also updated like in the case
1995        keywords had been passed with the value True.
1996
1997        So if in the case of a conflict you want to update every column that
1998        has been passed in the dictionary row , you would call upsert(table, row).
1999        If you don't want to do anything in case of a conflict, i.e. leave
2000        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2001
2002        If you need more fine-grained control of what gets updated, you can
2003        also pass strings in the keyword parameters.  These strings will
2004        be used as SQL expressions for the update columns.  In these
2005        expressions you can refer to the value that already exists in
2006        the table by prefixing the column name with "included.", and to
2007        the value that has been proposed for insertion by prefixing the
2008        column name with the "excluded."
2009
2010        The dictionary is modified in any case to reflect the values in
2011        the database after the operation has completed.
2012
2013        Note: The method uses the PostgreSQL "upsert" feature which is
2014        only available since PostgreSQL 9.5.
2015        """
2016        if table.endswith('*'):  # hint for descendant tables can be ignored
2017            table = table[:-1].rstrip()
2018        if row is None:
2019            row = {}
2020        if 'oid' in row:
2021            del row['oid']  # do not insert oid
2022        if 'oid' in kw:
2023            del kw['oid']  # do not update oid
2024        attnames = self.get_attnames(table)
2025        qoid = _oid_key(table) if 'oid' in attnames else None
2026        params = self.adapter.parameter_list()
2027        adapt = params.add
2028        col = self.escape_identifier
2029        names, values, updates = [], [], []
2030        for n in attnames:
2031            if n in row:
2032                names.append(col(n))
2033                values.append(adapt(row[n], attnames[n]))
2034        names, values = ', '.join(names), ', '.join(values)
2035        try:
2036            keyname = self.pkey(table, True)
2037        except KeyError:
2038            raise _prg_error('Table %s has no primary key' % table)
2039        target = ', '.join(col(k) for k in keyname)
2040        update = []
2041        keyname = set(keyname)
2042        keyname.add('oid')
2043        for n in attnames:
2044            if n not in keyname:
2045                value = kw.get(n, True)
2046                if value:
2047                    if not isinstance(value, basestring):
2048                        value = 'excluded.%s' % col(n)
2049                    update.append('%s = %s' % (col(n), value))
2050        if not values:
2051            return row
2052        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2053        ret = 'oid, *' if qoid else '*'
2054        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2055            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2056                self._escape_qualified_name(table), names, values,
2057                target, do, ret)
2058        self._do_debug(q, params)
2059        try:
2060            q = self.db.query(q, params)
2061        except ProgrammingError:
2062            if self.server_version < 90500:
2063                raise _prg_error(
2064                    'Upsert operation is not supported by PostgreSQL version')
2065            raise  # re-raise original error
2066        res = q.dictresult()
2067        if res:  # may be empty with "do nothing"
2068            for n, value in res[0].items():
2069                if qoid and n == 'oid':
2070                    n = qoid
2071                row[n] = value
2072        else:
2073            self.get(table, row)
2074        return row
2075
2076    def clear(self, table, row=None):
2077        """Clear all the attributes to values determined by the types.
2078
2079        Numeric types are set to 0, Booleans are set to false, and everything
2080        else is set to the empty string.  If the row argument is present,
2081        it is used as the row dictionary and any entries matching attribute
2082        names are cleared with everything else left unchanged.
2083        """
2084        # At some point we will need a way to get defaults from a table.
2085        if row is None:
2086            row = {}  # empty if argument is not present
2087        attnames = self.get_attnames(table)
2088        for n, t in attnames.items():
2089            if n == 'oid':
2090                continue
2091            t = t.simple
2092            if t in DbTypes._num_types:
2093                row[n] = 0
2094            elif t == 'bool':
2095                row[n] = self._make_bool(False)
2096            else:
2097                row[n] = ''
2098        return row
2099
2100    def delete(self, table, row=None, **kw):
2101        """Delete an existing row in a database table.
2102
2103        This method deletes the row from a table.  It deletes based on the
2104        primary key of the table or the OID value as munged by get() or
2105        passed as keyword.
2106
2107        The return value is the number of deleted rows (i.e. 0 if the row
2108        did not exist and 1 if the row was deleted).
2109
2110        Note that if the row cannot be deleted because e.g. it is still
2111        referenced by another table, this method raises a ProgrammingError.
2112        """
2113        if table.endswith('*'):  # hint for descendant tables can be ignored
2114            table = table[:-1].rstrip()
2115        attnames = self.get_attnames(table)
2116        qoid = _oid_key(table) if 'oid' in attnames else None
2117        if row is None:
2118            row = {}
2119        elif 'oid' in row:
2120            del row['oid']  # only accept oid key from named args for safety
2121        row.update(kw)
2122        if qoid and qoid in row and 'oid' not in row:
2123            row['oid'] = row[qoid]
2124        try:  # try using the primary key
2125            keyname = self.pkey(table, True)
2126        except KeyError:  # the table has no primary key
2127            # try using the oid instead
2128            if qoid and 'oid' in row:
2129                keyname = ('oid',)
2130            else:
2131                raise _prg_error('Table %s has no primary key' % table)
2132        else:  # the table has a primary key
2133            # check whether all key columns have values
2134            if not set(keyname).issubset(row):
2135                # try using the oid instead
2136                if qoid and 'oid' in row:
2137                    keyname = ('oid',)
2138                else:
2139                    raise KeyError('Missing primary key in row')
2140        params = self.adapter.parameter_list()
2141        adapt = params.add
2142        col = self.escape_identifier
2143        where = ' AND '.join('%s = %s' % (
2144            col(k), adapt(row[k], attnames[k])) for k in keyname)
2145        if 'oid' in row:
2146            if qoid:
2147                row[qoid] = row['oid']
2148            del row['oid']
2149        q = 'DELETE FROM %s WHERE %s' % (
2150            self._escape_qualified_name(table), where)
2151        self._do_debug(q, params)
2152        res = self.db.query(q, params)
2153        return int(res)
2154
2155    def truncate(self, table, restart=False, cascade=False, only=False):
2156        """Empty a table or set of tables.
2157
2158        This method quickly removes all rows from the given table or set
2159        of tables.  It has the same effect as an unqualified DELETE on each
2160        table, but since it does not actually scan the tables it is faster.
2161        Furthermore, it reclaims disk space immediately, rather than requiring
2162        a subsequent VACUUM operation. This is most useful on large tables.
2163
2164        If restart is set to True, sequences owned by columns of the truncated
2165        table(s) are automatically restarted.  If cascade is set to True, it
2166        also truncates all tables that have foreign-key references to any of
2167        the named tables.  If the parameter only is not set to True, all the
2168        descendant tables (if any) will also be truncated. Optionally, a '*'
2169        can be specified after the table name to explicitly indicate that
2170        descendant tables are included.
2171        """
2172        if isinstance(table, basestring):
2173            only = {table: only}
2174            table = [table]
2175        elif isinstance(table, (list, tuple)):
2176            if isinstance(only, (list, tuple)):
2177                only = dict(zip(table, only))
2178            else:
2179                only = dict.fromkeys(table, only)
2180        elif isinstance(table, (set, frozenset)):
2181            only = dict.fromkeys(table, only)
2182        else:
2183            raise TypeError('The table must be a string, list or set')
2184        if not (restart is None or isinstance(restart, (bool, int))):
2185            raise TypeError('Invalid type for the restart option')
2186        if not (cascade is None or isinstance(cascade, (bool, int))):
2187            raise TypeError('Invalid type for the cascade option')
2188        tables = []
2189        for t in table:
2190            u = only.get(t)
2191            if not (u is None or isinstance(u, (bool, int))):
2192                raise TypeError('Invalid type for the only option')
2193            if t.endswith('*'):
2194                if u:
2195                    raise ValueError(
2196                        'Contradictory table name and only options')
2197                t = t[:-1].rstrip()
2198            t = self._escape_qualified_name(t)
2199            if u:
2200                t = 'ONLY %s' % t
2201            tables.append(t)
2202        q = ['TRUNCATE', ', '.join(tables)]
2203        if restart:
2204            q.append('RESTART IDENTITY')
2205        if cascade:
2206            q.append('CASCADE')
2207        q = ' '.join(q)
2208        self._do_debug(q)
2209        return self.db.query(q)
2210
2211    def get_as_list(self, table, what=None, where=None,
2212            order=None, limit=None, offset=None, scalar=False):
2213        """Get a table as a list.
2214
2215        This gets a convenient representation of the table as a list
2216        of named tuples in Python.  You only need to pass the name of
2217        the table (or any other SQL expression returning rows).  Note that
2218        by default this will return the full content of the table which
2219        can be huge and overflow your memory.  However, you can control
2220        the amount of data returned using the other optional parameters.
2221
2222        The parameter 'what' can restrict the query to only return a
2223        subset of the table columns.  It can be a string, list or a tuple.
2224        The parameter 'where' can restrict the query to only return a
2225        subset of the table rows.  It can be a string, list or a tuple
2226        of SQL expressions that all need to be fulfilled.  The parameter
2227        'order' specifies the ordering of the rows.  It can also be a
2228        other string, list or a tuple.  If no ordering is specified,
2229        the result will be ordered by the primary key(s) or all columns
2230        if no primary key exists.  You can set 'order' to False if you
2231        don't care about the ordering.  The parameters 'limit' and 'offset'
2232        can be integers specifying the maximum number of rows returned
2233        and a number of rows skipped over.
2234
2235        If you set the 'scalar' option to True, then instead of the
2236        named tuples you will get the first items of these tuples.
2237        This is useful if the result has only one column anyway.
2238        """
2239        if not table:
2240            raise TypeError('The table name is missing')
2241        if what:
2242            if isinstance(what, (list, tuple)):
2243                what = ', '.join(map(str, what))
2244            if order is None:
2245                order = what
2246        else:
2247            what = '*'
2248        q = ['SELECT', what, 'FROM', table]
2249        if where:
2250            if isinstance(where, (list, tuple)):
2251                where = ' AND '.join(map(str, where))
2252            q.extend(['WHERE', where])
2253        if order is None:
2254            try:
2255                order = self.pkey(table, True)
2256            except (KeyError, ProgrammingError):
2257                try:
2258                    order = list(self.get_attnames(table))
2259                except (KeyError, ProgrammingError):
2260                    pass
2261        if order:
2262            if isinstance(order, (list, tuple)):
2263                order = ', '.join(map(str, order))
2264            q.extend(['ORDER BY', order])
2265        if limit:
2266            q.append('LIMIT %d' % limit)
2267        if offset:
2268            q.append('OFFSET %d' % offset)
2269        q = ' '.join(q)
2270        self._do_debug(q)
2271        q = self.db.query(q)
2272        res = q.namedresult()
2273        if res and scalar:
2274            res = [row[0] for row in res]
2275        return res
2276
2277    def get_as_dict(self, table, keyname=None, what=None, where=None,
2278            order=None, limit=None, offset=None, scalar=False):
2279        """Get a table as a dictionary.
2280
2281        This method is similar to get_as_list(), but returns the table
2282        as a Python dict instead of a Python list, which can be even
2283        more convenient. The primary key column(s) of the table will
2284        be used as the keys of the dictionary, while the other column(s)
2285        will be the corresponding values.  The keys will be named tuples
2286        if the table has a composite primary key.  The rows will be also
2287        named tuples unless the 'scalar' option has been set to True.
2288        With the optional parameter 'keyname' you can specify an alternative
2289        set of columns to be used as the keys of the dictionary.  It must
2290        be set as a string, list or a tuple.
2291
2292        If the Python version supports it, the dictionary will be an
2293        OrderedDict using the order specified with the 'order' parameter
2294        or the key column(s) if not specified.  You can set 'order' to False
2295        if you don't care about the ordering.  In this case the returned
2296        dictionary will be an ordinary one.
2297        """
2298        if not table:
2299            raise TypeError('The table name is missing')
2300        if not keyname:
2301            try:
2302                keyname = self.pkey(table, True)
2303            except (KeyError, ProgrammingError):
2304                raise _prg_error('Table %s has no primary key' % table)
2305        if isinstance(keyname, basestring):
2306            keyname = [keyname]
2307        elif not isinstance(keyname, (list, tuple)):
2308            raise KeyError('The keyname must be a string, list or tuple')
2309        if what:
2310            if isinstance(what, (list, tuple)):
2311                what = ', '.join(map(str, what))
2312            if order is None:
2313                order = what
2314        else:
2315            what = '*'
2316        q = ['SELECT', what, 'FROM', table]
2317        if where:
2318            if isinstance(where, (list, tuple)):
2319                where = ' AND '.join(map(str, where))
2320            q.extend(['WHERE', where])
2321        if order is None:
2322            order = keyname
2323        if order:
2324            if isinstance(order, (list, tuple)):
2325                order = ', '.join(map(str, order))
2326            q.extend(['ORDER BY', order])
2327        if limit:
2328            q.append('LIMIT %d' % limit)
2329        if offset:
2330            q.append('OFFSET %d' % offset)
2331        q = ' '.join(q)
2332        self._do_debug(q)
2333        q = self.db.query(q)
2334        res = q.getresult()
2335        cls = OrderedDict if order else dict
2336        if not res:
2337            return cls()
2338        keyset = set(keyname)
2339        fields = q.listfields()
2340        if not keyset.issubset(fields):
2341            raise KeyError('Missing keyname in row')
2342        keyind, rowind = [], []
2343        for i, f in enumerate(fields):
2344            (keyind if f in keyset else rowind).append(i)
2345        keytuple = len(keyind) > 1
2346        getkey = itemgetter(*keyind)
2347        keys = map(getkey, res)
2348        if scalar:
2349            rowind = rowind[:1]
2350            rowtuple = False
2351        else:
2352            rowtuple = len(rowind) > 1
2353        if scalar or rowtuple:
2354            getrow = itemgetter(*rowind)
2355        else:
2356            rowind = rowind[0]
2357            getrow = lambda row: (row[rowind],)
2358            rowtuple = True
2359        rows = map(getrow, res)
2360        if keytuple or rowtuple:
2361            namedresult = get_namedresult()
2362            if namedresult:
2363                if keytuple:
2364                    keys = namedresult(_MemoryQuery(keys, keyname))
2365                if rowtuple:
2366                    fields = [f for f in fields if f not in keyset]
2367                    rows = namedresult(_MemoryQuery(rows, fields))
2368        return cls(zip(keys, rows))
2369
2370    def notification_handler(self,
2371            event, callback, arg_dict=None, timeout=None, stop_event=None):
2372        """Get notification handler that will run the given callback."""
2373        return NotificationHandler(self,
2374            event, callback, arg_dict, timeout, stop_event)
2375
2376
2377# if run as script, print some information
2378
if __name__ == '__main__':
    # print the label and the version number separated by a space
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.