source: trunk/pg.py @ 799

Last change on this file since 799 was 799, checked in by cito, 4 years ago

Improve adaptation and add query_formatted() method

Also added more tests and documentation.

  • Property svn:keywords set to Id
File size: 77.8 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 799 2016-01-31 18:45:58Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from datetime import date, time, datetime, timedelta
38from decimal import Decimal
39from math import isnan, isinf
40from collections import namedtuple
41from operator import itemgetter
42from re import compile as regex
43from json import loads as jsondecode, dumps as jsonencode
44
45try:
46    long
47except NameError:  # Python >= 3.0
48    long = int
49
50try:
51    basestring
52except NameError:  # Python >= 3.0
53    basestring = (str, bytes)
54
55try:
56    from collections import OrderedDict
57except ImportError:  # Python 2.6 or 3.0
58    OrderedDict = dict
59
60
61    class AttrDict(dict):
62        """Simple read-only ordered dictionary for storing attribute names."""
63
64        def __init__(self, *args, **kw):
65            if len(args) > 1 or kw:
66                raise TypeError
67            items = args[0] if args else []
68            if isinstance(items, dict):
69                raise TypeError
70            items = list(items)
71            self._keys = [item[0] for item in items]
72            dict.__init__(self, items)
73            self._read_only = True
74            error = self._read_only_error
75            self.clear = self.update = error
76            self.pop = self.setdefault = self.popitem = error
77
78        def __setitem__(self, key, value):
79            if self._read_only:
80                self._read_only_error()
81            dict.__setitem__(self, key, value)
82
83        def __delitem__(self, key):
84            if self._read_only:
85                self._read_only_error()
86            dict.__delitem__(self, key)
87
88        def __iter__(self):
89            return iter(self._keys)
90
91        def keys(self):
92            return list(self._keys)
93
94        def values(self):
95            return [self[key] for key in self]
96
97        def items(self):
98            return [(key, self[key]) for key in self]
99
100        def iterkeys(self):
101            return self.__iter__()
102
103        def itervalues(self):
104            return iter(self.values())
105
106        def iteritems(self):
107            return iter(self.items())
108
109        @staticmethod
110        def _read_only_error(*args, **kw):
111            raise TypeError('This object is read-only')
112
113else:
114
115     class AttrDict(OrderedDict):
116        """Simple read-only ordered dictionary for storing attribute names."""
117
118        def __init__(self, *args, **kw):
119            self._read_only = False
120            OrderedDict.__init__(self, *args, **kw)
121            self._read_only = True
122            error = self._read_only_error
123            self.clear = self.update = error
124            self.pop = self.setdefault = self.popitem = error
125
126        def __setitem__(self, key, value):
127            if self._read_only:
128                self._read_only_error()
129            OrderedDict.__setitem__(self, key, value)
130
131        def __delitem__(self, key):
132            if self._read_only:
133                self._read_only_error()
134            OrderedDict.__delitem__(self, key)
135
136        @staticmethod
137        def _read_only_error(*args, **kw):
138            raise TypeError('This object is read-only')
139
140
141# Auxiliary classes and functions that are independent from a DB connection:
142
143def _oid_key(table):
144    """Build oid key from a table name."""
145    return 'oid(%s)' % table
146
147
148class _SimpleTypes(dict):
149    """Dictionary mapping pg_type names to simple type names."""
150
151    _types = {'bool': 'bool',
152        'bytea': 'bytea',
153        'date': 'date interval time timetz timestamp timestamptz'
154            ' abstime reltime',  # these are very old
155        'float': 'float4 float8',
156        'int': 'cid int2 int4 int8 oid xid',
157        'json': 'json jsonb',
158        'num': 'numeric',
159        'money': 'money',
160        'text': 'bpchar char name text varchar'}
161
162    def __init__(self):
163        for typ, keys in self._types.items():
164            for key in keys.split():
165                self[key] = typ
166                self['_%s' % key] = '%s[]' % typ
167
168    @staticmethod
169    def __missing__(key):
170        return 'text'
171
172_simpletypes = _SimpleTypes()
173
174
175def _quote_if_unqualified(param, name):
176    """Quote parameter representing a qualified name.
177
178    Puts a quote_ident() call around the give parameter unless
179    the name contains a dot, in which case the name is ambiguous
180    (could be a qualified name or just a name with a dot in it)
181    and must be quoted manually by the caller.
182    """
183    if isinstance(name, basestring) and '.' not in name:
184        return 'quote_ident(%s)' % (param,)
185    return param
186
187
188class _ParameterList(list):
189    """Helper class for building typed parameter lists."""
190
191    def add(self, value, typ=None):
192        """Typecast value with known database type and build parameter list.
193
194        If this is a literal value, it will be returned as is.  Otherwise, a
195        placeholder will be returned and the parameter list will be augmented.
196        """
197        value = self.adapt(value, typ)
198        if isinstance(value, Literal):
199            return value
200        self.append(value)
201        return '$%d' % len(self)
202
203
204class Literal(str):
205    """Wrapper class for marking literal SQL values."""
206
207
208class Json:
209    """Wrapper class for marking Json values."""
210
211    def __init__(self, obj):
212        self.obj = obj
213
214
215class Bytea(bytes):
216    """Wrapper class for marking Bytea values."""
217
218
219class Adapter:
220    """Class providing methods for adapting parameters to the database."""
221
222    _bool_true_values = frozenset('t true 1 y yes on'.split())
223
224    _date_literals = frozenset('current_date current_time'
225        ' current_timestamp localtime localtimestamp'.split())
226
227    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
228    _re_record_quote = regex(r'[(,"\\]')
229    _re_array_escape = _re_record_escape = regex(r'(["\\])')
230
231    def __init__(self, db):
232        self.db = db
233        self.encode_json = db.encode_json
234        db = db.db
235        self.escape_bytea = db.escape_bytea
236        self.escape_string = db.escape_string
237
238    @classmethod
239    def _adapt_bool(cls, v):
240        """Adapt a boolean parameter."""
241        if isinstance(v, basestring):
242            if not v:
243                return None
244            v = v.lower() in cls._bool_true_values
245        return 't' if v else 'f'
246
247    @classmethod
248    def _adapt_date(cls, v):
249        """Adapt a date parameter."""
250        if not v:
251            return None
252        if isinstance(v, basestring) and v.lower() in cls._date_literals:
253            return Literal(v)
254        return v
255
256    @staticmethod
257    def _adapt_num(v):
258        """Adapt a numeric parameter."""
259        if not v and v != 0:
260            return None
261        return v
262
263    _adapt_int = _adapt_float = _adapt_money = _adapt_num
264
265    def _adapt_bytea(self, v):
266        """Adapt a bytea parameter."""
267        return self.escape_bytea(v)
268
269    def _adapt_json(self, v):
270        """Adapt a json parameter."""
271        if not v:
272            return None
273        if isinstance(v, basestring):
274            return v
275        return self.encode_json(v)
276
277    @classmethod
278    def _adapt_text_array(cls, v):
279        """Adapt a text type array parameter."""
280        if isinstance(v, list):
281            adapt = cls._adapt_text_array
282            return '{%s}' % ','.join(adapt(v) for v in v)
283        if v is None:
284            return 'null'
285        if not v:
286            return '""'
287        v = str(v)
288        if cls._re_array_quote.search(v):
289            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
290        return v
291
292    _adapt_date_array = _adapt_text_array
293
294    @classmethod
295    def _adapt_bool_array(cls, v):
296        """Adapt a boolean array parameter."""
297        if isinstance(v, list):
298            adapt = cls._adapt_bool_array
299            return '{%s}' % ','.join(adapt(v) for v in v)
300        if v is None:
301            return 'null'
302        if isinstance(v, basestring):
303            if not v:
304                return 'null'
305            v = v.lower() in cls._bool_true_values
306        return 't' if v else 'f'
307
308    @classmethod
309    def _adapt_num_array(cls, v):
310        """Adapt a numeric array parameter."""
311        if isinstance(v, list):
312            adapt = cls._adapt_num_array
313            return '{%s}' % ','.join(adapt(v) for v in v)
314        if not v and v != 0:
315            return 'null'
316        return str(v)
317
318    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
319            _adapt_num_array
320
321    def _adapt_bytea_array(self, v):
322        """Adapt a bytea array parameter."""
323        if isinstance(v, list):
324            return b'{' + b','.join(
325                self._adapt_bytea_array(v) for v in v) + b'}'
326        if v is None:
327            return b'null'
328        return self.escape_bytea(v).replace(b'\\', b'\\\\')
329
330    def _adapt_json_array(self, v):
331        """Adapt a json array parameter."""
332        if isinstance(v, list):
333            adapt = self._adapt_json_array
334            return '{%s}' % ','.join(adapt(v) for v in v)
335        if not v:
336            return 'null'
337        if not isinstance(v, basestring):
338            v = self.encode_json(v)
339        if self._re_array_quote.search(v):
340            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
341        return v
342
343    def _adapt_record(self, v, typ):
344        """Adapt a record parameter with given type."""
345        typ = self.get_attnames(typ).values()
346        if len(typ) != len(v):
347            raise TypeError('Record parameter %s has wrong size' % v)
348        adapt = self.adapt
349        value = []
350        for v, t in zip(v, typ):
351            v = adapt(v, t)
352            if v is None:
353                v = ''
354            elif not v:
355                v = '""'
356            else:
357                if isinstance(v, bytes):
358                    if str is not bytes:
359                        v = v.decode('ascii')
360                else:
361                    v = str(v)
362                if self._re_record_quote.search(v):
363                    v = '"%s"' % self._re_record_escape.sub(r'\\\1', v)
364            value.append(v)
365        return '(%s)' % ','.join(value)
366
367    def adapt(self, value, typ=None):
368        """Adapt a value with known database type."""
369        if value is not None and not isinstance(value, Literal):
370            if typ:
371                simple = self.get_simple_name(typ)
372            else:
373                typ = simple = self.guess_simple_type(value) or 'text'
374            try:
375                value = value.__pg_str__(typ)
376            except AttributeError:
377                pass
378            if simple == 'text':
379                pass
380            elif simple == 'record':
381                if isinstance(value, tuple):
382                    value = self._adapt_record(value, typ)
383            elif simple.endswith('[]'):
384                if isinstance(value, list):
385                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
386                    value = adapt(value)
387            else:
388                adapt = getattr(self, '_adapt_%s' % simple)
389                value = adapt(value)
390        return value
391
392    @staticmethod
393    def simple_type(name):
394        """Create a simple database type with given attribute names."""
395        typ = DbType(name)
396        typ.simple = name
397        return typ
398
399    @staticmethod
400    def get_simple_name(typ):
401        """Get the simple name of a database type."""
402        if isinstance(typ, DbType):
403            return typ.simple
404        return _simpletypes[typ]
405
406    @staticmethod
407    def get_attnames(typ):
408        """Get the attribute names of a composite database type."""
409        if isinstance(typ, DbType):
410            return typ.attnames
411        return {}
412
413    @classmethod
414    def guess_simple_type(cls, value):
415        """Try to guess which database type the given value has."""
416        if isinstance(value, Bytea):
417            return 'bytea'
418        if isinstance(value, basestring):
419            return 'text'
420        if isinstance(value, bool):
421            return 'bool'
422        if isinstance(value, (int, long)):
423            return 'int'
424        if isinstance(value, float):
425            return 'float'
426        if isinstance(value, Decimal):
427            return 'num'
428        if isinstance(value, (date, time, datetime, timedelta)):
429            return 'date'
430        if isinstance(value, list):
431            return '%s[]' % cls.guess_simple_base_type(value)
432        if isinstance(value, tuple):
433            simple_type = cls.simple_type
434            typ = simple_type('record')
435            guess = cls.guess_simple_type
436            typ._get_attnames = lambda _self: AttrDict(
437                (str(n + 1), simple_type(guess(v)))
438                for n, v in enumerate(value))
439            return typ
440
441    @classmethod
442    def guess_simple_base_type(cls, value):
443        """Try to guess the base type of a given array."""
444        for v in value:
445            if isinstance(v, list):
446                typ = cls.guess_simple_base_type(v)
447            else:
448                typ = cls.guess_simple_type(v)
449            if typ:
450                return typ
451
452    def adapt_inline(self, value, nested=False):
453        """Adapt a value that is put into the SQL and needs to be quoted."""
454        if value is None:
455            return 'NULL'
456        if isinstance(value, Literal):
457            return value
458        if isinstance(value, Bytea):
459            value = self.escape_bytea(value)
460            if bytes is not str:  # Python >= 3.0
461                value = value.decode('ascii')
462        elif isinstance(value, Json):
463            if value.encode:
464                return value.encode()
465            value = self.encode_json(value)
466        elif isinstance(value, (datetime, date, time, timedelta)):
467            value = str(value)
468        if isinstance(value, basestring):
469            value = self.escape_string(value)
470            return "'%s'" % value
471        if isinstance(value, bool):
472            return 'true' if value else 'false'
473        if isinstance(value, float):
474            if isinf(value):
475                return "'-Infinity'" if value < 0 else "'Infinity'"
476            if isnan(value):
477                return "'NaN'"
478            return value
479        if isinstance(value, (int, long, Decimal)):
480            return value
481        if isinstance(value, list):
482            q = self.adapt_inline
483            s = '[%s]' if nested else 'ARRAY[%s]'
484            return s % ','.join(str(q(v, nested=True)) for v in value)
485        if isinstance(value, tuple):
486            q = self.adapt_inline
487            return '(%s)' % ','.join(str(q(v)) for v in value)
488        try:
489            value = value.__pg_repr__()
490        except AttributeError:
491            raise InterfaceError(
492                'Do not know how to adapt type %s' % type(value))
493        if isinstance(value, (tuple, list)):
494            value = self.adapt_inline(value)
495        return value
496
497    def parameter_list(self):
498        """Return a parameter list for parameters with known database types.
499
500        The list has an add(value, typ) method that will build up the
501        list and return either the literal value or a placeholder.
502        """
503        params = _ParameterList()
504        params.adapt = self.adapt
505        return params
506
507    def format_query(self, command, values, types=None, inline=False):
508        """Format a database query using the given values and types."""
509        if inline and types:
510            raise ValueError('Typed parameters must be sent separately')
511        params = self.parameter_list()
512        if isinstance(values, (list, tuple)):
513            if inline:
514                adapt = self.adapt_inline
515                literals = [adapt(value) for value in values]
516            else:
517                add = params.add
518                literals = []
519                append = literals.append
520                if types:
521                    if (not isinstance(types, (list, tuple)) or
522                            len(types) != len(values)):
523                        raise TypeError('The values and types do not match')
524                    for value, typ in zip(values, types):
525                        append(add(value, typ))
526                else:
527                    for value in values:
528                        append(add(value))
529            command = command % tuple(literals)
530        elif isinstance(values, dict):
531            if inline:
532                adapt = self.adapt_inline
533                literals = dict((key, adapt(value))
534                    for key, value in values.items())
535            else:
536                add = params.add
537                literals = {}
538                if types:
539                    if (not isinstance(types, dict) or
540                            len(types) < len(values)):
541                        raise TypeError('The values and types do not match')
542                    for key in sorted(values):
543                        literals[key] = add(values[key], types[key])
544                else:
545                    for key in sorted(values):
546                        literals[key] = add(values[key])
547            command = command % literals
548        else:
549            raise TypeError('The values must be passed as tuple, list or dict')
550        return command, params
551
552
553def cast_bool(value):
554    """Cast a boolean value."""
555    if not get_bool():
556        return value
557    return value[0] == 't'
558
559
560def cast_json(value):
561    """Cast a JSON value."""
562    cast = get_jsondecode()
563    if not cast:
564        return value
565    return cast(value)
566
567
568def cast_num(value):
569    """Cast a numeric value."""
570    return (get_decimal() or float)(value)
571
572
573def cast_money(value):
574    """Cast a money value."""
575    point = get_decimal_point()
576    if not point:
577        return value
578    if point != '.':
579        value = value.replace(point, '.')
580    value = value.replace('(', '-')
581    value = ''.join(c for c in value if c.isdigit() or c in '.-')
582    return (get_decimal() or float)(value)
583
584
585def cast_int2vector(value):
586    """Cast an int2vector value."""
587    return [int(v) for v in value.split()]
588
589
590class Typecasts(dict):
591    """Dictionary mapping database types to typecast functions.
592
593    The cast functions get passed the string representation of a value in
594    the database which they need to convert to a Python object.  The
595    passed string will never be None since NULL values are already be
596    handled before the cast function is called.
597
598    Note that the basic types are already handled by the C extension.
599    They only need to be handled here as record or array components.
600    """
601
602    # the default cast functions
603    # (str functions are ignored but have been added for faster access)
604    defaults = {'char': str, 'bpchar': str, 'name': str,
605        'text': str, 'varchar': str,
606        'bool': cast_bool, 'bytea': unescape_bytea,
607        'int2': int, 'int4': int, 'serial': int,
608        'int8': long, 'json': cast_json, 'jsonb': cast_json,
609        'oid': long, 'oid8': long,
610        'float4': float, 'float8': float,
611        'numeric': cast_num, 'money': cast_money,
612        'int2vector': cast_int2vector,
613        'anyarray': cast_array, 'record': cast_record}
614
615    def __missing__(self, typ):
616        """Create a cast function if it is not cached.
617       
618        Note that this class never raises a KeyError,
619        but returns None when no special cast function exists.
620        """
621        if not isinstance(typ, str):
622            raise TypeError('Invalid type: %s' % typ)
623        cast = self.defaults.get(typ)
624        if cast:
625            # store default for faster access
626            self[typ] = cast
627        elif typ.startswith('_'):
628            base_cast = self[typ[1:]]
629            cast = self.create_array_cast(base_cast)
630            if base_cast:
631                self[typ] = cast
632        else:
633            attnames = self.get_attnames(typ)
634            if attnames:
635                casts = [self[v.pgtype] for v in attnames.values()]
636                cast = self.create_record_cast(typ, attnames, casts)
637                self[typ] = cast
638        return cast
639
640    def get(self, typ, default=None):
641        """Get the typecast function for the given database type."""
642        return self[typ] or default
643
644    def set(self, typ, cast):
645        """Set a typecast function for the specified database type(s)."""
646        if isinstance(typ, basestring):
647            typ = [typ]
648        if cast is None:
649            for t in typ:
650                self.pop(t, None)
651                self.pop('_%s' % t, None)
652        else:
653            if not callable(cast):
654                raise TypeError("Cast parameter must be callable")
655            for t in typ:
656                self[t] = cast
657                self.pop('_%s' % t, None)
658
659    def reset(self, typ=None):
660        """Reset the typecasts for the specified type(s) to their defaults.
661
662        When no type is specified, all typecasts will be reset.
663        """
664        if typ is None:
665            self.clear()
666        else:
667            if isinstance(typ, basestring):
668                typ = [typ]
669            for t in typ:
670                self.pop(t, None)
671
672    @classmethod
673    def get_default(cls, typ):
674        """Get the default typecast function for the given database type."""
675        return cls.defaults.get(typ)
676
677    @classmethod
678    def set_default(cls, typ, cast):
679        """Set a default typecast function for the given database type(s)."""
680        if isinstance(typ, basestring):
681            typ = [typ]
682        defaults = cls.defaults
683        if cast is None:
684            for t in typ:
685                defaults.pop(t, None)
686                defaults.pop('_%s' % t, None)
687        else:
688            if not callable(cast):
689                raise TypeError("Cast parameter must be callable")
690            for t in typ:
691                defaults[t] = cast
692                defaults.pop('_%s' % t, None)
693
694    def get_attnames(self, typ):
695        """Return the fields for the given record type.
696
697        This method will be replaced with the get_attnames() method of DbTypes.
698        """
699        return {}
700
701    def create_array_cast(self, cast):
702        """Create an array typecast for the given base cast."""
703        return lambda v: cast_array(v, cast)
704
705    def create_record_cast(self, name, fields, casts):
706        """Create a named record typecast for the given fields and casts."""
707        record = namedtuple(name, fields)
708        return lambda v: record(*cast_record(v, casts))
709
710
711def get_typecast(typ):
712    """Get the global typecast function for the given database type(s)."""
713    return Typecasts.get_default(typ)
714
715
716def set_typecast(typ, cast):
717    """Set a global typecast function for the given database type(s).
718
719    Note that connections cache cast functions. To be sure a global change
720    is picked up by a running connection, call db.db_types.reset_typecast().
721    """
722    Typecasts.set_default(typ, cast)
723
724
725class DbType(str):
726    """Class augmenting the simple type name with additional info.
727
728    The following additional information is provided:
729
730        oid: the PostgreSQL type OID
731        pgtype: the PostgreSQL type name
732        regtype: the regular type name
733        simple: the simple PyGreSQL type name
734        typtype: b = base type, c = composite type etc.
735        category: A = Array, b =Boolean, C = Composite etc.
736        delim: delimiter for array types
737        relid: corresponding table for composite types
738        attnames: attributes for composite types
739    """
740
741    @property
742    def attnames(self):
743        """Get names and types of the fields of a composite type."""
744        return self._get_attnames(self)
745
746
747class DbTypes(dict):
748    """Cache for PostgreSQL data types.
749
750    This cache maps type OIDs and names to DbType objects containing
751    information on the associated database type.
752    """
753
754    _num_types = frozenset('int float num money'
755        ' int2 int4 int8 float4 float8 numeric money'.split())
756
757    def __init__(self, db):
758        """Initialize type cache for connection."""
759        super(DbTypes, self).__init__()
760        self._get_attnames = db.get_attnames
761        db = db.db
762        self.query = db.query
763        self.escape_string = db.escape_string
764        self._typecasts = Typecasts()
765        self._typecasts.get_attnames = self.get_attnames
766        self._regtypes = False
767
768    def add(self, oid, pgtype, regtype,
769               typtype, category, delim, relid):
770        """Create a PostgreSQL type name with additional info."""
771        if oid in self:
772            return self[oid]
773        simple = 'record' if relid else _simpletypes[pgtype]
774        typ = DbType(regtype if self._regtypes else simple)
775        typ.oid = oid
776        typ.simple = simple
777        typ.pgtype = pgtype
778        typ.regtype = regtype
779        typ.typtype = typtype
780        typ.category = category
781        typ.delim = delim
782        typ.relid = relid
783        typ._get_attnames = self.get_attnames
784        return typ
785
786    def __missing__(self, key):
787        """Get the type info from the database if it is not cached."""
788        try:
789            res = self.query("SELECT oid, typname, typname::regtype,"
790                " typtype, typcategory, typdelim, typrelid"
791                " FROM pg_type WHERE oid=%s::regtype" %
792                (_quote_if_unqualified('$1', key),), (key,)).getresult()
793        except ProgrammingError:
794            res = None
795        if not res:
796            raise KeyError('Type %s could not be found' % key)
797        res = res[0]
798        typ = self.add(*res)
799        self[typ.oid] = self[typ.pgtype] = typ
800        return typ
801
802    def get(self, key, default=None):
803        """Get the type even if it is not cached."""
804        try:
805            return self[key]
806        except KeyError:
807            return default
808
809    def get_attnames(self, typ):
810        """Get names and types of the fields of a composite type."""
811        if not isinstance(typ, DbType):
812            typ = self.get(typ)
813            if not typ:
814                return None
815        if not typ.relid:
816            return None
817        return self._get_attnames(typ.relid, with_oid=False)
818
819    def get_typecast(self, typ):
820        """Get the typecast function for the given database type."""
821        return self._typecasts.get(typ)
822
823    def set_typecast(self, typ, cast):
824        """Set a typecast function for the specified database type(s)."""
825        self._typecasts.set(typ, cast)
826
827    def reset_typecast(self, typ=None):
828        """Reset the typecast function for the specified database type(s)."""
829        self._typecasts.reset(typ)
830
831    def typecast(self, value, typ):
832        """Cast the given value according to the given database type."""
833        if value is None:
834            # for NULL values, no typecast is necessary
835            return None
836        if not isinstance(typ, DbType):
837            typ = self.get(typ)
838            if typ:
839                typ = typ.pgtype
840        cast = self.get_typecast(typ) if typ else None
841        if not cast or cast is str:
842            # no typecast is necessary
843            return value
844        return cast(value)
845
846
847def _namedresult(q):
848    """Get query result as named tuples."""
849    row = namedtuple('Row', q.listfields())
850    return [row(*r) for r in q.getresult()]
851
852
853class _MemoryQuery:
854    """Class that embodies a given query result."""
855
856    def __init__(self, result, fields):
857        """Create query from given result rows and field names."""
858        self.result = result
859        self.fields = fields
860
861    def listfields(self):
862        """Return the stored field names of this query."""
863        return self.fields
864
865    def getresult(self):
866        """Return the stored result of this query."""
867        return self.result
868
869
870def _db_error(msg, cls=DatabaseError):
871    """Return DatabaseError with empty sqlstate attribute."""
872    error = cls(msg)
873    error.sqlstate = None
874    return error
875
876
877def _int_error(msg):
878    """Return InternalError."""
879    return _db_error(msg, InternalError)
880
881
882def _prg_error(msg):
883    """Return ProgrammingError."""
884    return _db_error(msg, ProgrammingError)
885
886
887# Initialize the C module
888
889set_namedresult(_namedresult)
890set_decimal(Decimal)
891set_jsondecode(jsondecode)
892
893
894# The notification handler
895
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.unlisten()

    @staticmethod
    def _quote_ident(name):
        """Quote name as an SQL identifier, doubling embedded double quotes."""
        return '"%s"' % ('%s' % name).replace('"', '""')

    @staticmethod
    def _quote_literal(value):
        """Quote value as an escape string constant E'...'.

        Backslashes and single quotes are doubled, so the value arrives
        verbatim regardless of the standard_conforming_strings setting
        and cannot terminate the literal early.
        """
        value = ('%s' % value).replace('\\', '\\\\').replace("'", "''")
        return "E'%s'" % value

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_ident(self.event))
            self.db.query('listen %s' % self._quote_ident(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_ident(self.event))
            self.db.query('unlisten %s' % self._quote_ident(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is sent as a properly escaped string literal.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent (and None is returned) if the handler is not
        currently listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            event = self.stop_event if stop else self.event
            q = 'notify %s' % self._quote_ident(event)
            if payload:
                # escape the payload; embedding it raw would break the
                # statement (or allow injection) on quotes or backslashes
                q += ', %s' % self._quote_literal(payload)
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
1015
1016
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    All arguments are passed through unchanged; a DeprecationWarning
    attributed to the caller is emitted first.
    """
    msg = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
1022
1023
# The actual PostgreSQL database connection interface:
1025
1026class DB:
1027    """Wrapper class for the _pg connection type."""
1028
1029    def __init__(self, *args, **kw):
1030        """Create a new connection
1031
1032        You can pass either the connection parameters or an existing
1033        _pg or pgdb connection. This allows you to use the methods
1034        of the classic pg interface with a DB-API 2 pgdb connection.
1035        """
1036        if not args and len(kw) == 1:
1037            db = kw.get('db')
1038        elif not kw and len(args) == 1:
1039            db = args[0]
1040        else:
1041            db = None
1042        if db:
1043            if isinstance(db, DB):
1044                db = db.db
1045            else:
1046                try:
1047                    db = db._cnx
1048                except AttributeError:
1049                    pass
1050        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
1051            db = connect(*args, **kw)
1052            self._closeable = True
1053        else:
1054            self._closeable = False
1055        self.db = db
1056        self.dbname = db.db
1057        self._regtypes = False
1058        self._attnames = {}
1059        self._pkeys = {}
1060        self._privileges = {}
1061        self._args = args, kw
1062        self.adapter = Adapter(self)
1063        self.dbtypes = DbTypes(self)
1064        db.set_cast_hook(self.dbtypes.typecast)
1065        self.debug = None  # For debugging scripts, this can be set
1066            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
1067            # * to a file object to write debug statements or
1068            # * to a callable object which takes a string argument
1069            # * to any other true value to just print debug statements
1070
1071    def __getattr__(self, name):
1072        # All undefined members are same as in underlying connection:
1073        if self.db:
1074            return getattr(self.db, name)
1075        else:
1076            raise _int_error('Connection is not valid')
1077
1078    def __dir__(self):
1079        # Custom dir function including the attributes of the connection:
1080        attrs = set(self.__class__.__dict__)
1081        attrs.update(self.__dict__)
1082        attrs.update(dir(self.db))
1083        return sorted(attrs)
1084
1085    # Context manager methods
1086
1087    def __enter__(self):
1088        """Enter the runtime context. This will start a transactio."""
1089        self.begin()
1090        return self
1091
1092    def __exit__(self, et, ev, tb):
1093        """Exit the runtime context. This will end the transaction."""
1094        if et is None and ev is None and tb is None:
1095            self.commit()
1096        else:
1097            self.rollback()
1098
1099    # Auxiliary methods
1100
1101    def _do_debug(self, *args):
1102        """Print a debug message"""
1103        if self.debug:
1104            s = '\n'.join(str(arg) for arg in args)
1105            if isinstance(self.debug, basestring):
1106                print(self.debug % s)
1107            elif hasattr(self.debug, 'write'):
1108                self.debug.write(s + '\n')
1109            elif callable(self.debug):
1110                self.debug(s)
1111            else:
1112                print(s)
1113
1114    def _escape_qualified_name(self, s):
1115        """Escape a qualified name.
1116
1117        Escapes the name for use as an SQL identifier, unless the
1118        name contains a dot, in which case the name is ambiguous
1119        (could be a qualified name or just a name with a dot in it)
1120        and must be quoted manually by the caller.
1121        """
1122        if '.' not in s:
1123            s = self.escape_identifier(s)
1124        return s
1125
1126    @staticmethod
1127    def _make_bool(d):
1128        """Get boolean value corresponding to d."""
1129        return bool(d) if get_bool() else ('t' if d else 'f')
1130
1131    def _list_params(self, params):
1132        """Create a human readable parameter list."""
1133        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
1134
1135    # Public methods
1136
    # escape_string and escape_bytea exist as methods (they are forwarded
    # to the connection by __getattr__), so we expose the module-level
    # unescape_bytea function as a static method as well for symmetry.
    unescape_bytea = staticmethod(unescape_bytea)
1140
1141    def decode_json(self, s):
1142        """Decode a JSON string coming from the database."""
1143        return (get_jsondecode() or jsondecode)(s)
1144
1145    def encode_json(self, d):
1146        """Encode a JSON string for use within SQL."""
1147        return jsonencode(d)
1148
1149    def close(self):
1150        """Close the database connection."""
1151        # Wraps shared library function so we can track state.
1152        if self._closeable:
1153            if self.db:
1154                self.db.close()
1155                self.db = None
1156            else:
1157                raise _int_error('Connection already closed')
1158
1159    def reset(self):
1160        """Reset connection with current parameters.
1161
1162        All derived queries and large objects derived from this connection
1163        will not be usable after this call.
1164
1165        """
1166        if self.db:
1167            self.db.reset()
1168        else:
1169            raise _int_error('Connection already closed')
1170
1171    def reopen(self):
1172        """Reopen connection to the database.
1173
1174        Used in case we need another connection to the same database.
1175        Note that we can still reopen a database that we have closed.
1176
1177        """
1178        # There is no such shared library function.
1179        if self._closeable:
1180            db = connect(*self._args[0], **self._args[1])
1181            if self.db:
1182                self.db.close()
1183            self.db = db
1184
1185    def begin(self, mode=None):
1186        """Begin a transaction."""
1187        qstr = 'BEGIN'
1188        if mode:
1189            qstr += ' ' + mode
1190        return self.query(qstr)
1191
1192    start = begin
1193
1194    def commit(self):
1195        """Commit the current transaction."""
1196        return self.query('COMMIT')
1197
1198    end = commit
1199
1200    def rollback(self, name=None):
1201        """Roll back the current transaction."""
1202        qstr = 'ROLLBACK'
1203        if name:
1204            qstr += ' TO ' + name
1205        return self.query(qstr)
1206
1207    abort = rollback
1208
1209    def savepoint(self, name):
1210        """Define a new savepoint within the current transaction."""
1211        return self.query('SAVEPOINT ' + name)
1212
1213    def release(self, name):
1214        """Destroy a previously defined savepoint."""
1215        return self.query('RELEASE ' + name)
1216
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises TypeError if the parameter has the wrong type, is empty,
        or contains a name that is not a non-blank string.
        """
        # values is the result accumulator; its type follows the input type:
        # None for a single name (filled with the bare setting), a list for
        # a list/tuple, a new dict for a set, and the caller's dict itself.
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # For dict results, params maps each normalized name to the key as
        # given by the caller, so settings are stored under the caller's keys.
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' overrides everything else: return a new dict built
                # from the first two columns of each SHOW ALL row (the
                # parameter name and its setting)
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # only reached when 'all' was not requested (no break above):
            # query each parameter individually
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1274
1275    def set_parameter(self, parameter, value=None, local=False):
1276        """Set the value of a run-time parameter.
1277
1278        If the parameter and the value are strings, the run-time parameter
1279        will be set to that value.  If no value or None is passed as a value,
1280        then the run-time parameter will be restored to its default value.
1281
1282        You can set several parameters at once by passing a list of parameter
1283        names, together with a single value that all parameters should be
1284        set to or with a corresponding list of values.  You can also pass
1285        the parameters as a set if you only provide a single value.
1286        Finally, you can pass a dict with parameter names as keys.  In this
1287        case, you should not pass a value, since the values for the parameters
1288        will be taken from the dict.
1289
1290        By passing the special name 'all' as the parameter, you can reset
1291        all existing settable run-time parameters to their default values.
1292
1293        If you set local to True, then the command takes effect for only the
1294        current transaction.  After commit() or rollback(), the session-level
1295        setting takes effect again.  Setting local to True will appear to
1296        have no effect if it is executed outside a transaction, since the
1297        transaction will end immediately.
1298        """
1299        if isinstance(parameter, basestring):
1300            parameter = {parameter: value}
1301        elif isinstance(parameter, (list, tuple)):
1302            if isinstance(value, (list, tuple)):
1303                parameter = dict(zip(parameter, value))
1304            else:
1305                parameter = dict.fromkeys(parameter, value)
1306        elif isinstance(parameter, (set, frozenset)):
1307            if isinstance(value, (list, tuple, set, frozenset)):
1308                value = set(value)
1309                if len(value) == 1:
1310                    value = value.pop()
1311            if not(value is None or isinstance(value, basestring)):
1312                raise ValueError('A single value must be specified'
1313                    ' when parameter is a set')
1314            parameter = dict.fromkeys(parameter, value)
1315        elif isinstance(parameter, dict):
1316            if value is not None:
1317                raise ValueError('A value must not be specified'
1318                    ' when parameter is a dictionary')
1319        else:
1320            raise TypeError(
1321                'The parameter must be a string, list, set or dict')
1322        if not parameter:
1323            raise TypeError('No parameter has been specified')
1324        params = {}
1325        for key, value in parameter.items():
1326            param = key.strip().lower() if isinstance(
1327                key, basestring) else None
1328            if not param:
1329                raise TypeError('Invalid parameter')
1330            if param == 'all':
1331                if value is not None:
1332                    raise ValueError('A value must ot be specified'
1333                        " when parameter is 'all'")
1334                params = {'all': None}
1335                break
1336            params[param] = value
1337        local = ' LOCAL' if local else ''
1338        for param, value in params.items():
1339            if value is None:
1340                q = 'RESET%s %s' % (local, param)
1341            else:
1342                q = 'SET%s %s TO %s' % (local, param, value)
1343            self._do_debug(q)
1344            self.db.query(q)
1345
1346    def query(self, command, *args):
1347        """Execute a SQL command string.
1348
1349        This method simply sends a SQL query to the database.  If the query is
1350        an insert statement that inserted exactly one row into a table that
1351        has OIDs, the return value is the OID of the newly inserted row.
1352        If the query is an update or delete statement, or an insert statement
1353        that did not insert exactly one row in a table with OIDs, then the
1354        number of rows affected is returned as a string.  If it is a statement
1355        that returns rows as a result (usually a select statement, but maybe
1356        also an "insert/update ... returning" statement), this method returns
1357        a Query object that can be accessed via getresult() or dictresult()
1358        or simply printed.  Otherwise, it returns `None`.
1359
1360        The query can contain numbered parameters of the form $1 in place
1361        of any data constant.  Arguments given after the query string will
1362        be substituted for the corresponding numbered parameter.  Parameter
1363        values can also be given as a single list or tuple argument.
1364        """
1365        # Wraps shared library function for debugging.
1366        if not self.db:
1367            raise _int_error('Connection is not valid')
1368        if args:
1369            self._do_debug(command, args)
1370            return self.db.query(command, args)
1371        self._do_debug(command)
1372        return self.db.query(command)
1373
1374    def query_formatted(self, command, parameters, types=None, inline=False):
1375        """Execute a formatted SQL command string.
1376
1377        Similar to query, but using Python format placeholders of the form
1378        %s or %(names)s instead of PostgreSQL placeholders of the form $1.
1379        The parameters must be passed as a tuple, list or dict.  You can
1380        also pass a corresponding tuple, list or dict of database types in
1381        order to format the parameters properly in case there is ambiguity.
1382
1383        If you set inline to True, the parameters will be sent to the database
1384        embedded in the SQL command, otherwise they will be sent separately.
1385        """
1386        return self.query(*self.adapter.format_query(
1387            command, parameters, types, inline))
1388
1389    def pkey(self, table, composite=False, flush=False):
1390        """Get or set the primary key of a table.
1391
1392        Single primary keys are returned as strings unless you
1393        set the composite flag.  Composite primary keys are always
1394        represented as tuples.  Note that this raises a KeyError
1395        if the table does not have a primary key.
1396
1397        If flush is set then the internal cache for primary keys will
1398        be flushed.  This may be necessary after the database schema or
1399        the search path has been changed.
1400        """
1401        pkeys = self._pkeys
1402        if flush:
1403            pkeys.clear()
1404            self._do_debug('The pkey cache has been flushed')
1405        try:  # cache lookup
1406            pkey = pkeys[table]
1407        except KeyError:  # cache miss, check the database
1408            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1409                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1410                " AND a.attnum = ANY(i.indkey)"
1411                " AND NOT a.attisdropped"
1412                " WHERE i.indrelid=%s::regclass"
1413                " AND i.indisprimary ORDER BY a.attnum") % (
1414                    _quote_if_unqualified('$1', table),)
1415            pkey = self.db.query(q, (table,)).getresult()
1416            if not pkey:
1417                raise KeyError('Table %s has no primary key' % table)
1418            # we want to use the order defined in the primary key index here,
1419            # not the order as defined by the columns in the table
1420            if len(pkey) > 1:
1421                indkey = pkey[0][2]
1422                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1423                pkey = tuple(row[0] for row in pkey)
1424            else:
1425                pkey = pkey[0][0]
1426            pkeys[table] = pkey  # cache it
1427        if composite and not isinstance(pkey, tuple):
1428            pkey = (pkey,)
1429        return pkey
1430
1431    def get_databases(self):
1432        """Get list of databases in the system."""
1433        return [s[0] for s in
1434            self.db.query('SELECT datname FROM pg_database').getresult()]
1435
1436    def get_relations(self, kinds=None):
1437        """Get list of relations in connected database of specified kinds.
1438
1439        If kinds is None or empty, all kinds of relations are returned.
1440        Otherwise kinds can be a string or sequence of type letters
1441        specifying which kind of relations you want to list.
1442        """
1443        where = " AND r.relkind IN (%s)" % ','.join(
1444            ["'%s'" % k for k in kinds]) if kinds else ''
1445        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1446            " FROM pg_class r"
1447            " JOIN pg_namespace s ON s.oid = r.relnamespace"
1448            " WHERE s.nspname NOT SIMILAR"
1449            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
1450            " ORDER BY s.nspname, r.relname") % where
1451        return [r[0] for r in self.db.query(q).getresult()]
1452
1453    def get_tables(self):
1454        """Return list of tables in connected database."""
1455        return self.get_relations('r')
1456
1457    def get_attnames(self, table, with_oid=True, flush=False):
1458        """Given the name of a table, dig out the set of attribute names.
1459
1460        Returns a read-only dictionary of attribute names (the names are
1461        the keys, the values are the names of the attributes' types)
1462        with the column names in the proper order if you iterate over it.
1463
1464        If flush is set, then the internal cache for attribute names will
1465        be flushed. This may be necessary after the database schema or
1466        the search path has been changed.
1467
1468        By default, only a limited number of simple types will be returned.
1469        You can get the regular types after calling use_regtypes(True).
1470        """
1471        attnames = self._attnames
1472        if flush:
1473            attnames.clear()
1474            self._do_debug('The attnames cache has been flushed')
1475        try:  # cache lookup
1476            names = attnames[table]
1477        except KeyError:  # cache miss, check the database
1478            q = "a.attnum > 0"
1479            if with_oid:
1480                q = "(%s OR a.attname = 'oid')" % q
1481            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1482                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1483                " FROM pg_attribute a"
1484                " JOIN pg_type t ON t.oid = a.atttypid"
1485                " WHERE a.attrelid = %s::regclass AND %s"
1486                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1487                    _quote_if_unqualified('$1', table), q)
1488            names = self.db.query(q, (table,)).getresult()
1489            types = self.dbtypes
1490            names = ((name[0], types.add(*name[1:])) for name in names)
1491            names = AttrDict(names)
1492            attnames[table] = names  # cache it
1493        return names
1494
1495    def use_regtypes(self, regtypes=None):
1496        """Use regular type names instead of simplified type names."""
1497        if regtypes is None:
1498            return self.dbtypes._regtypes
1499        else:
1500            regtypes = bool(regtypes)
1501            if regtypes != self.dbtypes._regtypes:
1502                self.dbtypes._regtypes = regtypes
1503                self._attnames.clear()
1504                self.dbtypes.clear()
1505            return regtypes
1506
1507    def has_table_privilege(self, table, privilege='select'):
1508        """Check whether current user has specified table privilege."""
1509        privilege = privilege.lower()
1510        try:  # ask cache
1511            return self._privileges[(table, privilege)]
1512        except KeyError:  # cache miss, ask the database
1513            q = "SELECT has_table_privilege(%s, $2)" % (
1514                _quote_if_unqualified('$1', table),)
1515            q = self.db.query(q, (table, privilege))
1516            ret = q.getresult()[0][0] == self._make_bool(True)
1517            self._privileges[(table, privilege)] = ret  # cache it
1518            return ret
1519
1520    def get(self, table, row, keyname=None):
1521        """Get a row from a database table or view.
1522
1523        This method is the basic mechanism to get a single row.  It assumes
1524        that the keyname specifies a unique row.  It must be the name of a
1525        single column or a tuple of column names.  If the keyname is not
1526        specified, then the primary key for the table is used.
1527
1528        If row is a dictionary, then the value for the key is taken from it.
1529        Otherwise, the row must be a single value or a tuple of values
1530        corresponding to the passed keyname or primary key.  The fetched row
1531        from the table will be returned as a new dictionary or used to replace
1532        the existing values when row was passed as aa dictionary.
1533
1534        The OID is also put into the dictionary if the table has one, but
1535        in order to allow the caller to work with multiple tables, it is
1536        munged as "oid(table)" using the actual name of the table.
1537        """
1538        if table.endswith('*'):  # hint for descendant tables can be ignored
1539            table = table[:-1].rstrip()
1540        attnames = self.get_attnames(table)
1541        qoid = _oid_key(table) if 'oid' in attnames else None
1542        if keyname and isinstance(keyname, basestring):
1543            keyname = (keyname,)
1544        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
1545            row['oid'] = row[qoid]
1546        if not keyname:
1547            try:  # if keyname is not specified, try using the primary key
1548                keyname = self.pkey(table, True)
1549            except KeyError:  # the table has no primary key
1550                # try using the oid instead
1551                if qoid and isinstance(row, dict) and 'oid' in row:
1552                    keyname = ('oid',)
1553                else:
1554                    raise _prg_error('Table %s has no primary key' % table)
1555            else:  # the table has a primary key
1556                # check whether all key columns have values
1557                if isinstance(row, dict) and not set(keyname).issubset(row):
1558                    # try using the oid instead
1559                    if qoid and 'oid' in row:
1560                        keyname = ('oid',)
1561                    else:
1562                        raise KeyError(
1563                            'Missing value in row for specified keyname')
1564        if not isinstance(row, dict):
1565            if not isinstance(row, (tuple, list)):
1566                row = [row]
1567            if len(keyname) != len(row):
1568                raise KeyError(
1569                    'Differing number of items in keyname and row')
1570            row = dict(zip(keyname, row))
1571        params = self.adapter.parameter_list()
1572        adapt = params.add
1573        col = self.escape_identifier
1574        what = 'oid, *' if qoid else '*'
1575        where = ' AND '.join('%s = %s' % (
1576            col(k), adapt(row[k], attnames[k])) for k in keyname)
1577        if 'oid' in row:
1578            if qoid:
1579                row[qoid] = row['oid']
1580            del row['oid']
1581        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
1582            what, self._escape_qualified_name(table), where)
1583        self._do_debug(q, params)
1584        q = self.db.query(q, params)
1585        res = q.dictresult()
1586        if not res:
1587            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
1588                table, where, self._list_params(params)))
1589        for n, value in res[0].items():
1590            if qoid and n == 'oid':
1591                n = qoid
1592            row[n] = value
1593        return row
1594
1595    def insert(self, table, row=None, **kw):
1596        """Insert a row into a database table.
1597
1598        This method inserts a row into a table.  The name of the table must
1599        be passed as the first parameter.  The other parameters are used for
1600        providing the data of the row that shall be inserted into the table.
1601        If a dictionary is supplied as the second parameter, it starts with
1602        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1603        is updated from the keywords.
1604
1605        The dictionary is then reloaded with the values actually inserted in
1606        order to pick up values modified by rules, triggers, etc.
1607        """
1608        if table.endswith('*'):  # hint for descendant tables can be ignored
1609            table = table[:-1].rstrip()
1610        if row is None:
1611            row = {}
1612        row.update(kw)
1613        if 'oid' in row:
1614            del row['oid']  # do not insert oid
1615        attnames = self.get_attnames(table)
1616        qoid = _oid_key(table) if 'oid' in attnames else None
1617        params = self.adapter.parameter_list()
1618        adapt = params.add
1619        col = self.escape_identifier
1620        names, values = [], []
1621        for n in attnames:
1622            if n in row:
1623                names.append(col(n))
1624                values.append(adapt(row[n], attnames[n]))
1625        if not names:
1626            raise _prg_error('No column found that can be inserted')
1627        names, values = ', '.join(names), ', '.join(values)
1628        ret = 'oid, *' if qoid else '*'
1629        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1630            self._escape_qualified_name(table), names, values, ret)
1631        self._do_debug(q, params)
1632        q = self.db.query(q, params)
1633        res = q.dictresult()
1634        if res:  # this should always be true
1635            for n, value in res[0].items():
1636                if qoid and n == 'oid':
1637                    n = qoid
1638                row[n] = value
1639        return row
1640
1641    def update(self, table, row=None, **kw):
1642        """Update an existing row in a database table.
1643
1644        Similar to insert but updates an existing row.  The update is based
1645        on the primary key of the table or the OID value as munged by get
1646        or passed as keyword.
1647
1648        The dictionary is then modified to reflect any changes caused by the
1649        update due to triggers, rules, default values, etc.
1650        """
1651        if table.endswith('*'):
1652            table = table[:-1].rstrip()  # need parent table name
1653        attnames = self.get_attnames(table)
1654        qoid = _oid_key(table) if 'oid' in attnames else None
1655        if row is None:
1656            row = {}
1657        elif 'oid' in row:
1658            del row['oid']  # only accept oid key from named args for safety
1659        row.update(kw)
1660        if qoid and qoid in row and 'oid' not in row:
1661            row['oid'] = row[qoid]
1662        try:  # try using the primary key
1663            keyname = self.pkey(table, True)
1664        except KeyError:  # the table has no primary key
1665            # try using the oid instead
1666            if qoid and 'oid' in row:
1667                keyname = ('oid',)
1668            else:
1669                raise _prg_error('Table %s has no primary key' % table)
1670        else:  # the table has a primary key
1671            # check whether all key columns have values
1672            if not set(keyname).issubset(row):
1673                # try using the oid instead
1674                if qoid and 'oid' in row:
1675                    keyname = ('oid',)
1676                else:
1677                    raise KeyError('Missing primary key in row')
1678        params = self.adapter.parameter_list()
1679        adapt = params.add
1680        col = self.escape_identifier
1681        where = ' AND '.join('%s = %s' % (
1682            col(k), adapt(row[k], attnames[k])) for k in keyname)
1683        if 'oid' in row:
1684            if qoid:
1685                row[qoid] = row['oid']
1686            del row['oid']
1687        values = []
1688        keyname = set(keyname)
1689        for n in attnames:
1690            if n in row and n not in keyname:
1691                values.append('%s = %s' % (col(n), adapt(row[n], attnames[n])))
1692        if not values:
1693            return row
1694        values = ', '.join(values)
1695        ret = 'oid, *' if qoid else '*'
1696        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
1697            self._escape_qualified_name(table), values, where, ret)
1698        self._do_debug(q, params)
1699        q = self.db.query(q, params)
1700        res = q.dictresult()
1701        if res:  # may be empty when row does not exist
1702            for n, value in res[0].items():
1703                if qoid and n == 'oid':
1704                    n = qoid
1705                row[n] = value
1706        return row
1707
1708    def upsert(self, table, row=None, **kw):
1709        """Insert a row into a database table with conflict resolution
1710
1711        This method inserts a row into a table, but instead of raising a
1712        ProgrammingError exception in case a row with the same primary key
1713        already exists, an update will be executed instead.  This will be
1714        performed as a single atomic operation on the database, so race
1715        conditions can be avoided.
1716
1717        Like the insert method, the first parameter is the name of the
1718        table and the second parameter can be used to pass the values to
1719        be inserted as a dictionary.
1720
1721        Unlike the insert und update statement, keyword parameters are not
1722        used to modify the dictionary, but to specify which columns shall
1723        be updated in case of a conflict, and in which way:
1724
1725        A value of False or None means the column shall not be updated,
1726        a value of True means the column shall be updated with the value
1727        that has been proposed for insertion, i.e. has been passed as value
1728        in the dictionary.  Columns that are not specified by keywords but
1729        appear as keys in the dictionary are also updated like in the case
1730        keywords had been passed with the value True.
1731
1732        So if in the case of a conflict you want to update every column that
1733        has been passed in the dictionary row , you would call upsert(table, row).
1734        If you don't want to do anything in case of a conflict, i.e. leave
1735        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
1736
1737        If you need more fine-grained control of what gets updated, you can
1738        also pass strings in the keyword parameters.  These strings will
1739        be used as SQL expressions for the update columns.  In these
1740        expressions you can refer to the value that already exists in
1741        the table by prefixing the column name with "included.", and to
1742        the value that has been proposed for insertion by prefixing the
1743        column name with the "excluded."
1744
1745        The dictionary is modified in any case to reflect the values in
1746        the database after the operation has completed.
1747
1748        Note: The method uses the PostgreSQL "upsert" feature which is
1749        only available since PostgreSQL 9.5.
1750        """
1751        if table.endswith('*'):  # hint for descendant tables can be ignored
1752            table = table[:-1].rstrip()
1753        if row is None:
1754            row = {}
1755        if 'oid' in row:
1756            del row['oid']  # do not insert oid
1757        if 'oid' in kw:
1758            del kw['oid']  # do not update oid
1759        attnames = self.get_attnames(table)
1760        qoid = _oid_key(table) if 'oid' in attnames else None
1761        params = self.adapter.parameter_list()
1762        adapt = params.add
1763        col = self.escape_identifier
1764        names, values, updates = [], [], []
1765        for n in attnames:
1766            if n in row:
1767                names.append(col(n))
1768                values.append(adapt(row[n], attnames[n]))
1769        names, values = ', '.join(names), ', '.join(values)
1770        try:
1771            keyname = self.pkey(table, True)
1772        except KeyError:
1773            raise _prg_error('Table %s has no primary key' % table)
1774        target = ', '.join(col(k) for k in keyname)
1775        update = []
1776        keyname = set(keyname)
1777        keyname.add('oid')
1778        for n in attnames:
1779            if n not in keyname:
1780                value = kw.get(n, True)
1781                if value:
1782                    if not isinstance(value, basestring):
1783                        value = 'excluded.%s' % col(n)
1784                    update.append('%s = %s' % (col(n), value))
1785        if not values:
1786            return row
1787        do = 'update set %s' % ', '.join(update) if update else 'nothing'
1788        ret = 'oid, *' if qoid else '*'
1789        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
1790            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
1791                self._escape_qualified_name(table), names, values,
1792                target, do, ret)
1793        self._do_debug(q, params)
1794        try:
1795            q = self.db.query(q, params)
1796        except ProgrammingError:
1797            if self.server_version < 90500:
1798                raise _prg_error(
1799                    'Upsert operation is not supported by PostgreSQL version')
1800            raise  # re-raise original error
1801        res = q.dictresult()
1802        if res:  # may be empty with "do nothing"
1803            for n, value in res[0].items():
1804                if qoid and n == 'oid':
1805                    n = qoid
1806                row[n] = value
1807        else:
1808            self.get(table, row)
1809        return row
1810
1811    def clear(self, table, row=None):
1812        """Clear all the attributes to values determined by the types.
1813
1814        Numeric types are set to 0, Booleans are set to false, and everything
1815        else is set to the empty string.  If the row argument is present,
1816        it is used as the row dictionary and any entries matching attribute
1817        names are cleared with everything else left unchanged.
1818        """
1819        # At some point we will need a way to get defaults from a table.
1820        if row is None:
1821            row = {}  # empty if argument is not present
1822        attnames = self.get_attnames(table)
1823        for n, t in attnames.items():
1824            if n == 'oid':
1825                continue
1826            t = t.simple
1827            if t in DbTypes._num_types:
1828                row[n] = 0
1829            elif t == 'bool':
1830                row[n] = self._make_bool(False)
1831            else:
1832                row[n] = ''
1833        return row
1834
1835    def delete(self, table, row=None, **kw):
1836        """Delete an existing row in a database table.
1837
1838        This method deletes the row from a table.  It deletes based on the
1839        primary key of the table or the OID value as munged by get() or
1840        passed as keyword.
1841
1842        The return value is the number of deleted rows (i.e. 0 if the row
1843        did not exist and 1 if the row was deleted).
1844
1845        Note that if the row cannot be deleted because e.g. it is still
1846        referenced by another table, this method raises a ProgrammingError.
1847        """
1848        if table.endswith('*'):  # hint for descendant tables can be ignored
1849            table = table[:-1].rstrip()
1850        attnames = self.get_attnames(table)
1851        qoid = _oid_key(table) if 'oid' in attnames else None
1852        if row is None:
1853            row = {}
1854        elif 'oid' in row:
1855            del row['oid']  # only accept oid key from named args for safety
1856        row.update(kw)
1857        if qoid and qoid in row and 'oid' not in row:
1858            row['oid'] = row[qoid]
1859        try:  # try using the primary key
1860            keyname = self.pkey(table, True)
1861        except KeyError:  # the table has no primary key
1862            # try using the oid instead
1863            if qoid and 'oid' in row:
1864                keyname = ('oid',)
1865            else:
1866                raise _prg_error('Table %s has no primary key' % table)
1867        else:  # the table has a primary key
1868            # check whether all key columns have values
1869            if not set(keyname).issubset(row):
1870                # try using the oid instead
1871                if qoid and 'oid' in row:
1872                    keyname = ('oid',)
1873                else:
1874                    raise KeyError('Missing primary key in row')
1875        params = self.adapter.parameter_list()
1876        adapt = params.add
1877        col = self.escape_identifier
1878        where = ' AND '.join('%s = %s' % (
1879            col(k), adapt(row[k], attnames[k])) for k in keyname)
1880        if 'oid' in row:
1881            if qoid:
1882                row[qoid] = row['oid']
1883            del row['oid']
1884        q = 'DELETE FROM %s WHERE %s' % (
1885            self._escape_qualified_name(table), where)
1886        self._do_debug(q, params)
1887        res = self.db.query(q, params)
1888        return int(res)
1889
1890    def truncate(self, table, restart=False, cascade=False, only=False):
1891        """Empty a table or set of tables.
1892
1893        This method quickly removes all rows from the given table or set
1894        of tables.  It has the same effect as an unqualified DELETE on each
1895        table, but since it does not actually scan the tables it is faster.
1896        Furthermore, it reclaims disk space immediately, rather than requiring
1897        a subsequent VACUUM operation. This is most useful on large tables.
1898
1899        If restart is set to True, sequences owned by columns of the truncated
1900        table(s) are automatically restarted.  If cascade is set to True, it
1901        also truncates all tables that have foreign-key references to any of
1902        the named tables.  If the parameter only is not set to True, all the
1903        descendant tables (if any) will also be truncated. Optionally, a '*'
1904        can be specified after the table name to explicitly indicate that
1905        descendant tables are included.
1906        """
1907        if isinstance(table, basestring):
1908            only = {table: only}
1909            table = [table]
1910        elif isinstance(table, (list, tuple)):
1911            if isinstance(only, (list, tuple)):
1912                only = dict(zip(table, only))
1913            else:
1914                only = dict.fromkeys(table, only)
1915        elif isinstance(table, (set, frozenset)):
1916            only = dict.fromkeys(table, only)
1917        else:
1918            raise TypeError('The table must be a string, list or set')
1919        if not (restart is None or isinstance(restart, (bool, int))):
1920            raise TypeError('Invalid type for the restart option')
1921        if not (cascade is None or isinstance(cascade, (bool, int))):
1922            raise TypeError('Invalid type for the cascade option')
1923        tables = []
1924        for t in table:
1925            u = only.get(t)
1926            if not (u is None or isinstance(u, (bool, int))):
1927                raise TypeError('Invalid type for the only option')
1928            if t.endswith('*'):
1929                if u:
1930                    raise ValueError(
1931                        'Contradictory table name and only options')
1932                t = t[:-1].rstrip()
1933            t = self._escape_qualified_name(t)
1934            if u:
1935                t = 'ONLY %s' % t
1936            tables.append(t)
1937        q = ['TRUNCATE', ', '.join(tables)]
1938        if restart:
1939            q.append('RESTART IDENTITY')
1940        if cascade:
1941            q.append('CASCADE')
1942        q = ' '.join(q)
1943        self._do_debug(q)
1944        return self.db.query(q)
1945
1946    def get_as_list(self, table, what=None, where=None,
1947            order=None, limit=None, offset=None, scalar=False):
1948        """Get a table as a list.
1949
1950        This gets a convenient representation of the table as a list
1951        of named tuples in Python.  You only need to pass the name of
1952        the table (or any other SQL expression returning rows).  Note that
1953        by default this will return the full content of the table which
1954        can be huge and overflow your memory.  However, you can control
1955        the amount of data returned using the other optional parameters.
1956
1957        The parameter 'what' can restrict the query to only return a
1958        subset of the table columns.  It can be a string, list or a tuple.
1959        The parameter 'where' can restrict the query to only return a
1960        subset of the table rows.  It can be a string, list or a tuple
1961        of SQL expressions that all need to be fulfilled.  The parameter
1962        'order' specifies the ordering of the rows.  It can also be a
1963        other string, list or a tuple.  If no ordering is specified,
1964        the result will be ordered by the primary key(s) or all columns
1965        if no primary key exists.  You can set 'order' to False if you
1966        don't care about the ordering.  The parameters 'limit' and 'offset'
1967        can be integers specifying the maximum number of rows returned
1968        and a number of rows skipped over.
1969
1970        If you set the 'scalar' option to True, then instead of the
1971        named tuples you will get the first items of these tuples.
1972        This is useful if the result has only one column anyway.
1973        """
1974        if not table:
1975            raise TypeError('The table name is missing')
1976        if what:
1977            if isinstance(what, (list, tuple)):
1978                what = ', '.join(map(str, what))
1979            if order is None:
1980                order = what
1981        else:
1982            what = '*'
1983        q = ['SELECT', what, 'FROM', table]
1984        if where:
1985            if isinstance(where, (list, tuple)):
1986                where = ' AND '.join(map(str, where))
1987            q.extend(['WHERE', where])
1988        if order is None:
1989            try:
1990                order = self.pkey(table, True)
1991            except (KeyError, ProgrammingError):
1992                try:
1993                    order = list(self.get_attnames(table))
1994                except (KeyError, ProgrammingError):
1995                    pass
1996        if order:
1997            if isinstance(order, (list, tuple)):
1998                order = ', '.join(map(str, order))
1999            q.extend(['ORDER BY', order])
2000        if limit:
2001            q.append('LIMIT %d' % limit)
2002        if offset:
2003            q.append('OFFSET %d' % offset)
2004        q = ' '.join(q)
2005        self._do_debug(q)
2006        q = self.db.query(q)
2007        res = q.namedresult()
2008        if res and scalar:
2009            res = [row[0] for row in res]
2010        return res
2011
2012    def get_as_dict(self, table, keyname=None, what=None, where=None,
2013            order=None, limit=None, offset=None, scalar=False):
2014        """Get a table as a dictionary.
2015
2016        This method is similar to get_as_list(), but returns the table
2017        as a Python dict instead of a Python list, which can be even
2018        more convenient. The primary key column(s) of the table will
2019        be used as the keys of the dictionary, while the other column(s)
2020        will be the corresponding values.  The keys will be named tuples
2021        if the table has a composite primary key.  The rows will be also
2022        named tuples unless the 'scalar' option has been set to True.
2023        With the optional parameter 'keyname' you can specify an alternative
2024        set of columns to be used as the keys of the dictionary.  It must
2025        be set as a string, list or a tuple.
2026
2027        If the Python version supports it, the dictionary will be an
2028        OrderedDict using the order specified with the 'order' parameter
2029        or the key column(s) if not specified.  You can set 'order' to False
2030        if you don't care about the ordering.  In this case the returned
2031        dictionary will be an ordinary one.
2032        """
2033        if not table:
2034            raise TypeError('The table name is missing')
2035        if not keyname:
2036            try:
2037                keyname = self.pkey(table, True)
2038            except (KeyError, ProgrammingError):
2039                raise _prg_error('Table %s has no primary key' % table)
2040        if isinstance(keyname, basestring):
2041            keyname = [keyname]
2042        elif not isinstance(keyname, (list, tuple)):
2043            raise KeyError('The keyname must be a string, list or tuple')
2044        if what:
2045            if isinstance(what, (list, tuple)):
2046                what = ', '.join(map(str, what))
2047            if order is None:
2048                order = what
2049        else:
2050            what = '*'
2051        q = ['SELECT', what, 'FROM', table]
2052        if where:
2053            if isinstance(where, (list, tuple)):
2054                where = ' AND '.join(map(str, where))
2055            q.extend(['WHERE', where])
2056        if order is None:
2057            order = keyname
2058        if order:
2059            if isinstance(order, (list, tuple)):
2060                order = ', '.join(map(str, order))
2061            q.extend(['ORDER BY', order])
2062        if limit:
2063            q.append('LIMIT %d' % limit)
2064        if offset:
2065            q.append('OFFSET %d' % offset)
2066        q = ' '.join(q)
2067        self._do_debug(q)
2068        q = self.db.query(q)
2069        res = q.getresult()
2070        cls = OrderedDict if order else dict
2071        if not res:
2072            return cls()
2073        keyset = set(keyname)
2074        fields = q.listfields()
2075        if not keyset.issubset(fields):
2076            raise KeyError('Missing keyname in row')
2077        keyind, rowind = [], []
2078        for i, f in enumerate(fields):
2079            (keyind if f in keyset else rowind).append(i)
2080        keytuple = len(keyind) > 1
2081        getkey = itemgetter(*keyind)
2082        keys = map(getkey, res)
2083        if scalar:
2084            rowind = rowind[:1]
2085            rowtuple = False
2086        else:
2087            rowtuple = len(rowind) > 1
2088        if scalar or rowtuple:
2089            getrow = itemgetter(*rowind)
2090        else:
2091            rowind = rowind[0]
2092            getrow = lambda row: (row[rowind],)
2093            rowtuple = True
2094        rows = map(getrow, res)
2095        if keytuple or rowtuple:
2096            namedresult = get_namedresult()
2097            if namedresult:
2098                if keytuple:
2099                    keys = namedresult(_MemoryQuery(keys, keyname))
2100                if rowtuple:
2101                    fields = [f for f in fields if f not in keyset]
2102                    rows = namedresult(_MemoryQuery(rows, fields))
2103        return cls(zip(keys, rows))
2104
2105    def notification_handler(self,
2106            event, callback, arg_dict=None, timeout=None, stop_event=None):
2107        """Get notification handler that will run the given callback."""
2108        return NotificationHandler(self,
2109            event, callback, arg_dict, timeout, stop_event)
2110
2111
2112# if run as script, print some information
2113
if __name__ == '__main__':
    # Print the version (separated by a space) and the module docstring.
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.