source: trunk/pg.py @ 822

Last change on this file since 822 was 822, checked in by cito, 3 years ago

Make version available as __version__ in both modules

  • Property svn:keywords set to Id
File size: 87.2 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 822 2016-02-05 17:05:45Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34__version__ = version
35
36import select
37import warnings
38
39from datetime import date, time, datetime, timedelta
40from decimal import Decimal
41from math import isnan, isinf
42from collections import namedtuple
43from operator import itemgetter
44from functools import partial
45from re import compile as regex
46from json import loads as jsondecode, dumps as jsonencode
47from uuid import UUID
48
49try:
50    long
51except NameError:  # Python >= 3.0
52    long = int
53
54try:
55    basestring
56except NameError:  # Python >= 3.0
57    basestring = (str, bytes)
58
59
60# Auxiliary classes and functions that are independent from a DB connection:
61
62try:
63    from collections import OrderedDict
64except ImportError:  # Python 2.6 or 3.0
65    OrderedDict = dict
66
67
68    class AttrDict(dict):
69        """Simple read-only ordered dictionary for storing attribute names."""
70
71        def __init__(self, *args, **kw):
72            if len(args) > 1 or kw:
73                raise TypeError
74            items = args[0] if args else []
75            if isinstance(items, dict):
76                raise TypeError
77            items = list(items)
78            self._keys = [item[0] for item in items]
79            dict.__init__(self, items)
80            self._read_only = True
81            error = self._read_only_error
82            self.clear = self.update = error
83            self.pop = self.setdefault = self.popitem = error
84
85        def __setitem__(self, key, value):
86            if self._read_only:
87                self._read_only_error()
88            dict.__setitem__(self, key, value)
89
90        def __delitem__(self, key):
91            if self._read_only:
92                self._read_only_error()
93            dict.__delitem__(self, key)
94
95        def __iter__(self):
96            return iter(self._keys)
97
98        def keys(self):
99            return list(self._keys)
100
101        def values(self):
102            return [self[key] for key in self]
103
104        def items(self):
105            return [(key, self[key]) for key in self]
106
107        def iterkeys(self):
108            return self.__iter__()
109
110        def itervalues(self):
111            return iter(self.values())
112
113        def iteritems(self):
114            return iter(self.items())
115
116        @staticmethod
117        def _read_only_error(*args, **kw):
118            raise TypeError('This object is read-only')
119
120else:
121
122     class AttrDict(OrderedDict):
123        """Simple read-only ordered dictionary for storing attribute names."""
124
125        def __init__(self, *args, **kw):
126            self._read_only = False
127            OrderedDict.__init__(self, *args, **kw)
128            self._read_only = True
129            error = self._read_only_error
130            self.clear = self.update = error
131            self.pop = self.setdefault = self.popitem = error
132
133        def __setitem__(self, key, value):
134            if self._read_only:
135                self._read_only_error()
136            OrderedDict.__setitem__(self, key, value)
137
138        def __delitem__(self, key):
139            if self._read_only:
140                self._read_only_error()
141            OrderedDict.__delitem__(self, key)
142
143        @staticmethod
144        def _read_only_error(*args, **kw):
145            raise TypeError('This object is read-only')
146
147try:
148    from inspect import signature
149except ImportError:  # Python < 3.3
150    from inspect import getargspec
151
152    def get_args(func):
153        return getargspec(func).args
154else:
155
156    def get_args(func):
157        return list(signature(func).parameters)
158
159try:
160    if datetime.strptime('+0100', '%z') is None:
161        raise ValueError
162except ValueError:  # Python < 3.2
163    timezones = None
164else:
165    # time zones used in Postgres timestamptz output
166    timezones = dict(CET='+0100', EET='+0200', EST='-0500',
167        GMT='+0000', HST='-1000', MET='+0100', MST='-0700',
168        UCT='+0000', UTC='+0000', WET='+0000')
169
170
171def _oid_key(table):
172    """Build oid key from a table name."""
173    return 'oid(%s)' % table
174
175
176class _SimpleTypes(dict):
177    """Dictionary mapping pg_type names to simple type names."""
178
179    _types = {'bool': 'bool',
180        'bytea': 'bytea',
181        'date': 'date interval time timetz timestamp timestamptz'
182            ' abstime reltime',  # these are very old
183        'float': 'float4 float8',
184        'int': 'cid int2 int4 int8 oid xid',
185        'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid',
186        'num': 'numeric', 'money': 'money',
187        'text': 'bpchar char name text varchar'}
188
189    def __init__(self):
190        for typ, keys in self._types.items():
191            for key in keys.split():
192                self[key] = typ
193                self['_%s' % key] = '%s[]' % typ
194
195    @staticmethod
196    def __missing__(key):
197        return 'text'
198
199_simpletypes = _SimpleTypes()
200
201
202def _quote_if_unqualified(param, name):
203    """Quote parameter representing a qualified name.
204
205    Puts a quote_ident() call around the give parameter unless
206    the name contains a dot, in which case the name is ambiguous
207    (could be a qualified name or just a name with a dot in it)
208    and must be quoted manually by the caller.
209    """
210    if isinstance(name, basestring) and '.' not in name:
211        return 'quote_ident(%s)' % (param,)
212    return param
213
214
215class _ParameterList(list):
216    """Helper class for building typed parameter lists."""
217
218    def add(self, value, typ=None):
219        """Typecast value with known database type and build parameter list.
220
221        If this is a literal value, it will be returned as is.  Otherwise, a
222        placeholder will be returned and the parameter list will be augmented.
223        """
224        value = self.adapt(value, typ)
225        if isinstance(value, Literal):
226            return value
227        self.append(value)
228        return '$%d' % len(self)
229
230
231class Bytea(bytes):
232    """Wrapper class for marking Bytea values."""
233
234
235class Hstore(dict):
236    """Wrapper class for marking hstore values."""
237
238    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
239
240    @classmethod
241    def _quote(cls, s):
242        if s is None:
243            return 'NULL'
244        if not s:
245            return '""'
246        s = s.replace('"', '\\"')
247        if cls._re_quote.search(s):
248            s = '"%s"' % s
249        return s
250
251    def __str__(self):
252        q = self._quote
253        return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items())
254
255
256class Json:
257    """Wrapper class for marking Json values."""
258
259    def __init__(self, obj):
260        self.obj = obj
261
262
263class Literal(str):
264    """Wrapper class for marking literal SQL values."""
265
266
267class Adapter:
268    """Class providing methods for adapting parameters to the database."""
269
270    _bool_true_values = frozenset('t true 1 y yes on'.split())
271
272    _date_literals = frozenset('current_date current_time'
273        ' current_timestamp localtime localtimestamp'.split())
274
275    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
276    _re_record_quote = regex(r'[(,"\\]')
277    _re_array_escape = _re_record_escape = regex(r'(["\\])')
278
279    def __init__(self, db):
280        self.db = db
281        self.encode_json = db.encode_json
282        db = db.db
283        self.escape_bytea = db.escape_bytea
284        self.escape_string = db.escape_string
285
286    @classmethod
287    def _adapt_bool(cls, v):
288        """Adapt a boolean parameter."""
289        if isinstance(v, basestring):
290            if not v:
291                return None
292            v = v.lower() in cls._bool_true_values
293        return 't' if v else 'f'
294
295    @classmethod
296    def _adapt_date(cls, v):
297        """Adapt a date parameter."""
298        if not v:
299            return None
300        if isinstance(v, basestring) and v.lower() in cls._date_literals:
301            return Literal(v)
302        return v
303
304    @staticmethod
305    def _adapt_num(v):
306        """Adapt a numeric parameter."""
307        if not v and v != 0:
308            return None
309        return v
310
311    _adapt_int = _adapt_float = _adapt_money = _adapt_num
312
313    def _adapt_bytea(self, v):
314        """Adapt a bytea parameter."""
315        return self.escape_bytea(v)
316
317    def _adapt_json(self, v):
318        """Adapt a json parameter."""
319        if not v:
320            return None
321        if isinstance(v, basestring):
322            return v
323        return self.encode_json(v)
324
325    @classmethod
326    def _adapt_text_array(cls, v):
327        """Adapt a text type array parameter."""
328        if isinstance(v, list):
329            adapt = cls._adapt_text_array
330            return '{%s}' % ','.join(adapt(v) for v in v)
331        if v is None:
332            return 'null'
333        if not v:
334            return '""'
335        v = str(v)
336        if cls._re_array_quote.search(v):
337            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
338        return v
339
340    _adapt_date_array = _adapt_text_array
341
342    @classmethod
343    def _adapt_bool_array(cls, v):
344        """Adapt a boolean array parameter."""
345        if isinstance(v, list):
346            adapt = cls._adapt_bool_array
347            return '{%s}' % ','.join(adapt(v) for v in v)
348        if v is None:
349            return 'null'
350        if isinstance(v, basestring):
351            if not v:
352                return 'null'
353            v = v.lower() in cls._bool_true_values
354        return 't' if v else 'f'
355
356    @classmethod
357    def _adapt_num_array(cls, v):
358        """Adapt a numeric array parameter."""
359        if isinstance(v, list):
360            adapt = cls._adapt_num_array
361            return '{%s}' % ','.join(adapt(v) for v in v)
362        if not v and v != 0:
363            return 'null'
364        return str(v)
365
366    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
367            _adapt_num_array
368
369    def _adapt_bytea_array(self, v):
370        """Adapt a bytea array parameter."""
371        if isinstance(v, list):
372            return b'{' + b','.join(
373                self._adapt_bytea_array(v) for v in v) + b'}'
374        if v is None:
375            return b'null'
376        return self.escape_bytea(v).replace(b'\\', b'\\\\')
377
378    def _adapt_json_array(self, v):
379        """Adapt a json array parameter."""
380        if isinstance(v, list):
381            adapt = self._adapt_json_array
382            return '{%s}' % ','.join(adapt(v) for v in v)
383        if not v:
384            return 'null'
385        if not isinstance(v, basestring):
386            v = self.encode_json(v)
387        if self._re_array_quote.search(v):
388            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
389        return v
390
391    def _adapt_record(self, v, typ):
392        """Adapt a record parameter with given type."""
393        typ = self.get_attnames(typ).values()
394        if len(typ) != len(v):
395            raise TypeError('Record parameter %s has wrong size' % v)
396        adapt = self.adapt
397        value = []
398        for v, t in zip(v, typ):
399            v = adapt(v, t)
400            if v is None:
401                v = ''
402            elif not v:
403                v = '""'
404            else:
405                if isinstance(v, bytes):
406                    if str is not bytes:
407                        v = v.decode('ascii')
408                else:
409                    v = str(v)
410                if self._re_record_quote.search(v):
411                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
412            value.append(v)
413        return '(%s)' % ','.join(value)
414
415    def adapt(self, value, typ=None):
416        """Adapt a value with known database type."""
417        if value is not None and not isinstance(value, Literal):
418            if typ:
419                simple = self.get_simple_name(typ)
420            else:
421                typ = simple = self.guess_simple_type(value) or 'text'
422            try:
423                value = value.__pg_str__(typ)
424            except AttributeError:
425                pass
426            if simple == 'text':
427                pass
428            elif simple == 'record':
429                if isinstance(value, tuple):
430                    value = self._adapt_record(value, typ)
431            elif simple.endswith('[]'):
432                if isinstance(value, list):
433                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
434                    value = adapt(value)
435            else:
436                adapt = getattr(self, '_adapt_%s' % simple)
437                value = adapt(value)
438        return value
439
440    @staticmethod
441    def simple_type(name):
442        """Create a simple database type with given attribute names."""
443        typ = DbType(name)
444        typ.simple = name
445        return typ
446
447    @staticmethod
448    def get_simple_name(typ):
449        """Get the simple name of a database type."""
450        if isinstance(typ, DbType):
451            return typ.simple
452        return _simpletypes[typ]
453
454    @staticmethod
455    def get_attnames(typ):
456        """Get the attribute names of a composite database type."""
457        if isinstance(typ, DbType):
458            return typ.attnames
459        return {}
460
461    @classmethod
462    def guess_simple_type(cls, value):
463        """Try to guess which database type the given value has."""
464        if isinstance(value, Bytea):
465            return 'bytea'
466        if isinstance(value, basestring):
467            return 'text'
468        if isinstance(value, bool):
469            return 'bool'
470        if isinstance(value, (int, long)):
471            return 'int'
472        if isinstance(value, float):
473            return 'float'
474        if isinstance(value, Decimal):
475            return 'num'
476        if isinstance(value, (date, time, datetime, timedelta)):
477            return 'date'
478        if isinstance(value, list):
479            return '%s[]' % cls.guess_simple_base_type(value)
480        if isinstance(value, tuple):
481            simple_type = cls.simple_type
482            typ = simple_type('record')
483            guess = cls.guess_simple_type
484            def get_attnames(self):
485                return AttrDict((str(n + 1), simple_type(guess(v)))
486                    for n, v in enumerate(value))
487            typ._get_attnames = get_attnames
488            return typ
489
490    @classmethod
491    def guess_simple_base_type(cls, value):
492        """Try to guess the base type of a given array."""
493        for v in value:
494            if isinstance(v, list):
495                typ = cls.guess_simple_base_type(v)
496            else:
497                typ = cls.guess_simple_type(v)
498            if typ:
499                return typ
500
501    def adapt_inline(self, value, nested=False):
502        """Adapt a value that is put into the SQL and needs to be quoted."""
503        if value is None:
504            return 'NULL'
505        if isinstance(value, Literal):
506            return value
507        if isinstance(value, Bytea):
508            value = self.escape_bytea(value)
509            if bytes is not str:  # Python >= 3.0
510                value = value.decode('ascii')
511        elif isinstance(value, Json):
512            if value.encode:
513                return value.encode()
514            value = self.encode_json(value)
515        elif isinstance(value, (datetime, date, time, timedelta)):
516            value = str(value)
517        if isinstance(value, basestring):
518            value = self.escape_string(value)
519            return "'%s'" % value
520        if isinstance(value, bool):
521            return 'true' if value else 'false'
522        if isinstance(value, float):
523            if isinf(value):
524                return "'-Infinity'" if value < 0 else "'Infinity'"
525            if isnan(value):
526                return "'NaN'"
527            return value
528        if isinstance(value, (int, long, Decimal)):
529            return value
530        if isinstance(value, list):
531            q = self.adapt_inline
532            s = '[%s]' if nested else 'ARRAY[%s]'
533            return s % ','.join(str(q(v, nested=True)) for v in value)
534        if isinstance(value, tuple):
535            q = self.adapt_inline
536            return '(%s)' % ','.join(str(q(v)) for v in value)
537        try:
538            value = value.__pg_repr__()
539        except AttributeError:
540            raise InterfaceError(
541                'Do not know how to adapt type %s' % type(value))
542        if isinstance(value, (tuple, list)):
543            value = self.adapt_inline(value)
544        return value
545
546    def parameter_list(self):
547        """Return a parameter list for parameters with known database types.
548
549        The list has an add(value, typ) method that will build up the
550        list and return either the literal value or a placeholder.
551        """
552        params = _ParameterList()
553        params.adapt = self.adapt
554        return params
555
556    def format_query(self, command, values, types=None, inline=False):
557        """Format a database query using the given values and types."""
558        if inline and types:
559            raise ValueError('Typed parameters must be sent separately')
560        params = self.parameter_list()
561        if isinstance(values, (list, tuple)):
562            if inline:
563                adapt = self.adapt_inline
564                literals = [adapt(value) for value in values]
565            else:
566                add = params.add
567                literals = []
568                append = literals.append
569                if types:
570                    if (not isinstance(types, (list, tuple)) or
571                            len(types) != len(values)):
572                        raise TypeError('The values and types do not match')
573                    for value, typ in zip(values, types):
574                        append(add(value, typ))
575                else:
576                    for value in values:
577                        append(add(value))
578            command = command % tuple(literals)
579        elif isinstance(values, dict):
580            if inline:
581                adapt = self.adapt_inline
582                literals = dict((key, adapt(value))
583                    for key, value in values.items())
584            else:
585                add = params.add
586                literals = {}
587                if types:
588                    if (not isinstance(types, dict) or
589                            len(types) < len(values)):
590                        raise TypeError('The values and types do not match')
591                    for key in sorted(values):
592                        literals[key] = add(values[key], types[key])
593                else:
594                    for key in sorted(values):
595                        literals[key] = add(values[key])
596            command = command % literals
597        else:
598            raise TypeError('The values must be passed as tuple, list or dict')
599        return command, params
600
601
602def cast_bool(value):
603    """Cast a boolean value."""
604    if not get_bool():
605        return value
606    return value[0] == 't'
607
608
609def cast_json(value):
610    """Cast a JSON value."""
611    cast = get_jsondecode()
612    if not cast:
613        return value
614    return cast(value)
615
616
617def cast_num(value):
618    """Cast a numeric value."""
619    return (get_decimal() or float)(value)
620
621
622def cast_money(value):
623    """Cast a money value."""
624    point = get_decimal_point()
625    if not point:
626        return value
627    if point != '.':
628        value = value.replace(point, '.')
629    value = value.replace('(', '-')
630    value = ''.join(c for c in value if c.isdigit() or c in '.-')
631    return (get_decimal() or float)(value)
632
633
634def cast_int2vector(value):
635    """Cast an int2vector value."""
636    return [int(v) for v in value.split()]
637
638
639def cast_date(value, connection):
640    """Cast a date value."""
641    # The output format depends on the server setting DateStyle.  The default
642    # setting ISO and the setting for German are actually unambiguous.  The
643    # order of days and months in the other two settings is however ambiguous,
644    # so at least here we need to consult the setting to properly parse values.
645    if value == '-infinity':
646        return date.min
647    if value == 'infinity':
648        return date.max
649    value = value.split()
650    if value[-1] == 'BC':
651        return date.min
652    value = value[0]
653    if len(value) > 10:
654        return date.max
655    fmt = connection.date_format()
656    return datetime.strptime(value, fmt).date()
657
658
659def cast_time(value):
660    """Cast a time value."""
661    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
662    return datetime.strptime(value, fmt).time()
663
664
665_re_timezone = regex('(.*)([+-].*)')
666
667
668def cast_timetz(value):
669    """Cast a timetz value."""
670    tz = _re_timezone.match(value)
671    if tz:
672        value, tz = tz.groups()
673    else:
674        tz = '+0000'
675    fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S'
676    if timezones:
677        if tz.startswith(('+', '-')):
678            if len(tz) < 5:
679                tz += '00'
680            else:
681                tz = tz.replace(':', '')
682        elif tz in timezones:
683            tz = timezones[tz]
684        else:
685            tz = '+0000'
686        value += tz
687        fmt += '%z'
688    return datetime.strptime(value, fmt).timetz()
689
690
691def cast_timestamp(value, connection):
692    """Cast a timestamp value."""
693    if value == '-infinity':
694        return datetime.min
695    if value == 'infinity':
696        return datetime.max
697    value = value.split()
698    if value[-1] == 'BC':
699        return datetime.min
700    fmt = connection.date_format()
701    if fmt.endswith('-%Y') and len(value) > 2:
702        value = value[1:5]
703        if len(value[3]) > 4:
704            return datetime.max
705        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
706            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
707    else:
708        if len(value[0]) > 10:
709            return datetime.max
710        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
711    return datetime.strptime(' '.join(value), ' '.join(fmt))
712
713
714def cast_timestamptz(value, connection):
715    """Cast a timestamptz value."""
716    if value == '-infinity':
717        return datetime.min
718    if value == 'infinity':
719        return datetime.max
720    value = value.split()
721    if value[-1] == 'BC':
722        return datetime.min
723    fmt = connection.date_format()
724    if fmt.endswith('-%Y') and len(value) > 2:
725        value = value[1:]
726        if len(value[3]) > 4:
727            return datetime.max
728        fmt = ['%d %b' if fmt.startswith('%d') else '%b %d',
729            '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y']
730        value, tz = value[:-1], value[-1]
731    else:
732        if fmt.startswith('%Y-'):
733            tz = _re_timezone.match(value[1])
734            if tz:
735                value[1], tz = tz.groups()
736            else:
737                tz = '+0000'
738        else:
739            value, tz = value[:-1], value[-1]
740        if len(value[0]) > 10:
741            return datetime.max
742        fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S']
743    if timezones:
744        if tz.startswith(('+', '-')):
745            if len(tz) < 5:
746                tz += '00'
747            else:
748                tz = tz.replace(':', '')
749        elif tz in timezones:
750            tz = timezones[tz]
751        else:
752            tz = '+0000'
753        value.append(tz)
754        fmt.append('%z')
755    return datetime.strptime(' '.join(value), ' '.join(fmt))
756
757_re_interval_sql_standard = regex(
758    '(?:([+-])?([0-9]+)-([0-9]+) ?)?'
759    '(?:([+-]?[0-9]+)(?!:) ?)?'
760    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
761
762_re_interval_postgres = regex(
763    '(?:([+-]?[0-9]+) ?years? ?)?'
764    '(?:([+-]?[0-9]+) ?mons? ?)?'
765    '(?:([+-]?[0-9]+) ?days? ?)?'
766    '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?')
767
768_re_interval_postgres_verbose = regex(
769    '@ ?(?:([+-]?[0-9]+) ?years? ?)?'
770    '(?:([+-]?[0-9]+) ?mons? ?)?'
771    '(?:([+-]?[0-9]+) ?days? ?)?'
772    '(?:([+-]?[0-9]+) ?hours? ?)?'
773    '(?:([+-]?[0-9]+) ?mins? ?)?'
774    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?')
775
776_re_interval_iso_8601 = regex(
777    'P(?:([+-]?[0-9]+)Y)?'
778    '(?:([+-]?[0-9]+)M)?'
779    '(?:([+-]?[0-9]+)D)?'
780    '(?:T(?:([+-]?[0-9]+)H)?'
781    '(?:([+-]?[0-9]+)M)?'
782    '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?')
783
784
785def cast_interval(value):
786    """Cast an interval value."""
787    # The output format depends on the server setting IntervalStyle, but it's
788    # not necessary to consult this setting to parse it.  It's faster to just
789    # check all possible formats, and there is no ambiguity here.
790    m = _re_interval_iso_8601.match(value)
791    if m:
792        m = [d or '0' for d in m.groups()]
793        secs_ago = m.pop(5) == '-'
794        m = [int(d) for d in m]
795        years, mons, days, hours, mins, secs, usecs = m
796        if secs_ago:
797            secs = -secs
798            usecs = -usecs
799    else:
800        m = _re_interval_postgres_verbose.match(value)
801        if m:
802            m, ago = [d or '0' for d in m.groups()[:8]], m.group(9)
803            secs_ago = m.pop(5) == '-'
804            m = [-int(d) for d in m] if ago else [int(d) for d in m]
805            years, mons, days, hours, mins, secs, usecs = m
806            if secs_ago:
807                secs = - secs
808                usecs = -usecs
809        else:
810            m = _re_interval_postgres.match(value)
811            if m and any(m.groups()):
812                m = [d or '0' for d in m.groups()]
813                hours_ago = m.pop(3) == '-'
814                m = [int(d) for d in m]
815                years, mons, days, hours, mins, secs, usecs = m
816                if hours_ago:
817                    hours = -hours
818                    mins = -mins
819                    secs = -secs
820                    usecs = -usecs
821            else:
822                m = _re_interval_sql_standard.match(value)
823                if m and any(m.groups()):
824                    m = [d or '0' for d in m.groups()]
825                    years_ago = m.pop(0) == '-'
826                    hours_ago = m.pop(3) == '-'
827                    m = [int(d) for d in m]
828                    years, mons, days, hours, mins, secs, usecs = m
829                    if years_ago:
830                        years = -years
831                        mons = -mons
832                    if hours_ago:
833                        hours = -hours
834                        mins = -mins
835                        secs = -secs
836                        usecs = -usecs
837                else:
838                    raise ValueError('Cannot parse interval: %s' % value)
839    days += 365 * years + 30 * mons
840    return timedelta(days=days, hours=hours, minutes=mins,
841        seconds=secs, microseconds=usecs)
842
843
844class Typecasts(dict):
845    """Dictionary mapping database types to typecast functions.
846
847    The cast functions get passed the string representation of a value in
848    the database which they need to convert to a Python object.  The
849    passed string will never be None since NULL values are already be
850    handled before the cast function is called.
851
852    Note that the basic types are already handled by the C extension.
853    They only need to be handled here as record or array components.
854    """
855
856    # the default cast functions
857    # (str functions are ignored but have been added for faster access)
858    defaults = {'char': str, 'bpchar': str, 'name': str,
859        'text': str, 'varchar': str,
860        'bool': cast_bool, 'bytea': unescape_bytea,
861        'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int,
862        'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json,
863        'float4': float, 'float8': float,
864        'numeric': cast_num, 'money': cast_money,
865        'date': cast_date, 'interval': cast_interval,
866        'time': cast_time, 'timetz': cast_timetz,
867        'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz,
868        'int2vector': cast_int2vector, 'uuid': UUID,
869        'anyarray': cast_array, 'record': cast_record}
870
871    connection = None  # will be set in a connection specific instance
872
873    def __missing__(self, typ):
874        """Create a cast function if it is not cached.
875       
876        Note that this class never raises a KeyError,
877        but returns None when no special cast function exists.
878        """
879        if not isinstance(typ, str):
880            raise TypeError('Invalid type: %s' % typ)
881        cast = self.defaults.get(typ)
882        if cast:
883            # store default for faster access
884            cast = self._add_connection(cast)
885            self[typ] = cast
886        elif typ.startswith('_'):
887            base_cast = self[typ[1:]]
888            cast = self.create_array_cast(base_cast)
889            if base_cast:
890                self[typ] = cast
891        else:
892            attnames = self.get_attnames(typ)
893            if attnames:
894                casts = [self[v.pgtype] for v in attnames.values()]
895                cast = self.create_record_cast(typ, attnames, casts)
896                self[typ] = cast
897        return cast
898
899    @staticmethod
900    def _needs_connection(func):
901        """Check if a typecast function needs a connection argument."""
902        try:
903            args = get_args(func)
904        except (TypeError, ValueError):
905            return False
906        else:
907            return 'connection' in args[1:]
908
909    def _add_connection(self, cast):
910        """Add a connection argument to the typecast function if necessary."""
911        if not self.connection or not self._needs_connection(cast):
912            return cast
913        return partial(cast, connection=self.connection)
914
915    def get(self, typ, default=None):
916        """Get the typecast function for the given database type."""
917        return self[typ] or default
918
919    def set(self, typ, cast):
920        """Set a typecast function for the specified database type(s)."""
921        if isinstance(typ, basestring):
922            typ = [typ]
923        if cast is None:
924            for t in typ:
925                self.pop(t, None)
926                self.pop('_%s' % t, None)
927        else:
928            if not callable(cast):
929                raise TypeError("Cast parameter must be callable")
930            for t in typ:
931                self[t] = self._add_connection(cast)
932                self.pop('_%s' % t, None)
933
934    def reset(self, typ=None):
935        """Reset the typecasts for the specified type(s) to their defaults.
936
937        When no type is specified, all typecasts will be reset.
938        """
939        if typ is None:
940            self.clear()
941        else:
942            if isinstance(typ, basestring):
943                typ = [typ]
944            for t in typ:
945                self.pop(t, None)
946
947    @classmethod
948    def get_default(cls, typ):
949        """Get the default typecast function for the given database type."""
950        return cls.defaults.get(typ)
951
952    @classmethod
953    def set_default(cls, typ, cast):
954        """Set a default typecast function for the given database type(s)."""
955        if isinstance(typ, basestring):
956            typ = [typ]
957        defaults = cls.defaults
958        if cast is None:
959            for t in typ:
960                defaults.pop(t, None)
961                defaults.pop('_%s' % t, None)
962        else:
963            if not callable(cast):
964                raise TypeError("Cast parameter must be callable")
965            for t in typ:
966                defaults[t] = cast
967                defaults.pop('_%s' % t, None)
968
969    def get_attnames(self, typ):
970        """Return the fields for the given record type.
971
972        This method will be replaced with the get_attnames() method of DbTypes.
973        """
974        return {}
975
976    def dateformat(self):
977        """Return the current date format.
978
979        This method will be replaced with the dateformat() method of DbTypes.
980        """
981        return '%Y-%m-%d'
982
983    def create_array_cast(self, basecast):
984        """Create an array typecast for the given base cast."""
985        def cast(v):
986            return cast_array(v, basecast)
987        return cast
988
989    def create_record_cast(self, name, fields, casts):
990        """Create a named record typecast for the given fields and casts."""
991        record = namedtuple(name, fields)
992        def cast(v):
993            return record(*cast_record(v, casts))
994        return cast
995
996
def get_typecast(typ):
    """Return the global (default) typecast function for a database type."""
    default_cast = Typecasts.get_default(typ)
    return default_cast
1000
1001
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    Passing None as cast removes the global typecast function(s).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
1009
1010
class DbType(str):
    """Type name string carrying additional information as attributes.

    The string value itself is either the simple PyGreSQL type name or
    the regular type name.  The following attributes are provided:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b = Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Names and types of the fields of a composite type."""
        lookup = self._get_attnames
        return lookup(self)
1031
1032
class DbTypes(dict):
    """Per-connection cache of PostgreSQL data type information.

    Both type OIDs and type names are used as keys; the values are
    DbType objects describing the associated database type.  Missing
    entries are looked up in the pg_type system catalog on demand.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Create an empty type cache for the given DB wrapper instance."""
        super(DbTypes, self).__init__()
        self._regtypes = False
        self._get_attnames = db.get_attnames
        # connection-specific typecasts, wired back to this cache
        casts = Typecasts()
        casts.get_attnames = self.get_attnames
        casts.connection = db
        self._typecasts = casts
        # the catalog is queried through the underlying _pg connection
        cnx = db.db
        self.query = cnx.query
        self.escape_string = cnx.escape_string

    def add(self, oid, pgtype, regtype,
            typtype, category, delim, relid):
        """Return a DbType built from a pg_type row, or the cached one.

        Note that a newly built type is not stored in the cache here.
        """
        if oid in self:
            return self[oid]
        simple = _simpletypes[pgtype] if not relid else 'record'
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid, typ.pgtype, typ.regtype = oid, pgtype, regtype
        typ.simple = simple
        typ.typtype, typ.category = typtype, category
        typ.delim, typ.relid = delim, relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Fetch the type info from pg_type when it is not yet cached.

        The result is cached under both its OID and its type name.
        Raises KeyError if the type cannot be found.
        """
        try:
            sql = ("SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype" %
                (_quote_if_unqualified('$1', key),))
            rows = self.query(sql, (key,)).getresult()
        except ProgrammingError:
            rows = None
        if not rows:
            raise KeyError('Type %s could not be found' % key)
        typ = self.add(*rows[0])
        self[typ.oid] = typ
        self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Return the type for key, querying the database if necessary."""
        try:
            found = self[key]
        except KeyError:
            found = default
        return found

    def get_attnames(self, typ):
        """Return names and types of the fields of a composite type.

        Returns None if the type is unknown or not a composite type.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        relid = typ.relid
        return self._get_attnames(relid, with_oid=False) if relid else None

    def get_typecast(self, typ):
        """Return the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the given database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the given database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast a value according to the given database type.

        NULL values (None) and values of types without a special
        cast function are returned unchanged.
        """
        if value is None:
            return None
        if not isinstance(typ, DbType):
            info = self.get(typ)
            typ = info.pgtype if info else info
        cast = self.get_typecast(typ) if typ else None
        if cast and cast is not str:
            return cast(value)
        return value
1132
1133
1134def _namedresult(q):
1135    """Get query result as named tuples."""
1136    row = namedtuple('Row', q.listfields())
1137    return [row(*r) for r in q.getresult()]
1138
1139
1140class _MemoryQuery:
1141    """Class that embodies a given query result."""
1142
1143    def __init__(self, result, fields):
1144        """Create query from given result rows and field names."""
1145        self.result = result
1146        self.fields = fields
1147
1148    def listfields(self):
1149        """Return the stored field names of this query."""
1150        return self.fields
1151
1152    def getresult(self):
1153        """Return the stored result of this query."""
1154        return self.result
1155
1156
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The sqlstate attribute of the returned exception is set to None,
    so that the attribute is always present on the error.
    """
    error_instance = cls(msg)
    error_instance.sqlstate = None
    return error_instance
1162
1163
def _int_error(msg):
    """Create an InternalError with the given message and no sqlstate."""
    return _db_error(msg, cls=InternalError)
1167
1168
def _prg_error(msg):
    """Create a ProgrammingError with the given message and no sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
1172
1173
# Initialize the C module
#
# Register Python-level helpers with the _pg extension module.  Judging by
# the setter names (their behavior is implemented in C and not shown here):
# - _namedresult is used to build lists of named tuples from query results,
# - Decimal is presumably used for numeric values,
# - jsondecode (json.loads) is presumably used to decode JSON values.

set_namedresult(_namedresult)
set_decimal(Decimal)
set_jsondecode(jsondecode)
1179
1180
1181# The notification handler
1182
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # stop listening when the handler gets garbage collected
        self.unlisten()

    @staticmethod
    def _quote_ident(name):
        """Quote a channel name as an SQL identifier.

        Embedded double quotes are doubled, so that they cannot end the
        quoted identifier prematurely.
        """
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    @staticmethod
    def _quote_literal(value):
        """Quote a notification payload as an SQL string constant.

        Single quotes are doubled.  If the value contains backslashes, an
        escape string constant (E'...') with doubled backslashes is used,
        which is interpreted the same way regardless of the setting of
        standard_conforming_strings.
        """
        value = ('%s' % (value,)).replace("'", "''")
        if '\\' in value:
            return "E'%s'" % value.replace('\\', '\\\\')
        return "'%s'" % value

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_ident(self.event))
            self.db.query('listen %s' % self._quote_ident(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_ident(self.event))
            self.db.query('unlisten %s' % self._quote_ident(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is quoted as an SQL string constant, so it may
        contain arbitrary text including quotes and backslashes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent if the handler is not listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            event = self.stop_event if stop else self.event
            q = 'notify %s' % self._quote_ident(event)
            if payload:
                # quote the payload so that it cannot break out of the
                # string constant (broken or injected SQL)
                q += ', %s' % self._quote_literal(payload)
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1302
1303
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its deprecated traditional name.

    All arguments are passed on to NotificationHandler unchanged;
    a DeprecationWarning is issued on every call.
    """
    warnings.warn("pgnotify is deprecated, use NotificationHandler instead",
        category=DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
1309
1310
1311# The actual PostGreSQL database connection interface:
1312
1313class DB:
1314    """Wrapper class for the _pg connection type."""
1315
    def __init__(self, *args, **kw):
        """Create a new connection

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing connection can be passed as the single positional
        argument or as the single keyword argument 'db'.  It may also be
        another DB instance, whose underlying connection is then shared.
        Connections passed in this way are not closed by close().
        """
        # A single argument may be an existing connection; otherwise all
        # arguments are treated as connection parameters.
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # share the underlying connection of another DB wrapper
                db = db.db
            else:
                # a pgdb connection keeps its _pg connection in _cnx
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # Anything that does not look like a _pg connection (e.g. a plain
        # database name string) is used as parameters for a new connection.
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        # remembered so that reopen() can connect with the same parameters
        self._args = args, kw
        self.adapter = Adapter(self)
        self.dbtypes = DbTypes(self)
        # let the C module cast result values via the type cache
        db.set_cast_hook(self.dbtypes.typecast)
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
1357
1358    def __getattr__(self, name):
1359        # All undefined members are same as in underlying connection:
1360        if self.db:
1361            return getattr(self.db, name)
1362        else:
1363            raise _int_error('Connection is not valid')
1364
1365    def __dir__(self):
1366        # Custom dir function including the attributes of the connection:
1367        attrs = set(self.__class__.__dict__)
1368        attrs.update(self.__dict__)
1369        attrs.update(dir(self.db))
1370        return sorted(attrs)
1371
1372    # Context manager methods
1373
1374    def __enter__(self):
1375        """Enter the runtime context. This will start a transactio."""
1376        self.begin()
1377        return self
1378
1379    def __exit__(self, et, ev, tb):
1380        """Exit the runtime context. This will end the transaction."""
1381        if et is None and ev is None and tb is None:
1382            self.commit()
1383        else:
1384            self.rollback()
1385
1386    # Auxiliary methods
1387
1388    def _do_debug(self, *args):
1389        """Print a debug message"""
1390        if self.debug:
1391            s = '\n'.join(str(arg) for arg in args)
1392            if isinstance(self.debug, basestring):
1393                print(self.debug % s)
1394            elif hasattr(self.debug, 'write'):
1395                self.debug.write(s + '\n')
1396            elif callable(self.debug):
1397                self.debug(s)
1398            else:
1399                print(s)
1400
1401    def _escape_qualified_name(self, s):
1402        """Escape a qualified name.
1403
1404        Escapes the name for use as an SQL identifier, unless the
1405        name contains a dot, in which case the name is ambiguous
1406        (could be a qualified name or just a name with a dot in it)
1407        and must be quoted manually by the caller.
1408        """
1409        if '.' not in s:
1410            s = self.escape_identifier(s)
1411        return s
1412
1413    @staticmethod
1414    def _make_bool(d):
1415        """Get boolean value corresponding to d."""
1416        return bool(d) if get_bool() else ('t' if d else 'f')
1417
1418    def _list_params(self, params):
1419        """Create a human readable parameter list."""
1420        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1421
1422    # Public methods
1423
    # escape_string and escape_bytea are available as methods, so the
    # module-level unescape_bytea function is exposed as a method as well
    # (as a static method, since it does not need the DB instance)
    unescape_bytea = staticmethod(unescape_bytea)
1427
1428    def decode_json(self, s):
1429        """Decode a JSON string coming from the database."""
1430        return (get_jsondecode() or jsondecode)(s)
1431
1432    def encode_json(self, d):
1433        """Encode a JSON string for use within SQL."""
1434        return jsonencode(d)
1435
1436    def close(self):
1437        """Close the database connection."""
1438        # Wraps shared library function so we can track state.
1439        if self._closeable:
1440            if self.db:
1441                self.db.close()
1442                self.db = None
1443            else:
1444                raise _int_error('Connection already closed')
1445
1446    def reset(self):
1447        """Reset connection with current parameters.
1448
1449        All derived queries and large objects derived from this connection
1450        will not be usable after this call.
1451
1452        """
1453        if self.db:
1454            self.db.reset()
1455        else:
1456            raise _int_error('Connection already closed')
1457
1458    def reopen(self):
1459        """Reopen connection to the database.
1460
1461        Used in case we need another connection to the same database.
1462        Note that we can still reopen a database that we have closed.
1463
1464        """
1465        # There is no such shared library function.
1466        if self._closeable:
1467            db = connect(*self._args[0], **self._args[1])
1468            if self.db:
1469                self.db.close()
1470            self.db = db
1471
1472    def begin(self, mode=None):
1473        """Begin a transaction."""
1474        qstr = 'BEGIN'
1475        if mode:
1476            qstr += ' ' + mode
1477        return self.query(qstr)
1478
1479    start = begin
1480
1481    def commit(self):
1482        """Commit the current transaction."""
1483        return self.query('COMMIT')
1484
1485    end = commit
1486
1487    def rollback(self, name=None):
1488        """Roll back the current transaction."""
1489        qstr = 'ROLLBACK'
1490        if name:
1491            qstr += ' TO ' + name
1492        return self.query(qstr)
1493
1494    abort = rollback
1495
1496    def savepoint(self, name):
1497        """Define a new savepoint within the current transaction."""
1498        return self.query('SAVEPOINT ' + name)
1499
1500    def release(self, name):
1501        """Destroy a previously defined savepoint."""
1502        return self.query('RELEASE ' + name)
1503
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises TypeError if the parameter has an unsupported type, if it
        is empty, or if it contains a name that is not a non-empty string.

        NOTE(review): parameter names are only stripped and lower-cased and
        then interpolated unquoted into a SHOW command, so they should not
        come from untrusted input.
        """
        # Normalize the argument to an iterable of names and choose the
        # container in which the results will be collected (None means
        # a single value is returned as it is).
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            # the passed dict itself is filled with the settings
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # for dict results, map normalized names to the original keys
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' anywhere in the input returns a dict of all
                # settings (name -> setting), whatever was passed in
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # no 'all' found: query each parameter separately
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1561
1562    def set_parameter(self, parameter, value=None, local=False):
1563        """Set the value of a run-time parameter.
1564
1565        If the parameter and the value are strings, the run-time parameter
1566        will be set to that value.  If no value or None is passed as a value,
1567        then the run-time parameter will be restored to its default value.
1568
1569        You can set several parameters at once by passing a list of parameter
1570        names, together with a single value that all parameters should be
1571        set to or with a corresponding list of values.  You can also pass
1572        the parameters as a set if you only provide a single value.
1573        Finally, you can pass a dict with parameter names as keys.  In this
1574        case, you should not pass a value, since the values for the parameters
1575        will be taken from the dict.
1576
1577        By passing the special name 'all' as the parameter, you can reset
1578        all existing settable run-time parameters to their default values.
1579
1580        If you set local to True, then the command takes effect for only the
1581        current transaction.  After commit() or rollback(), the session-level
1582        setting takes effect again.  Setting local to True will appear to
1583        have no effect if it is executed outside a transaction, since the
1584        transaction will end immediately.
1585        """
1586        if isinstance(parameter, basestring):
1587            parameter = {parameter: value}
1588        elif isinstance(parameter, (list, tuple)):
1589            if isinstance(value, (list, tuple)):
1590                parameter = dict(zip(parameter, value))
1591            else:
1592                parameter = dict.fromkeys(parameter, value)
1593        elif isinstance(parameter, (set, frozenset)):
1594            if isinstance(value, (list, tuple, set, frozenset)):
1595                value = set(value)
1596                if len(value) == 1:
1597                    value = value.pop()
1598            if not(value is None or isinstance(value, basestring)):
1599                raise ValueError('A single value must be specified'
1600                    ' when parameter is a set')
1601            parameter = dict.fromkeys(parameter, value)
1602        elif isinstance(parameter, dict):
1603            if value is not None:
1604                raise ValueError('A value must not be specified'
1605                    ' when parameter is a dictionary')
1606        else:
1607            raise TypeError(
1608                'The parameter must be a string, list, set or dict')
1609        if not parameter:
1610            raise TypeError('No parameter has been specified')
1611        params = {}
1612        for key, value in parameter.items():
1613            param = key.strip().lower() if isinstance(
1614                key, basestring) else None
1615            if not param:
1616                raise TypeError('Invalid parameter')
1617            if param == 'all':
1618                if value is not None:
1619                    raise ValueError('A value must ot be specified'
1620                        " when parameter is 'all'")
1621                params = {'all': None}
1622                break
1623            params[param] = value
1624        local = ' LOCAL' if local else ''
1625        for param, value in params.items():
1626            if value is None:
1627                q = 'RESET%s %s' % (local, param)
1628            else:
1629                q = 'SET%s %s TO %s' % (local, param, value)
1630            self._do_debug(q)
1631            self.db.query(q)
1632
1633    def query(self, command, *args):
1634        """Execute a SQL command string.
1635
1636        This method simply sends a SQL query to the database.  If the query is
1637        an insert statement that inserted exactly one row into a table that
1638        has OIDs, the return value is the OID of the newly inserted row.
1639        If the query is an update or delete statement, or an insert statement
1640        that did not insert exactly one row in a table with OIDs, then the
1641        number of rows affected is returned as a string.  If it is a statement
1642        that returns rows as a result (usually a select statement, but maybe
1643        also an "insert/update ... returning" statement), this method returns
1644        a Query object that can be accessed via getresult() or dictresult()
1645        or simply printed.  Otherwise, it returns `None`.
1646
1647        The query can contain numbered parameters of the form $1 in place
1648        of any data constant.  Arguments given after the query string will
1649        be substituted for the corresponding numbered parameter.  Parameter
1650        values can also be given as a single list or tuple argument.
1651        """
1652        # Wraps shared library function for debugging.
1653        if not self.db:
1654            raise _int_error('Connection is not valid')
1655        if args:
1656            self._do_debug(command, args)
1657            return self.db.query(command, args)
1658        self._do_debug(command)
1659        return self.db.query(command)
1660
1661    def query_formatted(self, command, parameters, types=None, inline=False):
1662        """Execute a formatted SQL command string.
1663
1664        Similar to query, but using Python format placeholders of the form
1665        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1666        The parameters must be passed as a tuple, list or dict.  You can
1667        also pass a corresponding tuple, list or dict of database types in
1668        order to format the parameters properly in case there is ambiguity.
1669
1670        If you set inline to True, the parameters will be sent to the database
1671        embedded in the SQL command, otherwise they will be sent separately.
1672        """
1673        return self.query(*self.adapter.format_query(
1674            command, parameters, types, inline))
1675
1676    def pkey(self, table, composite=False, flush=False):
1677        """Get or set the primary key of a table.
1678
1679        Single primary keys are returned as strings unless you
1680        set the composite flag.  Composite primary keys are always
1681        represented as tuples.  Note that this raises a KeyError
1682        if the table does not have a primary key.
1683
1684        If flush is set then the internal cache for primary keys will
1685        be flushed.  This may be necessary after the database schema or
1686        the search path has been changed.
1687        """
1688        pkeys = self._pkeys
1689        if flush:
1690            pkeys.clear()
1691            self._do_debug('The pkey cache has been flushed')
1692        try:  # cache lookup
1693            pkey = pkeys[table]
1694        except KeyError:  # cache miss, check the database
1695            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1696                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1697                " AND a.attnum = ANY(i.indkey)"
1698                " AND NOT a.attisdropped"
1699                " WHERE i.indrelid=%s::regclass"
1700                " AND i.indisprimary ORDER BY a.attnum") % (
1701                    _quote_if_unqualified('$1', table),)
1702            pkey = self.db.query(q, (table,)).getresult()
1703            if not pkey:
1704                raise KeyError('Table %s has no primary key' % table)
1705            # we want to use the order defined in the primary key index here,
1706            # not the order as defined by the columns in the table
1707            if len(pkey) > 1:
1708                indkey = pkey[0][2]
1709                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1710                pkey = tuple(row[0] for row in pkey)
1711            else:
1712                pkey = pkey[0][0]
1713            pkeys[table] = pkey  # cache it
1714        if composite and not isinstance(pkey, tuple):
1715            pkey = (pkey,)
1716        return pkey
1717
1718    def get_databases(self):
1719        """Get list of databases in the system."""
1720        return [s[0] for s in
1721            self.db.query('SELECT datname FROM pg_database').getresult()]
1722
1723    def get_relations(self, kinds=None):
1724        """Get list of relations in connected database of specified kinds.
1725
1726        If kinds is None or empty, all kinds of relations are returned.
1727        Otherwise kinds can be a string or sequence of type letters
1728        specifying which kind of relations you want to list.
1729        """
1730        where = " AND r.relkind IN (%s)" % ','.join(
1731            ["'%s'" % k for k in kinds]) if kinds else ''
1732        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1733            " FROM pg_class r"
1734            " JOIN pg_namespace s ON s.oid = r.relnamespace"
1735            " WHERE s.nspname NOT SIMILAR"
1736            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
1737            " ORDER BY s.nspname, r.relname") % where
1738        return [r[0] for r in self.db.query(q).getresult()]
1739
1740    def get_tables(self):
1741        """Return list of tables in connected database."""
1742        return self.get_relations('r')
1743
1744    def get_attnames(self, table, with_oid=True, flush=False):
1745        """Given the name of a table, dig out the set of attribute names.
1746
1747        Returns a read-only dictionary of attribute names (the names are
1748        the keys, the values are the names of the attributes' types)
1749        with the column names in the proper order if you iterate over it.
1750
1751        If flush is set, then the internal cache for attribute names will
1752        be flushed. This may be necessary after the database schema or
1753        the search path has been changed.
1754
1755        By default, only a limited number of simple types will be returned.
1756        You can get the regular types after calling use_regtypes(True).
1757        """
1758        attnames = self._attnames
1759        if flush:
1760            attnames.clear()
1761            self._do_debug('The attnames cache has been flushed')
1762        try:  # cache lookup
1763            names = attnames[table]
1764        except KeyError:  # cache miss, check the database
1765            q = "a.attnum > 0"
1766            if with_oid:
1767                q = "(%s OR a.attname = 'oid')" % q
1768            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1769                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1770                " FROM pg_attribute a"
1771                " JOIN pg_type t ON t.oid = a.atttypid"
1772                " WHERE a.attrelid = %s::regclass AND %s"
1773                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1774                    _quote_if_unqualified('$1', table), q)
1775            names = self.db.query(q, (table,)).getresult()
1776            types = self.dbtypes
1777            names = ((name[0], types.add(*name[1:])) for name in names)
1778            names = AttrDict(names)
1779            attnames[table] = names  # cache it
1780        return names
1781
1782    def use_regtypes(self, regtypes=None):
1783        """Use regular type names instead of simplified type names."""
1784        if regtypes is None:
1785            return self.dbtypes._regtypes
1786        else:
1787            regtypes = bool(regtypes)
1788            if regtypes != self.dbtypes._regtypes:
1789                self.dbtypes._regtypes = regtypes
1790                self._attnames.clear()
1791                self.dbtypes.clear()
1792            return regtypes
1793
1794    def has_table_privilege(self, table, privilege='select', flush=False):
1795        """Check whether current user has specified table privilege.
1796
1797        If flush is set, then the internal cache for table privileges will
1798        be flushed. This may be necessary after privileges have been changed.
1799        """
1800        privileges = self._privileges
1801        if flush:
1802            privileges.clear()
1803            self._do_debug('The privileges cache has been flushed')
1804        privilege = privilege.lower()
1805        try:  # ask cache
1806            ret = privileges[table, privilege]
1807        except KeyError:  # cache miss, ask the database
1808            q = "SELECT has_table_privilege(%s, $2)" % (
1809                _quote_if_unqualified('$1', table),)
1810            q = self.db.query(q, (table, privilege))
1811            ret = q.getresult()[0][0] == self._make_bool(True)
1812            privileges[table, privilege] = ret  # cache it
1813        return ret
1814
1815    def get(self, table, row, keyname=None):
1816        """Get a row from a database table or view.
1817
1818        This method is the basic mechanism to get a single row.  It assumes
1819        that the keyname specifies a unique row.  It must be the name of a
1820        single column or a tuple of column names.  If the keyname is not
1821        specified, then the primary key for the table is used.
1822
1823        If row is a dictionary, then the value for the key is taken from it.
1824        Otherwise, the row must be a single value or a tuple of values
1825        corresponding to the passed keyname or primary key.  The fetched row
1826        from the table will be returned as a new dictionary or used to replace
1827        the existing values when row was passed as aa dictionary.
1828
1829        The OID is also put into the dictionary if the table has one, but
1830        in order to allow the caller to work with multiple tables, it is
1831        munged as "oid(table)" using the actual name of the table.
1832        """
1833        if table.endswith('*'):  # hint for descendant tables can be ignored
1834            table = table[:-1].rstrip()
1835        attnames = self.get_attnames(table)
1836        qoid = _oid_key(table) if 'oid' in attnames else None
1837        if keyname and isinstance(keyname, basestring):
1838            keyname = (keyname,)
1839        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
1840            row['oid'] = row[qoid]
1841        if not keyname:
1842            try:  # if keyname is not specified, try using the primary key
1843                keyname = self.pkey(table, True)
1844            except KeyError:  # the table has no primary key
1845                # try using the oid instead
1846                if qoid and isinstance(row, dict) and 'oid' in row:
1847                    keyname = ('oid',)
1848                else:
1849                    raise _prg_error('Table %s has no primary key' % table)
1850            else:  # the table has a primary key
1851                # check whether all key columns have values
1852                if isinstance(row, dict) and not set(keyname).issubset(row):
1853                    # try using the oid instead
1854                    if qoid and 'oid' in row:
1855                        keyname = ('oid',)
1856                    else:
1857                        raise KeyError(
1858                            'Missing value in row for specified keyname')
1859        if not isinstance(row, dict):
1860            if not isinstance(row, (tuple, list)):
1861                row = [row]
1862            if len(keyname) != len(row):
1863                raise KeyError(
1864                    'Differing number of items in keyname and row')
1865            row = dict(zip(keyname, row))
1866        params = self.adapter.parameter_list()
1867        adapt = params.add
1868        col = self.escape_identifier
1869        what = 'oid, *' if qoid else '*'
1870        where = ' AND '.join('%s = %s' % (
1871            col(k), adapt(row[k], attnames[k])) for k in keyname)
1872        if 'oid' in row:
1873            if qoid:
1874                row[qoid] = row['oid']
1875            del row['oid']
1876        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
1877            what, self._escape_qualified_name(table), where)
1878        self._do_debug(q, params)
1879        q = self.db.query(q, params)
1880        res = q.dictresult()
1881        if not res:
1882            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
1883                table, where, self._list_params(params)))
1884        for n, value in res[0].items():
1885            if qoid and n == 'oid':
1886                n = qoid
1887            row[n] = value
1888        return row
1889
1890    def insert(self, table, row=None, **kw):
1891        """Insert a row into a database table.
1892
1893        This method inserts a row into a table.  The name of the table must
1894        be passed as the first parameter.  The other parameters are used for
1895        providing the data of the row that shall be inserted into the table.
1896        If a dictionary is supplied as the second parameter, it starts with
1897        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1898        is updated from the keywords.
1899
1900        The dictionary is then reloaded with the values actually inserted in
1901        order to pick up values modified by rules, triggers, etc.
1902        """
1903        if table.endswith('*'):  # hint for descendant tables can be ignored
1904            table = table[:-1].rstrip()
1905        if row is None:
1906            row = {}
1907        row.update(kw)
1908        if 'oid' in row:
1909            del row['oid']  # do not insert oid
1910        attnames = self.get_attnames(table)
1911        qoid = _oid_key(table) if 'oid' in attnames else None
1912        params = self.adapter.parameter_list()
1913        adapt = params.add
1914        col = self.escape_identifier
1915        names, values = [], []
1916        for n in attnames:
1917            if n in row:
1918                names.append(col(n))
1919                values.append(adapt(row[n], attnames[n]))
1920        if not names:
1921            raise _prg_error('No column found that can be inserted')
1922        names, values = ', '.join(names), ', '.join(values)
1923        ret = 'oid, *' if qoid else '*'
1924        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1925            self._escape_qualified_name(table), names, values, ret)
1926        self._do_debug(q, params)
1927        q = self.db.query(q, params)
1928        res = q.dictresult()
1929        if res:  # this should always be true
1930            for n, value in res[0].items():
1931                if qoid and n == 'oid':
1932                    n = qoid
1933                row[n] = value
1934        return row
1935
1936    def update(self, table, row=None, **kw):
1937        """Update an existing row in a database table.
1938
1939        Similar to insert but updates an existing row.  The update is based
1940        on the primary key of the table or the OID value as munged by get
1941        or passed as keyword.
1942
1943        The dictionary is then modified to reflect any changes caused by the
1944        update due to triggers, rules, default values, etc.
1945        """
1946        if table.endswith('*'):
1947            table = table[:-1].rstrip()  # need parent table name
1948        attnames = self.get_attnames(table)
1949        qoid = _oid_key(table) if 'oid' in attnames else None
1950        if row is None:
1951            row = {}
1952        elif 'oid' in row:
1953            del row['oid']  # only accept oid key from named args for safety
1954        row.update(kw)
1955        if qoid and qoid in row and 'oid' not in row:
1956            row['oid'] = row[qoid]
1957        try:  # try using the primary key
1958            keyname = self.pkey(table, True)
1959        except KeyError:  # the table has no primary key
1960            # try using the oid instead
1961            if qoid and 'oid' in row:
1962                keyname = ('oid',)
1963            else:
1964                raise _prg_error('Table %s has no primary key' % table)
1965        else:  # the table has a primary key
1966            # check whether all key columns have values
1967            if not set(keyname).issubset(row):
1968                # try using the oid instead
1969                if qoid and 'oid' in row:
1970                    keyname = ('oid',)
1971                else:
1972                    raise KeyError('Missing primary key in row')
1973        params = self.adapter.parameter_list()
1974        adapt = params.add
1975        col = self.escape_identifier
1976        where = ' AND '.join('%s = %s' % (
1977            col(k), adapt(row[k], attnames[k])) for k in keyname)
1978        if 'oid' in row:
1979            if qoid:
1980                row[qoid] = row['oid']
1981            del row['oid']
1982        values = []
1983        keyname = set(keyname)
1984        for n in attnames:
1985            if n in row and n not in keyname:
1986                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
1987        if not values:
1988            return row
1989        values = ', '.join(values)
1990        ret = 'oid, *' if qoid else '*'
1991        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
1992            self._escape_qualified_name(table), values, where, ret)
1993        self._do_debug(q, params)
1994        q = self.db.query(q, params)
1995        res = q.dictresult()
1996        if res:  # may be empty when row does not exist
1997            for n, value in res[0].items():
1998                if qoid and n == 'oid':
1999                    n = qoid
2000                row[n] = value
2001        return row
2002
2003    def upsert(self, table, row=None, **kw):
2004        """Insert a row into a database table with conflict resolution
2005
2006        This method inserts a row into a table, but instead of raising a
2007        ProgrammingError exception in case a row with the same primary key
2008        already exists, an update will be executed instead.  This will be
2009        performed as a single atomic operation on the database, so race
2010        conditions can be avoided.
2011
2012        Like the insert method, the first parameter is the name of the
2013        table and the second parameter can be used to pass the values to
2014        be inserted as a dictionary.
2015
2016        Unlike the insert und update statement, keyword parameters are not
2017        used to modify the dictionary, but to specify which columns shall
2018        be updated in case of a conflict, and in which way:
2019
2020        A value of False or None means the column shall not be updated,
2021        a value of True means the column shall be updated with the value
2022        that has been proposed for insertion, i.e. has been passed as value
2023        in the dictionary.  Columns that are not specified by keywords but
2024        appear as keys in the dictionary are also updated like in the case
2025        keywords had been passed with the value True.
2026
2027        So if in the case of a conflict you want to update every column that
2028        has been passed in the dictionary row , you would call upsert(table, row).
2029        If you don't want to do anything in case of a conflict, i.e. leave
2030        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
2031
2032        If you need more fine-grained control of what gets updated, you can
2033        also pass strings in the keyword parameters.  These strings will
2034        be used as SQL expressions for the update columns.  In these
2035        expressions you can refer to the value that already exists in
2036        the table by prefixing the column name with "included.", and to
2037        the value that has been proposed for insertion by prefixing the
2038        column name with the "excluded."
2039
2040        The dictionary is modified in any case to reflect the values in
2041        the database after the operation has completed.
2042
2043        Note: The method uses the PostgreSQL "upsert" feature which is
2044        only available since PostgreSQL 9.5.
2045        """
2046        if table.endswith('*'):  # hint for descendant tables can be ignored
2047            table = table[:-1].rstrip()
2048        if row is None:
2049            row = {}
2050        if 'oid' in row:
2051            del row['oid']  # do not insert oid
2052        if 'oid' in kw:
2053            del kw['oid']  # do not update oid
2054        attnames = self.get_attnames(table)
2055        qoid = _oid_key(table) if 'oid' in attnames else None
2056        params = self.adapter.parameter_list()
2057        adapt = params.add
2058        col = self.escape_identifier
2059        names, values, updates = [], [], []
2060        for n in attnames:
2061            if n in row:
2062                names.append(col(n))
2063                values.append(adapt(row[n], attnames[n]))
2064        names, values = ', '.join(names), ', '.join(values)
2065        try:
2066            keyname = self.pkey(table, True)
2067        except KeyError:
2068            raise _prg_error('Table %s has no primary key' % table)
2069        target = ', '.join(col(k) for k in keyname)
2070        update = []
2071        keyname = set(keyname)
2072        keyname.add('oid')
2073        for n in attnames:
2074            if n not in keyname:
2075                value = kw.get(n, True)
2076                if value:
2077                    if not isinstance(value, basestring):
2078                        value = 'excluded.%s' % col(n)
2079                    update.append('%s = %s' % (col(n), value))
2080        if not values:
2081            return row
2082        do = 'update set %s' % ', '.join(update) if update else 'nothing'
2083        ret = 'oid, *' if qoid else '*'
2084        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
2085            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
2086                self._escape_qualified_name(table), names, values,
2087                target, do, ret)
2088        self._do_debug(q, params)
2089        try:
2090            q = self.db.query(q, params)
2091        except ProgrammingError:
2092            if self.server_version < 90500:
2093                raise _prg_error(
2094                    'Upsert operation is not supported by PostgreSQL version')
2095            raise  # re-raise original error
2096        res = q.dictresult()
2097        if res:  # may be empty with "do nothing"
2098            for n, value in res[0].items():
2099                if qoid and n == 'oid':
2100                    n = qoid
2101                row[n] = value
2102        else:
2103            self.get(table, row)
2104        return row
2105
2106    def clear(self, table, row=None):
2107        """Clear all the attributes to values determined by the types.
2108
2109        Numeric types are set to 0, Booleans are set to false, and everything
2110        else is set to the empty string.  If the row argument is present,
2111        it is used as the row dictionary and any entries matching attribute
2112        names are cleared with everything else left unchanged.
2113        """
2114        # At some point we will need a way to get defaults from a table.
2115        if row is None:
2116            row = {}  # empty if argument is not present
2117        attnames = self.get_attnames(table)
2118        for n, t in attnames.items():
2119            if n == 'oid':
2120                continue
2121            t = t.simple
2122            if t in DbTypes._num_types:
2123                row[n] = 0
2124            elif t == 'bool':
2125                row[n] = self._make_bool(False)
2126            else:
2127                row[n] = ''
2128        return row
2129
2130    def delete(self, table, row=None, **kw):
2131        """Delete an existing row in a database table.
2132
2133        This method deletes the row from a table.  It deletes based on the
2134        primary key of the table or the OID value as munged by get() or
2135        passed as keyword.
2136
2137        The return value is the number of deleted rows (i.e. 0 if the row
2138        did not exist and 1 if the row was deleted).
2139
2140        Note that if the row cannot be deleted because e.g. it is still
2141        referenced by another table, this method raises a ProgrammingError.
2142        """
2143        if table.endswith('*'):  # hint for descendant tables can be ignored
2144            table = table[:-1].rstrip()
2145        attnames = self.get_attnames(table)
2146        qoid = _oid_key(table) if 'oid' in attnames else None
2147        if row is None:
2148            row = {}
2149        elif 'oid' in row:
2150            del row['oid']  # only accept oid key from named args for safety
2151        row.update(kw)
2152        if qoid and qoid in row and 'oid' not in row:
2153            row['oid'] = row[qoid]
2154        try:  # try using the primary key
2155            keyname = self.pkey(table, True)
2156        except KeyError:  # the table has no primary key
2157            # try using the oid instead
2158            if qoid and 'oid' in row:
2159                keyname = ('oid',)
2160            else:
2161                raise _prg_error('Table %s has no primary key' % table)
2162        else:  # the table has a primary key
2163            # check whether all key columns have values
2164            if not set(keyname).issubset(row):
2165                # try using the oid instead
2166                if qoid and 'oid' in row:
2167                    keyname = ('oid',)
2168                else:
2169                    raise KeyError('Missing primary key in row')
2170        params = self.adapter.parameter_list()
2171        adapt = params.add
2172        col = self.escape_identifier
2173        where = ' AND '.join('%s = %s' % (
2174            col(k), adapt(row[k], attnames[k])) for k in keyname)
2175        if 'oid' in row:
2176            if qoid:
2177                row[qoid] = row['oid']
2178            del row['oid']
2179        q = 'DELETE FROM %s WHERE %s' % (
2180            self._escape_qualified_name(table), where)
2181        self._do_debug(q, params)
2182        res = self.db.query(q, params)
2183        return int(res)
2184
2185    def truncate(self, table, restart=False, cascade=False, only=False):
2186        """Empty a table or set of tables.
2187
2188        This method quickly removes all rows from the given table or set
2189        of tables.  It has the same effect as an unqualified DELETE on each
2190        table, but since it does not actually scan the tables it is faster.
2191        Furthermore, it reclaims disk space immediately, rather than requiring
2192        a subsequent VACUUM operation. This is most useful on large tables.
2193
2194        If restart is set to True, sequences owned by columns of the truncated
2195        table(s) are automatically restarted.  If cascade is set to True, it
2196        also truncates all tables that have foreign-key references to any of
2197        the named tables.  If the parameter only is not set to True, all the
2198        descendant tables (if any) will also be truncated. Optionally, a '*'
2199        can be specified after the table name to explicitly indicate that
2200        descendant tables are included.
2201        """
2202        if isinstance(table, basestring):
2203            only = {table: only}
2204            table = [table]
2205        elif isinstance(table, (list, tuple)):
2206            if isinstance(only, (list, tuple)):
2207                only = dict(zip(table, only))
2208            else:
2209                only = dict.fromkeys(table, only)
2210        elif isinstance(table, (set, frozenset)):
2211            only = dict.fromkeys(table, only)
2212        else:
2213            raise TypeError('The table must be a string, list or set')
2214        if not (restart is None or isinstance(restart, (bool, int))):
2215            raise TypeError('Invalid type for the restart option')
2216        if not (cascade is None or isinstance(cascade, (bool, int))):
2217            raise TypeError('Invalid type for the cascade option')
2218        tables = []
2219        for t in table:
2220            u = only.get(t)
2221            if not (u is None or isinstance(u, (bool, int))):
2222                raise TypeError('Invalid type for the only option')
2223            if t.endswith('*'):
2224                if u:
2225                    raise ValueError(
2226                        'Contradictory table name and only options')
2227                t = t[:-1].rstrip()
2228            t = self._escape_qualified_name(t)
2229            if u:
2230                t = 'ONLY %s' % t
2231            tables.append(t)
2232        q = ['TRUNCATE', ', '.join(tables)]
2233        if restart:
2234            q.append('RESTART IDENTITY')
2235        if cascade:
2236            q.append('CASCADE')
2237        q = ' '.join(q)
2238        self._do_debug(q)
2239        return self.db.query(q)
2240
2241    def get_as_list(self, table, what=None, where=None,
2242            order=None, limit=None, offset=None, scalar=False):
2243        """Get a table as a list.
2244
2245        This gets a convenient representation of the table as a list
2246        of named tuples in Python.  You only need to pass the name of
2247        the table (or any other SQL expression returning rows).  Note that
2248        by default this will return the full content of the table which
2249        can be huge and overflow your memory.  However, you can control
2250        the amount of data returned using the other optional parameters.
2251
2252        The parameter 'what' can restrict the query to only return a
2253        subset of the table columns.  It can be a string, list or a tuple.
2254        The parameter 'where' can restrict the query to only return a
2255        subset of the table rows.  It can be a string, list or a tuple
2256        of SQL expressions that all need to be fulfilled.  The parameter
2257        'order' specifies the ordering of the rows.  It can also be a
2258        other string, list or a tuple.  If no ordering is specified,
2259        the result will be ordered by the primary key(s) or all columns
2260        if no primary key exists.  You can set 'order' to False if you
2261        don't care about the ordering.  The parameters 'limit' and 'offset'
2262        can be integers specifying the maximum number of rows returned
2263        and a number of rows skipped over.
2264
2265        If you set the 'scalar' option to True, then instead of the
2266        named tuples you will get the first items of these tuples.
2267        This is useful if the result has only one column anyway.
2268        """
2269        if not table:
2270            raise TypeError('The table name is missing')
2271        if what:
2272            if isinstance(what, (list, tuple)):
2273                what = ', '.join(map(str, what))
2274            if order is None:
2275                order = what
2276        else:
2277            what = '*'
2278        q = ['SELECT', what, 'FROM', table]
2279        if where:
2280            if isinstance(where, (list, tuple)):
2281                where = ' AND '.join(map(str, where))
2282            q.extend(['WHERE', where])
2283        if order is None:
2284            try:
2285                order = self.pkey(table, True)
2286            except (KeyError, ProgrammingError):
2287                try:
2288                    order = list(self.get_attnames(table))
2289                except (KeyError, ProgrammingError):
2290                    pass
2291        if order:
2292            if isinstance(order, (list, tuple)):
2293                order = ', '.join(map(str, order))
2294            q.extend(['ORDER BY', order])
2295        if limit:
2296            q.append('LIMIT %d' % limit)
2297        if offset:
2298            q.append('OFFSET %d' % offset)
2299        q = ' '.join(q)
2300        self._do_debug(q)
2301        q = self.db.query(q)
2302        res = q.namedresult()
2303        if res and scalar:
2304            res = [row[0] for row in res]
2305        return res
2306
2307    def get_as_dict(self, table, keyname=None, what=None, where=None,
2308            order=None, limit=None, offset=None, scalar=False):
2309        """Get a table as a dictionary.
2310
2311        This method is similar to get_as_list(), but returns the table
2312        as a Python dict instead of a Python list, which can be even
2313        more convenient. The primary key column(s) of the table will
2314        be used as the keys of the dictionary, while the other column(s)
2315        will be the corresponding values.  The keys will be named tuples
2316        if the table has a composite primary key.  The rows will be also
2317        named tuples unless the 'scalar' option has been set to True.
2318        With the optional parameter 'keyname' you can specify an alternative
2319        set of columns to be used as the keys of the dictionary.  It must
2320        be set as a string, list or a tuple.
2321
2322        If the Python version supports it, the dictionary will be an
2323        OrderedDict using the order specified with the 'order' parameter
2324        or the key column(s) if not specified.  You can set 'order' to False
2325        if you don't care about the ordering.  In this case the returned
2326        dictionary will be an ordinary one.
2327        """
2328        if not table:
2329            raise TypeError('The table name is missing')
2330        if not keyname:
2331            try:
2332                keyname = self.pkey(table, True)
2333            except (KeyError, ProgrammingError):
2334                raise _prg_error('Table %s has no primary key' % table)
2335        if isinstance(keyname, basestring):
2336            keyname = [keyname]
2337        elif not isinstance(keyname, (list, tuple)):
2338            raise KeyError('The keyname must be a string, list or tuple')
2339        if what:
2340            if isinstance(what, (list, tuple)):
2341                what = ', '.join(map(str, what))
2342            if order is None:
2343                order = what
2344        else:
2345            what = '*'
2346        q = ['SELECT', what, 'FROM', table]
2347        if where:
2348            if isinstance(where, (list, tuple)):
2349                where = ' AND '.join(map(str, where))
2350            q.extend(['WHERE', where])
2351        if order is None:
2352            order = keyname
2353        if order:
2354            if isinstance(order, (list, tuple)):
2355                order = ', '.join(map(str, order))
2356            q.extend(['ORDER BY', order])
2357        if limit:
2358            q.append('LIMIT %d' % limit)
2359        if offset:
2360            q.append('OFFSET %d' % offset)
2361        q = ' '.join(q)
2362        self._do_debug(q)
2363        q = self.db.query(q)
2364        res = q.getresult()
2365        cls = OrderedDict if order else dict
2366        if not res:
2367            return cls()
2368        keyset = set(keyname)
2369        fields = q.listfields()
2370        if not keyset.issubset(fields):
2371            raise KeyError('Missing keyname in row')
2372        keyind, rowind = [], []
2373        for i, f in enumerate(fields):
2374            (keyind if f in keyset else rowind).append(i)
2375        keytuple = len(keyind) > 1
2376        getkey = itemgetter(*keyind)
2377        keys = map(getkey, res)
2378        if scalar:
2379            rowind = rowind[:1]
2380            rowtuple = False
2381        else:
2382            rowtuple = len(rowind) > 1
2383        if scalar or rowtuple:
2384            getrow = itemgetter(*rowind)
2385        else:
2386            rowind = rowind[0]
2387            getrow = lambda row: (row[rowind],)
2388            rowtuple = True
2389        rows = map(getrow, res)
2390        if keytuple or rowtuple:
2391            namedresult = get_namedresult()
2392            if namedresult:
2393                if keytuple:
2394                    keys = namedresult(_MemoryQuery(keys, keyname))
2395                if rowtuple:
2396                    fields = [f for f in fields if f not in keyset]
2397                    rows = namedresult(_MemoryQuery(rows, fields))
2398        return cls(zip(keys, rows))
2399
2400    def notification_handler(self,
2401            event, callback, arg_dict=None, timeout=None, stop_event=None):
2402        """Get notification handler that will run the given callback."""
2403        return NotificationHandler(self,
2404            event, callback, arg_dict, timeout, stop_event)
2405
2406
# If run as a script, print the version banner and the module documentation.

if __name__ == '__main__':
    # Pass the version as a separate argument so print() inserts the
    # separating space.  The former 'PyGreSQL version' + version printed
    # e.g. "PyGreSQL version5.0", and would fail for a non-str version.
    # Multiple print() arguments are safe on Python 2 as well, because the
    # module imports print_function from __future__.
    print('PyGreSQL version', version)
    print()
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.