source: trunk/pg.py @ 798

Last change on this file since 798 was 798, checked in by cito, 4 years ago

Port type cache and typecasting from pgdb to pg

So far, the typecasting in the classic module had only been done by
the C extension module and was not extensible through typecasting
functions in Python. This has now been made extensible by adding
a cast hook to the C extension module which has been hooked up to
a new type cache object that holds information on the types and the
associated typecast functions. All of this works very similarly to the
pgdb module now, except that the basic types are still handled by
the C extension module and the Python typecast functions are only
called via the hook for types which are not supported internally.

Also added tests and a chapter on the type cache in the documentation,
and cleaned up the error messages in the C extension module.

  • Property svn:keywords set to Id
File size: 71.5 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 798 2016-01-30 19:55:18Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40from operator import itemgetter
41from re import compile as regex
42from json import loads as jsondecode, dumps as jsonencode
43
# Python 3 removed the builtins "long" and "basestring"; provide stand-ins.

try:
    long
except NameError:  # Python >= 3.0: long was merged into int
    long = int

try:
    basestring
except NameError:  # Python >= 3.0: no common base of str and bytes
    basestring = (str, bytes)
53
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

    class AttrDict(dict):
        """Read-only dictionary keeping the order of attribute names.

        Fallback used when collections.OrderedDict is not available: the
        key order is remembered in a separate list.  Only a single sequence
        of (key, value) pairs is accepted; keyword arguments, more than one
        positional argument or a dict (which has no defined order) raise
        a TypeError.
        """

        def __init__(self, *args, **kw):
            if kw or len(args) > 1:
                raise TypeError
            pairs = args[0] if args else []
            if isinstance(pairs, dict):
                raise TypeError
            pairs = list(pairs)
            self._keys = [pair[0] for pair in pairs]
            dict.__init__(self, pairs)
            self._read_only = True
            # shadow all mutating methods with the error raiser
            for method in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, method, self._read_only_error)

        def __setitem__(self, key, value):
            if not self._read_only:
                dict.__setitem__(self, key, value)
            else:
                self._read_only_error()

        def __delitem__(self, key):
            if not self._read_only:
                dict.__delitem__(self, key)
            else:
                self._read_only_error()

        def __iter__(self):
            return iter(self._keys)

        def keys(self):
            return self._keys[:]

        def values(self):
            return [self[name] for name in self._keys]

        def items(self):
            return [(name, self[name]) for name in self._keys]

        def iterkeys(self):
            return iter(self._keys)

        def itervalues(self):
            return iter(self.values())

        def iteritems(self):
            return iter(self.items())

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')

else:

    class AttrDict(OrderedDict):
        """Read-only ordered dictionary for storing attribute names."""

        def __init__(self, *args, **kw):
            # writes must pass through __setitem__ while the dict is filled
            self._read_only = False
            OrderedDict.__init__(self, *args, **kw)
            self._read_only = True
            # shadow all mutating methods with the error raiser
            for method in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
                setattr(self, method, self._read_only_error)

        def __setitem__(self, key, value):
            if not self._read_only:
                OrderedDict.__setitem__(self, key, value)
            else:
                self._read_only_error()

        def __delitem__(self, key):
            if not self._read_only:
                OrderedDict.__delitem__(self, key)
            else:
                self._read_only_error()

        @staticmethod
        def _read_only_error(*args, **kw):
            raise TypeError('This object is read-only')
138
139
140# Auxiliary classes and functions that are independent from a DB connection:
141
142def _oid_key(table):
143    """Build oid key from a table name."""
144    return 'oid(%s)' % table
145
146
147class _SimpleType(dict):
148    """Dictionary mapping pg_type names to simple type names."""
149
150    _types = {'bool': 'bool',
151        'bytea': 'bytea',
152        'date': 'date interval time timetz timestamp timestamptz'
153            ' abstime reltime',  # these are very old
154        'float': 'float4 float8',
155        'int': 'cid int2 int4 int8 oid xid',
156        'json': 'json jsonb',
157        'num': 'numeric',
158        'money': 'money',
159        'text': 'bpchar char name text varchar'}
160
161    def __init__(self):
162        for typ, keys in self._types.items():
163            for key in keys.split():
164                self[key] = typ
165                self['_%s' % key] = '%s[]' % typ
166
167    @staticmethod
168    def __missing__(key):
169        return 'text'
170
171_simpletype = _SimpleType()
172
173
174class _Adapt:
175    """Mixin providing methods for adapting records and record elements.
176
177    This is used when passing values from one of the higher level DB
178    methods as parameters for a query.
179
180    This class must be mixed in to a connection class, because it needs
181    connection specific methods such as escape_bytea().
182    """
183
184    _bool_true_values = frozenset('t true 1 y yes on'.split())
185
186    _date_literals = frozenset('current_date current_time'
187        ' current_timestamp localtime localtimestamp'.split())
188
189    _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$')
190    _re_record_quote = regex(r'[(,"\\]')
191    _re_array_escape = _re_record_escape = regex(r'(["\\])')
192
193    @classmethod
194    def _adapt_bool(cls, v):
195        """Adapt a boolean parameter."""
196        if isinstance(v, basestring):
197            if not v:
198                return None
199            v = v.lower() in cls._bool_true_values
200        return 't' if v else 'f'
201
202    @classmethod
203    def _adapt_date(cls, v):
204        """Adapt a date parameter."""
205        if not v:
206            return None
207        if isinstance(v, basestring) and v.lower() in cls._date_literals:
208            return _Literal(v)
209        return v
210
211    @staticmethod
212    def _adapt_num(v):
213        """Adapt a numeric parameter."""
214        if not v and v != 0:
215            return None
216        return v
217
218    _adapt_int = _adapt_float = _adapt_money = _adapt_num
219
220    def _adapt_bytea(self, v):
221        """Adapt a bytea parameter."""
222        return self.escape_bytea(v)
223
224    def _adapt_json(self, v):
225        """Adapt a json parameter."""
226        if not v:
227            return None
228        if isinstance(v, basestring):
229            return v
230        return self.encode_json(v)
231
232    @classmethod
233    def _adapt_text_array(cls, v):
234        """Adapt a text type array parameter."""
235        if isinstance(v, list):
236            adapt = cls._adapt_text_array
237            return '{%s}' % ','.join(adapt(v) for v in v)
238        if v is None:
239            return 'null'
240        if not v:
241            return '""'
242        v = str(v)
243        if cls._re_array_quote.search(v):
244            v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v)
245        return v
246
247    _adapt_date_array = _adapt_text_array
248
249    @classmethod
250    def _adapt_bool_array(cls, v):
251        """Adapt a boolean array parameter."""
252        if isinstance(v, list):
253            adapt = cls._adapt_bool_array
254            return '{%s}' % ','.join(adapt(v) for v in v)
255        if v is None:
256            return 'null'
257        if isinstance(v, basestring):
258            if not v:
259                return 'null'
260            v = v.lower() in cls._bool_true_values
261        return 't' if v else 'f'
262
263    @classmethod
264    def _adapt_num_array(cls, v):
265        """Adapt a numeric array parameter."""
266        if isinstance(v, list):
267            adapt = cls._adapt_num_array
268            return '{%s}' % ','.join(adapt(v) for v in v)
269        if not v and v != 0:
270            return 'null'
271        return str(v)
272
273    _adapt_int_array = _adapt_float_array = _adapt_money_array = \
274            _adapt_num_array
275
276    def _adapt_bytea_array(self, v):
277        """Adapt a bytea array parameter."""
278        if isinstance(v, list):
279            return b'{' + b','.join(
280                self._adapt_bytea_array(v) for v in v) + b'}'
281        if v is None:
282            return b'null'
283        return self.escape_bytea(v).replace(b'\\', b'\\\\')
284
285    def _adapt_json_array(self, v):
286        """Adapt a json array parameter."""
287        if isinstance(v, list):
288            adapt = self._adapt_json_array
289            return '{%s}' % ','.join(adapt(v) for v in v)
290        if not v:
291            return 'null'
292        if not isinstance(v, basestring):
293            v = self.encode_json(v)
294        if self._re_array_quote.search(v):
295            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
296        return v
297
298    def _adapt_record(self, v, typ):
299        """Adapt a record parameter with given type."""
300        typ = typ.attnames.values()
301        if len(typ) != len(v):
302            raise TypeError('Record parameter %s has wrong size' % v)
303        return '(%s)' % ','.join(getattr(self,
304            '_adapt_record_%s' % t.simple)(v) for v, t in zip(v, typ))
305
306    @classmethod
307    def _adapt_record_text(cls, v):
308        """Adapt a text type record component."""
309        if v is None:
310            return ''
311        if not v:
312            return '""'
313        v = str(v)
314        if cls._re_record_quote.search(v):
315            v = '"%s"' % cls._re_record_escape.sub(r'\\\1', v)
316        return v
317
318    _adapt_record_date = _adapt_record_text
319
320    @classmethod
321    def _adapt_record_bool(cls, v):
322        """Adapt a boolean record component."""
323        if v is None:
324            return ''
325        if isinstance(v, basestring):
326            if not v:
327                return ''
328            v = v.lower() in cls._bool_true_values
329        return 't' if v else 'f'
330
331    @staticmethod
332    def _adapt_record_num(v):
333        """Adapt a numeric record component."""
334        if not v and v != 0:
335            return ''
336        return str(v)
337
338    _adapt_record_int = _adapt_record_float = _adapt_record_money = \
339        _adapt_record_num
340
341    def _adapt_record_bytea(self, v):
342        if v is None:
343            return ''
344        v = self.escape_bytea(v)
345        if bytes is not str and isinstance(v, bytes):
346            v = v.decode('ascii')
347        return v.replace('\\', '\\\\')
348
349    def _adapt_record_json(self, v):
350        """Adapt a bytea record component."""
351        if not v:
352            return ''
353        if not isinstance(v, basestring):
354            v = self.encode_json(v)
355        if self._re_array_quote.search(v):
356            v = '"%s"' % self._re_array_escape.sub(r'\\\1', v)
357        return v
358
359    def _adapt_param(self, value, typ, params):
360        """Adapt and add a parameter to the list."""
361        if isinstance(value, _Literal):
362            return value
363        if value is not None:
364            simple = typ.simple
365            if simple == 'text':
366                pass
367            elif simple == 'record':
368                if isinstance(value, tuple):
369                    value = self._adapt_record(value, typ)
370            elif simple.endswith('[]'):
371                if isinstance(value, list):
372                    adapt = getattr(self, '_adapt_%s_array' % simple[:-2])
373                    value = adapt(value)
374            else:
375                adapt = getattr(self, '_adapt_%s' % simple)
376                value = adapt(value)
377                if isinstance(value, _Literal):
378                    return value
379        params.append(value)
380        return '$%d' % len(params)
381
382
def cast_bool(value):
    """Cast a boolean value.

    If get_bool() returns a false value, the string is returned as is;
    otherwise the result is True if the value starts with 't'.
    """
    if get_bool():
        return value[0] == 't'
    return value
388
389
def cast_json(value):
    """Cast a JSON value with the decoder returned by get_jsondecode().

    The string is returned unchanged if no decoder is set.
    """
    decode = get_jsondecode()
    return decode(value) if decode else value
396
397
def cast_num(value):
    """Cast a numeric value using get_decimal(), falling back to float."""
    number_type = get_decimal() or float
    return number_type(value)
401
402
def cast_money(value):
    """Cast a money value.

    The value is returned unchanged if get_decimal_point() is empty.
    Otherwise, an opening parenthesis is taken as a minus sign, all
    characters except digits, the decimal point and the minus sign
    (e.g. currency symbols and thousands separators) are removed, and the
    rest is converted with get_decimal(), falling back to float.
    """
    point = get_decimal_point()
    if not point:
        return value
    if point != '.':
        # With a different decimal point, a '.' is not the decimal point
        # (e.g. a thousands separator as in '1.234,56'), so drop it before
        # translating the real decimal point; otherwise '1.234.56' results.
        value = value.replace('.', '').replace(point, '.')
    value = value.replace('(', '-')
    value = ''.join(c for c in value if c.isdigit() or c in '.-')
    return (get_decimal() or float)(value)
413
414
def cast_int2vector(value):
    """Cast an int2vector value (whitespace separated integers) to a list."""
    return list(map(int, value.split()))
418
419
class Typecasts(dict):
    """Dictionary mapping database types to typecast functions.

    The cast functions get passed the string representation of a value in
    the database which they need to convert to a Python object.  The
    passed string will never be None since NULL values are already
    handled before the cast function is called.

    Note that the basic types are already handled by the C extension.
    They only need to be handled here as record or array components.
    """

    # the default cast functions
    # (str functions are ignored but have been added for faster access)
    defaults = {'char': str, 'bpchar': str, 'name': str,
        'text': str, 'varchar': str,
        'bool': cast_bool, 'bytea': unescape_bytea,
        'int2': int, 'int4': int, 'serial': int,
        'int8': long, 'json': cast_json, 'jsonb': cast_json,
        'oid': long, 'oid8': long,
        'float4': float, 'float8': float,
        'numeric': cast_num, 'money': cast_money,
        'int2vector': cast_int2vector,
        'anyarray': cast_array, 'record': cast_record}

    def __missing__(self, typ):
        """Create a cast function if it is not cached.

        Note that this class never raises a KeyError, but returns None
        when no special cast function exists.  For array types (names
        starting with '_'), an array cast is returned, but it is only
        cached if the element type has a cast function.
        Raises a TypeError if typ is not a string.
        """
        if not isinstance(typ, str):
            raise TypeError('Invalid type: %s' % typ)
        cast = self.defaults.get(typ)
        if cast:
            # store default for faster access
            self[typ] = cast
        elif typ.startswith('_'):
            # array type: build the cast from the cast of the element type
            base_cast = self[typ[1:]]
            cast = self.create_array_cast(base_cast)
            if base_cast:
                self[typ] = cast
        else:
            # possibly a composite type: build a named record cast
            attnames = self.get_attnames(typ)
            if attnames:
                casts = [self[v.pgtype] for v in attnames.values()]
                cast = self.create_record_cast(typ, attnames, casts)
                self[typ] = cast
        return cast

    def get(self, typ, default=None):
        """Get the typecast function for the given database type."""
        return self[typ] or default

    def set(self, typ, cast):
        """Set a typecast function for the specified database type(s).

        The typ argument can be a single type name or a list of names.
        Passing None as cast removes the cached casts for these types and
        their array types.  Otherwise the cast must be callable, or else
        a TypeError is raised.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        if cast is None:
            for t in typ:
                self.pop(t, None)
                self.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                self[t] = cast
                # drop the array cast built from the previous base cast
                self.pop('_%s' % t, None)

    def reset(self, typ=None):
        """Reset the typecasts for the specified type(s) to their defaults.

        When no type is specified, all typecasts will be reset.
        """
        defaults = self.defaults
        if typ is None:
            self.clear()
            self.update(defaults)
        else:
            if isinstance(typ, basestring):
                typ = [typ]
            for t in typ:
                self.set(t, defaults.get(t))

    @classmethod
    def get_default(cls, typ):
        """Get the default typecast function for the given database type."""
        return cls.defaults.get(typ)

    @classmethod
    def set_default(cls, typ, cast):
        """Set a default typecast function for the given database type(s).

        The typ argument can be a single type name or a list of names.
        Passing None as cast removes the defaults for these types and
        their array types.  Otherwise the cast must be callable, or else
        a TypeError is raised.
        """
        if isinstance(typ, basestring):
            typ = [typ]
        defaults = cls.defaults
        if cast is None:
            for t in typ:
                defaults.pop(t, None)
                defaults.pop('_%s' % t, None)
        else:
            if not callable(cast):
                raise TypeError("Cast parameter must be callable")
            for t in typ:
                defaults[t] = cast
                # drop an array default that relates to the old base cast
                defaults.pop('_%s' % t, None)

    def get_attnames(self, typ):
        """Return the fields for the given record type.

        This method will be replaced with the get_attnames() method of DbTypes.
        """
        return {}

    def create_array_cast(self, cast):
        """Create an array typecast for the given base cast."""
        return lambda v: cast_array(v, cast)

    def create_record_cast(self, name, fields, casts):
        """Create a named record typecast for the given fields and casts."""
        record = namedtuple(name, fields)
        return lambda v: record(*cast_record(v, casts))
541
542
def get_typecast(typ):
    """Get the global typecast function for the given database type.

    Only a single type name can be passed.  Returns None if there is
    no global typecast function for this type.
    """
    return Typecasts.get_default(typ)
546
547
def set_typecast(typ, cast):
    """Set a global typecast function for the given database type(s).

    The typ argument can be a single type name or a list of type names.
    Passing None as cast removes the global typecast for these types,
    otherwise cast must be callable (or a TypeError is raised).

    Note that connections cache cast functions. To be sure a global change
    is picked up by a running connection, call db.dbtypes.reset_typecast().
    """
    Typecasts.set_default(typ, cast)
555
556
class DbType(str):
    """String naming a database type, augmented with additional info.

    The string value itself is the simple (or regular) type name.
    The following additional information is provided as attributes:

        oid: the PostgreSQL type OID
        pgtype: the PostgreSQL type name
        regtype: the regular type name
        simple: the simple PyGreSQL type name
        typtype: b = base type, c = composite type etc.
        category: A = Array, b =Boolean, C = Composite etc.
        delim: delimiter for array types
        relid: corresponding table for composite types
        attnames: attributes for composite types
    """

    @property
    def attnames(self):
        """Get names and types of the fields of a composite type.

        The lookup is delegated to the _get_attnames callable that has
        been attached to this instance.
        """
        lookup = self._get_attnames
        return lookup(self)
577
578
class DbTypes(dict):
    """Cache for PostgreSQL data types.

    This cache maps type OIDs and names to DbType objects containing
    information on the associated database type.
    """

    _num_types = frozenset('int float num money'
        ' int2 int4 int8 float4 float8 numeric money'.split())

    def __init__(self, db):
        """Initialize type cache for connection.

        The db argument is the DB wrapper: its get_attnames() method is
        used for composite types, and query() and escape_string() are
        taken from its underlying connection db.db.
        """
        super(DbTypes, self).__init__()
        self._get_attnames = db.get_attnames
        db = db.db
        self.query = db.query
        self.escape_string = db.escape_string
        self._typecasts = Typecasts()
        # let the typecasts look up fields of composite types through us
        self._typecasts.get_attnames = self.get_attnames
        self._regtypes = False

    def add(self, oid, pgtype, regtype,
               typtype, category, delim, relid):
        """Create a PostgreSQL type name with additional info.

        Returns the cached type if oid is already known.  A newly created
        type is not stored in the cache by this method.
        """
        if oid in self:
            return self[oid]
        simple = 'record' if relid else _simpletype[pgtype]
        typ = DbType(regtype if self._regtypes else simple)
        typ.oid = oid
        typ.simple = simple
        typ.pgtype = pgtype
        typ.regtype = regtype
        typ.typtype = typtype
        typ.category = category
        typ.delim = delim
        typ.relid = relid
        typ._get_attnames = self.get_attnames
        return typ

    def __missing__(self, key):
        """Get the type info from the database if it is not cached.

        The type is then cached under both its OID and its pgtype name.
        Raises a KeyError if the type cannot be found.
        """
        try:
            res = self.query("SELECT oid, typname, typname::regtype,"
                " typtype, typcategory, typdelim, typrelid"
                " FROM pg_type WHERE oid=%s::regtype" %
                (DB._adapt_qualified_param(key, 1),), (key,)).getresult()
        except ProgrammingError:
            res = None
        if not res:
            raise KeyError('Type %s could not be found' % key)
        res = res[0]
        typ = self.add(*res)
        self[typ.oid] = self[typ.pgtype] = typ
        return typ

    def get(self, key, default=None):
        """Get the type even if it is not cached."""
        try:
            return self[key]
        except KeyError:
            return default

    def get_attnames(self, typ):
        """Get names and types of the fields of a composite type.

        Returns None if the type is unknown or not a composite type.
        """
        if not isinstance(typ, DbType):
            typ = self.get(typ)
            if not typ:
                return None
        if not typ.relid:
            return None
        return self._get_attnames(typ.relid, with_oid=False)

    def get_typecast(self, typ):
        """Get the typecast function for the given database type."""
        return self._typecasts.get(typ)

    def set_typecast(self, typ, cast):
        """Set a typecast function for the specified database type(s)."""
        self._typecasts.set(typ, cast)

    def reset_typecast(self, typ=None):
        """Reset the typecast function for the specified database type(s)."""
        self._typecasts.reset(typ)

    def typecast(self, value, typ):
        """Cast the given value according to the given database type.

        The type can be given as a DbType or as anything this cache can
        look up (OID or name).  The value is returned unchanged if it is
        None, if the type is unknown or if there is no special cast.
        """
        if value is None:
            # for NULL values, no typecast is necessary
            return None
        if not isinstance(typ, DbType):
            typ = self.get(typ)
        # Typecasts are keyed by the PostgreSQL type name, while the string
        # value of a DbType is the simple or regular type name, so always
        # look up the cast by pgtype, also when a DbType has been passed.
        pgtype = typ.pgtype if typ else None
        cast = self.get_typecast(pgtype) if pgtype else None
        if not cast or cast is str:
            # no typecast is necessary
            return value
        return cast(value)
677
678
679class _Literal(str):
680    """Wrapper class for literal SQL."""
681
682
683def _namedresult(q):
684    """Get query result as named tuples."""
685    row = namedtuple('Row', q.listfields())
686    return [row(*r) for r in q.getresult()]
687
688
689class _MemoryQuery:
690    """Class that embodies a given query result."""
691
692    def __init__(self, result, fields):
693        """Create query from given result rows and field names."""
694        self.result = result
695        self.fields = fields
696
697    def listfields(self):
698        """Return the stored field names of this query."""
699        return self.fields
700
701    def getresult(self):
702        """Return the stored result of this query."""
703        return self.result
704
705
def _db_error(msg, cls=DatabaseError):
    """Create an error of class cls (DatabaseError by default) for msg.

    The sqlstate attribute of the error is set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
711
712
def _int_error(msg):
    """Create an InternalError for msg (with sqlstate set to None)."""
    return _db_error(msg, cls=InternalError)
716
717
def _prg_error(msg):
    """Create a ProgrammingError for msg (with sqlstate set to None)."""
    return _db_error(msg, cls=ProgrammingError)
721
722
# Initialize the C module by handing it the Python-side helpers.

for _setter, _value in ((set_namedresult, _namedresult),
        (set_decimal, Decimal), (set_jsondecode, jsondecode)):
    _setter(_value)
del _setter, _value
728
729
# The notification handler

class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is garbage collected."""
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    @staticmethod
    def _quote_channel(name):
        """Quote a channel name as an SQL identifier.

        Embedded double quotes are doubled, so that any name can be used
        without breaking the SQL statement.
        """
        name = '%s' % (name,)
        return '"%s"' % name.replace('"', '""')

    @staticmethod
    def _quote_payload(payload):
        """Quote a payload as an SQL escape string constant (E'...').

        Backslashes and single quotes are escaped, so the payload is sent
        literally regardless of the standard_conforming_strings setting.
        """
        payload = '%s' % (payload,)
        return "E'%s'" % payload.replace('\\', '\\\\').replace("'", "''")

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_channel(self.event))
            self.db.query('listen %s' % self._quote_channel(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_channel(self.event))
            self.db.query(
                'unlisten %s' % self._quote_channel(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent (and None is returned) if the handler is not
        listening; otherwise the result of the query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            channel = self.stop_event if stop else self.event
            q = 'notify %s' % self._quote_channel(channel)
            if payload:
                q += ', %s' % self._quote_payload(payload)
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
851
852
def pgnotify(*args, **kw):
    """Deprecated alias creating a NotificationHandler.

    All arguments are passed on to NotificationHandler; a
    DeprecationWarning is issued on every call.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
858
859
860# The actual PostGreSQL database connection interface:
861
862class DB(_Adapt):
863    """Wrapper class for the _pg connection type."""
864
    def __init__(self, *args, **kw):
        """Create a new connection

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing connection is only considered when it is passed as the
        single positional argument or as the single keyword argument 'db'.
        If it is another DB instance, its underlying connection is shared.
        Anything else is passed on to connect().  Only a connection opened
        here is marked as closeable, i.e. will be closed by close().
        """
        # Pick up a possible existing connection passed as the only argument
        # (a single keyword argument other than 'db' yields None here).
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # another DB wrapper: share its underlying connection
                db = db.db
            else:
                try:
                    # presumably a pgdb connection wrapping a _pg connection
                    db = db._cnx
                except AttributeError:
                    pass
        # Anything that does not look like a _pg connection (for instance a
        # database name passed as the only argument) is used for connect().
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        # NOTE(review): use_regtypes() reads and writes self.dbtypes._regtypes,
        # not this attribute; it may be unused - verify before relying on it.
        self._regtypes = False
        # caches for attribute names, primary keys and table privileges
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        self._args = args, kw  # remembered so that reopen() can reconnect
        self.dbtypes = DbTypes(self)
        # NOTE(review): for a reused connection this replaces any cast hook
        # that another wrapper of the same connection may have installed.
        db.set_cast_hook(self.dbtypes.typecast)
        # For debugging scripts, debug can be set
        # * to a string format specification (e.g. in CGI set to "%s<BR>"),
        # * to a file object to write debug statements or
        # * to a callable object which takes a string argument
        # * to any other true value to just print debug statements
        self.debug = None
905
906    def __getattr__(self, name):
907        # All undefined members are same as in underlying connection:
908        if self.db:
909            return getattr(self.db, name)
910        else:
911            raise _int_error('Connection is not valid')
912
913    def __dir__(self):
914        # Custom dir function including the attributes of the connection:
915        attrs = set(self.__class__.__dict__)
916        attrs.update(self.__dict__)
917        attrs.update(dir(self.db))
918        return sorted(attrs)
919
920    # Context manager methods
921
922    def __enter__(self):
923        """Enter the runtime context. This will start a transactio."""
924        self.begin()
925        return self
926
927    def __exit__(self, et, ev, tb):
928        """Exit the runtime context. This will end the transaction."""
929        if et is None and ev is None and tb is None:
930            self.commit()
931        else:
932            self.rollback()
933
934    # Auxiliary methods
935
936    def _do_debug(self, *args):
937        """Print a debug message"""
938        if self.debug:
939            s = '\n'.join(str(arg) for arg in args)
940            if isinstance(self.debug, basestring):
941                print(self.debug % s)
942            elif hasattr(self.debug, 'write'):
943                self.debug.write(s + '\n')
944            elif callable(self.debug):
945                self.debug(s)
946            else:
947                print(s)
948
949    def _escape_qualified_name(self, s):
950        """Escape a qualified name.
951
952        Escapes the name for use as an SQL identifier, unless the
953        name contains a dot, in which case the name is ambiguous
954        (could be a qualified name or just a name with a dot in it)
955        and must be quoted manually by the caller.
956        """
957        if '.' not in s:
958            s = self.escape_identifier(s)
959        return s
960
961    @staticmethod
962    def _make_bool(d):
963        """Get boolean value corresponding to d."""
964        return bool(d) if get_bool() else ('t' if d else 'f')
965
966    def _list_params(self, params):
967        """Create a human readable parameter list."""
968        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
969
970    @staticmethod
971    def _adapt_qualified_param(name, param):
972        """Quote parameter representing a qualified name.
973
974        Escapes the name for use as an SQL parameter, unless the
975        name contains a dot, in which case the name is ambiguous
976        (could be a qualified name or just a name with a dot in it)
977        and must be quoted manually by the caller.
978
979        """
980        if isinstance(param, int):
981            param = "$%d" % param
982        if isinstance(name, basestring) and '.' not in name:
983            param = 'quote_ident(%s)' % (param,)
984        return param
985
986    # Public methods
987
    # escape_string and escape_bytea exist as methods (of the connection,
    # reached through __getattr__), so we expose the module-level
    # unescape_bytea function as a static method of DB as well
    unescape_bytea = staticmethod(unescape_bytea)
991
992    def decode_json(self, s):
993        """Decode a JSON string coming from the database."""
994        return (get_jsondecode() or jsondecode)(s)
995
996    def encode_json(self, d):
997        """Encode a JSON string for use within SQL."""
998        return jsonencode(d)
999
1000    def close(self):
1001        """Close the database connection."""
1002        # Wraps shared library function so we can track state.
1003        if self._closeable:
1004            if self.db:
1005                self.db.close()
1006                self.db = None
1007            else:
1008                raise _int_error('Connection already closed')
1009
1010    def reset(self):
1011        """Reset connection with current parameters.
1012
1013        All derived queries and large objects derived from this connection
1014        will not be usable after this call.
1015
1016        """
1017        if self.db:
1018            self.db.reset()
1019        else:
1020            raise _int_error('Connection already closed')
1021
1022    def reopen(self):
1023        """Reopen connection to the database.
1024
1025        Used in case we need another connection to the same database.
1026        Note that we can still reopen a database that we have closed.
1027
1028        """
1029        # There is no such shared library function.
1030        if self._closeable:
1031            db = connect(*self._args[0], **self._args[1])
1032            if self.db:
1033                self.db.close()
1034            self.db = db
1035
1036    def begin(self, mode=None):
1037        """Begin a transaction."""
1038        qstr = 'BEGIN'
1039        if mode:
1040            qstr += ' ' + mode
1041        return self.query(qstr)
1042
1043    start = begin
1044
1045    def commit(self):
1046        """Commit the current transaction."""
1047        return self.query('COMMIT')
1048
1049    end = commit
1050
1051    def rollback(self, name=None):
1052        """Roll back the current transaction."""
1053        qstr = 'ROLLBACK'
1054        if name:
1055            qstr += ' TO ' + name
1056        return self.query(qstr)
1057
1058    abort = rollback
1059
1060    def savepoint(self, name):
1061        """Define a new savepoint within the current transaction."""
1062        return self.query('SAVEPOINT ' + name)
1063
1064    def release(self, name):
1065        """Destroy a previously defined savepoint."""
1066        return self.query('RELEASE ' + name)
1067
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Parameter names are stripped and lower-cased before being queried.
        Raises TypeError if the parameter has an unsupported type, is empty,
        or contains a name that is not a non-blank string.
        """
        # Decide how the results are collected: None means that a single
        # value is returned, a list collects the values in order, and a dict
        # maps the caller's keys to their settings (a passed dict is filled
        # in place and returned).
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # normalized names: mapped to the caller's original keys when the
        # result is a dict, otherwise kept in order in a list
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' overrides any other names: return a new dict built
                # from the first two columns (name, setting) of SHOW ALL
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:  # loop not broken by 'all': query each name separately
            for param in params:
                # NOTE(review): param is interpolated into the SQL unquoted,
                # so parameter names must not come from untrusted input
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
1125
1126    def set_parameter(self, parameter, value=None, local=False):
1127        """Set the value of a run-time parameter.
1128
1129        If the parameter and the value are strings, the run-time parameter
1130        will be set to that value.  If no value or None is passed as a value,
1131        then the run-time parameter will be restored to its default value.
1132
1133        You can set several parameters at once by passing a list of parameter
1134        names, together with a single value that all parameters should be
1135        set to or with a corresponding list of values.  You can also pass
1136        the parameters as a set if you only provide a single value.
1137        Finally, you can pass a dict with parameter names as keys.  In this
1138        case, you should not pass a value, since the values for the parameters
1139        will be taken from the dict.
1140
1141        By passing the special name 'all' as the parameter, you can reset
1142        all existing settable run-time parameters to their default values.
1143
1144        If you set local to True, then the command takes effect for only the
1145        current transaction.  After commit() or rollback(), the session-level
1146        setting takes effect again.  Setting local to True will appear to
1147        have no effect if it is executed outside a transaction, since the
1148        transaction will end immediately.
1149        """
1150        if isinstance(parameter, basestring):
1151            parameter = {parameter: value}
1152        elif isinstance(parameter, (list, tuple)):
1153            if isinstance(value, (list, tuple)):
1154                parameter = dict(zip(parameter, value))
1155            else:
1156                parameter = dict.fromkeys(parameter, value)
1157        elif isinstance(parameter, (set, frozenset)):
1158            if isinstance(value, (list, tuple, set, frozenset)):
1159                value = set(value)
1160                if len(value) == 1:
1161                    value = value.pop()
1162            if not(value is None or isinstance(value, basestring)):
1163                raise ValueError('A single value must be specified'
1164                    ' when parameter is a set')
1165            parameter = dict.fromkeys(parameter, value)
1166        elif isinstance(parameter, dict):
1167            if value is not None:
1168                raise ValueError('A value must not be specified'
1169                    ' when parameter is a dictionary')
1170        else:
1171            raise TypeError(
1172                'The parameter must be a string, list, set or dict')
1173        if not parameter:
1174            raise TypeError('No parameter has been specified')
1175        params = {}
1176        for key, value in parameter.items():
1177            param = key.strip().lower() if isinstance(
1178                key, basestring) else None
1179            if not param:
1180                raise TypeError('Invalid parameter')
1181            if param == 'all':
1182                if value is not None:
1183                    raise ValueError('A value must ot be specified'
1184                        " when parameter is 'all'")
1185                params = {'all': None}
1186                break
1187            params[param] = value
1188        local = ' LOCAL' if local else ''
1189        for param, value in params.items():
1190            if value is None:
1191                q = 'RESET%s %s' % (local, param)
1192            else:
1193                q = 'SET%s %s TO %s' % (local, param, value)
1194            self._do_debug(q)
1195            self.db.query(q)
1196
1197    def query(self, command, *args):
1198        """Execute a SQL command string.
1199
1200        This method simply sends a SQL query to the database.  If the query is
1201        an insert statement that inserted exactly one row into a table that
1202        has OIDs, the return value is the OID of the newly inserted row.
1203        If the query is an update or delete statement, or an insert statement
1204        that did not insert exactly one row in a table with OIDs, then the
1205        number of rows affected is returned as a string.  If it is a statement
1206        that returns rows as a result (usually a select statement, but maybe
1207        also an "insert/update ... returning" statement), this method returns
1208        a Query object that can be accessed via getresult() or dictresult()
1209        or simply printed.  Otherwise, it returns `None`.
1210
1211        The query can contain numbered parameters of the form $1 in place
1212        of any data constant.  Arguments given after the query string will
1213        be substituted for the corresponding numbered parameter.  Parameter
1214        values can also be given as a single list or tuple argument.
1215        """
1216        # Wraps shared library function for debugging.
1217        if not self.db:
1218            raise _int_error('Connection is not valid')
1219        if args:
1220            self._do_debug(command, args)
1221            return self.db.query(command, args)
1222        self._do_debug(command)
1223        return self.db.query(command)
1224
1225    def pkey(self, table, composite=False, flush=False):
1226        """Get or set the primary key of a table.
1227
1228        Single primary keys are returned as strings unless you
1229        set the composite flag.  Composite primary keys are always
1230        represented as tuples.  Note that this raises a KeyError
1231        if the table does not have a primary key.
1232
1233        If flush is set then the internal cache for primary keys will
1234        be flushed.  This may be necessary after the database schema or
1235        the search path has been changed.
1236        """
1237        pkeys = self._pkeys
1238        if flush:
1239            pkeys.clear()
1240            self._do_debug('The pkey cache has been flushed')
1241        try:  # cache lookup
1242            pkey = pkeys[table]
1243        except KeyError:  # cache miss, check the database
1244            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
1245                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
1246                " AND a.attnum = ANY(i.indkey)"
1247                " AND NOT a.attisdropped"
1248                " WHERE i.indrelid=%s::regclass"
1249                " AND i.indisprimary ORDER BY a.attnum") % (
1250                    self._adapt_qualified_param(table, 1),)
1251            pkey = self.db.query(q, (table,)).getresult()
1252            if not pkey:
1253                raise KeyError('Table %s has no primary key' % table)
1254            # we want to use the order defined in the primary key index here,
1255            # not the order as defined by the columns in the table
1256            if len(pkey) > 1:
1257                indkey = pkey[0][2]
1258                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
1259                pkey = tuple(row[0] for row in pkey)
1260            else:
1261                pkey = pkey[0][0]
1262            pkeys[table] = pkey  # cache it
1263        if composite and not isinstance(pkey, tuple):
1264            pkey = (pkey,)
1265        return pkey
1266
1267    def get_databases(self):
1268        """Get list of databases in the system."""
1269        return [s[0] for s in
1270            self.db.query('SELECT datname FROM pg_database').getresult()]
1271
1272    def get_relations(self, kinds=None):
1273        """Get list of relations in connected database of specified kinds.
1274
1275        If kinds is None or empty, all kinds of relations are returned.
1276        Otherwise kinds can be a string or sequence of type letters
1277        specifying which kind of relations you want to list.
1278        """
1279        where = " AND r.relkind IN (%s)" % ','.join(
1280            ["'%s'" % k for k in kinds]) if kinds else ''
1281        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
1282            " FROM pg_class r"
1283            " JOIN pg_namespace s ON s.oid = r.relnamespace"
1284            " WHERE s.nspname NOT SIMILAR"
1285            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
1286            " ORDER BY s.nspname, r.relname") % where
1287        return [r[0] for r in self.db.query(q).getresult()]
1288
1289    def get_tables(self):
1290        """Return list of tables in connected database."""
1291        return self.get_relations('r')
1292
1293    def get_attnames(self, table, with_oid=True, flush=False):
1294        """Given the name of a table, dig out the set of attribute names.
1295
1296        Returns a read-only dictionary of attribute names (the names are
1297        the keys, the values are the names of the attributes' types)
1298        with the column names in the proper order if you iterate over it.
1299
1300        If flush is set, then the internal cache for attribute names will
1301        be flushed. This may be necessary after the database schema or
1302        the search path has been changed.
1303
1304        By default, only a limited number of simple types will be returned.
1305        You can get the regular types after calling use_regtypes(True).
1306        """
1307        attnames = self._attnames
1308        if flush:
1309            attnames.clear()
1310            self._do_debug('The attnames cache has been flushed')
1311        try:  # cache lookup
1312            names = attnames[table]
1313        except KeyError:  # cache miss, check the database
1314            q = "a.attnum > 0"
1315            if with_oid:
1316                q = "(%s OR a.attname = 'oid')" % q
1317            q = ("SELECT a.attname, t.oid, t.typname, t.typname::regtype,"
1318                " t.typtype, t.typcategory, t.typdelim, t.typrelid"
1319                " FROM pg_attribute a"
1320                " JOIN pg_type t ON t.oid = a.atttypid"
1321                " WHERE a.attrelid = %s::regclass AND %s"
1322                " AND NOT a.attisdropped ORDER BY a.attnum") % (
1323                    self._adapt_qualified_param(table, 1), q)
1324            names = self.db.query(q, (table,)).getresult()
1325            types = self.dbtypes
1326            names = ((name[0], types.add(*name[1:])) for name in names)
1327            names = AttrDict(names)
1328            attnames[table] = names  # cache it
1329        return names
1330
1331    def use_regtypes(self, regtypes=None):
1332        """Use regular type names instead of simplified type names."""
1333        if regtypes is None:
1334            return self.dbtypes._regtypes
1335        else:
1336            regtypes = bool(regtypes)
1337            if regtypes != self.dbtypes._regtypes:
1338                self.dbtypes._regtypes = regtypes
1339                self._attnames.clear()
1340                self.dbtypes.clear()
1341            return regtypes
1342
1343    def has_table_privilege(self, table, privilege='select'):
1344        """Check whether current user has specified table privilege."""
1345        privilege = privilege.lower()
1346        try:  # ask cache
1347            return self._privileges[(table, privilege)]
1348        except KeyError:  # cache miss, ask the database
1349            q = "SELECT has_table_privilege(%s, $2)" % (
1350                self._adapt_qualified_param(table, 1),)
1351            q = self.db.query(q, (table, privilege))
1352            ret = q.getresult()[0][0] == self._make_bool(True)
1353            self._privileges[(table, privilege)] = ret  # cache it
1354            return ret
1355
    def get(self, table, row, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  It assumes
        that the keyname specifies a unique row.  It must be the name of a
        single column or a tuple of column names.  If the keyname is not
        specified, then the primary key for the table is used.

        If row is a dictionary, then the value for the key is taken from it.
        Otherwise, the row must be a single value or a tuple of values
        corresponding to the passed keyname or primary key.  The fetched row
        from the table will be returned as a new dictionary or used to replace
        the existing values when row was passed as a dictionary.

        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(table)" using the actual name of the table.

        Raises KeyError if key values are missing or their number does not
        match the key columns, the error built by _prg_error if no key can
        be determined, and the error built by _db_error if no row matches.
        """
        if table.endswith('*'):  # hint for descendant tables can be ignored
            table = table[:-1].rstrip()
        attnames = self.get_attnames(table)
        # key under which the oid is stored in row dicts (None: no oid column)
        qoid = _oid_key(table) if 'oid' in attnames else None
        if keyname and isinstance(keyname, basestring):
            keyname = (keyname,)
        # accept a munged oid (as stored by this method) as the plain oid
        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        if not keyname:
            try:  # if keyname is not specified, try using the primary key
                keyname = self.pkey(table, True)
            except KeyError:  # the table has no primary key
                # try using the oid instead
                if qoid and isinstance(row, dict) and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise _prg_error('Table %s has no primary key' % table)
            else:  # the table has a primary key
                # check whether all key columns have values
                if isinstance(row, dict) and not set(keyname).issubset(row):
                    # try using the oid instead
                    if qoid and 'oid' in row:
                        keyname = ('oid',)
                    else:
                        raise KeyError(
                            'Missing value in row for specified keyname')
        # a plain value or sequence of values is turned into a dict
        # mapping the key column names to these values
        if not isinstance(row, dict):
            if not isinstance(row, (tuple, list)):
                row = [row]
            if len(keyname) != len(row):
                raise KeyError(
                    'Differing number of items in keyname and row')
            row = dict(zip(keyname, row))
        params = []
        # param() adapts a value, appends it to params and returns its
        # placeholder (presumably "$n" - see _adapt_param)
        param = partial(self._adapt_param, params=params)
        col = self.escape_identifier
        what = 'oid, *' if qoid else '*'
        where = ' AND '.join('%s = %s' % (
            col(k), param(row[k], attnames[k])) for k in keyname)
        # keep the oid in the row only under its munged name
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(table), where)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if not res:
            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
                table, where, self._list_params(params)))
        # copy the fetched values into the row, munging the oid column name
        for n, value in res[0].items():
            if qoid and n == 'oid':
                n = qoid
            row[n] = value
        return row
1430
1431    def insert(self, table, row=None, **kw):
1432        """Insert a row into a database table.
1433
1434        This method inserts a row into a table.  The name of the table must
1435        be passed as the first parameter.  The other parameters are used for
1436        providing the data of the row that shall be inserted into the table.
1437        If a dictionary is supplied as the second parameter, it starts with
1438        that.  Otherwise it uses a blank dictionary. Either way the dictionary
1439        is updated from the keywords.
1440
1441        The dictionary is then reloaded with the values actually inserted in
1442        order to pick up values modified by rules, triggers, etc.
1443        """
1444        if table.endswith('*'):  # hint for descendant tables can be ignored
1445            table = table[:-1].rstrip()
1446        if row is None:
1447            row = {}
1448        row.update(kw)
1449        if 'oid' in row:
1450            del row['oid']  # do not insert oid
1451        attnames = self.get_attnames(table)
1452        qoid = _oid_key(table) if 'oid' in attnames else None
1453        params = []
1454        param = partial(self._adapt_param, params=params)
1455        col = self.escape_identifier
1456        names, values = [], []
1457        for n in attnames:
1458            if n in row:
1459                names.append(col(n))
1460                values.append(param(row[n], attnames[n]))
1461        if not names:
1462            raise _prg_error('No column found that can be inserted')
1463        names, values = ', '.join(names), ', '.join(values)
1464        ret = 'oid, *' if qoid else '*'
1465        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
1466            self._escape_qualified_name(table), names, values, ret)
1467        self._do_debug(q, params)
1468        q = self.db.query(q, params)
1469        res = q.dictresult()
1470        if res:  # this should always be true
1471            for n, value in res[0].items():
1472                if qoid and n == 'oid':
1473                    n = qoid
1474                row[n] = value
1475        return row
1476
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the primary key of the table or the OID value as munged by get
        or passed as keyword.

        The dictionary is then modified to reflect any changes caused by the
        update due to triggers, rules, default values, etc.

        Columns not present in the row are left unchanged, and the columns
        used to locate the row are not updated themselves.  If there is no
        other column to change, the row is returned without querying the
        database.  If no row matched, the row is returned without any values
        from the database merged in.  Raises KeyError if key values are
        missing, and the error built by _prg_error if the table has neither
        a primary key nor an oid value in the row.
        """
        if table.endswith('*'):
            table = table[:-1].rstrip()  # need parent table name
        attnames = self.get_attnames(table)
        # key under which the oid is stored in row dicts (None: no oid column)
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # the munged oid (as stored by get or insert) also identifies the row
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        try:  # try using the primary key
            keyname = self.pkey(table, True)
        except KeyError:  # the table has no primary key
            # try using the oid instead
            if qoid and 'oid' in row:
                keyname = ('oid',)
            else:
                raise _prg_error('Table %s has no primary key' % table)
        else:  # the table has a primary key
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                # try using the oid instead
                if qoid and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise KeyError('Missing primary key in row')
        params = []
        # param() adapts a value, appends it to params and returns its
        # placeholder (presumably "$n" - see _adapt_param)
        param = partial(self._adapt_param, params=params)
        col = self.escape_identifier
        where = ' AND '.join('%s = %s' % (
            col(k), param(row[k], attnames[k])) for k in keyname)
        # keep the oid in the row only under its munged name
        if 'oid' in row:
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        values = []
        keyname = set(keyname)
        # only non-key columns present in the row are updated
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
        if not values:
            return row
        values = ', '.join(values)
        ret = 'oid, *' if qoid else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            # copy the returned values into the row, munging the oid name
            for n, value in res[0].items():
                if qoid and n == 'oid':
                    n = qoid
                row[n] = value
        return row
1543
1544    def upsert(self, table, row=None, **kw):
1545        """Insert a row into a database table with conflict resolution
1546
1547        This method inserts a row into a table, but instead of raising a
1548        ProgrammingError exception in case a row with the same primary key
1549        already exists, an update will be executed instead.  This will be
1550        performed as a single atomic operation on the database, so race
1551        conditions can be avoided.
1552
1553        Like the insert method, the first parameter is the name of the
1554        table and the second parameter can be used to pass the values to
1555        be inserted as a dictionary.
1556
1557        Unlike the insert und update statement, keyword parameters are not
1558        used to modify the dictionary, but to specify which columns shall
1559        be updated in case of a conflict, and in which way:
1560
1561        A value of False or None means the column shall not be updated,
1562        a value of True means the column shall be updated with the value
1563        that has been proposed for insertion, i.e. has been passed as value
1564        in the dictionary.  Columns that are not specified by keywords but
1565        appear as keys in the dictionary are also updated like in the case
1566        keywords had been passed with the value True.
1567
1568        So if in the case of a conflict you want to update every column that
1569        has been passed in the dictionary row , you would call upsert(table, row).
1570        If you don't want to do anything in case of a conflict, i.e. leave
1571        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
1572
1573        If you need more fine-grained control of what gets updated, you can
1574        also pass strings in the keyword parameters.  These strings will
1575        be used as SQL expressions for the update columns.  In these
1576        expressions you can refer to the value that already exists in
1577        the table by prefixing the column name with "included.", and to
1578        the value that has been proposed for insertion by prefixing the
1579        column name with the "excluded."
1580
1581        The dictionary is modified in any case to reflect the values in
1582        the database after the operation has completed.
1583
1584        Note: The method uses the PostgreSQL "upsert" feature which is
1585        only available since PostgreSQL 9.5.
1586        """
1587        if table.endswith('*'):  # hint for descendant tables can be ignored
1588            table = table[:-1].rstrip()
1589        if row is None:
1590            row = {}
1591        if 'oid' in row:
1592            del row['oid']  # do not insert oid
1593        if 'oid' in kw:
1594            del kw['oid']  # do not update oid
1595        attnames = self.get_attnames(table)
1596        qoid = _oid_key(table) if 'oid' in attnames else None
1597        params = []
1598        param = partial(self._adapt_param,params=params)
1599        col = self.escape_identifier
1600        names, values, updates = [], [], []
1601        for n in attnames:
1602            if n in row:
1603                names.append(col(n))
1604                values.append(param(row[n], attnames[n]))
1605        names, values = ', '.join(names), ', '.join(values)
1606        try:
1607            keyname = self.pkey(table, True)
1608        except KeyError:
1609            raise _prg_error('Table %s has no primary key' % table)
1610        target = ', '.join(col(k) for k in keyname)
1611        update = []
1612        keyname = set(keyname)
1613        keyname.add('oid')
1614        for n in attnames:
1615            if n not in keyname:
1616                value = kw.get(n, True)
1617                if value:
1618                    if not isinstance(value, basestring):
1619                        value = 'excluded.%s' % col(n)
1620                    update.append('%s = %s' % (col(n), value))
1621        if not values:
1622            return row
1623        do = 'update set %s' % ', '.join(update) if update else 'nothing'
1624        ret = 'oid, *' if qoid else '*'
1625        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
1626            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
1627                self._escape_qualified_name(table), names, values,
1628                target, do, ret)
1629        self._do_debug(q, params)
1630        try:
1631            q = self.db.query(q, params)
1632        except ProgrammingError:
1633            if self.server_version < 90500:
1634                raise _prg_error(
1635                    'Upsert operation is not supported by PostgreSQL version')
1636            raise  # re-raise original error
1637        res = q.dictresult()
1638        if res:  # may be empty with "do nothing"
1639            for n, value in res[0].items():
1640                if qoid and n == 'oid':
1641                    n = qoid
1642                row[n] = value
1643        else:
1644            self.get(table, row)
1645        return row
1646
1647    def clear(self, table, row=None):
1648        """Clear all the attributes to values determined by the types.
1649
1650        Numeric types are set to 0, Booleans are set to false, and everything
1651        else is set to the empty string.  If the row argument is present,
1652        it is used as the row dictionary and any entries matching attribute
1653        names are cleared with everything else left unchanged.
1654        """
1655        # At some point we will need a way to get defaults from a table.
1656        if row is None:
1657            row = {}  # empty if argument is not present
1658        attnames = self.get_attnames(table)
1659        for n, t in attnames.items():
1660            if n == 'oid':
1661                continue
1662            t = t.simple
1663            if t in DbTypes._num_types:
1664                row[n] = 0
1665            elif t == 'bool':
1666                row[n] = self._make_bool(False)
1667            else:
1668                row[n] = ''
1669        return row
1670
1671    def delete(self, table, row=None, **kw):
1672        """Delete an existing row in a database table.
1673
1674        This method deletes the row from a table.  It deletes based on the
1675        primary key of the table or the OID value as munged by get() or
1676        passed as keyword.
1677
1678        The return value is the number of deleted rows (i.e. 0 if the row
1679        did not exist and 1 if the row was deleted).
1680
1681        Note that if the row cannot be deleted because e.g. it is still
1682        referenced by another table, this method raises a ProgrammingError.
1683        """
1684        if table.endswith('*'):  # hint for descendant tables can be ignored
1685            table = table[:-1].rstrip()
1686        attnames = self.get_attnames(table)
1687        qoid = _oid_key(table) if 'oid' in attnames else None
1688        if row is None:
1689            row = {}
1690        elif 'oid' in row:
1691            del row['oid']  # only accept oid key from named args for safety
1692        row.update(kw)
1693        if qoid and qoid in row and 'oid' not in row:
1694            row['oid'] = row[qoid]
1695        try:  # try using the primary key
1696            keyname = self.pkey(table, True)
1697        except KeyError:  # the table has no primary key
1698            # try using the oid instead
1699            if qoid and 'oid' in row:
1700                keyname = ('oid',)
1701            else:
1702                raise _prg_error('Table %s has no primary key' % table)
1703        else:  # the table has a primary key
1704            # check whether all key columns have values
1705            if not set(keyname).issubset(row):
1706                # try using the oid instead
1707                if qoid and 'oid' in row:
1708                    keyname = ('oid',)
1709                else:
1710                    raise KeyError('Missing primary key in row')
1711        params = []
1712        param = partial(self._adapt_param, params=params)
1713        col = self.escape_identifier
1714        where = ' AND '.join('%s = %s' % (
1715            col(k), param(row[k], attnames[k])) for k in keyname)
1716        if 'oid' in row:
1717            if qoid:
1718                row[qoid] = row['oid']
1719            del row['oid']
1720        q = 'DELETE FROM %s WHERE %s' % (
1721            self._escape_qualified_name(table), where)
1722        self._do_debug(q, params)
1723        res = self.db.query(q, params)
1724        return int(res)
1725
1726    def truncate(self, table, restart=False, cascade=False, only=False):
1727        """Empty a table or set of tables.
1728
1729        This method quickly removes all rows from the given table or set
1730        of tables.  It has the same effect as an unqualified DELETE on each
1731        table, but since it does not actually scan the tables it is faster.
1732        Furthermore, it reclaims disk space immediately, rather than requiring
1733        a subsequent VACUUM operation. This is most useful on large tables.
1734
1735        If restart is set to True, sequences owned by columns of the truncated
1736        table(s) are automatically restarted.  If cascade is set to True, it
1737        also truncates all tables that have foreign-key references to any of
1738        the named tables.  If the parameter only is not set to True, all the
1739        descendant tables (if any) will also be truncated. Optionally, a '*'
1740        can be specified after the table name to explicitly indicate that
1741        descendant tables are included.
1742        """
1743        if isinstance(table, basestring):
1744            only = {table: only}
1745            table = [table]
1746        elif isinstance(table, (list, tuple)):
1747            if isinstance(only, (list, tuple)):
1748                only = dict(zip(table, only))
1749            else:
1750                only = dict.fromkeys(table, only)
1751        elif isinstance(table, (set, frozenset)):
1752            only = dict.fromkeys(table, only)
1753        else:
1754            raise TypeError('The table must be a string, list or set')
1755        if not (restart is None or isinstance(restart, (bool, int))):
1756            raise TypeError('Invalid type for the restart option')
1757        if not (cascade is None or isinstance(cascade, (bool, int))):
1758            raise TypeError('Invalid type for the cascade option')
1759        tables = []
1760        for t in table:
1761            u = only.get(t)
1762            if not (u is None or isinstance(u, (bool, int))):
1763                raise TypeError('Invalid type for the only option')
1764            if t.endswith('*'):
1765                if u:
1766                    raise ValueError(
1767                        'Contradictory table name and only options')
1768                t = t[:-1].rstrip()
1769            t = self._escape_qualified_name(t)
1770            if u:
1771                t = 'ONLY %s' % t
1772            tables.append(t)
1773        q = ['TRUNCATE', ', '.join(tables)]
1774        if restart:
1775            q.append('RESTART IDENTITY')
1776        if cascade:
1777            q.append('CASCADE')
1778        q = ' '.join(q)
1779        self._do_debug(q)
1780        return self.db.query(q)
1781
1782    def get_as_list(self, table, what=None, where=None,
1783            order=None, limit=None, offset=None, scalar=False):
1784        """Get a table as a list.
1785
1786        This gets a convenient representation of the table as a list
1787        of named tuples in Python.  You only need to pass the name of
1788        the table (or any other SQL expression returning rows).  Note that
1789        by default this will return the full content of the table which
1790        can be huge and overflow your memory.  However, you can control
1791        the amount of data returned using the other optional parameters.
1792
1793        The parameter 'what' can restrict the query to only return a
1794        subset of the table columns.  It can be a string, list or a tuple.
1795        The parameter 'where' can restrict the query to only return a
1796        subset of the table rows.  It can be a string, list or a tuple
1797        of SQL expressions that all need to be fulfilled.  The parameter
1798        'order' specifies the ordering of the rows.  It can also be a
1799        other string, list or a tuple.  If no ordering is specified,
1800        the result will be ordered by the primary key(s) or all columns
1801        if no primary key exists.  You can set 'order' to False if you
1802        don't care about the ordering.  The parameters 'limit' and 'offset'
1803        can be integers specifying the maximum number of rows returned
1804        and a number of rows skipped over.
1805
1806        If you set the 'scalar' option to True, then instead of the
1807        named tuples you will get the first items of these tuples.
1808        This is useful if the result has only one column anyway.
1809        """
1810        if not table:
1811            raise TypeError('The table name is missing')
1812        if what:
1813            if isinstance(what, (list, tuple)):
1814                what = ', '.join(map(str, what))
1815            if order is None:
1816                order = what
1817        else:
1818            what = '*'
1819        q = ['SELECT', what, 'FROM', table]
1820        if where:
1821            if isinstance(where, (list, tuple)):
1822                where = ' AND '.join(map(str, where))
1823            q.extend(['WHERE', where])
1824        if order is None:
1825            try:
1826                order = self.pkey(table, True)
1827            except (KeyError, ProgrammingError):
1828                try:
1829                    order = list(self.get_attnames(table))
1830                except (KeyError, ProgrammingError):
1831                    pass
1832        if order:
1833            if isinstance(order, (list, tuple)):
1834                order = ', '.join(map(str, order))
1835            q.extend(['ORDER BY', order])
1836        if limit:
1837            q.append('LIMIT %d' % limit)
1838        if offset:
1839            q.append('OFFSET %d' % offset)
1840        q = ' '.join(q)
1841        self._do_debug(q)
1842        q = self.db.query(q)
1843        res = q.namedresult()
1844        if res and scalar:
1845            res = [row[0] for row in res]
1846        return res
1847
1848    def get_as_dict(self, table, keyname=None, what=None, where=None,
1849            order=None, limit=None, offset=None, scalar=False):
1850        """Get a table as a dictionary.
1851
1852        This method is similar to get_as_list(), but returns the table
1853        as a Python dict instead of a Python list, which can be even
1854        more convenient. The primary key column(s) of the table will
1855        be used as the keys of the dictionary, while the other column(s)
1856        will be the corresponding values.  The keys will be named tuples
1857        if the table has a composite primary key.  The rows will be also
1858        named tuples unless the 'scalar' option has been set to True.
1859        With the optional parameter 'keyname' you can specify an alternative
1860        set of columns to be used as the keys of the dictionary.  It must
1861        be set as a string, list or a tuple.
1862
1863        If the Python version supports it, the dictionary will be an
1864        OrderedDict using the order specified with the 'order' parameter
1865        or the key column(s) if not specified.  You can set 'order' to False
1866        if you don't care about the ordering.  In this case the returned
1867        dictionary will be an ordinary one.
1868        """
1869        if not table:
1870            raise TypeError('The table name is missing')
1871        if not keyname:
1872            try:
1873                keyname = self.pkey(table, True)
1874            except (KeyError, ProgrammingError):
1875                raise _prg_error('Table %s has no primary key' % table)
1876        if isinstance(keyname, basestring):
1877            keyname = [keyname]
1878        elif not isinstance(keyname, (list, tuple)):
1879            raise KeyError('The keyname must be a string, list or tuple')
1880        if what:
1881            if isinstance(what, (list, tuple)):
1882                what = ', '.join(map(str, what))
1883            if order is None:
1884                order = what
1885        else:
1886            what = '*'
1887        q = ['SELECT', what, 'FROM', table]
1888        if where:
1889            if isinstance(where, (list, tuple)):
1890                where = ' AND '.join(map(str, where))
1891            q.extend(['WHERE', where])
1892        if order is None:
1893            order = keyname
1894        if order:
1895            if isinstance(order, (list, tuple)):
1896                order = ', '.join(map(str, order))
1897            q.extend(['ORDER BY', order])
1898        if limit:
1899            q.append('LIMIT %d' % limit)
1900        if offset:
1901            q.append('OFFSET %d' % offset)
1902        q = ' '.join(q)
1903        self._do_debug(q)
1904        q = self.db.query(q)
1905        res = q.getresult()
1906        cls = OrderedDict if order else dict
1907        if not res:
1908            return cls()
1909        keyset = set(keyname)
1910        fields = q.listfields()
1911        if not keyset.issubset(fields):
1912            raise KeyError('Missing keyname in row')
1913        keyind, rowind = [], []
1914        for i, f in enumerate(fields):
1915            (keyind if f in keyset else rowind).append(i)
1916        keytuple = len(keyind) > 1
1917        getkey = itemgetter(*keyind)
1918        keys = map(getkey, res)
1919        if scalar:
1920            rowind = rowind[:1]
1921            rowtuple = False
1922        else:
1923            rowtuple = len(rowind) > 1
1924        if scalar or rowtuple:
1925            getrow = itemgetter(*rowind)
1926        else:
1927            rowind = rowind[0]
1928            getrow = lambda row: (row[rowind],)
1929            rowtuple = True
1930        rows = map(getrow, res)
1931        if keytuple or rowtuple:
1932            namedresult = get_namedresult()
1933            if namedresult:
1934                if keytuple:
1935                    keys = namedresult(_MemoryQuery(keys, keyname))
1936                if rowtuple:
1937                    fields = [f for f in fields if f not in keyset]
1938                    rows = namedresult(_MemoryQuery(rows, fields))
1939        return cls(zip(keys, rows))
1940
1941    def notification_handler(self,
1942            event, callback, arg_dict=None, timeout=None, stop_event=None):
1943        """Get notification handler that will run the given callback."""
1944        return NotificationHandler(self,
1945            event, callback, arg_dict, timeout, stop_event)
1946
1947
# if run as script, print some information

if __name__ == '__main__':
    # separate the label from the version number with a space
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.