source: trunk/pg.py @ 820

Last change on this file since 820 was 820, checked in by cito, 3 years ago

Skip test for privileges when running as superuser

It's not possible to test for missing privileges as a superuser
because superusers always have all privileges.

  • Property svn:keywords set to Id
File size: 87.1 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 820 2016-02-05 14:04:31Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from datetime import date, time, datetime, timedelta
38from decimal import Decimal
39from math import isnan, isinf
40from collections import namedtuple
41from operator import itemgetter
42from functools import partial
43from re import compile as regex
44from json import loads as jsondecode, dumps as jsonencode
45
46try:
47    long
48except NameError:  # Python >= 3.0
49    long = int
50
51try:
52    basestring
53except NameError:  # Python >= 3.0
54    basestring = (str, bytes)
55
56
57# Auxiliary classes and functions that are independent from a DB connection:
58
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Simple read-only ordered dictionary for storing attribute names.

        Fallback for Python versions without collections.OrderedDict.
        The order of the keys is tracked in a separate list, therefore the
        items must be passed as a sequence of (key, value) pairs, not as
        a dictionary (which would have lost the order already).
        """

        def __init__(self, *args, **kw):
            if len(args) > 1 or kw:
                raise TypeError('AttrDict expects a single sequence of items')
            items = args[0] if args else []
            if isinstance(items, dict):
                # a plain dict would not preserve the order of the items
                raise TypeError('AttrDict cannot be created from a dict')
            items = list(items)
            self._keys = [item[0] for item in items]
            dict.__init__(self, items)
            self._read_only = True
            # block all methods that would modify the dictionary
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            dict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            dict.__delitem__(self, key)

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return list(self._keys)

        def values(self):
            return [self[key] for key in self]

        def items(self):
            return [(key, self[key]) for key in self]

        def iterkeys(self):
            return self.__iter__()

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Simple read-only ordered dictionary for storing attribute names."""

        def __init__(self, *args, **kw):
            # OrderedDict.__init__ may add the items through __setitem__,
            # so the write protection must be off while initializing
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # block all methods that would modify the dictionary,
            # including move_to_end() which would change the order
            error = self._read_only_error
            self.clear = self.update = error
            self.pop = self.setdefault = self.popitem = error
            self.move_to_end = error

        def __setitem__(self, key, value):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__setitem__(self, key, value)

        def __delitem__(self, key):
            if self._read_only:
                self._read_only_error()
            OrderedDict.__delitem__(self, key)

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
143
try:
    from inspect import signature
except ImportError:  # Python < 3.3
    from inspect import getargspec

    def get_args(func):
        """Return the names of the positional arguments of a function."""
        argspec = getargspec(func)
        return argspec.args
else:

    def get_args(func):
        """Return the names of all parameters of a callable."""
        parameters = signature(func).parameters
        return [name for name in parameters]
155
try:
    # the %z directive is needed to parse numeric time zone offsets
    datetime.strptime('+0100', '%z')
except ValueError:  # Python < 3.2
    timezones = None
else:
    # time zones used in Postgres timestamptz output
    timezones = {
        'CET': '+0100', 'EET': '+0200', 'EST': '-0500',
        'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700',
        'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000'}
166
167
def _oid_key(table):
    """Build oid key from a table name.

    The name is passed as a one-element tuple to the % operator, so that
    it is always formatted as a single value, whatever its type.
    """
    return 'oid(%s)' % (table,)
171
172
class _SimpleTypes(dict):
    """Dictionary mapping pg_type names to simple type names.

    Every base type also gets an entry for its array type, which is named
    like the base type with a leading underscore and maps to the simple
    type name followed by '[]'.  Type names without an entry are mapped
    to 'text' (without adding them to the dictionary).
    """

    _types = {
        'bool': 'bool',
        'bytea': 'bytea',
        'date': 'date interval time timetz timestamp timestamptz'
            ' abstime reltime',  # these are very old
        'float': 'float4 float8',
        'int': 'cid int2 int4 int8 oid xid',
        'hstore': 'hstore',
        'json': 'json jsonb',
        'num': 'numeric',
        'money': 'money',
        'text': 'bpchar char name text varchar'}

    def __init__(self):
        for simple_name, pg_names in self._types.items():
            for pg_name in pg_names.split():
                self[pg_name] = simple_name
                self['_' + pg_name] = simple_name + '[]'

    @staticmethod
    def __missing__(key):
        # unknown types are treated as text
        return 'text'


_simpletypes = _SimpleTypes()
198
199
def _quote_if_unqualified(param, name):
    """Quote parameter representing a qualified name.

    Wraps the given parameter in a quote_ident() call, except when the
    name contains a dot.  Such a name is ambiguous (it could be a qualified
    name or just a name containing a dot), so the caller must quote it
    manually.  Non-string names are returned unchanged as well.
    """
    unqualified = isinstance(name, basestring) and '.' not in name
    return 'quote_ident(%s)' % (param,) if unqualified else param
211
212
class _ParameterList(list):
    """Helper class for building typed parameter lists.

    The adapt attribute must be set to an adaptation function before
    add() can be used; Adapter.parameter_list() does this.
    """

    def add(self, value, typ=None):
        """Typecast value with known database type and build parameter list.

        Literal values are returned as they are.  Any other adapted value
        is appended to the list, and a positional placeholder ($1, $2, ...)
        referring to it is returned.
        """
        adapted = self.adapt(value, typ)
        if not isinstance(adapted, Literal):
            self.append(adapted)
            adapted = '$%d' % len(self)
        return adapted
227
228
class Bytea(bytes):
    """Marker type for binary values that shall be sent as bytea."""
231
232
233class Hstore(dict):
234    """Wrapper class for marking hstore values."""
235
236    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
237
238    @classmethod
239    def _quote(cls, s):
240        if s is None:
241            return 'NULL'
242        if not s:
243            return '""'
244        s = s.replace('"', '\\"')
245        if cls._re_quote.search(s):
246            s = '"%s"' % s
247        return s
248
249    def __str__(self):
250        q = self._quote
251        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
252
253
class Json:
    """Wrapper class for marking Json values.

    The wrapped Python object is kept in the obj attribute.
    """

    def __init__(self, obj):
        self.obj = obj
259
260
class Literal(str):
    """Marker type for SQL snippets that are inserted without quoting."""
263
264
class Adapter:
    """Class providing methods for adapting parameters to the database.

    An instance is bound to a DB wrapper object.  It uses the encode_json
    attribute of the wrapper and the escape_bytea and escape_string
    attributes of the underlying connection (the db attribute of the
    wrapper).
    """

    _bool_true_values = frozenset('t true 1 y yes on'.split())

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    # array elements need quotes if they contain special characters or
    # whitespace, or if they would otherwise be taken as NULL
    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
    # record fields need quotes if they contain parentheses (opening as
    # well as closing ones), commas, double quotes or backslashes
    _re_record_quote = regex(r'[(),"\\]')
    _re_array_escape = _re_record_escape = regex(r'(["\\])')

    def __init__(self, db):
        self.db = db
        self.encode_json = db.encode_json
        db = db.db
        self.escape_bytea = db.escape_bytea
        self.escape_string = db.escape_string

    @classmethod
    def _adapt_bool(cls, v):
        """Adapt a boolean parameter.

        Strings are interpreted case-insensitively; empty strings are NULL.
        """
        if isinstance(v, basestring):
            if not v:
                return None
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_date(cls, v):
        """Adapt a date parameter.

        None and empty strings are adapted to NULL, and special values such
        as current_date are passed as literals.  All other values, including
        falsy ones like timedelta(0), are passed unchanged.
        """
        if v is None:
            return None
        if isinstance(v, basestring):
            if not v:
                return None
            if v.lower() in cls._date_literals:
                return Literal(v)
        return v

    @staticmethod
    def _adapt_num(v):
        """Adapt a numeric parameter (empty values other than 0 are NULL)."""
        if not v and v != 0:
            return None
        return v

    _adapt_int = _adapt_float = _adapt_money = _adapt_num

    def _adapt_bytea(self, v):
        """Adapt a bytea parameter."""
        return self.escape_bytea(v)

    @staticmethod
    def _adapt_hstore(v):
        """Adapt a hstore parameter given as a string or a dictionary.

        Raises a TypeError for other kinds of values.
        """
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        if isinstance(v, dict):
            return str(Hstore(v))
        raise TypeError('Hstore parameter %s has wrong type' % (v,))

    def _adapt_json(self, v):
        """Adapt a json parameter.

        Values wrapped in Json are unwrapped first.  Strings are passed
        as they are, other objects are encoded as JSON.
        """
        if isinstance(v, Json):
            v = v.obj
        if not v:
            return None
        if isinstance(v, basestring):
            return v
        return self.encode_json(v)

    @classmethod
    def _adapt_text_array(cls, v):
        """Adapt a text type array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_text_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if v is None:
            return 'null'
        # only empty strings become empty elements; other falsy values
        # like 0 or timedelta(0) are rendered by str() below
        if not v and isinstance(v, basestring):
            return '""'
        v = str(v)
        if cls._re_array_quote.search(v):
            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
        return v

    _adapt_date_array = _adapt_text_array

    @classmethod
    def _adapt_bool_array(cls, v):
        """Adapt a boolean array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_bool_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if v is None:
            return 'null'
        if isinstance(v, basestring):
            if not v:
                return 'null'
            v = v.lower() in cls._bool_true_values
        return 't' if v else 'f'

    @classmethod
    def _adapt_num_array(cls, v):
        """Adapt a numeric array parameter."""
        if isinstance(v, list):
            adapt = cls._adapt_num_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if not v and v != 0:
            return 'null'
        return str(v)

    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
            _adapt_num_array

    def _adapt_bytea_array(self, v):
        """Adapt a bytea array parameter."""
        if isinstance(v, list):
            return b'{' + b','.join(
                self._adapt_bytea_array(item) for item in v) + b'}'
        if v is None:
            return b'null'
        # backslashes must be doubled inside array literals
        return self.escape_bytea(v).replace(b'\\', b'\\\\')

    def _adapt_json_array(self, v):
        """Adapt a json array parameter."""
        if isinstance(v, list):
            adapt = self._adapt_json_array
            return '{%s}' % ','.join(adapt(item) for item in v)
        if isinstance(v, Json):
            v = v.obj
        if not v:
            return 'null'
        if not isinstance(v, basestring):
            v = self.encode_json(v)
        if self._re_array_quote.search(v):
            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
        return v

    def _adapt_record(self, v, typ):
        """Adapt a record parameter with given type.

        Raises a TypeError if the number of values does not match the
        number of attributes of the record type.
        """
        attypes = self.get_attnames(typ).values()
        if len(attypes) != len(v):
            raise TypeError('Record parameter %s has wrong size' % (v,))
        adapt = self.adapt
        fields = []
        for item, attyp in zip(v, attypes):
            item = adapt(item, attyp)
            if item is None:
                item = ''  # an empty field is NULL
            else:
                if isinstance(item, bytes):
                    if str is not bytes:
                        item = item.decode('ascii')
                else:
                    item = str(item)
                if not item:
                    item = '""'  # an empty string must be quoted
                elif self._re_record_quote.search(item):
                    item = '"%s"' % self._re_record_escape.sub(r'\\\1', item)
            fields.append(item)
        return '(%s)' % ','.join(fields)

    def adapt(self, value, typ=None):
        """Adapt a value with known database type.

        If no type is given, it is guessed from the value.  Values having
        a __pg_str__() method are converted with it first.
        """
        if value is not None and not isinstance(value, Literal):
            if typ:
                simple = self.get_simple_name(typ)
            else:
                typ = simple = self.guess_simple_type(value) or 'text'
            # look up the hook first, so that AttributeErrors raised
            # inside of it are not silently swallowed
            pg_str = getattr(value, '__pg_str__', None)
            if pg_str is not None:
                value = pg_str(typ)
            if simple == 'text':
                pass
            elif simple == 'record':
                if isinstance(value, tuple):
                    value = self._adapt_record(value, typ)
            elif simple.endswith('[]'):
                if isinstance(value, list):
                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
                    value = adapt(value)
            else:
                adapt = getattr(self, '_adapt_%s' % simple)
                value = adapt(value)
        return value

    @staticmethod
    def simple_type(name):
        """Create a simple database type with given attribute names."""
        typ = DbType(name)
        typ.simple = name
        return typ

    @staticmethod
    def get_simple_name(typ):
        """Get the simple name of a database type."""
        if isinstance(typ, DbType):
            return typ.simple
        return _simpletypes[typ]

    @staticmethod
    def get_attnames(typ):
        """Get the attribute names of a composite database type."""
        if isinstance(typ, DbType):
            return typ.attnames
        return {}

    @classmethod
    def guess_simple_type(cls, value):
        """Try to guess which database type the given value has.

        Returns None if the type cannot be guessed.  Lists whose base type
        cannot be guessed (e.g. empty lists) are taken as text arrays.
        """
        if isinstance(value, Bytea):
            return 'bytea'
        if isinstance(value, basestring):
            return 'text'
        if isinstance(value, bool):
            return 'bool'
        if isinstance(value, (int, long)):
            return 'int'
        if isinstance(value, float):
            return 'float'
        if isinstance(value, Decimal):
            return 'num'
        if isinstance(value, (date, time, datetime, timedelta)):
            return 'date'
        if isinstance(value, Json):
            return 'json'
        if isinstance(value, list):
            return '%s[]' % (cls.guess_simple_base_type(value) or 'text')
        if isinstance(value, tuple):
            simple_type = cls.simple_type
            typ = simple_type('record')
            guess = cls.guess_simple_type

            def get_attnames(self):
                return AttrDict((str(n + 1), simple_type(guess(v)))
                    for n, v in enumerate(value))
            typ._get_attnames = get_attnames
            return typ
        return None

    @classmethod
    def guess_simple_base_type(cls, value):
        """Try to guess the base type of a given array."""
        for v in value:
            if isinstance(v, list):
                typ = cls.guess_simple_base_type(v)
            else:
                typ = cls.guess_simple_type(v)
            if typ:
                return typ

    def adapt_inline(self, value, nested=False):
        """Adapt a value that is put into the SQL and needs to be quoted.

        Raises an InterfaceError for values that cannot be adapted.
        """
        if value is None:
            return 'NULL'
        if isinstance(value, Literal):
            return value
        if isinstance(value, Bytea):
            value = self.escape_bytea(value)
            if bytes is not str:  # Python >= 3.0
                value = value.decode('ascii')
        elif isinstance(value, Json):
            # encode the wrapped object (not the wrapper), using an encoder
            # attached to the wrapper as "encode" attribute if there is one
            encode = getattr(value, 'encode', None) or self.encode_json
            value = value.obj
            if not isinstance(value, basestring):
                value = encode(value)
        elif isinstance(value, Hstore):
            value = str(value)
        elif isinstance(value, (datetime, date, time, timedelta)):
            value = str(value)
        if isinstance(value, basestring):
            value = self.escape_string(value)
            return "'%s'" % value
        if isinstance(value, bool):
            return 'true' if value else 'false'
        if isinstance(value, float):
            if isinf(value):
                return "'-Infinity'" if value < 0 else "'Infinity'"
            if isnan(value):
                return "'NaN'"
            return value
        if isinstance(value, (int, long, Decimal)):
            return value
        if isinstance(value, list):
            q = self.adapt_inline
            s = '[%s]' if nested else 'ARRAY[%s]'
            return s % ','.join(str(q(v, nested=True)) for v in value)
        if isinstance(value, tuple):
            q = self.adapt_inline
            return '(%s)' % ','.join(str(q(v)) for v in value)
        try:
            value = value.__pg_repr__()
        except AttributeError:
            raise InterfaceError(
                'Do not know how to adapt type %s' % type(value))
        if isinstance(value, (tuple, list)):
            value = self.adapt_inline(value)
        return value

    def parameter_list(self):
        """Return a parameter list for parameters with known database types.

        The list has an add(value, typ) method that will build up the
        list and return either the literal value or a placeholder.
        """
        params = _ParameterList()
        params.adapt = self.adapt
        return params

    def format_query(self, command, values, types=None, inline=False):
        """Format a database query using the given values and types.

        Returns the formatted command and the list of parameters.
        Raises a TypeError if values and types do not match.
        """
        if inline and types:
            raise ValueError('Typed parameters must be sent separately')
        params = self.parameter_list()
        if isinstance(values, (list, tuple)):
            if inline:
                adapt = self.adapt_inline
                literals = [adapt(value) for value in values]
            else:
                add = params.add
                literals = []
                append = literals.append
                if types:
                    if (not isinstance(types, (list, tuple)) or
                            len(types) != len(values)):
                        raise TypeError('The values and types do not match')
                    for value, typ in zip(values, types):
                        append(add(value, typ))
                else:
                    for value in values:
                        append(add(value))
            command = command % tuple(literals)
        elif isinstance(values, dict):
            if inline:
                adapt = self.adapt_inline
                literals = dict((key, adapt(value))
                    for key, value in values.items())
            else:
                add = params.add
                literals = {}
                if types:
                    # every value needs a type (extra types are allowed)
                    if (not isinstance(types, dict) or
                            any(key not in types for key in values)):
                        raise TypeError('The values and types do not match')
                    for key in sorted(values):
                        literals[key] = add(values[key], types[key])
                else:
                    for key in sorted(values):
                        literals[key] = add(values[key])
            command = command % literals
        else:
            raise TypeError('The values must be passed as tuple, list or dict')
        return command, params
598
599
def cast_bool(value):
    """Cast a boolean value.

    If casting of booleans is switched off (get_bool() is falsy), the
    string is returned as it is; otherwise True is returned if and only
    if the value starts with 't'.
    """
    if get_bool():
        return value[0] == 't'
    return value
605
606
def cast_json(value):
    """Cast a JSON value.

    Uses the decoder returned by get_jsondecode(); if there is none, the
    JSON string is returned undecoded.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
613
614
def cast_num(value):
    """Cast a numeric value.

    Uses the type returned by get_decimal(), falling back to float.
    """
    number_type = get_decimal() or float
    return number_type(value)
618
619
def cast_money(value):
    """Cast a money value.

    The text representation of money values may contain a currency symbol,
    thousands separators and a decimal point other than '.'.  Negative
    amounts may also be written in parentheses.  If get_decimal_point()
    returns a falsy value, the string is returned as it is.  Otherwise the
    amount is converted with the type returned by get_decimal(), falling
    back to float.
    """
    point = get_decimal_point()
    if not point:
        return value
    # a negative amount may be shown in parentheses
    value = value.replace('(', '-')
    # Remove everything except digits, the sign and the decimal point
    # before normalizing the point.  Thousands separators are then dropped
    # even if they are '.' (as in "1.234,56"), which would otherwise be
    # mistaken for decimal points.
    value = ''.join(c for c in value
        if c.isdigit() or c == '-' or c in point)
    if point != '.':
        value = value.replace(point, '.')
    return (get_decimal() or float)(value)
630
631
def cast_int2vector(value):
    """Cast an int2vector value (space separated integers) to a list."""
    return list(map(int, value.split()))
635
636
def cast_date(value, connection):
    """Cast a date value.

    The output format depends on the server setting DateStyle.  The default
    setting ISO and the setting for German are actually unambiguous, but
    the order of days and months in the other two settings is ambiguous,
    so the format is taken from connection.date_format().

    Infinite dates, dates BC and dates with more than 10 characters
    (i.e. years beyond 9999) are mapped to date.min or date.max.
    """
    if value == '-infinity':
        return date.min
    if value == 'infinity':
        return date.max
    words = value.split()
    if words[-1] == 'BC':
        return date.min
    day = words[0]
    if len(day) > 10:
        return date.max
    return datetime.strptime(day, connection.date_format()).date()
655
656
def cast_time(value):
    """Cast a time value, with or without fractional seconds."""
    fmt = '%H:%M:%S'
    if len(value) > 8:
        fmt += '.%f'
    return datetime.strptime(value, fmt).time()
661
662
# splits a time into the part before the last sign character and the
# numeric time zone offset starting with that sign
_re_timezone = regex(r'(.*)([+-].*)')
664
665
def cast_timetz(value):
    """Cast a timetz value.

    A trailing numeric offset such as +01 or -05:30 is split off the time.
    If the %z directive is supported (timezones is not None), it is
    normalized to the +HHMM form and parsed, so that an aware time is
    returned; otherwise the offset is ignored and the time is naive.
    """
    match = _re_timezone.match(value)
    if match:
        value, offset = match.groups()
    else:
        offset = '+0000'
    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
    if timezones:
        if offset.startswith(('+', '-')):
            if len(offset) < 5:
                offset = offset + '00'
            else:
                offset = offset.replace(':', '')
        else:
            offset = timezones.get(offset, '+0000')
        value += offset
        fmt += '%z'
    return datetime.strptime(value, fmt).timetz()
687
688
def cast_timestamp(value, connection):
    """Cast a timestamp value.

    The text representation depends on the DateStyle of the server, whose
    date format is provided by connection.date_format().  Infinite values
    and values BC are mapped to datetime.min or datetime.max, as are years
    beyond 9999 (mapped to datetime.max).
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # Postgres style such as "Wed Dec 17 07:37:16 1997":
        # skip the weekday, keep month and day, time and year
        parts = parts[1:5]
        if len(parts[3]) > 4:
            return datetime.max
        month_day = '%d %b' if date_fmt.startswith('%d') else '%b %d'
        clock = '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S'
        full_fmt = ' '.join((month_day, clock, '%Y'))
    else:
        if len(parts[0]) > 10:
            return datetime.max
        clock = '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S'
        full_fmt = ' '.join((date_fmt, clock))
    return datetime.strptime(' '.join(parts), full_fmt)
710
711
def cast_timestamptz(value, connection):
    """Cast a timestamptz value.

    The text representation depends on the DateStyle of the server, whose
    date format is provided by connection.date_format().  With ISO style
    the zone offset is appended to the time; with the other styles it is
    the last word.  If the %z directive is supported (timezones is not
    None), the zone is converted to the +HHMM form (known abbreviations via
    timezones, unknown ones as +0000) and an aware datetime is returned.
    Infinite values and values BC are mapped to datetime.min or
    datetime.max, as are years beyond 9999 (mapped to datetime.max).
    """
    if value == '-infinity':
        return datetime.min
    if value == 'infinity':
        return datetime.max
    parts = value.split()
    if parts[-1] == 'BC':
        return datetime.min
    date_fmt = connection.date_format()
    if date_fmt.endswith('-%Y') and len(parts) > 2:
        # Postgres style such as "Wed Dec 17 07:37:16 1997 PST"
        parts = parts[1:]
        if len(parts[3]) > 4:
            return datetime.max
        formats = [
            '%d %b' if date_fmt.startswith('%d') else '%b %d',
            '%H:%M:%S.%f' if len(parts[2]) > 8 else '%H:%M:%S',
            '%Y']
        zone = parts.pop()
    else:
        if date_fmt.startswith('%Y-'):
            # ISO style, the offset is attached to the time
            match = _re_timezone.match(parts[1])
            if match:
                parts[1], zone = match.groups()
            else:
                zone = '+0000'
        else:
            # SQL or German style, the zone is a separate word
            zone = parts.pop()
        if len(parts[0]) > 10:
            return datetime.max
        formats = [date_fmt,
            '%H:%M:%S.%f' if len(parts[1]) > 8 else '%H:%M:%S']
    if timezones:
        if zone.startswith(('+', '-')):
            if len(zone) < 5:
                zone = zone + '00'
            else:
                zone = zone.replace(':', '')
        else:
            zone = timezones.get(zone, '+0000')
        parts.append(zone)
        formats.append('%z')
    return datetime.strptime(' '.join(parts), ' '.join(formats))
754
_re_interval_sql_standard = regex(
    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
    '(?:([+-]?[0-9]+)(?!:) ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres = regex(
    '(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')

_re_interval_postgres_verbose = regex(
    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
    '(?:([+-]?[0-9]+) ?mons? ?)?'
    '(?:([+-]?[0-9]+) ?days? ?)?'
    '(?:([+-]?[0-9]+) ?hours? ?)?'
    '(?:([+-]?[0-9]+) ?mins? ?)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')

_re_interval_iso_8601 = regex(
    'P(?:([+-]?[0-9]+)Y)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-]?[0-9]+)D)?'
    '(?:T(?:([+-]?[0-9]+)H)?'
    '(?:([+-]?[0-9]+)M)?'
    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')


def cast_interval(value):
    """Cast an interval value.

    Years are counted as 365 days and months as 30 days.
    Raises a ValueError if the value cannot be parsed.
    """
    # The output format depends on the server setting IntervalStyle, but it's
    # not necessary to consult this setting to parse it.  It's faster to just
    # check all possible formats, and there is no ambiguity here.
    # In all formats, the last group holds the digits of the fractional
    # seconds.  They are written without trailing zeros (e.g. "01.5"), so
    # they must be right-padded to six digits to get the microseconds.
    m = _re_interval_iso_8601.match(value)
    if m:
        m = [d or '0' for d in m.groups()]
        m[-1] = m[-1][:6].ljust(6, '0')
        secs_ago = m.pop(5) == '-'
        m = [int(d) for d in m]
        years, mons, days, hours, mins, secs, usecs = m
        if secs_ago:
            secs = -secs
            usecs = -usecs
    else:
        m = _re_interval_postgres_verbose.match(value)
        if m:
            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
            m[-1] = m[-1][:6].ljust(6, '0')
            secs_ago = m.pop(5) == '-'
            m = [-int(d) for d in m] if ago else [int(d) for d in m]
            years, mons, days, hours, mins, secs, usecs = m
            if secs_ago:
                secs = -secs
                usecs = -usecs
        else:
            m = _re_interval_postgres.match(value)
            if m and any(m.groups()):
                m = [d or '0' for d in m.groups()]
                m[-1] = m[-1][:6].ljust(6, '0')
                hours_ago = m.pop(3) == '-'
                m = [int(d) for d in m]
                years, mons, days, hours, mins, secs, usecs = m
                if hours_ago:
                    hours = -hours
                    mins = -mins
                    secs = -secs
                    usecs = -usecs
            else:
                m = _re_interval_sql_standard.match(value)
                if m and any(m.groups()):
                    m = [d or '0' for d in m.groups()]
                    m[-1] = m[-1][:6].ljust(6, '0')
                    years_ago = m.pop(0) == '-'
                    hours_ago = m.pop(3) == '-'
                    m = [int(d) for d in m]
                    years, mons, days, hours, mins, secs, usecs = m
                    if years_ago:
                        years = -years
                        mons = -mons
                    if hours_ago:
                        hours = -hours
                        mins = -mins
                        secs = -secs
                        usecs = -usecs
                else:
                    raise ValueError('Cannot parse interval: %s' % value)
    days += 365 * years + 30 * mons
    return timedelta(days=days, hours=hours, minutes=mins,
        seconds=secs, microseconds=usecs)
840
841
842class Typecasts(dict):
843    """Dictionary mapping database types to typecast functions.
844
845    The cast functions get passed the string representation of a value in
846    the database which they need to convert to a Python object.  The
847    passed string will never be None since NULL values are already be
848    handled before the cast function is called.
849
850    Note that the basic types are already handled by the C extension.
851    They only need to be handled here as record or array components.
852    """
853
854    # the default cast functions
855    # (str functions are ignored but have been added for faster access)
856    defaults = {'char': str, 'bpchar': str, 'name': str,
857        'text': str, 'varchar': str,
858        'bool': cast_bool, 'bytea': unescape_bytea,
859        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
860        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
861        'float4': float, 'float8': float,
862        'numeric': cast_num, 'money': cast_money,
863        'date': cast_date, 'interval': cast_interval,
864        'time': cast_time, 'timetz': cast_timetz,
865        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
866        'int2vector': cast_int2vector,
867        'anyarray': cast_array, 'record': cast_record}
868
869    connection = None  # will be set in a connection specific instance
870
871    def __missing__(self, typ):
872        """Create a cast function if it is not cached.
873       
874        Note that this class never raises a KeyError,
875        but returns None when no special cast function exists.
876        """
877        if not isinstance(typ, str):
878            raise TypeError('Invalid type: %s' % typ)
879        cast = self.defaults.get(typ)
880        if cast:
881            # store default for faster access
882            cast = self._add_connection(cast)
883            self[typ] = cast
884        elif typ.startswith('_'):
885            base_cast = self[typ[1:]]
886            cast = self.create_array_cast(base_cast)
887            if base_cast:
888                self[typ] = cast
889        else:
890            attnames = self.get_attnames(typ)
891            if attnames:
892                casts = [self[v.pgtype] for v in attnames.values()]
893                cast = self.create_record_cast(typ, attnames, casts)
894                self[typ] = cast
895        return cast
896
897    @staticmethod
898    def _needs_connection(func):
899        """Check if a typecast function needs a connection argument."""
900        try:
901            args = get_args(func)
902        except (TypeError, ValueError):
903            return False
904        else:
905            return 'connection' in args[1:]
906
907    def _add_connection(self, cast):
908        """Add a connection argument to the typecast function if necessary."""
909        if not self.connection or not self._needs_connection(cast):
910            return cast
911        return partial(cast, connection=self.connection)
912
913    def get(self, typ, default=None):
914        """Get the typecast function for the given database type."""
915        return self[typ] or default
916
917    def set(self, typ, cast):
918        """Set a typecast function for the specified database type(s)."""
919        if isinstance(typ, basestring):
920            typ = [typ]
921        if cast is None:
922            for t in typ:
923                self.pop(t, None)
924                self.pop('_%s' % t, None)
925        else:
926            if not callable(cast):
927                raise TypeError("Cast parameter must be callable")
928            for t in typ:
929                self[t] = self._add_connection(cast)
930                self.pop('_%s' % t, None)
931
932    def reset(self, typ=None):
933        """Reset the typecasts for the specified type(s) to their defaults.
934
935        When no type is specified, all typecasts will be reset.
936        """
937        if typ is None:
938            self.clear()
939        else:
940            if isinstance(typ, basestring):
941                typ = [typ]
942            for t in typ:
943                self.pop(t, None)
944
945    @classmethod
946    def get_default(cls, typ):
947        """Get the default typecast function for the given database type."""
948        return cls.defaults.get(typ)
949
950    @classmethod
951    def set_default(cls, typ, cast):
952        """Set a default typecast function for the given database type(s)."""
953        if isinstance(typ, basestring):
954            typ = [typ]
955        defaults = cls.defaults
956        if cast is None:
957            for t in typ:
958                defaults.pop(t, None)
959                defaults.pop('_%s' % t, None)
960        else:
961            if not callable(cast):
962                raise TypeError("Cast parameter must be callable")
963            for t in typ:
964                defaults[t] = cast
965                defaults.pop('_%s' % t, None)
966
967    def get_attnames(self, typ):
968        """Return the fields for the given record type.
969
970        This method will be replaced with the get_attnames() method of DbTypes.
971        """
972        return {}
973
974    def dateformat(self):
975        """Return the current date format.
976
977        This method will be replaced with the dateformat() method of DbTypes.
978        """
979        return '%Y-%m-%d'
980
981    def create_array_cast(self, basecast):
982        """Create an array typecast for the given base cast."""
983        def cast(v):
984            return cast_array(v, basecast)
985        return cast
986
987    def create_record_cast(self, name, fields, casts):
988        """Create a named record typecast for the given fields and casts."""
989        record = namedtuple(name, fields)
990        def cast(v):
991            return record(*cast_record(v, casts))
992        return cast
993
994
def get_typecast(typ):
    """Get the global default typecast function for a database type.

    Returns None if no default typecast exists for this type.
    """
    cast = Typecasts.get_default(typ)
    return cast
998
999
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    typ can be a single type name or a sequence of names; cast must be
    a callable, or None to remove the global typecasts for these types.

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1007
1008
class DbType(str):
    """Class augmenting the simple type name with additional info.

    Instances are plain strings holding a type name.  DbTypes.add()
    attaches the following additional information as attributes:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    # computed on access through the _get_attnames function set by DbTypes
    attnames = property(
        lambda self: self._get_attnames(self),
        doc="Get names and types of the fields of a composite type.")
1029
1030
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    This cache maps type OIDs and names to DbType objects containing
    information on the associated database type.  Types that are not
    cached yet are looked up in the pg_type system catalog on access.
    It also manages the typecast functions used for the connection.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize type cache for connection.

        db is the DB wrapper.  Its get_attnames() method is used for
        composite types, while type lookups are sent through the query()
        method of its underlying connection db.db.
        """
        super(DbTypes, self).__init__()
        self._regtypes = False
        self._get_attnames = db.get_attnames
        self._typecasts = Typecasts()
        # let the typecasts resolve composite types through this cache
        self._typecasts.get_attnames = self.get_attnames
        self._typecasts.connection = db
        db = db.db
        # note that these are bound to the current underlying connection
        self.query = db.query
        self.escape_string = db.escape_string

    def add(self, oid, pgtype, regtype,
               typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        The arguments correspond to the columns queried from pg_type.
        If the oid is already cached, the cached DbType is returned.
        Otherwise a new DbType is returned without caching it.  Its string
        value is the regular type name if _regtypes is set, and the simple
        type name otherwise ('record' for types with a relid).
        """
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletypes[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid = oid
        typ.simple = simple
        typ.pgtype = pgtype
        typ.regtype = regtype
        typ.typtype = typtype
        typ.category = category
        typ.delim = delim
        typ.relid = relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        key can be a type OID or a type name.  The new DbType is cached
        under its OID, its PostgreSQL type name and the requested key.
        Raises a KeyError if the type cannot be found.
        """
        try:
            res = self.query("SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype" %
                (_quote_if_unqualified('$1', key),), (key,)).getresult()
        except ProgrammingError:
            res = None
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = res[0]
        typ = self.add(*res)
        self[typ.oid] = self[typ.pgtype] = typ
        # also cache under the requested key (e.g. an alias like 'integer'
        # or a qualified name), so repeated lookups need no extra query
        self[key] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached.

        Returns default if the type cannot be found in the database.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or has no relid, i.e. if it
        is not a composite type.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None
        return self._get_attnames(typ.relid, with_oid=False)

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        typ can be a DbType, a type OID or a type name.  The cast function
        is looked up with the PostgreSQL type name (pgtype), since this is
        how the typecasts are keyed.  The value is returned unchanged if
        it is None (NULL) or if no special typecast is necessary.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if not isinstance(typ, DbType):
            typ = self.get(typ)
        # the string value of a DbType is the simple or the regular type
        # name, but the typecasts are keyed by the PostgreSQL type name
        pgtype = getattr(typ, 'pgtype', typ) if typ else None
        cast = self.get_typecast(pgtype) if pgtype else None
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
1130
1131
def _namedresult(q):
    """Get query result as named tuples.

    Column names that are not valid Python identifiers (for instance
    keywords, duplicates or generated names like '?column?') are replaced
    with positional names instead of raising a ValueError.
    """
    fields = q.listfields()
    try:
        row = namedtuple('Row', fields, rename=True)
    except TypeError:  # Python < 2.7 does not support rename
        row = namedtuple('Row', fields)
    return [row(*r) for r in q.getresult()]
1136
1137
class _MemoryQuery:
    """Stand-in for a query object, holding an already computed result."""

    def __init__(self, result, fields):
        """Store the given result rows together with their field names."""
        self.fields = fields
        self.result = result

    def listfields(self):
        """Return the field names that were stored with this query."""
        return self.fields

    def getresult(self):
        """Return the result rows that were stored with this query."""
        return self.result
1153
1154
def _db_error(msg, cls=DatabaseError):
    """Build an exception of class cls for msg, with an empty sqlstate.

    The sqlstate attribute is set to None on the new exception.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
1160
1161
def _int_error(msg):
    """Build an InternalError for msg, with an empty sqlstate."""
    return _db_error(msg, cls=InternalError)
1165
1166
def _prg_error(msg):
    """Build a ProgrammingError for msg, with an empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
1170
1171
# Initialize the C module

# Register Python helpers with the _pg extension module.  Judging from
# the setter names, the C code presumably uses them to build named tuple
# rows (_namedresult), to create numeric values (Decimal) and to decode
# JSON values (jsondecode) -- confirm against the C module.
set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
1177
1178
1179# The notification handler
1180
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is garbage collected."""
        # guard against objects whose __init__ did not complete
        if getattr(self, 'listening', False):
            self.unlisten()

    @staticmethod
    def _quote_channel(name):
        """Quote a channel name as SQL identifier, doubling inner quotes."""
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_channel(self.event))
            self.db.query('listen %s' % self._quote_channel(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_channel(self.event))
            self.db.query(
                'unlisten %s' % self._quote_channel(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is escaped with the escape_string() method of the
        connection, so it may safely contain quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent if the handler is not currently listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            channel = self.stop_event if stop else self.event
            q = 'notify %s' % self._quote_channel(channel)
            if payload:
                # escape the payload so that quotes cannot break the command
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and return.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1300
1301
def pgnotify(*args, **kw):
    """Same as NotificationHandler, under the traditional name.

    Deprecated: a DeprecationWarning is issued, attributed to the caller.
    """
    warnings.warn("pgnotify is deprecated, use NotificationHandler instead",
        DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1307
1308
1309# The actual PostGreSQL database connection interface:
1310
1311class DB:
1312    """Wrapper class for the _pg connection type."""
1313
    def __init__(self, *args, **kw):
        """Create a new connection

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing connection is only considered when it is the single
        argument, passed either positionally or as keyword 'db'.  A DB
        instance is unwrapped to its underlying connection, and an object
        with a _cnx attribute (presumably a pgdb connection) to that
        attribute.  If the result does not look like a _pg connection
        (having 'db' and 'query' attributes), a new connection is opened
        with connect(*args, **kw).  Only such a newly opened connection
        will actually be closed by close().
        """
        # find out whether an existing connection has been passed
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                db = db.db
            else:
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # anything that does not look like a _pg connection is treated
        # as connection parameters for opening a new connection
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        # caches for table attributes, primary keys and privileges
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        # remember the arguments so that reopen() can connect again
        self._args = args, kw
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        # let the connection cast values through the type cache
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1355
1356    def __getattr__(self, name):
1357        # All undefined members are same as in underlying connection:
1358        if self.db:
1359            return getattr(self.db, name)
1360        else:
1361            raise _int_error('Connection is not valid')
1362
1363    def __dir__(self):
1364        # Custom dir function including the attributes of the connection:
1365        attrs = set(self.__class__.__dict__)
1366        attrs.update(self.__dict__)
1367        attrs.update(dir(self.db))
1368        return sorted(attrs)
1369
1370    # Context manager methods
1371
1372    def __enter__(self):
1373        """Enter the runtime context. This will start a transactio."""
1374        self.begin()
1375        return self
1376
1377    def __exit__(self, et, ev, tb):
1378        """Exit the runtime context. This will end the transaction."""
1379        if et is None and ev is None and tb is None:
1380            self.commit()
1381        else:
1382            self.rollback()
1383
1384    # Auxiliary methods
1385
1386    def _do_debug(self, *args):
1387        """Print a debug message"""
1388        if self.debug:
1389            s = '\n'.join(str(arg) for arg in args)
1390            if isinstance(self.debug, basestring):
1391                print(self.debug % s)
1392            elif hasattr(self.debug, 'write'):
1393                self.debug.write(s + '\n')
1394            elif callable(self.debug):
1395                self.debug(s)
1396            else:
1397                print(s)
1398
1399    def _escape_qualified_name(self, s):
1400        """Escape a qualified name.
1401
1402        Escapes the name for use as an SQL identifier, unless the
1403        name contains a dot, in which case the name is ambiguous
1404        (could be a qualified name or just a name with a dot in it)
1405        and must be quoted manually by the caller.
1406        """
1407        if '.' not in s:
1408            s = self.escape_identifier(s)
1409        return s
1410
1411    @staticmethod
1412    def _make_bool(d):
1413        """Get boolean value corresponding to d."""
1414        return bool(d) if get_bool() else ('t' if d else 'f')
1415
1416    def _list_params(self, params):
1417        """Create a human readable parameter list."""
1418        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1419
1420    # Public methods
1421
1422    # escape_string and escape_bytea exist as methods,
1423    # so we define unescape_bytea as a method as well
1424    unescape_bytea = staticmethod(unescape_bytea)
1425
1426    def decode_json(self, s):
1427        """Decode a JSON string coming from the database."""
1428        return (get_jsondecode() or jsondecode)(s)
1429
1430    def encode_json(self, d):
1431        """Encode a JSON string for use within SQL."""
1432        return jsonencode(d)
1433
1434    def close(self):
1435        """Close the database connection."""
1436        # Wraps shared library function so we can track state.
1437        if self._closeable:
1438            if self.db:
1439                self.db.close()
1440                self.db = None
1441            else:
1442                raise _int_error('Connection already closed')
1443
1444    def reset(self):
1445        """Reset connection with current parameters.
1446
1447        All derived queries and large objects derived from this connection
1448        will not be usable after this call.
1449
1450        """
1451        if self.db:
1452            self.db.reset()
1453        else:
1454            raise _int_error('Connection already closed')
1455
1456    def reopen(self):
1457        """Reopen connection to the database.
1458
1459        Used in case we need another connection to the same database.
1460        Note that we can still reopen a database that we have closed.
1461
1462        """
1463        # There is no such shared library function.
1464        if self._closeable:
1465            db = connect(*self._args[0], **self._args[1])
1466            if self.db:
1467                self.db.close()
1468            self.db = db
1469
1470    def begin(self, mode=None):
1471        """Begin a transaction."""
1472        qstr = 'BEGIN'
1473        if mode:
1474            qstr += ' ' + mode
1475        return self.query(qstr)
1476
1477    start = begin
1478
1479    def commit(self):
1480        """Commit the current transaction."""
1481        return self.query('COMMIT')
1482
1483    end = commit
1484
1485    def rollback(self, name=None):
1486        """Roll back the current transaction."""
1487        qstr = 'ROLLBACK'
1488        if name:
1489            qstr += ' TO ' + name
1490        return self.query(qstr)
1491
1492    abort = rollback
1493
1494    def savepoint(self, name):
1495        """Define a new savepoint within the current transaction."""
1496        return self.query('SAVEPOINT ' + name)
1497
1498    def release(self, name):
1499        """Destroy a previously defined savepoint."""
1500        return self.query('RELEASE ' + name)
1501
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises a TypeError if the parameter has an unsupported type, is
        empty, or contains a name that is not a non-blank string.
        """
        # values is the result container: None (single string), a new list,
        # a new dict (for sets) or the passed dict itself (filled in place)
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # params maps normalized names to the original keys (dict results)
        # or lists the normalized names in order (other results)
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            # names are used in stripped lower case form
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' supersedes everything else: return a new dict built
                # from the first two columns (name, setting) of SHOW ALL
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # only reached if 'all' was not requested (no break above)
            for param in params:
                # NOTE(review): param is interpolated into SQL unescaped
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1559
1560    def set_parameter(self, parameter, value=None, local=False):
1561        """Set the value of a run-time parameter.
1562
1563        If the parameter and the value are strings, the run-time parameter
1564        will be set to that value.  If no value or None is passed as a value,
1565        then the run-time parameter will be restored to its default value.
1566
1567        You can set several parameters at once by passing a list of parameter
1568        names, together with a single value that all parameters should be
1569        set to or with a corresponding list of values.  You can also pass
1570        the parameters as a set if you only provide a single value.
1571        Finally, you can pass a dict with parameter names as keys.  In this
1572        case, you should not pass a value, since the values for the parameters
1573        will be taken from the dict.
1574
1575        By passing the special name 'all' as the parameter, you can reset
1576        all existing settable run-time parameters to their default values.
1577
1578        If you set local to True, then the command takes effect for only the
1579        current transaction.  After commit() or rollback(), the session-level
1580        setting takes effect again.  Setting local to True will appear to
1581        have no effect if it is executed outside a transaction, since the
1582        transaction will end immediately.
1583        """
1584        if isinstance(parameter, basestring):
1585            parameter = {parameter: value}
1586        elif isinstance(parameter, (list, tuple)):
1587            if isinstance(value, (list, tuple)):
1588                parameter = dict(zip(parameter, value))
1589            else:
1590                parameter = dict.fromkeys(parameter, value)
1591        elif isinstance(parameter, (set, frozenset)):
1592            if isinstance(value, (list, tuple, set, frozenset)):
1593                value = set(value)
1594                if len(value) == 1:
1595                    value = value.pop()
1596            if not(value is None or isinstance(value, basestring)):
1597                raise ValueError('A single value must be specified'
1598                    ' when parameter is a set')
1599            parameter = dict.fromkeys(parameter, value)
1600        elif isinstance(parameter, dict):
1601            if value is not None:
1602                raise ValueError('A value must not be specified'
1603                    ' when parameter is a dictionary')
1604        else:
1605            raise TypeError(
1606                'The parameter must be a string, list, set or dict')
1607        if not parameter:
1608            raise TypeError('No parameter has been specified')
1609        params = {}
1610        for key, value in parameter.items():
1611            param = key.strip().lower() if isinstance(
1612                key, basestring) else None
1613            if not param:
1614                raise TypeError('Invalid parameter')
1615            if param == 'all':
1616                if value is not None:
1617                    raise ValueError('A value must ot be specified'
1618                        " when parameter is 'all'")
1619                params = {'all': None}
1620                break
1621            params[param] = value
1622        local = ' LOCAL' if local else ''
1623        for param, value in params.items():
1624            if value is None:
1625                q = 'RESET%s %s' % (local, param)
1626            else:
1627                q = 'SET%s %s TO %s' % (local, param, value)
1628            self._do_debug(q)
1629            self.db.query(q)
1630
1631    def query(self, command, *args):
1632        """Execute a SQL command string.
1633
1634        This method simply sends a SQL query to the database.  If the query is
1635        an insert statement that inserted exactly one row into a table that
1636        has OIDs, the return value is the OID of the newly inserted row.
1637        If the query is an update or delete statement, or an insert statement
1638        that did not insert exactly one row in a table with OIDs, then the
1639        number of rows affected is returned as a string.  If it is a statement
1640        that returns rows as a result (usually a select statement, but maybe
1641        also an "insert/update ... returning" statement), this method returns
1642        a Query object that can be accessed via getresult() or dictresult()
1643        or simply printed.  Otherwise, it returns `None`.
1644
1645        The query can contain numbered parameters of the form $1 in place
1646        of any data constant.  Arguments given after the query string will
1647        be substituted for the corresponding numbered parameter.  Parameter
1648        values can also be given as a single list or tuple argument.
1649        """
1650        # Wraps shared library function for debugging.
1651        if not self.db:
1652            raise _int_error('Connection is not valid')
1653        if args:
1654            self._do_debug(command, args)
1655            return self.db.query(command, args)
1656        self._do_debug(command)
1657        return self.db.query(command)
1658
1659    def query_formatted(self, command, parameters, types=None, inline=False):
1660        """Execute a formatted SQL command string.
1661
1662        Similar to query, but using Python format placeholders of the form
1663        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1664        The parameters must be passed as a tuple, list or dict.  You can
1665        also pass a corresponding tuple, list or dict of database types in
1666        order to format the parameters properly in case there is ambiguity.
1667
1668        If you set inline to True, the parameters will be sent to the database
1669        embedded in the SQL command, otherwise they will be sent separately.
1670        """
1671        return self.query(*self.adapter.format_query(
1672            command, parameters, types, inline))
1673
    def pkey(self, table, composite=False, flush=False):
        """Get the primary key of a table.

        Single primary keys are returned as strings unless you
        set the composite flag.  Composite primary keys are always
        represented as tuples.  Note that this raises a KeyError
        if the table does not have a primary key.

        The result is cached per table name.  If flush is set then the
        internal cache for primary keys will be flushed.  This may be
        necessary after the database schema or the search path has been
        changed.
        """
        pkeys = self._pkeys
        if flush:
            pkeys.clear()
            self._do_debug('The pkey cache has been flushed')
        try:  # cache lookup
            pkey = pkeys[table]
        except KeyError:  # cache miss, check the database
            # the table name is passed as parameter $1 and cast to regclass;
            # _quote_if_unqualified presumably decides how $1 is quoted
            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
                " AND a.attnum = ANY(i.indkey)"
                " AND NOT a.attisdropped"
                " WHERE i.indrelid=%s::regclass"
                " AND i.indisprimary ORDER BY a.attnum") % (
                    _quote_if_unqualified('$1', table),)
            pkey = self.db.query(q, (table,)).getresult()
            if not pkey:
                raise KeyError('Table %s has no primary key' % table)
            # we want to use the order defined in the primary key index here,
            # not the order as defined by the columns in the table
            if len(pkey) > 1:
                # assumes indkey arrives as a Python sequence of attnums
                # (presumably via the int2vector typecast) -- confirm
                indkey = pkey[0][2]
                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
                pkey = tuple(row[0] for row in pkey)
            else:
                pkey = pkey[0][0]
            pkeys[table] = pkey  # cache it
        if composite and not isinstance(pkey, tuple):
            pkey = (pkey,)
        return pkey
1715
1716    def get_databases(self):
1717        """Get list of databases in the system."""
1718        return [s[0] for s in
1719            self.db.query('SELECT datname FROM pg_database').getresult()]
1720
1721    def get_relations(self, kinds=None):
1722        """Get list of relations in connected database of specified kinds.
1723
1724        If kinds is None or empty, all kinds of relations are returned.
1725        Otherwise kinds can be a string or sequence of type letters
1726        specifying which kind of relations you want to list.
1727        """
1728        where = " AND r.relkind IN (%s)" % ','.join(
1729            ["'%s'" % k for k in kinds]) if kinds else ''
1730        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1731            " FROM pg_class r"
1732            " JOIN pg_namespace s ON s.oid = r.relnamespace"
1733            " WHERE s.nspname NOT SIMILAR"
1734            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
1735            " ORDER BY s.nspname, r.relname") % where
1736        return [r[0] for r in self.db.query(q).getresult()]
1737
1738    def get_tables(self):
1739        """Return list of tables in connected database."""
1740        return self.get_relations('r')
1741
1742    def get_attnames(self, table, with_oid=True, flush=False):
1743        """Given the name of a table, dig out the set of attribute names.
1744
1745        Returns a read-only dictionary of attribute names (the names are
1746        the keys, the values are the names of the attributes' types)
1747        with the column names in the proper order if you iterate over it.
1748
1749        If flush is set, then the internal cache for attribute names will
1750        be flushed. This may be necessary after the database schema or
1751        the search path has been changed.
1752
1753        By default, only a limited number of simple types will be returned.
1754        You can get the regular types after calling use_regtypes(True).
1755        """
1756        attnames = self._attnames
1757        if flush:
1758            attnames.clear()
1759            self._do_debug('The attnames cache has been flushed')
1760        try:  # cache lookup
1761            names = attnames[table]
1762        except KeyError:  # cache miss, check the database
1763            q = "a.attnum > 0"
1764            if with_oid:
1765                q = "(%s OR a.attname = 'oid')" % q
1766            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1767                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1768                " FROM pg_attribute a"
1769                " JOIN pg_type t ON t.oid = a.atttypid"
1770                " WHERE a.attrelid = %s::regclass AND %s"
1771                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1772                    _quote_if_unqualified('$1', table), q)
1773            names = self.db.query(q, (table,)).getresult()
1774            types = self.dbtypes
1775            names = ((name[0], types.add(*name[1:])) for name in names)
1776            names = AttrDict(names)
1777            attnames[table] = names  # cache it
1778        return names
1779
1780    def use_regtypes(self, regtypes=None):
1781        """Use regular type names instead of simplified type names."""
1782        if regtypes is None:
1783            return self.dbtypes._regtypes
1784        else:
1785            regtypes = bool(regtypes)
1786            if regtypes != self.dbtypes._regtypes:
1787                self.dbtypes._regtypes = regtypes
1788                self._attnames.clear()
1789                self.dbtypes.clear()
1790            return regtypes
1791
1792    def has_table_privilege(self, table, privilege='select', flush=False):
1793        """Check whether current user has specified table privilege.
1794
1795        If flush is set, then the internal cache for table privileges will
1796        be flushed. This may be necessary after privileges have been changed.
1797        """
1798        privileges = self._privileges
1799        if flush:
1800            privileges.clear()
1801            self._do_debug('The privileges cache has been flushed')
1802        privilege = privilege.lower()
1803        try:  # ask cache
1804            ret = privileges[table, privilege]
1805        except KeyError:  # cache miss, ask the database
1806            q = "SELECT has_table_privilege(%s, $2)" % (
1807                _quote_if_unqualified('$1', table),)
1808            q = self.db.query(q, (table, privilege))
1809            ret = q.getresult()[0][0] == self._make_bool(True)
1810            privileges[table, privilege] = ret  # cache it
1811        return ret
1812
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        Raises the error built by _prg_error if no keyname was given and
        neither the primary key nor the oid can be used, a KeyError if key
        values are missing from the row or their number does not match the
        keyname, and the error built by _db_error if no row was found.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # qoid is the munged oid key for this table, or None without oid
        qoid = _oid_key(table) if 'oid' in attnames else None
        # a single column name is turned into a one-element tuple
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # accept the munged oid key as the oid of the row
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # a scalar or sequence row holds the key values in keyname order;
        # turn it into a dictionary that will also receive the result
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = self.adapter.parameter_list()
        adapt = params.add
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), adapt(row[k], attnames[k])) for k in keyname)
        # replace the plain 'oid' key with the munged oid key (if the table
        # has an oid); the result loop below stores the oid under qoid
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row dictionary
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1887
1888    def insert(self, table, row=None, **kw):
1889        """Insert a row into a database table.
1890
1891        This method inserts a row into a table.  The name of the table must
1892        be passed as the first parameter.  The other parameters are used for
1893        providing the data of the row that shall be inserted into the table.
1894        If a dictionary is supplied as the second parameter, it starts with
1895        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1896        is updated from the keywords.
1897
1898        The dictionary is then reloaded with the values actually inserted in
1899        order to pick up values modified by rules, triggers, etc.
1900        """
1901        if table.endswith('*'):  # hint for descendant tables can be ignored
1902            table = table[:-1].rstrip()
1903        if row is None:
1904            row = {}
1905        row.update(kw)
1906        if 'oid' in row:
1907            del row['oid']  # do not insert oid
1908        attnames = self.get_attnames(table)
1909        qoid = _oid_key(table) if 'oid' in attnames else None
1910        params = self.adapter.parameter_list()
1911        adapt = params.add
1912        col = self.escape_identifier
1913        names, values = [], []
1914        for n in attnames:
1915            if n in row:
1916                names.append(col(n))
1917                values.append(adapt(row[n], attnames[n]))
1918        if not names:
1919            raise _prg_error('No column found that can be inserted')
1920        names, values = ', '.join(names), ', '.join(values)
1921        ret = 'oid, *' if qoid else '*'
1922        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1923            self._escape_qualified_name(table), names, values, ret)
1924        self._do_debug(q, params)
1925        q = self.db.query(q, params)
1926        res = q.dictresult()
1927        if res:  # this should always be true
1928            for n, value in res[0].items():
1929                if qoid and n == 'oid':
1930                    n = qoid
1931                row[n] = value
1932        return row
1933
1934    def update(self, table, row=None, **kw):
1935        """Update an existing row in a database table.
1936
1937        Similar to insert but updates an existing row.  The update is based
1938        on the primary key of the table or the OID value as munged by get
1939        or passed as keyword.
1940
1941        The dictionary is then modified to reflect any changes caused by the
1942        update due to triggers, rules, default values, etc.
1943        """
1944        if table.endswith('*'):
1945            table = table[:-1].rstrip()  # need parent table name
1946        attnames = self.get_attnames(table)
1947        qoid = _oid_key(table) if 'oid' in attnames else None
1948        if row is None:
1949            row = {}
1950        elif 'oid' in row:
1951            del row['oid']  # only accept oid key from named args for safety
1952        row.update(kw)
1953        if qoid and qoid in row and 'oid' not in row:
1954            row['oid'] = row[qoid]
1955        try:  # try using the primary key
1956            keyname = self.pkey(table, True)
1957        except KeyError:  # the table has no primary key
1958            # try using the oid instead
1959            if qoid and 'oid' in row:
1960                keyname = ('oid',)
1961            else:
1962                raise _prg_error('Table %s has no primary key' % table)
1963        else:  # the table has a primary key
1964            # check whether all key columns have values
1965            if not set(keyname).issubset(row):
1966                # try using the oid instead
1967                if qoid and 'oid' in row:
1968                    keyname = ('oid',)
1969                else:
1970                    raise KeyError('Missing primary key in row')
1971        params = self.adapter.parameter_list()
1972        adapt = params.add
1973        col = self.escape_identifier
1974        where = ' AND '.join('%s = %s' % (
1975            col(k), adapt(row[k], attnames[k])) for k in keyname)
1976        if 'oid' in row:
1977            if qoid:
1978                row[qoid] = row['oid']
1979            del row['oid']
1980        values = []
1981        keyname = set(keyname)
1982        for n in attnames:
1983            if n in row and n not in keyname:
1984                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
1985        if not values:
1986            return row
1987        values = ', '.join(values)
1988        ret = 'oid, *' if qoid else '*'
1989        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
1990            self._escape_qualified_name(table), values, where, ret)
1991        self._do_debug(q, params)
1992        q = self.db.query(q, params)
1993        res = q.dictresult()
1994        if res:  # may be empty when row does not exist
1995            for n, value in res[0].items():
1996                if qoid and n == 'oid':
1997                    n = qoid
1998                row[n] = value
1999        return row
2000
2001    def upsert(self, table, row=None, **kw):
2002        """Insert a row into a database table with conflict resolution
2003
2004        This method inserts a row into a table, but instead of raising a
2005        ProgrammingError exception in case a row with the same primary key
2006        already exists, an update will be executed instead.  This will be
2007        performed as a single atomic operation on the database, so race
2008        conditions can be avoided.
2009
2010        Like the insert method, the first parameter is the name of the
2011        table and the second parameter can be used to pass the values to
2012        be inserted as a dictionary.
2013
2014        Unlike the insert und update statement, keyword parameters are not
2015        used to modify the dictionary, but to specify which columns shall
2016        be updated in case of a conflict, and in which way:
2017
2018        A value of False or None means the column shall not be updated,
2019        a value of True means the column shall be updated with the value
2020        that has been proposed for insertion, i.e. has been passed as value
2021        in the dictionary.  Columns that are not specified by keywords but
2022        appear as keys in the dictionary are also updated like in the case
2023        keywords had been passed with the value True.
2024
2025        So if in the case of a conflict you want to update every column that
2026        has been passed in the dictionary row , you would call upsert(table, row).
2027        If you don't want to do anything in case of a conflict, i.e. leave
2028        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2029
2030        If you need more fine-grained control of what gets updated, you can
2031        also pass strings in the keyword parameters.  These strings will
2032        be used as SQL expressions for the update columns.  In these
2033        expressions you can refer to the value that already exists in
2034        the table by prefixing the column name with "included.", and to
2035        the value that has been proposed for insertion by prefixing the
2036        column name with the "excluded."
2037
2038        The dictionary is modified in any case to reflect the values in
2039        the database after the operation has completed.
2040
2041        Note: The method uses the PostgreSQL "upsert" feature which is
2042        only available since PostgreSQL 9.5.
2043        """
2044        if table.endswith('*'):  # hint for descendant tables can be ignored
2045            table = table[:-1].rstrip()
2046        if row is None:
2047            row = {}
2048        if 'oid' in row:
2049            del row['oid']  # do not insert oid
2050        if 'oid' in kw:
2051            del kw['oid']  # do not update oid
2052        attnames = self.get_attnames(table)
2053        qoid = _oid_key(table) if 'oid' in attnames else None
2054        params = self.adapter.parameter_list()
2055        adapt = params.add
2056        col = self.escape_identifier
2057        names, values, updates = [], [], []
2058        for n in attnames:
2059            if n in row:
2060                names.append(col(n))
2061                values.append(adapt(row[n], attnames[n]))
2062        names, values = ', '.join(names), ', '.join(values)
2063        try:
2064            keyname = self.pkey(table, True)
2065        except KeyError:
2066            raise _prg_error('Table %s has no primary key' % table)
2067        target = ', '.join(col(k) for k in keyname)
2068        update = []
2069        keyname = set(keyname)
2070        keyname.add('oid')
2071        for n in attnames:
2072            if n not in keyname:
2073                value = kw.get(n, True)
2074                if value:
2075                    if not isinstance(value, basestring):
2076                        value = 'excluded.%s' % col(n)
2077                    update.append('%s = %s' % (col(n), value))
2078        if not values:
2079            return row
2080        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2081        ret = 'oid, *' if qoid else '*'
2082        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2083            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2084                self._escape_qualified_name(table), names, values,
2085                target, do, ret)
2086        self._do_debug(q, params)
2087        try:
2088            q = self.db.query(q, params)
2089        except ProgrammingError:
2090            if self.server_version < 90500:
2091                raise _prg_error(
2092                    'Upsert operation is not supported by PostgreSQL version')
2093            raise  # re-raise original error
2094        res = q.dictresult()
2095        if res:  # may be empty with "do nothing"
2096            for n, value in res[0].items():
2097                if qoid and n == 'oid':
2098                    n = qoid
2099                row[n] = value
2100        else:
2101            self.get(table, row)
2102        return row
2103
2104    def clear(self, table, row=None):
2105        """Clear all the attributes to values determined by the types.
2106
2107        Numeric types are set to 0, Booleans are set to false, and everything
2108        else is set to the empty string.  If the row argument is present,
2109        it is used as the row dictionary and any entries matching attribute
2110        names are cleared with everything else left unchanged.
2111        """
2112        # At some point we will need a way to get defaults from a table.
2113        if row is None:
2114            row = {}  # empty if argument is not present
2115        attnames = self.get_attnames(table)
2116        for n, t in attnames.items():
2117            if n == 'oid':
2118                continue
2119            t = t.simple
2120            if t in DbTypes._num_types:
2121                row[n] = 0
2122            elif t == 'bool':
2123                row[n] = self._make_bool(False)
2124            else:
2125                row[n] = ''
2126        return row
2127
2128    def delete(self, table, row=None, **kw):
2129        """Delete an existing row in a database table.
2130
2131        This method deletes the row from a table.  It deletes based on the
2132        primary key of the table or the OID value as munged by get() or
2133        passed as keyword.
2134
2135        The return value is the number of deleted rows (i.e. 0 if the row
2136        did not exist and 1 if the row was deleted).
2137
2138        Note that if the row cannot be deleted because e.g. it is still
2139        referenced by another table, this method raises a ProgrammingError.
2140        """
2141        if table.endswith('*'):  # hint for descendant tables can be ignored
2142            table = table[:-1].rstrip()
2143        attnames = self.get_attnames(table)
2144        qoid = _oid_key(table) if 'oid' in attnames else None
2145        if row is None:
2146            row = {}
2147        elif 'oid' in row:
2148            del row['oid']  # only accept oid key from named args for safety
2149        row.update(kw)
2150        if qoid and qoid in row and 'oid' not in row:
2151            row['oid'] = row[qoid]
2152        try:  # try using the primary key
2153            keyname = self.pkey(table, True)
2154        except KeyError:  # the table has no primary key
2155            # try using the oid instead
2156            if qoid and 'oid' in row:
2157                keyname = ('oid',)
2158            else:
2159                raise _prg_error('Table %s has no primary key' % table)
2160        else:  # the table has a primary key
2161            # check whether all key columns have values
2162            if not set(keyname).issubset(row):
2163                # try using the oid instead
2164                if qoid and 'oid' in row:
2165                    keyname = ('oid',)
2166                else:
2167                    raise KeyError('Missing primary key in row')
2168        params = self.adapter.parameter_list()
2169        adapt = params.add
2170        col = self.escape_identifier
2171        where = ' AND '.join('%s = %s' % (
2172            col(k), adapt(row[k], attnames[k])) for k in keyname)
2173        if 'oid' in row:
2174            if qoid:
2175                row[qoid] = row['oid']
2176            del row['oid']
2177        q = 'DELETE FROM %s WHERE %s' % (
2178            self._escape_qualified_name(table), where)
2179        self._do_debug(q, params)
2180        res = self.db.query(q, params)
2181        return int(res)
2182
2183    def truncate(self, table, restart=False, cascade=False, only=False):
2184        """Empty a table or set of tables.
2185
2186        This method quickly removes all rows from the given table or set
2187        of tables.  It has the same effect as an unqualified DELETE on each
2188        table, but since it does not actually scan the tables it is faster.
2189        Furthermore, it reclaims disk space immediately, rather than requiring
2190        a subsequent VACUUM operation. This is most useful on large tables.
2191
2192        If restart is set to True, sequences owned by columns of the truncated
2193        table(s) are automatically restarted.  If cascade is set to True, it
2194        also truncates all tables that have foreign-key references to any of
2195        the named tables.  If the parameter only is not set to True, all the
2196        descendant tables (if any) will also be truncated. Optionally, a '*'
2197        can be specified after the table name to explicitly indicate that
2198        descendant tables are included.
2199        """
2200        if isinstance(table, basestring):
2201            only = {table: only}
2202            table = [table]
2203        elif isinstance(table, (list, tuple)):
2204            if isinstance(only, (list, tuple)):
2205                only = dict(zip(table, only))
2206            else:
2207                only = dict.fromkeys(table, only)
2208        elif isinstance(table, (set, frozenset)):
2209            only = dict.fromkeys(table, only)
2210        else:
2211            raise TypeError('The table must be a string, list or set')
2212        if not (restart is None or isinstance(restart, (bool, int))):
2213            raise TypeError('Invalid type for the restart option')
2214        if not (cascade is None or isinstance(cascade, (bool, int))):
2215            raise TypeError('Invalid type for the cascade option')
2216        tables = []
2217        for t in table:
2218            u = only.get(t)
2219            if not (u is None or isinstance(u, (bool, int))):
2220                raise TypeError('Invalid type for the only option')
2221            if t.endswith('*'):
2222                if u:
2223                    raise ValueError(
2224                        'Contradictory table name and only options')
2225                t = t[:-1].rstrip()
2226            t = self._escape_qualified_name(t)
2227            if u:
2228                t = 'ONLY %s' % t
2229            tables.append(t)
2230        q = ['TRUNCATE', ', '.join(tables)]
2231        if restart:
2232            q.append('RESTART IDENTITY')
2233        if cascade:
2234            q.append('CASCADE')
2235        q = ' '.join(q)
2236        self._do_debug(q)
2237        return self.db.query(q)
2238
2239    def get_as_list(self, table, what=None, where=None,
2240            order=None, limit=None, offset=None, scalar=False):
2241        """Get a table as a list.
2242
2243        This gets a convenient representation of the table as a list
2244        of named tuples in Python.  You only need to pass the name of
2245        the table (or any other SQL expression returning rows).  Note that
2246        by default this will return the full content of the table which
2247        can be huge and overflow your memory.  However, you can control
2248        the amount of data returned using the other optional parameters.
2249
2250        The parameter 'what' can restrict the query to only return a
2251        subset of the table columns.  It can be a string, list or a tuple.
2252        The parameter 'where' can restrict the query to only return a
2253        subset of the table rows.  It can be a string, list or a tuple
2254        of SQL expressions that all need to be fulfilled.  The parameter
2255        'order' specifies the ordering of the rows.  It can also be a
2256        other string, list or a tuple.  If no ordering is specified,
2257        the result will be ordered by the primary key(s) or all columns
2258        if no primary key exists.  You can set 'order' to False if you
2259        don't care about the ordering.  The parameters 'limit' and 'offset'
2260        can be integers specifying the maximum number of rows returned
2261        and a number of rows skipped over.
2262
2263        If you set the 'scalar' option to True, then instead of the
2264        named tuples you will get the first items of these tuples.
2265        This is useful if the result has only one column anyway.
2266        """
2267        if not table:
2268            raise TypeError('The table name is missing')
2269        if what:
2270            if isinstance(what, (list, tuple)):
2271                what = ', '.join(map(str, what))
2272            if order is None:
2273                order = what
2274        else:
2275            what = '*'
2276        q = ['SELECT', what, 'FROM', table]
2277        if where:
2278            if isinstance(where, (list, tuple)):
2279                where = ' AND '.join(map(str, where))
2280            q.extend(['WHERE', where])
2281        if order is None:
2282            try:
2283                order = self.pkey(table, True)
2284            except (KeyError, ProgrammingError):
2285                try:
2286                    order = list(self.get_attnames(table))
2287                except (KeyError, ProgrammingError):
2288                    pass
2289        if order:
2290            if isinstance(order, (list, tuple)):
2291                order = ', '.join(map(str, order))
2292            q.extend(['ORDER BY', order])
2293        if limit:
2294            q.append('LIMIT %d' % limit)
2295        if offset:
2296            q.append('OFFSET %d' % offset)
2297        q = ' '.join(q)
2298        self._do_debug(q)
2299        q = self.db.query(q)
2300        res = q.namedresult()
2301        if res and scalar:
2302            res = [row[0] for row in res]
2303        return res
2304
2305    def get_as_dict(self, table, keyname=None, what=None, where=None,
2306            order=None, limit=None, offset=None, scalar=False):
2307        """Get a table as a dictionary.
2308
2309        This method is similar to get_as_list(), but returns the table
2310        as a Python dict instead of a Python list, which can be even
2311        more convenient. The primary key column(s) of the table will
2312        be used as the keys of the dictionary, while the other column(s)
2313        will be the corresponding values.  The keys will be named tuples
2314        if the table has a composite primary key.  The rows will be also
2315        named tuples unless the 'scalar' option has been set to True.
2316        With the optional parameter 'keyname' you can specify an alternative
2317        set of columns to be used as the keys of the dictionary.  It must
2318        be set as a string, list or a tuple.
2319
2320        If the Python version supports it, the dictionary will be an
2321        OrderedDict using the order specified with the 'order' parameter
2322        or the key column(s) if not specified.  You can set 'order' to False
2323        if you don't care about the ordering.  In this case the returned
2324        dictionary will be an ordinary one.
2325        """
2326        if not table:
2327            raise TypeError('The table name is missing')
2328        if not keyname:
2329            try:
2330                keyname = self.pkey(table, True)
2331            except (KeyError, ProgrammingError):
2332                raise _prg_error('Table %s has no primary key' % table)
2333        if isinstance(keyname, basestring):
2334            keyname = [keyname]
2335        elif not isinstance(keyname, (list, tuple)):
2336            raise KeyError('The keyname must be a string, list or tuple')
2337        if what:
2338            if isinstance(what, (list, tuple)):
2339                what = ', '.join(map(str, what))
2340            if order is None:
2341                order = what
2342        else:
2343            what = '*'
2344        q = ['SELECT', what, 'FROM', table]
2345        if where:
2346            if isinstance(where, (list, tuple)):
2347                where = ' AND '.join(map(str, where))
2348            q.extend(['WHERE', where])
2349        if order is None:
2350            order = keyname
2351        if order:
2352            if isinstance(order, (list, tuple)):
2353                order = ', '.join(map(str, order))
2354            q.extend(['ORDER BY', order])
2355        if limit:
2356            q.append('LIMIT %d' % limit)
2357        if offset:
2358            q.append('OFFSET %d' % offset)
2359        q = ' '.join(q)
2360        self._do_debug(q)
2361        q = self.db.query(q)
2362        res = q.getresult()
2363        cls = OrderedDict if order else dict
2364        if not res:
2365            return cls()
2366        keyset = set(keyname)
2367        fields = q.listfields()
2368        if not keyset.issubset(fields):
2369            raise KeyError('Missing keyname in row')
2370        keyind, rowind = [], []
2371        for i, f in enumerate(fields):
2372            (keyind if f in keyset else rowind).append(i)
2373        keytuple = len(keyind) > 1
2374        getkey = itemgetter(*keyind)
2375        keys = map(getkey, res)
2376        if scalar:
2377            rowind = rowind[:1]
2378            rowtuple = False
2379        else:
2380            rowtuple = len(rowind) > 1
2381        if scalar or rowtuple:
2382            getrow = itemgetter(*rowind)
2383        else:
2384            rowind = rowind[0]
2385            getrow = lambda row: (row[rowind],)
2386            rowtuple = True
2387        rows = map(getrow, res)
2388        if keytuple or rowtuple:
2389            namedresult = get_namedresult()
2390            if namedresult:
2391                if keytuple:
2392                    keys = namedresult(_MemoryQuery(keys, keyname))
2393                if rowtuple:
2394                    fields = [f for f in fields if f not in keyset]
2395                    rows = namedresult(_MemoryQuery(rows, fields))
2396        return cls(zip(keys, rows))
2397
2398    def notification_handler(self,
2399            event, callback, arg_dict=None, timeout=None, stop_event=None):
2400        """Get notification handler that will run the given callback."""
2401        return NotificationHandler(self,
2402            event, callback, arg_dict, timeout, stop_event)
2403
2404
# if run as script, print some information

if __name__ == '__main__':
    # `version` is presumably provided by the `from _pg import *` at the
    # top of the module.  Passing it as a separate argument lets print()
    # insert the separating space ("PyGreSQL version X.Y" instead of the
    # former "PyGreSQL versionX.Y").  It also avoids a TypeError if the
    # value is ever not a str.
    print('PyGreSQL version', version)
    print()
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.