source: trunk/pg.py @ 762

Last change on this file since 762 was 762, checked in by cito, 4 years ago

Docs and 100% test coverage for NotificationHandler

  • Property svn:keywords set to Id
File size: 45.0 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 762 2016-01-17 16:32:18Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40
41try:
42    from collections import OrderedDict
43except ImportError:  # Python 2.6 or 3.0
44    OrderedDict = dict
45
46try:
47    basestring
48except NameError:  # Python >= 3.0
49    basestring = (str, bytes)
50
51set_decimal(Decimal)
52
53
54# Auxiliary functions that are independent from a DB connection:
55
56def _oid_key(table):
57    """Build oid key from a table name."""
58    return 'oid(%s)' % table
59
60
61def _simpletype(typ):
62    """Determine a simplified name a pg_type name."""
63    if typ.startswith('bool'):
64        return 'bool'
65    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
66        return 'date'
67    if typ.startswith(('cid', 'oid', 'int', 'xid')):
68        return 'int'
69    if typ.startswith('float'):
70        return 'float'
71    if typ.startswith('numeric'):
72        return 'num'
73    if typ.startswith('money'):
74        return 'money'
75    if typ.startswith('bytea'):
76        return 'bytea'
77    return 'text'
78
79
80def _namedresult(q):
81    """Get query result as named tuples."""
82    row = namedtuple('Row', q.listfields())
83    return [row(*r) for r in q.getresult()]
84
85set_namedresult(_namedresult)
86
87
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) a database error without an SQL state.

    The exception is an instance of cls (DatabaseError by default)
    built from msg, with its sqlstate attribute set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
93
94
def _int_error(msg):
    """Create an InternalError for msg with an empty sqlstate."""
    return _db_error(msg, cls=InternalError)
98
99
def _prg_error(msg):
    """Create a ProgrammingError for msg with an empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
103
104
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler.

    The handler listens on an event channel and on a companion stop
    channel of the given connection.  When called, it dispatches the
    incoming notifications to a callback until a stop notification is
    received or the timeout expires.
    """

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            # create a fresh dict per handler instead of sharing a default
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is garbage collected."""
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.  The
        payload is escaped before it is embedded as a string literal in
        the NOTIFY command, so it may safely contain quote characters.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent (and None is returned) if the handler is not
        listening; otherwise the result of the query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                payload = '%s' % (payload,)
                # Prefer the escaping of the connection itself; fall back
                # to doubling single quotes if it offers no such method.
                escape = getattr(db, 'escape_string', None)
                if escape is None:
                    payload = payload.replace("'", "''")
                else:
                    payload = escape(payload)
                q += ", '%s'" % payload
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications is
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        If the timeout expires without a notification, the handler stops
        listening and invokes the callback with None.  A notification for
        any other event makes the handler stop listening and raise a
        database error.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
224
225
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its traditional, deprecated name.

    All arguments are passed on to NotificationHandler unchanged, after
    a DeprecationWarning has been issued for the caller.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
231
232
233# The actual PostgreSQL database connection interface:
234
235class DB(object):
236    """Wrapper class for the _pg connection type."""
237
238    def __init__(self, *args, **kw):
239        """Create a new connection
240
241        You can pass either the connection parameters or an existing
242        _pg or pgdb connection. This allows you to use the methods
243        of the classic pg interface with a DB-API 2 pgdb connection.
244        """
245        if not args and len(kw) == 1:
246            db = kw.get('db')
247        elif not kw and len(args) == 1:
248            db = args[0]
249        else:
250            db = None
251        if db:
252            if isinstance(db, DB):
253                db = db.db
254            else:
255                try:
256                    db = db._cnx
257                except AttributeError:
258                    pass
259        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
260            db = connect(*args, **kw)
261            self._closeable = True
262        else:
263            self._closeable = False
264        self.db = db
265        self.dbname = db.db
266        self._regtypes = False
267        self._attnames = {}
268        self._pkeys = {}
269        self._privileges = {}
270        self._args = args, kw
271        self.debug = None  # For debugging scripts, this can be set
272            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
273            # * to a file object to write debug statements or
274            # * to a callable object which takes a string argument
275            # * to any other true value to just print debug statements
276
277    def __getattr__(self, name):
278        # All undefined members are same as in underlying connection:
279        if self.db:
280            return getattr(self.db, name)
281        else:
282            raise _int_error('Connection is not valid')
283
284    def __dir__(self):
285        # Custom dir function including the attributes of the connection:
286        attrs = set(self.__class__.__dict__)
287        attrs.update(self.__dict__)
288        attrs.update(dir(self.db))
289        return sorted(attrs)
290
291    # Context manager methods
292
293    def __enter__(self):
294        """Enter the runtime context. This will start a transactio."""
295        self.begin()
296        return self
297
298    def __exit__(self, et, ev, tb):
299        """Exit the runtime context. This will end the transaction."""
300        if et is None and ev is None and tb is None:
301            self.commit()
302        else:
303            self.rollback()
304
305    # Auxiliary methods
306
307    def _do_debug(self, *args):
308        """Print a debug message"""
309        if self.debug:
310            s = '\n'.join(args)
311            if isinstance(self.debug, basestring):
312                print(self.debug % s)
313            elif hasattr(self.debug, 'write'):
314                self.debug.write(s + '\n')
315            elif callable(self.debug):
316                self.debug(s)
317            else:
318                print(s)
319
320    def _escape_qualified_name(self, s):
321        """Escape a qualified name.
322
323        Escapes the name for use as an SQL identifier, unless the
324        name contains a dot, in which case the name is ambiguous
325        (could be a qualified name or just a name with a dot in it)
326        and must be quoted manually by the caller.
327        """
328        if '.' not in s:
329            s = self.escape_identifier(s)
330        return s
331
332    @staticmethod
333    def _make_bool(d):
334        """Get boolean value corresponding to d."""
335        return bool(d) if get_bool() else ('t' if d else 'f')
336
337    _bool_true_values = frozenset('t true 1 y yes on'.split())
338
339    def _prepare_bool(self, d):
340        """Prepare a boolean parameter."""
341        if isinstance(d, basestring):
342            if not d:
343                return None
344            d = d.lower() in self._bool_true_values
345        return 't' if d else 'f'
346
347    _date_literals = frozenset('current_date current_time'
348        ' current_timestamp localtime localtimestamp'.split())
349
350    def _prepare_date(self, d):
351        """Prepare a date parameter."""
352        if not d:
353            return None
354        if isinstance(d, basestring) and d.lower() in self._date_literals:
355            raise ValueError
356        return d
357
358    def _prepare_num(self, d):
359        """Prepare a numeric parameter."""
360        if not d and d != 0:
361            return None
362        return d
363
364    def _prepare_bytea(self, d):
365        """Prepare a bytea parameter."""
366        return self.escape_bytea(d)
367
368    _prepare_funcs = dict(  # quote methods for each type
369        bool=_prepare_bool, date=_prepare_date,
370        int=_prepare_num, num=_prepare_num, float=_prepare_num,
371        money=_prepare_num, bytea=_prepare_bytea)
372
373    def _prepare_param(self, value, typ, params):
374        """Prepare and add a parameter to the list."""
375        if value is not None and typ != 'text':
376            try:
377                prepare = self._prepare_funcs[typ]
378            except KeyError:
379                pass
380            else:
381                try:
382                    value = prepare(self, value)
383                except ValueError:
384                    return value
385        params.append(value)
386        return '$%d' % len(params)
387
388    def _list_params(self, params):
389        """Create a human readable parameter list."""
390        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
391
392    @staticmethod
393    def _prepare_qualified_param(name, param):
394        """Quote parameter representing a qualified name.
395
396        Escapes the name for use as an SQL parameter, unless the
397        name contains a dot, in which case the name is ambiguous
398        (could be a qualified name or just a name with a dot in it)
399        and must be quoted manually by the caller.
400
401        """
402        if isinstance(param, int):
403            param = "$%d" % param
404        if '.' not in name:
405            param = 'quote_ident(%s)' % (param,)
406        return param
407
408    # Public methods
409
410    # escape_string and escape_bytea exist as methods,
411    # so we define unescape_bytea as a method as well
412    unescape_bytea = staticmethod(unescape_bytea)
413
414    def close(self):
415        """Close the database connection."""
416        # Wraps shared library function so we can track state.
417        if self._closeable:
418            if self.db:
419                self.db.close()
420                self.db = None
421            else:
422                raise _int_error('Connection already closed')
423
424    def reset(self):
425        """Reset connection with current parameters.
426
427        All derived queries and large objects derived from this connection
428        will not be usable after this call.
429
430        """
431        if self.db:
432            self.db.reset()
433        else:
434            raise _int_error('Connection already closed')
435
436    def reopen(self):
437        """Reopen connection to the database.
438
439        Used in case we need another connection to the same database.
440        Note that we can still reopen a database that we have closed.
441
442        """
443        # There is no such shared library function.
444        if self._closeable:
445            db = connect(*self._args[0], **self._args[1])
446            if self.db:
447                self.db.close()
448            self.db = db
449
450    def begin(self, mode=None):
451        """Begin a transaction."""
452        qstr = 'BEGIN'
453        if mode:
454            qstr += ' ' + mode
455        return self.query(qstr)
456
457    start = begin
458
459    def commit(self):
460        """Commit the current transaction."""
461        return self.query('COMMIT')
462
463    end = commit
464
465    def rollback(self, name=None):
466        """Roll back the current transaction."""
467        qstr = 'ROLLBACK'
468        if name:
469            qstr += ' TO ' + name
470        return self.query(qstr)
471
472    def savepoint(self, name):
473        """Define a new savepoint within the current transaction."""
474        return self.query('SAVEPOINT ' + name)
475
476    def release(self, name):
477        """Destroy a previously defined savepoint."""
478        return self.query('RELEASE ' + name)
479
480    def get_parameter(self, parameter):
481        """Get the value of a run-time parameter.
482
483        If the parameter is a string, the return value will also be a string
484        that is the current setting of the run-time parameter with that name.
485
486        You can get several parameters at once by passing a list, set or dict.
487        When passing a list of parameter names, the return value will be a
488        corresponding list of parameter settings.  When passing a set of
489        parameter names, a new dict will be returned, mapping these parameter
490        names to their settings.  Finally, if you pass a dict as parameter,
491        its values will be set to the current parameter settings corresponding
492        to its keys.
493
494        By passing the special name 'all' as the parameter, you can get a dict
495        of all existing configuration parameters.
496        """
497        if isinstance(parameter, basestring):
498            parameter = [parameter]
499            values = None
500        elif isinstance(parameter, (list, tuple)):
501            values = []
502        elif isinstance(parameter, (set, frozenset)):
503            values = {}
504        elif isinstance(parameter, dict):
505            values = parameter
506        else:
507            raise TypeError(
508                'The parameter must be a string, list, set or dict')
509        if not parameter:
510            raise TypeError('No parameter has been specified')
511        params = {} if isinstance(values, dict) else []
512        for key in parameter:
513            param = key.strip().lower() if isinstance(
514                key, basestring) else None
515            if not param:
516                raise TypeError('Invalid parameter')
517            if param == 'all':
518                q = 'SHOW ALL'
519                values = self.db.query(q).getresult()
520                values = dict(value[:2] for value in values)
521                break
522            if isinstance(values, dict):
523                params[param] = key
524            else:
525                params.append(param)
526        else:
527            for param in params:
528                q = 'SHOW %s' % (param,)
529                value = self.db.query(q).getresult()[0][0]
530                if values is None:
531                    values = value
532                elif isinstance(values, list):
533                    values.append(value)
534                else:
535                    values[params[param]] = value
536        return values
537
538    def set_parameter(self, parameter, value=None, local=False):
539        """Set the value of a run-time parameter.
540
541        If the parameter and the value are strings, the run-time parameter
542        will be set to that value.  If no value or None is passed as a value,
543        then the run-time parameter will be restored to its default value.
544
545        You can set several parameters at once by passing a list of parameter
546        names, together with a single value that all parameters should be
547        set to or with a corresponding list of values.  You can also pass
548        the parameters as a set if you only provide a single value.
549        Finally, you can pass a dict with parameter names as keys.  In this
550        case, you should not pass a value, since the values for the parameters
551        will be taken from the dict.
552
553        By passing the special name 'all' as the parameter, you can reset
554        all existing settable run-time parameters to their default values.
555
556        If you set local to True, then the command takes effect for only the
557        current transaction.  After commit() or rollback(), the session-level
558        setting takes effect again.  Setting local to True will appear to
559        have no effect if it is executed outside a transaction, since the
560        transaction will end immediately.
561        """
562        if isinstance(parameter, basestring):
563            parameter = {parameter: value}
564        elif isinstance(parameter, (list, tuple)):
565            if isinstance(value, (list, tuple)):
566                parameter = dict(zip(parameter, value))
567            else:
568                parameter = dict.fromkeys(parameter, value)
569        elif isinstance(parameter, (set, frozenset)):
570            if isinstance(value, (list, tuple, set, frozenset)):
571                value = set(value)
572                if len(value) == 1:
573                    value = value.pop()
574            if not(value is None or isinstance(value, basestring)):
575                raise ValueError('A single value must be specified'
576                    ' when parameter is a set')
577            parameter = dict.fromkeys(parameter, value)
578        elif isinstance(parameter, dict):
579            if value is not None:
580                raise ValueError('A value must not be specified'
581                    ' when parameter is a dictionary')
582        else:
583            raise TypeError(
584                'The parameter must be a string, list, set or dict')
585        if not parameter:
586            raise TypeError('No parameter has been specified')
587        params = {}
588        for key, value in parameter.items():
589            param = key.strip().lower() if isinstance(
590                key, basestring) else None
591            if not param:
592                raise TypeError('Invalid parameter')
593            if param == 'all':
594                if value is not None:
595                    raise ValueError('A value must ot be specified'
596                        " when parameter is 'all'")
597                params = {'all': None}
598                break
599            params[param] = value
600        local = ' LOCAL' if local else ''
601        for param, value in params.items():
602            if value is None:
603                q = 'RESET%s %s' % (local, param)
604            else:
605                q = 'SET%s %s TO %s' % (local, param, value)
606            self._do_debug(q)
607            self.db.query(q)
608
609    def query(self, qstr, *args):
610        """Execute a SQL command string.
611
612        This method simply sends a SQL query to the database. If the query is
613        an insert statement that inserted exactly one row into a table that
614        has OIDs, the return value is the OID of the newly inserted row.
615        If the query is an update or delete statement, or an insert statement
616        that did not insert exactly one row in a table with OIDs, then the
617        number of rows affected is returned as a string. If it is a statement
618        that returns rows as a result (usually a select statement, but maybe
619        also an "insert/update ... returning" statement), this method returns
620        a Query object that can be accessed via getresult() or dictresult()
621        or simply printed. Otherwise, it returns `None`.
622
623        The query can contain numbered parameters of the form $1 in place
624        of any data constant. Arguments given after the query string will
625        be substituted for the corresponding numbered parameter. Parameter
626        values can also be given as a single list or tuple argument.
627        """
628        # Wraps shared library function for debugging.
629        if not self.db:
630            raise _int_error('Connection is not valid')
631        self._do_debug(qstr)
632        return self.db.query(qstr, args)
633
634    def pkey(self, table, flush=False):
635        """Get or set the primary key of a table.
636
637        Composite primary keys are represented as frozensets. Note that
638        this raises a KeyError if the table does not have a primary key.
639
640        If flush is set then the internal cache for primary keys will
641        be flushed. This may be necessary after the database schema or
642        the search path has been changed.
643        """
644        pkeys = self._pkeys
645        if flush:
646            pkeys.clear()
647            self._do_debug('The pkey cache has been flushed')
648        try:  # cache lookup
649            pkey = pkeys[table]
650        except KeyError:  # cache miss, check the database
651            q = ("SELECT a.attname FROM pg_index i"
652                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
653                " AND a.attnum = ANY(i.indkey)"
654                " AND NOT a.attisdropped"
655                " WHERE i.indrelid=%s::regclass"
656                " AND i.indisprimary") % (
657                    self._prepare_qualified_param(table, 1),)
658            pkey = self.db.query(q, (table,)).getresult()
659            if not pkey:
660                raise KeyError('Table %s has no primary key' % table)
661            if len(pkey) > 1:
662                pkey = frozenset(k[0] for k in pkey)
663            else:
664                pkey = pkey[0][0]
665            pkeys[table] = pkey  # cache it
666        return pkey
667
668    def get_databases(self):
669        """Get list of databases in the system."""
670        return [s[0] for s in
671            self.db.query('SELECT datname FROM pg_database').getresult()]
672
673    def get_relations(self, kinds=None):
674        """Get list of relations in connected database of specified kinds.
675
676        If kinds is None or empty, all kinds of relations are returned.
677        Otherwise kinds can be a string or sequence of type letters
678        specifying which kind of relations you want to list.
679        """
680        where = " AND r.relkind IN (%s)" % ','.join(
681            ["'%s'" % k for k in kinds]) if kinds else ''
682        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
683            " FROM pg_class r"
684            " JOIN pg_namespace s ON s.oid = r.relnamespace"
685            " WHERE s.nspname NOT SIMILAR"
686            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
687            " ORDER BY s.nspname, r.relname") % where
688        return [r[0] for r in self.db.query(q).getresult()]
689
690    def get_tables(self):
691        """Return list of tables in connected database."""
692        return self.get_relations('r')
693
694    def get_attnames(self, table, flush=False):
695        """Given the name of a table, dig out the set of attribute names.
696
697        Returns a dictionary of attribute names (the names are the keys,
698        the values are the names of the attributes' types).
699
700        If your Python version supports this, the dictionary will be an
701        OrderedDictionary with the column names in the right order.
702
703        If flush is set then the internal cache for attribute names will
704        be flushed. This may be necessary after the database schema or
705        the search path has been changed.
706
707        By default, only a limited number of simple types will be returned.
708        You can get the regular types after calling use_regtypes(True).
709        """
710        attnames = self._attnames
711        if flush:
712            attnames.clear()
713            self._do_debug('The attnames cache has been flushed')
714        try:  # cache lookup
715            names = attnames[table]
716        except KeyError:  # cache miss, check the database
717            q = ("SELECT a.attname, t.typname%s"
718                " FROM pg_attribute a"
719                " JOIN pg_type t ON t.oid = a.atttypid"
720                " WHERE a.attrelid = %s::regclass"
721                " AND (a.attnum > 0 OR a.attname = 'oid')"
722                " AND NOT a.attisdropped ORDER BY a.attnum") % (
723                    '::regtype' if self._regtypes else '',
724                    self._prepare_qualified_param(table, 1))
725            names = self.db.query(q, (table,)).getresult()
726            if not names:
727                raise KeyError('Table %s does not exist' % table)
728            if not self._regtypes:
729                names = ((name, _simpletype(typ)) for name, typ in names)
730            names = OrderedDict(names)
731            attnames[table] = names  # cache it
732        return names
733
734    def use_regtypes(self, regtypes=None):
735        """Use regular type names instead of simplified type names."""
736        if regtypes is None:
737            return self._regtypes
738        else:
739            regtypes = bool(regtypes)
740            if regtypes != self._regtypes:
741                self._regtypes = regtypes
742                self._attnames.clear()
743            return regtypes
744
745    def has_table_privilege(self, table, privilege='select'):
746        """Check whether current user has specified table privilege."""
747        privilege = privilege.lower()
748        try:  # ask cache
749            return self._privileges[(table, privilege)]
750        except KeyError:  # cache miss, ask the database
751            q = "SELECT has_table_privilege(%s, $2)" % (
752                self._prepare_qualified_param(table, 1),)
753            q = self.db.query(q, (table, privilege))
754            ret = q.getresult()[0][0] == self._make_bool(True)
755            self._privileges[(table, privilege)] = ret  # cache it
756            return ret
757
758    def get(self, table, row, keyname=None):
759        """Get a row from a database table or view.
760
761        This method is the basic mechanism to get a single row.  The keyname
762        that the key specifies a unique row.  If keyname is not specified
763        then the primary key for the table is used.  If row is a dictionary
764        then the value for the key is taken from it and it is modified to
765        include the new values, replacing existing values where necessary.
766        For a composite key, keyname can also be a sequence of key names.
767        The OID is also put into the dictionary if the table has one, but
768        in order to allow the caller to work with multiple tables, it is
769        munged as "oid(table)".
770        """
771        if table.endswith('*'):  # scan descendant tables?
772            table = table[:-1].rstrip()  # need parent table name
773        if not keyname:
774            # use the primary key by default
775            try:
776                keyname = self.pkey(table)
777            except KeyError:
778                raise _prg_error('Table %s has no primary key' % table)
779        attnames = self.get_attnames(table)
780        params = []
781        param = partial(self._prepare_param, params=params)
782        col = self.escape_identifier
783        # We want the oid for later updates if that isn't the key.
784        # To allow users to work with multiple tables, we munge
785        # the name of the "oid" key by adding the name of the table.
786        qoid = _oid_key(table)
787        if keyname == 'oid':
788            if isinstance(row, dict):
789                if qoid not in row:
790                    raise _db_error('%s not in row' % qoid)
791            else:
792                row = {qoid: row}
793            what = '*'
794            where = 'oid = %s' % param(row[qoid], 'int')
795        else:
796            keyname = [keyname] if isinstance(
797                keyname, basestring) else sorted(keyname)
798            if not isinstance(row, dict):
799                if len(keyname) > 1:
800                    raise _prg_error('Composite key needs dict as row')
801                row = dict((k, row) for k in keyname)
802            what = ', '.join(col(k) for k in attnames)
803            where = ' AND '.join('%s = %s' % (
804                col(k), param(row[k], attnames[k])) for k in keyname)
805        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
806            what, self._escape_qualified_name(table), where)
807        self._do_debug(q, params)
808        q = self.db.query(q, params)
809        res = q.dictresult()
810        if not res:
811            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
812                table, where, self._list_params(params)))
813        for n, value in res[0].items():
814            if n == 'oid':
815                n = qoid
816            elif attnames.get(n) == 'bytea':
817                value = self.unescape_bytea(value)
818            row[n] = value
819        return row
820
821    def insert(self, table, row=None, **kw):
822        """Insert a row into a database table.
823
824        This method inserts a row into a table.  The name of the table must
825        be passed as the first parameter.  The other parameters are used for
826        providing the data of the row that shall be inserted into the table.
827        If a dictionary is supplied as the second parameter, it starts with
828        that.  Otherwise it uses a blank dictionary. Either way the dictionary
829        is updated from the keywords.
830
831        The dictionary is then reloaded with the values actually inserted in
832        order to pick up values modified by rules, triggers, etc.
833
834        Note: The method currently doesn't support insert into views
835        although PostgreSQL does.
836        """
837        if 'oid' in kw:
838            del kw['oid']
839        if row is None:
840            row = {}
841        row.update(kw)
842        attnames = self.get_attnames(table)
843        params = []
844        param = partial(self._prepare_param, params=params)
845        col = self.escape_identifier
846        names, values = [], []
847        for n in attnames:
848            if n in row:
849                names.append(col(n))
850                values.append(param(row[n], attnames[n]))
851        names, values = ', '.join(names), ', '.join(values)
852        ret = 'oid, *' if 'oid' in attnames else '*'
853        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
854            self._escape_qualified_name(table), names, values, ret)
855        self._do_debug(q, params)
856        q = self.db.query(q, params)
857        res = q.dictresult()
858        if not res:
859            raise _int_error('Insert operation did not return new values')
860        for n, value in res[0].items():
861            if n == 'oid':
862                n = _oid_key(table)
863            elif attnames.get(n) == 'bytea' and value is not None:
864                value = self.unescape_bytea(value)
865            row[n] = value
866        return row
867
868    def update(self, table, row=None, **kw):
869        """Update an existing row in a database table.
870
871        Similar to insert but updates an existing row.  The update is based
872        on the OID value as munged by get or passed as keyword, or on the
873        primary key of the table.  The dictionary is modified to reflect
874        any changes caused by the update due to triggers, rules, default
875        values, etc.
876        """
877        # Update always works on the oid which get() returns if available,
878        # otherwise use the primary key.  Fail if neither.
879        # Note that we only accept oid key from named args for safety.
880        qoid = _oid_key(table)
881        if 'oid' in kw:
882            kw[qoid] = kw['oid']
883            del kw['oid']
884        if row is None:
885            row = {}
886        row.update(kw)
887        attnames = self.get_attnames(table)
888        params = []
889        param = partial(self._prepare_param, params=params)
890        col = self.escape_identifier
891        if qoid in row:
892            where = 'oid = %s' % param(row[qoid], 'int')
893            keyname = []
894        else:
895            try:
896                keyname = self.pkey(table)
897            except KeyError:
898                raise _prg_error('Table %s has no primary key' % table)
899            keyname = [keyname] if isinstance(
900                keyname, basestring) else sorted(keyname)
901            try:
902                where = ' AND '.join('%s = %s' % (
903                    col(k), param(row[k], attnames[k])) for k in keyname)
904            except KeyError:
905                raise _prg_error('Update operation needs primary key or oid')
906        keyname = set(keyname)
907        keyname.add('oid')
908        values = []
909        for n in attnames:
910            if n in row and n not in keyname:
911                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
912        if not values:
913            return row
914        values = ', '.join(values)
915        ret = 'oid, *' if 'oid' in attnames else '*'
916        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
917            self._escape_qualified_name(table), values, where, ret)
918        self._do_debug(q, params)
919        q = self.db.query(q, params)
920        res = q.dictresult()
921        if res:  # may be empty when row does not exist
922            for n, value in res[0].items():
923                if n == 'oid':
924                    n = qoid
925                elif attnames.get(n) == 'bytea' and value is not None:
926                    value = self.unescape_bytea(value)
927                row[n] = value
928        return row
929
930    def upsert(self, table, row=None, **kw):
931        """Insert a row into a database table with conflict resolution
932
933        This method inserts a row into a table, but instead of raising a
934        ProgrammingError exception in case a row with the same primary key
935        already exists, an update will be executed instead.  This will be
936        performed as a single atomic operation on the database, so race
937        conditions can be avoided.
938
939        Like the insert method, the first parameter is the name of the
940        table and the second parameter can be used to pass the values to
941        be inserted as a dictionary.
942
943        Unlike the insert und update statement, keyword parameters are not
944        used to modify the dictionary, but to specify which columns shall
945        be updated in case of a conflict, and in which way:
946
947        A value of False or None means the column shall not be updated,
948        a value of True means the column shall be updated with the value
949        that has been proposed for insertion, i.e. has been passed as value
950        in the dictionary.  Columns that are not specified by keywords but
951        appear as keys in the dictionary are also updated like in the case
952        keywords had been passed with the value True.
953
954        So if in the case of a conflict you want to update every column that
955        has been passed in the dictionary row , you would call upsert(table, row).
956        If you don't want to do anything in case of a conflict, i.e. leave
957        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
958
959        If you need more fine-grained control of what gets updated, you can
960        also pass strings in the keyword parameters.  These strings will
961        be used as SQL expressions for the update columns.  In these
962        expressions you can refer to the value that already exists in
963        the table by prefixing the column name with "included.", and to
964        the value that has been proposed for insertion by prefixing the
965        column name with the "excluded."
966
967        The dictionary is modified in any case to reflect the values in
968        the database after the operation has completed.
969
970        Note: The method uses the PostgreSQL "upsert" feature which is
971        only available since PostgreSQL 9.5.
972        """
973        if 'oid' in kw:
974            del kw['oid']
975        if row is None:
976            row = {}
977        attnames = self.get_attnames(table)
978        params = []
979        param = partial(self._prepare_param,params=params)
980        col = self.escape_identifier
981        names, values, updates = [], [], []
982        for n in attnames:
983            if n in row:
984                names.append(col(n))
985                values.append(param(row[n], attnames[n]))
986        names, values = ', '.join(names), ', '.join(values)
987        try:
988            keyname = self.pkey(table)
989        except KeyError:
990            raise _prg_error('Table %s has no primary key' % table)
991        keyname = [keyname] if isinstance(
992            keyname, basestring) else sorted(keyname)
993        try:
994            target = ', '.join(col(k) for k in keyname)
995        except KeyError:
996            raise _prg_error('Upsert operation needs primary key or oid')
997        update = []
998        keyname = set(keyname)
999        keyname.add('oid')
1000        for n in attnames:
1001            if n not in keyname:
1002                value = kw.get(n, True)
1003                if value:
1004                    if not isinstance(value, basestring):
1005                        value = 'excluded.%s' % col(n)
1006                    update.append('%s = %s' % (col(n), value))
1007        if not values and not update:
1008            return row
1009        do = 'update set %s' % ', '.join(update) if update else 'nothing'
1010        ret = 'oid, *' if 'oid' in attnames else '*'
1011        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
1012            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
1013                self._escape_qualified_name(table), names, values,
1014                target, do, ret)
1015        self._do_debug(q, params)
1016        try:
1017            q = self.db.query(q, params)
1018        except ProgrammingError:
1019            if self.server_version < 90500:
1020                raise _prg_error(
1021                    'Upsert operation is not supported by PostgreSQL version')
1022            raise  # re-raise original error
1023        res = q.dictresult()
1024        if res:  # may be empty with "do nothing"
1025            for n, value in res[0].items():
1026                if n == 'oid':
1027                    n = _oid_key(table)
1028                elif attnames.get(n) == 'bytea':
1029                    value = self.unescape_bytea(value)
1030                row[n] = value
1031        elif update:
1032            raise _int_error('Upsert operation did not return new values')
1033        else:
1034            self.get(table, row)
1035        return row
1036
1037    def clear(self, table, row=None):
1038        """Clear all the attributes to values determined by the types.
1039
1040        Numeric types are set to 0, Booleans are set to false, and everything
1041        else is set to the empty string.  If the row argument is present,
1042        it is used as the row dictionary and any entries matching attribute
1043        names are cleared with everything else left unchanged.
1044        """
1045        # At some point we will need a way to get defaults from a table.
1046        if row is None:
1047            row = {}  # empty if argument is not present
1048        attnames = self.get_attnames(table)
1049        for n, t in attnames.items():
1050            if n == 'oid':
1051                continue
1052            if t in ('int', 'integer', 'smallint', 'bigint',
1053                    'float', 'real', 'double precision',
1054                    'num', 'numeric', 'money'):
1055                row[n] = 0
1056            elif t in ('bool', 'boolean'):
1057                row[n] = self._make_bool(False)
1058            else:
1059                row[n] = ''
1060        return row
1061
1062    def delete(self, table, row=None, **kw):
1063        """Delete an existing row in a database table.
1064
1065        This method deletes the row from a table.  It deletes based on the
1066        OID value as munged by get or passed as keyword, or on the primary
1067        key of the table.  The return value is the number of deleted rows
1068        (i.e. 0 if the row did not exist and 1 if the row was deleted).
1069        """
1070        # Like update, delete works on the oid.
1071        # One day we will be testing that the record to be deleted
1072        # isn't referenced somewhere (or else PostgreSQL will).
1073        # Note that we only accept oid key from named args for safety.
1074        qoid = _oid_key(table)
1075        if 'oid' in kw:
1076            kw[qoid] = kw['oid']
1077            del kw['oid']
1078        if row is None:
1079            row = {}
1080        row.update(kw)
1081        params = []
1082        param = partial(self._prepare_param, params=params)
1083        if qoid in row:
1084            where = 'oid = %s' % param(row[qoid], 'int')
1085        else:
1086            try:
1087                keyname = self.pkey(table)
1088            except KeyError:
1089                raise _prg_error('Table %s has no primary key' % table)
1090            keyname = [keyname] if isinstance(
1091                keyname, basestring) else sorted(keyname)
1092            attnames = self.get_attnames(table)
1093            col = self.escape_identifier
1094            try:
1095                where = ' AND '.join('%s = %s' % (
1096                    col(k), param(row[k], attnames[k])) for k in keyname)
1097            except KeyError:
1098                raise _prg_error('Delete operation needs primary key or oid')
1099        q = 'DELETE FROM %s WHERE %s' % (
1100            self._escape_qualified_name(table), where)
1101        self._do_debug(q, params)
1102        res = self.db.query(q, params)
1103        return int(res)
1104
1105    def truncate(self, table, restart=False, cascade=False, only=False):
1106        """Empty a table or set of tables.
1107
1108        This method quickly removes all rows from the given table or set
1109        of tables.  It has the same effect as an unqualified DELETE on each
1110        table, but since it does not actually scan the tables it is faster.
1111        Furthermore, it reclaims disk space immediately, rather than requiring
1112        a subsequent VACUUM operation. This is most useful on large tables.
1113
1114        If restart is set to True, sequences owned by columns of the truncated
1115        table(s) are automatically restarted.  If cascade is set to True, it
1116        also truncates all tables that have foreign-key references to any of
1117        the named tables.  If the parameter only is not set to True, all the
1118        descendant tables (if any) will also be truncated. Optionally, a '*'
1119        can be specified after the table name to explicitly indicate that
1120        descendant tables are included.
1121        """
1122        if isinstance(table, basestring):
1123            only = {table: only}
1124            table = [table]
1125        elif isinstance(table, (list, tuple)):
1126            if isinstance(only, (list, tuple)):
1127                only = dict(zip(table, only))
1128            else:
1129                only = dict.fromkeys(table, only)
1130        elif isinstance(table, (set, frozenset)):
1131            only = dict.fromkeys(table, only)
1132        else:
1133            raise TypeError('The table must be a string, list or set')
1134        if not (restart is None or isinstance(restart, (bool, int))):
1135            raise TypeError('Invalid type for the restart option')
1136        if not (cascade is None or isinstance(cascade, (bool, int))):
1137            raise TypeError('Invalid type for the cascade option')
1138        tables = []
1139        for t in table:
1140            u = only.get(t)
1141            if not (u is None or isinstance(u, (bool, int))):
1142                raise TypeError('Invalid type for the only option')
1143            if t.endswith('*'):
1144                if u:
1145                    raise ValueError(
1146                        'Contradictory table name and only options')
1147                t = t[:-1].rstrip()
1148            t = self._escape_qualified_name(t)
1149            if u:
1150                t = 'ONLY %s' % t
1151            tables.append(t)
1152        q = ['TRUNCATE', ', '.join(tables)]
1153        if restart:
1154            q.append('RESTART IDENTITY')
1155        if cascade:
1156            q.append('CASCADE')
1157        q = ' '.join(q)
1158        self._do_debug(q)
1159        return self.query(q)
1160
1161    def notification_handler(self,
1162            event, callback, arg_dict=None, timeout=None, stop_event=None):
1163        """Get notification handler that will run the given callback."""
1164        return NotificationHandler(self,
1165            event, callback, arg_dict, timeout, stop_event)
1166
1167
1168# if run as script, print some information
1169
if __name__ == '__main__':
    # print label and version number as separate arguments so that
    # they are separated by a blank instead of being glued together
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.