source: trunk/pg.py @ 817

Last change on this file since 817 was 817, checked in by cito, 3 years ago

Support the hstore data type

Added adaptation and typecasting of the hstore type as Python dictionaries.
For the typecasting, a fast parser has been added to the C extension.

  • Property svn:keywords set to Id
File size: 86.8 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 817 2016-02-04 20:18:08Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from datetime import date, time, datetime, timedelta
38from decimal import Decimal
39from math import isnan, isinf
40from collections import namedtuple
41from operator import itemgetter
42from functools import partial
43from re import compile as regex
44from json import loads as jsondecode, dumps as jsonencode
45
# Python 3 removed the ``long`` and ``basestring`` builtins;
# provide equivalents under the same names when they are missing.
try:  # Python 2 has both names as builtins
    long
except NameError:  # Python >= 3.0
    long = int

try:  # Python 2 has both names as builtins
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)
55
56
57# Auxiliary classes and functions that are independent from a DB connection:
58
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Simple read-only ordered dictionary for storing attribute names.

        Fallback implementation for Python versions without OrderedDict:
        the order of the keys is remembered in a separate list.  A dict is
        rejected as initializer, presumably because its key order would be
        undefined; pass a sequence of (key, value) pairs instead.
        """

        _read_only = False  # becomes True once __init__ has finished

        def __init__(self, *args, **kw):
            if len(args) > 1 or kw:
                raise TypeError('AttrDict takes at most one positional'
                                ' argument and no keyword arguments')
            items = args[0] if args else []
            if isinstance(items, dict):
                raise TypeError('AttrDict must be initialized'
                                ' with a sequence of key-value pairs')
            items = list(items)
            self._keys = [item[0] for item in items]
            dict.__init__(self, items)
            self._read_only = True
            # make the other modifying methods raise a TypeError as well
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[key] for key in self]

        def items(self):
            return [(key, self[key]) for key in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Simple read-only ordered dictionary for storing attribute names."""

        def __init__(self, *args, **kw):
            # the initial items must still be settable during initialization
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # make the other modifying methods raise a TypeError as well
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
143
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the names of the positional arguments of func."""
        argspec = getargspec(func)
        return argspec.args
else:

    def get_args(func):
        """Return the names of all parameters of func, in order."""
        parameters = signature(func).parameters
        return [name for name in parameters]
155
try:
    # check whether strptime() supports the %z directive
    if datetime.strptime('+0100', '%z') is None:
        raise ValueError
except ValueError:  # Python < 3.2
    timezones = None
else:
    # time zones used in Postgres timestamptz output,
    # mapped to their UTC offsets in the form understood by %z
    timezones = {
        'CET': '+0100', 'EET': '+0200', 'EST': '-0500',
        'GMT': '+0000', 'HST': '-1000', 'MET': '+0100',
        'MST': '-0700', 'UCT': '+0000', 'UTC': '+0000',
        'WET': '+0000',
    }
166
167
def _oid_key(table):
    """Return the oid key for a table name, e.g. 'oid(mytable)'."""
    key = 'oid(%s)' % table
    return key
171
172
class _SimpleTypes(dict):
    """Map pg_type names to the simple type names used for adaptation.

    Each type name also gets an entry for its array type (the name with a
    leading underscore), mapped to the simple name with '[]' appended.
    Names that are not in the mapping are looked up as 'text' (without
    being added to the dictionary).
    """

    _types = {
        'bool': 'bool',
        'bytea': 'bytea',
        'date': 'date interval time timetz timestamp timestamptz'
            ' abstime reltime',  # these are very old
        'float': 'float4 float8',
        'int': 'cid int2 int4 int8 oid xid',
        'hstore': 'hstore', 'json': 'json jsonb',
        'num': 'numeric',
        'money': 'money',
        'text': 'bpchar char name text varchar'}

    def __init__(self):
        pairs = []
        for simple, names in self._types.items():
            for name in names.split():
                pairs.append((name, simple))
                pairs.append(('_' + name, simple + '[]'))
        dict.__init__(self, pairs)

    @staticmethod
    def __missing__(key):
        return 'text'

_simpletypes = _SimpleTypes()
198
199
def _quote_if_unqualified(param, name):
    """Wrap a parameter in quote_ident() if the name is unqualified.

    The given parameter is put into a quote_ident() call unless the name
    contains a dot.  A dotted name is ambiguous (it could be a qualified
    name or just a name with a dot in it), so quoting it is left to the
    caller.  The parameter is also returned unchanged for non-string names.
    """
    unqualified = isinstance(name, basestring) and '.' not in name
    if not unqualified:
        return param
    return 'quote_ident(%s)' % (param,)
211
212
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    An ``adapt`` attribute (a function taking a value and an optional
    type) is expected to be assigned before add() is called.
    """

    def add(self, value, typ=None):
        """Adapt the value and return the SQL text standing for it.

        A Literal produced by the adaptation is returned as it is.  Any
        other adapted value is appended to this list, and the numbered
        placeholder ($1, $2, ...) referring to it is returned.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
227
228
class Bytea(bytes):
    """Wrapper class for marking bytea values.

    Instances are ordinary bytes objects.  Only their type signals that
    the value is binary data to be adapted as bytea.
    """
231
232
class Hstore(dict):
    """Wrapper class for marking hstore values.

    The string representation of an instance is the text representation
    of the hstore type, e.g. ``a=>1,"b c"=>NULL``.
    """

    # keys and values matching this must be put in double quotes:
    # the word NULL (in any case) and anything containing white space,
    # commas, equal signs or greater-than signs
    _re_quote = regex(r'^[Nn][Uu][Ll][Ll]$|[\s,=>]')

    @classmethod
    def _quote(cls, s):
        """Quote a key or value for the hstore text representation.

        None becomes NULL and an empty string becomes "".  Backslashes and
        double quotes are escaped with a backslash; the result is put in
        double quotes when it contains characters that delimit unquoted
        hstore items or when it would otherwise be read as NULL.
        """
        if s is None:
            return 'NULL'
        if not s:
            return '""'
        # escape backslashes first, so that the escapes added for the
        # double quotes are not escaped once more
        s = s.replace('\\', '\\\\').replace('"', '\\"')
        if cls._re_quote.search(s):
            s = '"%s"' % s
        return s

    def __str__(self):
        q = self._quote
        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
252
253
class Json:
    """Wrapper class for marking values that shall be adapted as json.

    The wrapped Python object is kept unchanged in the ``obj`` attribute.
    """

    def __init__(self, obj):
        """Wrap the given Python object."""
        self.obj = obj
259
260
class Literal(str):
    """Wrapper class for marking literal SQL values.

    Instances are ordinary strings; their type tells the adapter to put
    them into the SQL command as they are, without quoting or escaping.
    """
263
264
class Adapter:
    """Class providing methods for adapting parameters to the database.

    The ``db`` passed to the constructor is expected to provide an
    encode_json() method and, as its ``db`` attribute, a connection with
    escape_bytea() and escape_string() methods.
    """

    _bool_true_values = frozenset('t true 1 y yes on'.split())

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    # An unquoted record field ends at a comma or a closing parenthesis,
    # and double quotes and backslashes have a special meaning inside it,
    # so fields containing any of these characters must be quoted.
    _re_record_quote = regex(r'[(),"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        self.db = db
        self.encode_json = db.encode_json
        db = db.db
        self.escape_bytea = db.escape_bytea
        self.escape_string = db.escape_string

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Strings are interpreted case-insensitively.  Returns 't' or 'f',
        or None for an empty string.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        Empty values become None, keywords such as 'current_date' are
        passed as literal SQL, and other values are returned unchanged.
        """
        if not v:
            return None
        if isinstance(v, basestring) and v.lower() in cls._date_literals:
            return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter (empty values other than 0 give None)."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.escape_bytea(v)

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Json wrappers are unwrapped first.  Empty values become None,
        strings are passed unchanged and other objects are JSON encoded.
        """
        if isinstance(v, Json):
            v = v.obj
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.encode_json(v)

    @classmethod
    def _adapt_hstore(cls, v):
        """Adapt a hstore parameter.

        Empty values become None (as for json parameters), strings are
        passed unchanged, and dictionaries are converted to the hstore
        text representation.
        """
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        if not isinstance(v, Hstore):
            v = Hstore(v)
        return str(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if not v:
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(v) for v in v) + b'}'
        if v is None:
            return b'null'
        return self.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        if isinstance(v, Json):
            v = v.obj
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    @classmethod
    def _adapt_hstore_array(cls, v):
        """Adapt a hstore array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_hstore_array
            return '{%s}' % ','.join(adapt(v) for v in v)
        v = cls._adapt_hstore(v)
        if v is None:
            return 'null'
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Raises TypeError if the number of values does not match the
        number of attributes of the record type.
        """
        typ = self.get_attnames(typ).values()
        if len(typ) != len(v):
            # v is a tuple, so it must be wrapped for the % operator
            raise TypeError('Record parameter %s has wrong size' % (v,))
        adapt = self.adapt
        value = []
        for v, t in zip(v, typ):
            v = adapt(v, t)
            if v is None:
                v = ''  # an empty unquoted field means NULL
            else:
                if isinstance(v, bytes):
                    if str is not bytes:
                        v = v.decode('ascii')
                else:
                    v = str(v)
                # test emptiness only after the conversion to a string,
                # so that falsy values such as 0 are not taken as empty
                if not v:
                    v = '""'
                elif self._re_record_quote.search(v):
                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
            value.append(v)
        return '(%s)' % ','.join(value)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is given, it is guessed from the value.  None and
        Literal values are returned unchanged.  A value that has a
        __pg_str__() method is first converted by calling it with the type.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            # only a missing method is ignored, not errors raised inside it
            pg_str = getattr(value, '__pg_str__', None)
            if pg_str:
                value = pg_str(typ)
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns a simple type name, a record type for tuples,
        or None if the type cannot be guessed.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, Hstore):
            return 'hstore'
        if isinstance(value, Json):
            return 'json'
        if isinstance(value, list):
            # empty arrays or arrays of nulls only are passed as text arrays
            return '%s[]' % (cls.guess_simple_base_type(value) or 'text')
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type
            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))
            typ._get_attnames = get_attnames
            return typ

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array."""
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Raises InterfaceError if the type of the value cannot be adapted.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # strings are passed unchanged, other objects are encoded
            value = value.obj
            if not isinstance(value, basestring):
                value = self.encode_json(value)
        elif isinstance(value, Hstore):
            value = str(value)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        pg_repr = getattr(value, '__pg_repr__', None)
        if not pg_repr:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        value = pg_repr()
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        The values are passed as a list or tuple for positional or as a
        dict for named %-style placeholders in the command.  With inline
        set, they are adapted and quoted directly into the command;
        otherwise, numbered placeholders are put into the command and the
        adapted values are collected in the returned parameter list.

        Returns the tuple (command, params).  Raises ValueError when types
        are given together with inline, and TypeError when the values and
        types do not match or the values have an unsupported type.
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command = command % tuple(literals)
        elif isinstance(values, dict):
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                if types:
                    # every value needs a type, extra types are ignored
                    if (not isinstance(types, dict) or
                            any(key not in types for key in values)):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types[key])
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command = command % literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
598
599
def cast_bool(value):
    """Cast a boolean value.

    Unless get_bool() returns a true value, the string is returned as it
    is.  Otherwise the result is True exactly if it starts with 't'.
    """
    if get_bool():
        return value[0] == 't'
    return value
605
606
def cast_json(value):
    """Cast a JSON value with the decoder returned by get_jsondecode().

    If no decoder is set, the string is returned unchanged.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
613
614
def cast_num(value):
    """Cast a numeric value.

    The value is converted with the type returned by get_decimal(),
    or with float if no such type is set.
    """
    number_type = get_decimal()
    if not number_type:
        number_type = float
    return number_type(value)
618
619
def cast_money(value):
    """Cast a money value.

    The decimal point character is taken from get_decimal_point(); if it
    is not set, the string is returned unchanged.  An opening parenthesis
    (presumably marking a negative amount) becomes a minus sign.  All other
    characters except digits, the minus sign and the decimal point, such as
    currency symbols, thousands separators and white space, are removed.
    The result is converted with get_decimal() or float.
    """
    point = get_decimal_point()
    if not point:
        return value
    value = value.replace('(', '-')
    # Remove thousands separators *before* normalizing the decimal point;
    # otherwise a '.' used as thousands separator together with a ','
    # decimal point (e.g. '1.234,56') would yield the invalid '1.234.56'.
    keep = ('-', point)
    value = ''.join(c for c in value if c.isdigit() or c in keep)
    if point != '.':
        value = value.replace(point, '.')
    return (get_decimal() or float)(value)
630
631
def cast_int2vector(value):
    """Cast an int2vector value (integers separated by white space)."""
    return list(map(int, value.split()))
635
636
def cast_date(value, connection):
    """Cast a date value.

    '-infinity' and dates BC are mapped to date.min; 'infinity' and dates
    longer than ten characters (more than four year digits) to date.max.
    Other dates are parsed with the format from connection.date_format().
    """
    # The output format depends on the server setting DateStyle.  The default
    # setting ISO and the setting for German are actually unambiguous.  The
    # order of days and months in the other two settings is however ambiguous,
    # so at least here we need to consult the setting to properly parse values.
    if value in ('-infinity', 'infinity'):
        return date.min if value[0] == '-' else date.max
    words = value.split()
    if words[-1] == 'BC':
        return date.min
    text = words[0]
    if len(text) > 10:
        return date.max
    return datetime.strptime(text, connection.date_format()).date()
655
656
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    if len(value) > 8:  # 'HH:MM:SS' is followed by a fraction
        fmt = '%H:%M:%S.%f'
    else:
        fmt = '%H:%M:%S'
    return datetime.strptime(value, fmt).time()
661
662
# Split a time value into the time proper and a trailing UTC offset,
# e.g. '12:00:00+01' into '12:00:00' and '+01'.  Since the first group
# is greedy, the offset starts at the last '+' or '-' in the string.
_re_timezone = regex('(.*)([+-].*)')
664
665
def cast_timetz(value):
    """Cast a timetz value.

    The UTC offset is split off from the time.  If strptime() supports %z
    (module-level ``timezones`` is set), the offset is brought into the
    +HHMM form and a time zone aware time is returned; otherwise the offset
    is dropped and a naive time is returned.
    """
    match = _re_timezone.match(value)
    if match:
        value, tz = match.groups()
    else:
        tz = '+0000'
    fmt = '%H:%M:%S'
    if len(value) > 8:  # fractional seconds
        fmt += '.%f'
    if timezones:
        if tz.startswith(('+', '-')):
            # '+01' -> '+0100', '+05:30' -> '+0530'
            tz = tz + '00' if len(tz) < 5 else tz.replace(':', '')
        else:
            tz = timezones.get(tz, '+0000')
        value += tz
        fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
687
688
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    '-infinity' and dates BC are mapped to datetime.min; 'infinity' and
    dates with more than four year digits are mapped to datetime.max.
    The layout of the string is derived from connection.date_format().
    """
    if value in ('-infinity', 'infinity'):
        return datetime.min if value.startswith('-') else datetime.max
    words = value.split()
    if words[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(words) > 2:
        # weekday, day and month name (in either order), time and year:
        # the weekday is skipped, the month is an abbreviated name
        words = words[1:5]
        if len(words[3]) > 4:
            return datetime.max
        day_fmt = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        time_fmt = '%H:%M:%S.%f' if len(words[2]) > 8 else '%H:%M:%S'
        fmt = ' '.join((day_fmt, time_fmt, '%Y'))
    else:
        # the date followed by the time
        if len(words[0]) > 10:
            return datetime.max
        time_fmt = '%H:%M:%S.%f' if len(words[1]) > 8 else '%H:%M:%S'
        fmt = ' '.join((date_fmt, time_fmt))
    return datetime.strptime(' '.join(words), fmt)
710
711
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    '-infinity' and dates BC are mapped to datetime.min; 'infinity' and
    dates whose date part is too long (more than four year digits) are
    mapped to datetime.max.  The layout of the string is derived from
    connection.date_format().  If strptime() supports %z (module-level
    ``timezones`` is set), the time zone is brought into the +HHMM form
    and an aware datetime is returned; otherwise the time zone is dropped
    and a naive datetime is returned.
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    value = value.split()
    if value[-1] == 'BC':
        return datetime.min
    fmt = connection.date_format()
    if fmt.endswith('-%Y') and len(value) > 2:
        # weekday, day and month name (in either order), time, year and
        # time zone: skip the weekday, the zone is the last word
        value = value[1:]
        if len(value[3]) > 4:
            return datetime.max  # year with more than four digits
        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
        value, tz = value[:-1], value[-1]
    else:
        if fmt.startswith('%Y-'):
            # ISO-like format: the UTC offset is attached to the time
            tz = _re_timezone.match(value[1])
            if tz:
                value[1], tz = tz.groups()
            else:
                tz = '+0000'
        else:
            # other formats: the time zone is a separate last word
            value, tz = value[:-1], value[-1]
        if len(value[0]) > 10:
            return datetime.max  # date part too long for a 4-digit year
        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
    if timezones:
        # bring the time zone into the +HHMM form expected by %z
        if tz.startswith(('+', '-')):
            if len(tz) < 5:
                tz += '00'  # e.g. '+01' -> '+0100'
            else:
                tz = tz.replace(':', '')  # e.g. '+05:30' -> '+0530'
        elif tz in timezones:
            tz = timezones[tz]  # known abbreviation such as 'CET'
        else:
            # NOTE(review): unknown time zone names are taken as UTC
            tz = '+0000'
        value.append(tz)
        fmt.append('%z')
    return datetime.strptime(' '.join(value), ' '.join(fmt))
754
# Regular expressions for the output formats of the interval type that
# can be selected with the server setting IntervalStyle.  Each of them has
# a group for the digits after the decimal point of the seconds.

# IntervalStyle sql_standard, e.g. '1-2', '3 4:05:06' or '-1-2 +3 -4:05:06'
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# IntervalStyle postgres, e.g. '-1 year -2 mons +3 days -04:05:06'
_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

# IntervalStyle postgres_verbose,
# e.g. '@ 1 year 2 mons -3 days 4 hours 5 mins 6 secs ago'
_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

# IntervalStyle iso_8601, e.g. 'P1Y2M3DT4H5M6S' or 'P-1Y-2M3DT-4H-5M-6S'
_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def _interval_fields(fields):
    """Convert the matched fields of an interval value to integers.

    The fields are the strings for years, months, days, hours, minutes,
    seconds and the digits after the decimal point of the seconds.  The
    server omits trailing zeros in that fraction, so it must be scaled
    to microseconds: '5' (from '1.5') becomes 500000, '000001' becomes 1.
    """
    usecs = int(fields[-1][:6].ljust(6, '0'))
    return [int(d) for d in fields[:-1]] + [usecs]


def cast_interval(value):
    """Cast an interval value.

    All four interval output formats are recognized.  Since timedelta has
    no notion of years and months, a year is counted as 365 days and a
    month as 30 days.  Raises ValueError if the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        secs_ago = m.pop(5) == '-'
        years, mons, days, hours, mins, secs, usecs = _interval_fields(m)
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            secs_ago = m.pop(5) == '-'
            m = _interval_fields(m)
            if ago:  # 'ago' negates the whole interval
                m = [-d for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                hours_ago = m.pop(3) == '-'
                years, mons, days, hours, mins, secs, usecs = \
                    _interval_fields(m)
                if hours_ago:  # the sign applies to the whole time part
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    years, mons, days, hours, mins, secs, usecs = \
                        _interval_fields(m)
                    if years_ago:  # sign of the year-month part
                        years = -years
                        mons = -mons
                    if hours_ago:  # sign of the time part
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
840
841
842class Typecasts(dict):
843    """Dictionary mapping database types to typecast functions.
844
845    The cast functions get passed the string representation of a value in
846    the database which they need to convert to a Python object.  The
847    passed string will never be None since NULL values are already be
848    handled before the cast function is called.
849
850    Note that the basic types are already handled by the C extension.
851    They only need to be handled here as record or array components.
852    """
853
854    # the default cast functions
855    # (str functions are ignored but have been added for faster access)
856    defaults = {'char': str, 'bpchar': str, 'name': str,
857        'text': str, 'varchar': str,
858        'bool': cast_bool, 'bytea': unescape_bytea,
859        'int2': int, 'int4': int, 'serial': int, 'int8': long,
860        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
861        'oid': long, 'oid8': long,
862        'float4': float, 'float8': float,
863        'numeric': cast_num, 'money': cast_money,
864        'date': cast_date, 'interval': cast_interval,
865        'time': cast_time, 'timetz': cast_timetz,
866        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
867        'int2vector': cast_int2vector,
868        'anyarray': cast_array, 'record': cast_record}
869
870    connection = None  # will be set in a connection specific instance
871
872    def __missing__(self, typ):
873        """Create a cast function if it is not cached.
874       
875        Note that this class never raises a KeyError,
876        but returns None when no special cast function exists.
877        """
878        if not isinstance(typ, str):
879            raise TypeError('Invalid type: %s' % typ)
880        cast = self.defaults.get(typ)
881        if cast:
882            # store default for faster access
883            cast = self._add_connection(cast)
884            self[typ] = cast
885        elif typ.startswith('_'):
886            base_cast = self[typ[1:]]
887            cast = self.create_array_cast(base_cast)
888            if base_cast:
889                self[typ] = cast
890        else:
891            attnames = self.get_attnames(typ)
892            if attnames:
893                casts = [self[v.pgtype] for v in attnames.values()]
894                cast = self.create_record_cast(typ, attnames, casts)
895                self[typ] = cast
896        return cast
897
898    @staticmethod
899    def _needs_connection(func):
900        """Check if a typecast function needs a connection argument."""
901        try:
902            args = get_args(func)
903        except (TypeError, ValueError):
904            return False
905        else:
906            return 'connection' in args[1:]
907
908    def _add_connection(self, cast):
909        """Add a connection argument to the typecast function if necessary."""
910        if not self.connection or not self._needs_connection(cast):
911            return cast
912        return partial(cast, connection=self.connection)
913
914    def get(self, typ, default=None):
915        """Get the typecast function for the given database type."""
916        return self[typ] or default
917
918    def set(self, typ, cast):
919        """Set a typecast function for the specified database type(s)."""
920        if isinstance(typ, basestring):
921            typ = [typ]
922        if cast is None:
923            for t in typ:
924                self.pop(t, None)
925                self.pop('_%s' % t, None)
926        else:
927            if not callable(cast):
928                raise TypeError("Cast parameter must be callable")
929            for t in typ:
930                self[t] = self._add_connection(cast)
931                self.pop('_%s' % t, None)
932
933    def reset(self, typ=None):
934        """Reset the typecasts for the specified type(s) to their defaults.
935
936        When no type is specified, all typecasts will be reset.
937        """
938        if typ is None:
939            self.clear()
940        else:
941            if isinstance(typ, basestring):
942                typ = [typ]
943            for t in typ:
944                self.pop(t, None)
945
946    @classmethod
947    def get_default(cls, typ):
948        """Get the default typecast function for the given database type."""
949        return cls.defaults.get(typ)
950
951    @classmethod
952    def set_default(cls, typ, cast):
953        """Set a default typecast function for the given database type(s)."""
954        if isinstance(typ, basestring):
955            typ = [typ]
956        defaults = cls.defaults
957        if cast is None:
958            for t in typ:
959                defaults.pop(t, None)
960                defaults.pop('_%s' % t, None)
961        else:
962            if not callable(cast):
963                raise TypeError("Cast parameter must be callable")
964            for t in typ:
965                defaults[t] = cast
966                defaults.pop('_%s' % t, None)
967
968    def get_attnames(self, typ):
969        """Return the fields for the given record type.
970
971        This method will be replaced with the get_attnames() method of DbTypes.
972        """
973        return {}
974
975    def dateformat(self):
976        """Return the current date format.
977
978        This method will be replaced with the dateformat() method of DbTypes.
979        """
980        return '%Y-%m-%d'
981
982    def create_array_cast(self, basecast):
983        """Create an array typecast for the given base cast."""
984        def cast(v):
985            return cast_array(v, basecast)
986        return cast
987
988    def create_record_cast(self, name, fields, casts):
989        """Create a named record typecast for the given fields and casts."""
990        record = namedtuple(name, fields)
991        def cast(v):
992            return record(*cast_record(v, casts))
993        return cast
994
995
def get_typecast(typ):
    """Return the global default typecast function for a database type.

    None is returned if there is no default typecast for the type.
    """
    default_cast = Typecasts.get_default(typ)
    return default_cast
999
1000
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    Passing None as cast removes the global typecast for the type(s).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1008
1009
class DbType(str):
    """Class augmenting the simple type name with additional info.

    The string value itself is the regular type name if the type cache
    is set to use regular type names, otherwise the simple PyGreSQL
    type name.  The following additional information is provided:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, B = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Get names and types of the fields of a composite type.

        The lookup is delegated to the _get_attnames function that has
        been attached to this type by the type cache that created it.
        """
        return self._get_attnames(self)
1031
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    This cache maps type OIDs and names to DbType objects containing
    information on the associated database type.  Types that are not
    cached yet are looked up in the pg_type system catalog on demand
    and then cached under their OID, their type name and the key that
    has been used for looking them up.
    """

    # simple and PostgreSQL names of the numeric types
    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize type cache for connection.

        The db argument is the DB wrapper instance owning this cache;
        the system catalog is queried through its connection db.db.
        """
        super(DbTypes, self).__init__()
        self._regtypes = False  # use regular instead of simple type names
        self._get_attnames = db.get_attnames
        self._typecasts = Typecasts()
        # record casts need to know the fields of composite types
        self._typecasts.get_attnames = self.get_attnames
        self._typecasts.connection = db
        db = db.db
        self.query = db.query
        self.escape_string = db.escape_string

    def add(self, oid, pgtype, regtype, typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        The arguments correspond to the columns queried from pg_type.
        If a type with the given OID is already cached, the cached type
        is returned.  Note that the new type is not put into the cache.
        """
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid = oid
        typ.simple = simple
        typ.pgtype = pgtype
        typ.regtype = regtype
        typ.typtype = typtype
        typ.category = category
        typ.delim = delim
        typ.relid = relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The key can be a type OID or a type name.  A KeyError is raised
        if the type cannot be found.
        """
        try:
            res = self.query("SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype" %
                (_quote_if_unqualified('$1', key),), (key,)).getresult()
        except ProgrammingError:
            res = None
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = res[0]
        typ = self.add(*res)
        # Also cache the type under the requested key, which may differ
        # from both OID and type name (e.g. an alias like 'integer' or a
        # schema qualified name); otherwise every lookup with such a key
        # would query the database again.
        self[key] = self[typ.oid] = self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached.

        Returns the default if the type cannot be found.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or not a composite type.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None
        return self._get_attnames(typ.relid, with_oid=False)

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        The type can be given as a DbType or as anything this cache can
        look up (type OID or name).  The value is returned unchanged if it
        is None, if the type is unknown or if no typecast is necessary.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if typ:
                typ = typ.pgtype
        cast = self.get_typecast(typ) if typ else None
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
1131
1132
1133def _namedresult(q):
1134    """Get query result as named tuples."""
1135    row = namedtuple('Row', q.listfields())
1136    return [row(*r) for r in q.getresult()]
1137
1138
1139class _MemoryQuery:
1140    """Class that embodies a given query result."""
1141
1142    def __init__(self, result, fields):
1143        """Create query from given result rows and field names."""
1144        self.result = result
1145        self.fields = fields
1146
1147    def listfields(self):
1148        """Return the stored field names of this query."""
1149        return self.fields
1150
1151    def getresult(self):
1152        """Return the stored result of this query."""
1153        return self.result
1154
1155
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The sqlstate attribute of the new exception is set to None.
    """
    err = cls(msg)
    err.sqlstate = None
    return err
1161
1162
def _int_error(msg):
    """Create an InternalError with an empty sqlstate attribute."""
    return _db_error(msg, cls=InternalError)
1166
1167
def _prg_error(msg):
    """Create a ProgrammingError with an empty sqlstate attribute."""
    return _db_error(msg, cls=ProgrammingError)
1171
1172
# Initialize the C module

# Register the Python helpers the C extension should use:
# the function creating named tuple results (_namedresult above),
set_namedresult(_namedresult)
# presumably the class used for numeric values (Decimal) - verify in _pg,
set_decimal(Decimal)
# and the JSON decoder (also used by DB.decode_json via get_jsondecode()).
set_jsondecode(jsondecode)
1178
1179
1180# The notification handler
1181
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.unlisten()

    @staticmethod
    def _quote_ident(name):
        """Quote an event (channel) name as an SQL identifier.

        Embedded double quotes are doubled, so that such names can
        neither break nor inject into the generated SQL commands.
        """
        name = '%s' % (name,)
        return '"%s"' % name.replace('"', '""')

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_ident(self.event))
            self.db.query('listen %s' % self._quote_ident(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_ident(self.event))
            self.db.query('unlisten %s' % self._quote_ident(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.

        Nothing is sent (and None is returned) if the handler is not
        listening.
        """
        if self.listening:
            if not db:
                db = self.db
            event = self.stop_event if stop else self.event
            q = 'notify %s' % self._quote_ident(event)
            if payload:
                # the payload becomes a string literal and must be escaped
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1301
1302
def pgnotify(*args, **kw):
    """Deprecated alias for creating a NotificationHandler.

    Emits a DeprecationWarning and passes all arguments on unchanged.
    """
    warnings.warn("pgnotify is deprecated, use NotificationHandler instead",
        DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1308
1309
1310# The actual PostGreSQL database connection interface:
1311
1312class DB:
1313    """Wrapper class for the _pg connection type."""
1314
    def __init__(self, *args, **kw):
        """Create a new connection

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        Only a single positional argument, or a single keyword argument
        named 'db', is considered as an existing connection.  A DB instance
        is unwrapped to its underlying connection, and an object with a
        _cnx attribute is unwrapped to that attribute.  If the result does
        not have both a db and a query attribute, all arguments are passed
        to connect() instead.  Only connections created here are closed
        by close() and can be reopened by reopen().
        """
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                db = db.db
            else:
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # duck-typing check whether we got a usable _pg-like connection
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True  # our own connection, may be closed
        else:
            self._closeable = False  # borrowed connection, left open
        self.db = db
        self.dbname = db.db  # presumably the database name - verify in _pg
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        self._args = args, kw  # kept for reopen()
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        # let the C extension call back into the type cache for casting
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1356
1357    def __getattr__(self, name):
1358        # All undefined members are same as in underlying connection:
1359        if self.db:
1360            return getattr(self.db, name)
1361        else:
1362            raise _int_error('Connection is not valid')
1363
1364    def __dir__(self):
1365        # Custom dir function including the attributes of the connection:
1366        attrs = set(self.__class__.__dict__)
1367        attrs.update(self.__dict__)
1368        attrs.update(dir(self.db))
1369        return sorted(attrs)
1370
1371    # Context manager methods
1372
1373    def __enter__(self):
1374        """Enter the runtime context. This will start a transactio."""
1375        self.begin()
1376        return self
1377
1378    def __exit__(self, et, ev, tb):
1379        """Exit the runtime context. This will end the transaction."""
1380        if et is None and ev is None and tb is None:
1381            self.commit()
1382        else:
1383            self.rollback()
1384
1385    # Auxiliary methods
1386
1387    def _do_debug(self, *args):
1388        """Print a debug message"""
1389        if self.debug:
1390            s = '\n'.join(str(arg) for arg in args)
1391            if isinstance(self.debug, basestring):
1392                print(self.debug % s)
1393            elif hasattr(self.debug, 'write'):
1394                self.debug.write(s + '\n')
1395            elif callable(self.debug):
1396                self.debug(s)
1397            else:
1398                print(s)
1399
1400    def _escape_qualified_name(self, s):
1401        """Escape a qualified name.
1402
1403        Escapes the name for use as an SQL identifier, unless the
1404        name contains a dot, in which case the name is ambiguous
1405        (could be a qualified name or just a name with a dot in it)
1406        and must be quoted manually by the caller.
1407        """
1408        if '.' not in s:
1409            s = self.escape_identifier(s)
1410        return s
1411
1412    @staticmethod
1413    def _make_bool(d):
1414        """Get boolean value corresponding to d."""
1415        return bool(d) if get_bool() else ('t' if d else 'f')
1416
1417    def _list_params(self, params):
1418        """Create a human readable parameter list."""
1419        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1420
1421    # Public methods
1422
    # escape_string and escape_bytea exist as methods (of the underlying
    # connection, reachable through __getattr__), so we expose the
    # module-level unescape_bytea function as a (static) method as well
    unescape_bytea = staticmethod(unescape_bytea)
1426
1427    def decode_json(self, s):
1428        """Decode a JSON string coming from the database."""
1429        return (get_jsondecode() or jsondecode)(s)
1430
1431    def encode_json(self, d):
1432        """Encode a JSON string for use within SQL."""
1433        return jsonencode(d)
1434
1435    def close(self):
1436        """Close the database connection."""
1437        # Wraps shared library function so we can track state.
1438        if self._closeable:
1439            if self.db:
1440                self.db.close()
1441                self.db = None
1442            else:
1443                raise _int_error('Connection already closed')
1444
1445    def reset(self):
1446        """Reset connection with current parameters.
1447
1448        All derived queries and large objects derived from this connection
1449        will not be usable after this call.
1450
1451        """
1452        if self.db:
1453            self.db.reset()
1454        else:
1455            raise _int_error('Connection already closed')
1456
1457    def reopen(self):
1458        """Reopen connection to the database.
1459
1460        Used in case we need another connection to the same database.
1461        Note that we can still reopen a database that we have closed.
1462
1463        """
1464        # There is no such shared library function.
1465        if self._closeable:
1466            db = connect(*self._args[0], **self._args[1])
1467            if self.db:
1468                self.db.close()
1469            self.db = db
1470
1471    def begin(self, mode=None):
1472        """Begin a transaction."""
1473        qstr = 'BEGIN'
1474        if mode:
1475            qstr += ' ' + mode
1476        return self.query(qstr)
1477
1478    start = begin
1479
1480    def commit(self):
1481        """Commit the current transaction."""
1482        return self.query('COMMIT')
1483
1484    end = commit
1485
1486    def rollback(self, name=None):
1487        """Roll back the current transaction."""
1488        qstr = 'ROLLBACK'
1489        if name:
1490            qstr += ' TO ' + name
1491        return self.query(qstr)
1492
1493    abort = rollback
1494
1495    def savepoint(self, name):
1496        """Define a new savepoint within the current transaction."""
1497        return self.query('SAVEPOINT ' + name)
1498
1499    def release(self, name):
1500        """Destroy a previously defined savepoint."""
1501        return self.query('RELEASE ' + name)
1502
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys (the passed dict is modified in place and returned).

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.
        """
        # "values" determines the kind of result: None for a single
        # string, a list for a list/tuple, or a dict for a set/dict.
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # For dict results, params maps each normalized parameter name to
        # the original key under which its value will be stored;
        # otherwise it is just a list of normalized parameter names.
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' overrides everything else: the result is a dict
                # made of the first two columns of each SHOW ALL row
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # the loop was not broken off ('all' was not requested),
            # so query the parameters one by one
            # NOTE(review): the parameter name is interpolated into the
            # SHOW command unquoted, so it must come from a trusted source
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1560
1561    def set_parameter(self, parameter, value=None, local=False):
1562        """Set the value of a run-time parameter.
1563
1564        If the parameter and the value are strings, the run-time parameter
1565        will be set to that value.  If no value or None is passed as a value,
1566        then the run-time parameter will be restored to its default value.
1567
1568        You can set several parameters at once by passing a list of parameter
1569        names, together with a single value that all parameters should be
1570        set to or with a corresponding list of values.  You can also pass
1571        the parameters as a set if you only provide a single value.
1572        Finally, you can pass a dict with parameter names as keys.  In this
1573        case, you should not pass a value, since the values for the parameters
1574        will be taken from the dict.
1575
1576        By passing the special name 'all' as the parameter, you can reset
1577        all existing settable run-time parameters to their default values.
1578
1579        If you set local to True, then the command takes effect for only the
1580        current transaction.  After commit() or rollback(), the session-level
1581        setting takes effect again.  Setting local to True will appear to
1582        have no effect if it is executed outside a transaction, since the
1583        transaction will end immediately.
1584        """
1585        if isinstance(parameter, basestring):
1586            parameter = {parameter: value}
1587        elif isinstance(parameter, (list, tuple)):
1588            if isinstance(value, (list, tuple)):
1589                parameter = dict(zip(parameter, value))
1590            else:
1591                parameter = dict.fromkeys(parameter, value)
1592        elif isinstance(parameter, (set, frozenset)):
1593            if isinstance(value, (list, tuple, set, frozenset)):
1594                value = set(value)
1595                if len(value) == 1:
1596                    value = value.pop()
1597            if not(value is None or isinstance(value, basestring)):
1598                raise ValueError('A single value must be specified'
1599                    ' when parameter is a set')
1600            parameter = dict.fromkeys(parameter, value)
1601        elif isinstance(parameter, dict):
1602            if value is not None:
1603                raise ValueError('A value must not be specified'
1604                    ' when parameter is a dictionary')
1605        else:
1606            raise TypeError(
1607                'The parameter must be a string, list, set or dict')
1608        if not parameter:
1609            raise TypeError('No parameter has been specified')
1610        params = {}
1611        for key, value in parameter.items():
1612            param = key.strip().lower() if isinstance(
1613                key, basestring) else None
1614            if not param:
1615                raise TypeError('Invalid parameter')
1616            if param == 'all':
1617                if value is not None:
1618                    raise ValueError('A value must ot be specified'
1619                        " when parameter is 'all'")
1620                params = {'all': None}
1621                break
1622            params[param] = value
1623        local = ' LOCAL' if local else ''
1624        for param, value in params.items():
1625            if value is None:
1626                q = 'RESET%s %s' % (local, param)
1627            else:
1628                q = 'SET%s %s TO %s' % (local, param, value)
1629            self._do_debug(q)
1630            self.db.query(q)
1631
1632    def query(self, command, *args):
1633        """Execute a SQL command string.
1634
1635        This method simply sends a SQL query to the database.  If the query is
1636        an insert statement that inserted exactly one row into a table that
1637        has OIDs, the return value is the OID of the newly inserted row.
1638        If the query is an update or delete statement, or an insert statement
1639        that did not insert exactly one row in a table with OIDs, then the
1640        number of rows affected is returned as a string.  If it is a statement
1641        that returns rows as a result (usually a select statement, but maybe
1642        also an "insert/update ... returning" statement), this method returns
1643        a Query object that can be accessed via getresult() or dictresult()
1644        or simply printed.  Otherwise, it returns `None`.
1645
1646        The query can contain numbered parameters of the form $1 in place
1647        of any data constant.  Arguments given after the query string will
1648        be substituted for the corresponding numbered parameter.  Parameter
1649        values can also be given as a single list or tuple argument.
1650        """
1651        # Wraps shared library function for debugging.
1652        if not self.db:
1653            raise _int_error('Connection is not valid')
1654        if args:
1655            self._do_debug(command, args)
1656            return self.db.query(command, args)
1657        self._do_debug(command)
1658        return self.db.query(command)
1659
1660    def query_formatted(self, command, parameters, types=None, inline=False):
1661        """Execute a formatted SQL command string.
1662
1663        Similar to query, but using Python format placeholders of the form
1664        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1665        The parameters must be passed as a tuple, list or dict.  You can
1666        also pass a corresponding tuple, list or dict of database types in
1667        order to format the parameters properly in case there is ambiguity.
1668
1669        If you set inline to True, the parameters will be sent to the database
1670        embedded in the SQL command, otherwise they will be sent separately.
1671        """
1672        return self.query(*self.adapter.format_query(
1673            command, parameters, types, inline))
1674
1675    def pkey(self, table, composite=False, flush=False):
1676        """Get or set the primary key of a table.
1677
1678        Single primary keys are returned as strings unless you
1679        set the composite flag.  Composite primary keys are always
1680        represented as tuples.  Note that this raises a KeyError
1681        if the table does not have a primary key.
1682
1683        If flush is set then the internal cache for primary keys will
1684        be flushed.  This may be necessary after the database schema or
1685        the search path has been changed.
1686        """
1687        pkeys = self._pkeys
1688        if flush:
1689            pkeys.clear()
1690            self._do_debug('The pkey cache has been flushed')
1691        try:  # cache lookup
1692            pkey = pkeys[table]
1693        except KeyError:  # cache miss, check the database
1694            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1695                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1696                " AND a.attnum = ANY(i.indkey)"
1697                " AND NOT a.attisdropped"
1698                " WHERE i.indrelid=%s::regclass"
1699                " AND i.indisprimary ORDER BY a.attnum") % (
1700                    _quote_if_unqualified('$1', table),)
1701            pkey = self.db.query(q, (table,)).getresult()
1702            if not pkey:
1703                raise KeyError('Table %s has no primary key' % table)
1704            # we want to use the order defined in the primary key index here,
1705            # not the order as defined by the columns in the table
1706            if len(pkey) > 1:
1707                indkey = pkey[0][2]
1708                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1709                pkey = tuple(row[0] for row in pkey)
1710            else:
1711                pkey = pkey[0][0]
1712            pkeys[table] = pkey  # cache it
1713        if composite and not isinstance(pkey, tuple):
1714            pkey = (pkey,)
1715        return pkey
1716
1717    def get_databases(self):
1718        """Get list of databases in the system."""
1719        return [s[0] for s in
1720            self.db.query('SELECT datname FROM pg_database').getresult()]
1721
1722    def get_relations(self, kinds=None):
1723        """Get list of relations in connected database of specified kinds.
1724
1725        If kinds is None or empty, all kinds of relations are returned.
1726        Otherwise kinds can be a string or sequence of type letters
1727        specifying which kind of relations you want to list.
1728        """
1729        where = " AND r.relkind IN (%s)" % ','.join(
1730            ["'%s'" % k for k in kinds]) if kinds else ''
1731        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1732            " FROM pg_class r"
1733            " JOIN pg_namespace s ON s.oid = r.relnamespace"
1734            " WHERE s.nspname NOT SIMILAR"
1735            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
1736            " ORDER BY s.nspname, r.relname") % where
1737        return [r[0] for r in self.db.query(q).getresult()]
1738
1739    def get_tables(self):
1740        """Return list of tables in connected database."""
1741        return self.get_relations('r')
1742
1743    def get_attnames(self, table, with_oid=True, flush=False):
1744        """Given the name of a table, dig out the set of attribute names.
1745
1746        Returns a read-only dictionary of attribute names (the names are
1747        the keys, the values are the names of the attributes' types)
1748        with the column names in the proper order if you iterate over it.
1749
1750        If flush is set, then the internal cache for attribute names will
1751        be flushed. This may be necessary after the database schema or
1752        the search path has been changed.
1753
1754        By default, only a limited number of simple types will be returned.
1755        You can get the regular types after calling use_regtypes(True).
1756        """
1757        attnames = self._attnames
1758        if flush:
1759            attnames.clear()
1760            self._do_debug('The attnames cache has been flushed')
1761        try:  # cache lookup
1762            names = attnames[table]
1763        except KeyError:  # cache miss, check the database
1764            q = "a.attnum > 0"
1765            if with_oid:
1766                q = "(%s OR a.attname = 'oid')" % q
1767            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1768                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1769                " FROM pg_attribute a"
1770                " JOIN pg_type t ON t.oid = a.atttypid"
1771                " WHERE a.attrelid = %s::regclass AND %s"
1772                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1773                    _quote_if_unqualified('$1', table), q)
1774            names = self.db.query(q, (table,)).getresult()
1775            types = self.dbtypes
1776            names = ((name[0], types.add(*name[1:])) for name in names)
1777            names = AttrDict(names)
1778            attnames[table] = names  # cache it
1779        return names
1780
1781    def use_regtypes(self, regtypes=None):
1782        """Use regular type names instead of simplified type names."""
1783        if regtypes is None:
1784            return self.dbtypes._regtypes
1785        else:
1786            regtypes = bool(regtypes)
1787            if regtypes != self.dbtypes._regtypes:
1788                self.dbtypes._regtypes = regtypes
1789                self._attnames.clear()
1790                self.dbtypes.clear()
1791            return regtypes
1792
1793    def has_table_privilege(self, table, privilege='select'):
1794        """Check whether current user has specified table privilege."""
1795        privilege = privilege.lower()
1796        try:  # ask cache
1797            return self._privileges[(table, privilege)]
1798        except KeyError:  # cache miss, ask the database
1799            q = "SELECT has_table_privilege(%s, $2)" % (
1800                _quote_if_unqualified('$1', table),)
1801            q = self.db.query(q, (table, privilege))
1802            ret = q.getresult()[0][0] == self._make_bool(True)
1803            self._privileges[(table, privilege)] = ret  # cache it
1804            return ret
1805
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        Raises the error built by _prg_error if no key can be determined,
        a KeyError if key values are missing or do not match the keyname,
        and the error built by _db_error if no matching row was found.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # munged oid key for this table, or None if it has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)  # a single column name becomes a 1-tuple
        # accept the munged oid key of the row as the plain 'oid' key
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        if not isinstance(row, dict):
            # a single key value or a sequence of key values: pair them
            # with the key column names to get a row dictionary
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'  # request the oid explicitly if any
        # one condition per key column; params collects the key values
        # that are passed along with the query
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        if 'oid' in row:
            # move the plain 'oid' entry back under the munged key
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched columns into row, renaming 'oid' to the munged key
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1880
1881    def insert(self, table, row=None, **kw):
1882        """Insert a row into a database table.
1883
1884        This method inserts a row into a table.  The name of the table must
1885        be passed as the first parameter.  The other parameters are used for
1886        providing the data of the row that shall be inserted into the table.
1887        If a dictionary is supplied as the second parameter, it starts with
1888        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1889        is updated from the keywords.
1890
1891        The dictionary is then reloaded with the values actually inserted in
1892        order to pick up values modified by rules, triggers, etc.
1893        """
1894        if table.endswith('*'):  # hint for descendant tables can be ignored
1895            table = table[:-1].rstrip()
1896        if row is None:
1897            row = {}
1898        row.update(kw)
1899        if 'oid' in row:
1900            del row['oid']  # do not insert oid
1901        attnames = self.get_attnames(table)
1902        qoid = _oid_key(table) if 'oid' in attnames else None
1903        params = self.adapter.parameter_list()
1904        adapt = params.add
1905        col = self.escape_identifier
1906        names, values = [], []
1907        for n in attnames:
1908            if n in row:
1909                names.append(col(n))
1910                values.append(adapt(row[n], attnames[n]))
1911        if not names:
1912            raise _prg_error('No column found that can be inserted')
1913        names, values = ', '.join(names), ', '.join(values)
1914        ret = 'oid, *' if qoid else '*'
1915        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1916            self._escape_qualified_name(table), names, values, ret)
1917        self._do_debug(q, params)
1918        q = self.db.query(q, params)
1919        res = q.dictresult()
1920        if res:  # this should always be true
1921            for n, value in res[0].items():
1922                if qoid and n == 'oid':
1923                    n = qoid
1924                row[n] = value
1925        return row
1926
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the primary key of the table or the OID value as munged by get
        or passed as keyword.

        The dictionary is then modified to reflect any changes caused by the
        update due to triggers, rules, default values, etc.

        An 'oid' entry in the row dictionary itself is discarded; the oid is
        only accepted as a keyword argument or under the munged oid key.
        If the row contains no table columns besides the key columns, the
        dictionary is returned without querying the database.  If no row
        matches, the dictionary is returned without being reloaded.
        Raises a KeyError if key values are missing, and the error built by
        _prg_error if the table has no primary key and no oid was given.
        Returns the row dictionary (a passed dictionary is modified in place).
        """
        if table.endswith('*'):
            table = table[:-1].rstrip()  # need parent table name
        attnames = self.get_attnames(table)
        # munged oid key for this table, or None if it has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # accept the munged oid key of the row as the plain 'oid' key
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        try:  # try using the primary key
            keyname = self.pkey(table, True)
        except KeyError:  # the table has no primary key
            # try using the oid instead
            if qoid and 'oid' in row:
                keyname = ('oid',)
            else:
                raise _prg_error('Table %s has no primary key' % table)
        else:  # the table has a primary key
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                # try using the oid instead
                if qoid and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise KeyError('Missing primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        # the key conditions are built first, so their values come first
        # in params
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        if 'oid' in row:
            # move the plain 'oid' entry back under the munged key
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        values = []
        keyname = set(keyname)
        # key columns are used in the WHERE clause and are not updated
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
        if not values:
            return row
        values = ', '.join(values)
        ret = 'oid, *' if qoid else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            # copy the returned columns into row, renaming 'oid' to the
            # munged key
            for n, value in res[0].items():
                if qoid and n == 'oid':
                    n = qoid
                row[n] = value
        return row
1993
1994    def upsert(self, table, row=None, **kw):
1995        """Insert a row into a database table with conflict resolution
1996
1997        This method inserts a row into a table, but instead of raising a
1998        ProgrammingError exception in case a row with the same primary key
1999        already exists, an update will be executed instead.  This will be
2000        performed as a single atomic operation on the database, so race
2001        conditions can be avoided.
2002
2003        Like the insert method, the first parameter is the name of the
2004        table and the second parameter can be used to pass the values to
2005        be inserted as a dictionary.
2006
2007        Unlike the insert und update statement, keyword parameters are not
2008        used to modify the dictionary, but to specify which columns shall
2009        be updated in case of a conflict, and in which way:
2010
2011        A value of False or None means the column shall not be updated,
2012        a value of True means the column shall be updated with the value
2013        that has been proposed for insertion, i.e. has been passed as value
2014        in the dictionary.  Columns that are not specified by keywords but
2015        appear as keys in the dictionary are also updated like in the case
2016        keywords had been passed with the value True.
2017
2018        So if in the case of a conflict you want to update every column that
2019        has been passed in the dictionary row , you would call upsert(table, row).
2020        If you don't want to do anything in case of a conflict, i.e. leave
2021        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2022
2023        If you need more fine-grained control of what gets updated, you can
2024        also pass strings in the keyword parameters.  These strings will
2025        be used as SQL expressions for the update columns.  In these
2026        expressions you can refer to the value that already exists in
2027        the table by prefixing the column name with "included.", and to
2028        the value that has been proposed for insertion by prefixing the
2029        column name with the "excluded."
2030
2031        The dictionary is modified in any case to reflect the values in
2032        the database after the operation has completed.
2033
2034        Note: The method uses the PostgreSQL "upsert" feature which is
2035        only available since PostgreSQL 9.5.
2036        """
2037        if table.endswith('*'):  # hint for descendant tables can be ignored
2038            table = table[:-1].rstrip()
2039        if row is None:
2040            row = {}
2041        if 'oid' in row:
2042            del row['oid']  # do not insert oid
2043        if 'oid' in kw:
2044            del kw['oid']  # do not update oid
2045        attnames = self.get_attnames(table)
2046        qoid = _oid_key(table) if 'oid' in attnames else None
2047        params = self.adapter.parameter_list()
2048        adapt = params.add
2049        col = self.escape_identifier
2050        names, values, updates = [], [], []
2051        for n in attnames:
2052            if n in row:
2053                names.append(col(n))
2054                values.append(adapt(row[n], attnames[n]))
2055        names, values = ', '.join(names), ', '.join(values)
2056        try:
2057            keyname = self.pkey(table, True)
2058        except KeyError:
2059            raise _prg_error('Table %s has no primary key' % table)
2060        target = ', '.join(col(k) for k in keyname)
2061        update = []
2062        keyname = set(keyname)
2063        keyname.add('oid')
2064        for n in attnames:
2065            if n not in keyname:
2066                value = kw.get(n, True)
2067                if value:
2068                    if not isinstance(value, basestring):
2069                        value = 'excluded.%s' % col(n)
2070                    update.append('%s = %s' % (col(n), value))
2071        if not values:
2072            return row
2073        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2074        ret = 'oid, *' if qoid else '*'
2075        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2076            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2077                self._escape_qualified_name(table), names, values,
2078                target, do, ret)
2079        self._do_debug(q, params)
2080        try:
2081            q = self.db.query(q, params)
2082        except ProgrammingError:
2083            if self.server_version < 90500:
2084                raise _prg_error(
2085                    'Upsert operation is not supported by PostgreSQL version')
2086            raise  # re-raise original error
2087        res = q.dictresult()
2088        if res:  # may be empty with "do nothing"
2089            for n, value in res[0].items():
2090                if qoid and n == 'oid':
2091                    n = qoid
2092                row[n] = value
2093        else:
2094            self.get(table, row)
2095        return row
2096
2097    def clear(self, table, row=None):
2098        """Clear all the attributes to values determined by the types.
2099
2100        Numeric types are set to 0, Booleans are set to false, and everything
2101        else is set to the empty string.  If the row argument is present,
2102        it is used as the row dictionary and any entries matching attribute
2103        names are cleared with everything else left unchanged.
2104        """
2105        # At some point we will need a way to get defaults from a table.
2106        if row is None:
2107            row = {}  # empty if argument is not present
2108        attnames = self.get_attnames(table)
2109        for n, t in attnames.items():
2110            if n == 'oid':
2111                continue
2112            t = t.simple
2113            if t in DbTypes._num_types:
2114                row[n] = 0
2115            elif t == 'bool':
2116                row[n] = self._make_bool(False)
2117            else:
2118                row[n] = ''
2119        return row
2120
    def delete(self, table, row=None, **kw):
        """Delete an existing row in a database table.

        This method deletes the row from a table.  It deletes based on the
        primary key of the table or the OID value as munged by get() or
        passed as keyword.

        The return value is the number of deleted rows (i.e. 0 if the row
        did not exist and 1 if the row was deleted).

        Note that if the row cannot be deleted because e.g. it is still
        referenced by another table, this method raises a ProgrammingError.

        An 'oid' entry in the row dictionary itself is discarded; the oid is
        only accepted as a keyword argument or under the munged oid key.
        A passed row dictionary is modified in place (keyword arguments are
        merged into it).  Raises a KeyError if key values are missing, and
        the error built by _prg_error if the table has no primary key and
        no oid was given.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # munged oid key for this table, or None if it has no oid column
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # accept the munged oid key of the row as the plain 'oid' key
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        try:  # try using the primary key
            keyname = self.pkey(table, True)
        except KeyError:  # the table has no primary key
            # try using the oid instead
            if qoid and 'oid' in row:
                keyname = ('oid',)
            else:
                raise _prg_error('Table %s has no primary key' % table)
        else:  # the table has a primary key
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                # try using the oid instead
                if qoid and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise KeyError('Missing primary key in row')
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        # one condition per key column; params collects the key values
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        if 'oid' in row:
            # move the plain 'oid' entry back under the munged key
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'DELETE FROM %s WHERE %s' % (
            self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        res = self.db.query(q, params)
        # the query result is converted to the number of deleted rows
        return int(res)
2175
2176    def truncate(self, table, restart=False, cascade=False, only=False):
2177        """Empty a table or set of tables.
2178
2179        This method quickly removes all rows from the given table or set
2180        of tables.  It has the same effect as an unqualified DELETE on each
2181        table, but since it does not actually scan the tables it is faster.
2182        Furthermore, it reclaims disk space immediately, rather than requiring
2183        a subsequent VACUUM operation. This is most useful on large tables.
2184
2185        If restart is set to True, sequences owned by columns of the truncated
2186        table(s) are automatically restarted.  If cascade is set to True, it
2187        also truncates all tables that have foreign-key references to any of
2188        the named tables.  If the parameter only is not set to True, all the
2189        descendant tables (if any) will also be truncated. Optionally, a '*'
2190        can be specified after the table name to explicitly indicate that
2191        descendant tables are included.
2192        """
2193        if isinstance(table, basestring):
2194            only = {table: only}
2195            table = [table]
2196        elif isinstance(table, (list, tuple)):
2197            if isinstance(only, (list, tuple)):
2198                only = dict(zip(table, only))
2199            else:
2200                only = dict.fromkeys(table, only)
2201        elif isinstance(table, (set, frozenset)):
2202            only = dict.fromkeys(table, only)
2203        else:
2204            raise TypeError('The table must be a string, list or set')
2205        if not (restart is None or isinstance(restart, (bool, int))):
2206            raise TypeError('Invalid type for the restart option')
2207        if not (cascade is None or isinstance(cascade, (bool, int))):
2208            raise TypeError('Invalid type for the cascade option')
2209        tables = []
2210        for t in table:
2211            u = only.get(t)
2212            if not (u is None or isinstance(u, (bool, int))):
2213                raise TypeError('Invalid type for the only option')
2214            if t.endswith('*'):
2215                if u:
2216                    raise ValueError(
2217                        'Contradictory table name and only options')
2218                t = t[:-1].rstrip()
2219            t = self._escape_qualified_name(t)
2220            if u:
2221                t = 'ONLY %s' % t
2222            tables.append(t)
2223        q = ['TRUNCATE', ', '.join(tables)]
2224        if restart:
2225            q.append('RESTART IDENTITY')
2226        if cascade:
2227            q.append('CASCADE')
2228        q = ' '.join(q)
2229        self._do_debug(q)
2230        return self.db.query(q)
2231
2232    def get_as_list(self, table, what=None, where=None,
2233            order=None, limit=None, offset=None, scalar=False):
2234        """Get a table as a list.
2235
2236        This gets a convenient representation of the table as a list
2237        of named tuples in Python.  You only need to pass the name of
2238        the table (or any other SQL expression returning rows).  Note that
2239        by default this will return the full content of the table which
2240        can be huge and overflow your memory.  However, you can control
2241        the amount of data returned using the other optional parameters.
2242
2243        The parameter 'what' can restrict the query to only return a
2244        subset of the table columns.  It can be a string, list or a tuple.
2245        The parameter 'where' can restrict the query to only return a
2246        subset of the table rows.  It can be a string, list or a tuple
2247        of SQL expressions that all need to be fulfilled.  The parameter
2248        'order' specifies the ordering of the rows.  It can also be a
2249        other string, list or a tuple.  If no ordering is specified,
2250        the result will be ordered by the primary key(s) or all columns
2251        if no primary key exists.  You can set 'order' to False if you
2252        don't care about the ordering.  The parameters 'limit' and 'offset'
2253        can be integers specifying the maximum number of rows returned
2254        and a number of rows skipped over.
2255
2256        If you set the 'scalar' option to True, then instead of the
2257        named tuples you will get the first items of these tuples.
2258        This is useful if the result has only one column anyway.
2259        """
2260        if not table:
2261            raise TypeError('The table name is missing')
2262        if what:
2263            if isinstance(what, (list, tuple)):
2264                what = ', '.join(map(str, what))
2265            if order is None:
2266                order = what
2267        else:
2268            what = '*'
2269        q = ['SELECT', what, 'FROM', table]
2270        if where:
2271            if isinstance(where, (list, tuple)):
2272                where = ' AND '.join(map(str, where))
2273            q.extend(['WHERE', where])
2274        if order is None:
2275            try:
2276                order = self.pkey(table, True)
2277            except (KeyError, ProgrammingError):
2278                try:
2279                    order = list(self.get_attnames(table))
2280                except (KeyError, ProgrammingError):
2281                    pass
2282        if order:
2283            if isinstance(order, (list, tuple)):
2284                order = ', '.join(map(str, order))
2285            q.extend(['ORDER BY', order])
2286        if limit:
2287            q.append('LIMIT %d' % limit)
2288        if offset:
2289            q.append('OFFSET %d' % offset)
2290        q = ' '.join(q)
2291        self._do_debug(q)
2292        q = self.db.query(q)
2293        res = q.namedresult()
2294        if res and scalar:
2295            res = [row[0] for row in res]
2296        return res
2297
2298    def get_as_dict(self, table, keyname=None, what=None, where=None,
2299            order=None, limit=None, offset=None, scalar=False):
2300        """Get a table as a dictionary.
2301
2302        This method is similar to get_as_list(), but returns the table
2303        as a Python dict instead of a Python list, which can be even
2304        more convenient. The primary key column(s) of the table will
2305        be used as the keys of the dictionary, while the other column(s)
2306        will be the corresponding values.  The keys will be named tuples
2307        if the table has a composite primary key.  The rows will be also
2308        named tuples unless the 'scalar' option has been set to True.
2309        With the optional parameter 'keyname' you can specify an alternative
2310        set of columns to be used as the keys of the dictionary.  It must
2311        be set as a string, list or a tuple.
2312
2313        If the Python version supports it, the dictionary will be an
2314        OrderedDict using the order specified with the 'order' parameter
2315        or the key column(s) if not specified.  You can set 'order' to False
2316        if you don't care about the ordering.  In this case the returned
2317        dictionary will be an ordinary one.
2318        """
2319        if not table:
2320            raise TypeError('The table name is missing')
2321        if not keyname:
2322            try:
2323                keyname = self.pkey(table, True)
2324            except (KeyError, ProgrammingError):
2325                raise _prg_error('Table %s has no primary key' % table)
2326        if isinstance(keyname, basestring):
2327            keyname = [keyname]
2328        elif not isinstance(keyname, (list, tuple)):
2329            raise KeyError('The keyname must be a string, list or tuple')
2330        if what:
2331            if isinstance(what, (list, tuple)):
2332                what = ', '.join(map(str, what))
2333            if order is None:
2334                order = what
2335        else:
2336            what = '*'
2337        q = ['SELECT', what, 'FROM', table]
2338        if where:
2339            if isinstance(where, (list, tuple)):
2340                where = ' AND '.join(map(str, where))
2341            q.extend(['WHERE', where])
2342        if order is None:
2343            order = keyname
2344        if order:
2345            if isinstance(order, (list, tuple)):
2346                order = ', '.join(map(str, order))
2347            q.extend(['ORDER BY', order])
2348        if limit:
2349            q.append('LIMIT %d' % limit)
2350        if offset:
2351            q.append('OFFSET %d' % offset)
2352        q = ' '.join(q)
2353        self._do_debug(q)
2354        q = self.db.query(q)
2355        res = q.getresult()
2356        cls = OrderedDict if order else dict
2357        if not res:
2358            return cls()
2359        keyset = set(keyname)
2360        fields = q.listfields()
2361        if not keyset.issubset(fields):
2362            raise KeyError('Missing keyname in row')
2363        keyind, rowind = [], []
2364        for i, f in enumerate(fields):
2365            (keyind if f in keyset else rowind).append(i)
2366        keytuple = len(keyind) > 1
2367        getkey = itemgetter(*keyind)
2368        keys = map(getkey, res)
2369        if scalar:
2370            rowind = rowind[:1]
2371            rowtuple = False
2372        else:
2373            rowtuple = len(rowind) > 1
2374        if scalar or rowtuple:
2375            getrow = itemgetter(*rowind)
2376        else:
2377            rowind = rowind[0]
2378            getrow = lambda row: (row[rowind],)
2379            rowtuple = True
2380        rows = map(getrow, res)
2381        if keytuple or rowtuple:
2382            namedresult = get_namedresult()
2383            if namedresult:
2384                if keytuple:
2385                    keys = namedresult(_MemoryQuery(keys, keyname))
2386                if rowtuple:
2387                    fields = [f for f in fields if f not in keyset]
2388                    rows = namedresult(_MemoryQuery(rows, fields))
2389        return cls(zip(keys, rows))
2390
2391    def notification_handler(self,
2392            event, callback, arg_dict=None, timeout=None, stop_event=None):
2393        """Get notification handler that will run the given callback."""
2394        return NotificationHandler(self,
2395            event, callback, arg_dict, timeout, stop_event)
2396
2397
# if run as script, print some information

if __name__ == '__main__':
    # pass the version as a separate argument so that print() inserts the
    # separating space (string concatenation printed "PyGreSQL version5.0")
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.