source: trunk/pg.py @ 763

Last change on this file since 763 was 763, checked in by cito, 4 years ago

Achieve 100% test coverage for pg module on the trunk

Note that some lines are only covered in certain Pg or Py versions,
so you need to run tests with different versions to be sure.

Also added another synonym for transaction methods,
you can now pick your favorite for all three of them.

  • Property svn:keywords set to Id
File size: 44.6 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 763 2016-01-17 21:01:04Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40
41try:
42    from collections import OrderedDict
43except ImportError:  # Python 2.6 or 3.0
44    OrderedDict = dict
45
46try:
47    basestring
48except NameError:  # Python >= 3.0
49    basestring = (str, bytes)
50
# Register decimal.Decimal with the _pg module, presumably so that
# numeric column values are returned as Decimal (confirm against _pg docs).
set_decimal(Decimal)
52
53
54# Auxiliary functions that are independent from a DB connection:
55
56def _oid_key(table):
57    """Build oid key from a table name."""
58    return 'oid(%s)' % table
59
60
61def _simpletype(typ):
62    """Determine a simplified name a pg_type name."""
63    if typ.startswith('bool'):
64        return 'bool'
65    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
66        return 'date'
67    if typ.startswith(('cid', 'oid', 'int', 'xid')):
68        return 'int'
69    if typ.startswith('float'):
70        return 'float'
71    if typ.startswith('numeric'):
72        return 'num'
73    if typ.startswith('money'):
74        return 'money'
75    if typ.startswith('bytea'):
76        return 'bytea'
77    return 'text'
78
79
80def _namedresult(q):
81    """Get query result as named tuples."""
82    row = namedtuple('Row', q.listfields())
83    return [row(*r) for r in q.getresult()]
84
# Register _namedresult with the _pg module, presumably as the function
# the query object uses to build named-tuple results (confirm in _pg docs).
set_namedresult(_namedresult)
86
87
def _db_error(msg, cls=DatabaseError):
    """Build (without raising) an exception of class cls for msg.

    The returned exception has its sqlstate attribute set to None.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
93
94
def _int_error(msg):
    """Build (without raising) an InternalError with empty sqlstate."""
    return _db_error(msg, cls=InternalError)
98
99
def _prg_error(msg):
    """Build (without raising) a ProgrammingError with empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
103
104
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # Stop listening when the handler goes away (no-op if not listening).
        self.unlisten()

    @staticmethod
    def _quote_channel(name):
        """Quote a channel name as an SQL identifier.

        Embedded double quotes are doubled, so that the name cannot
        terminate the quoted identifier early.
        """
        return '"%s"' % name.replace('"', '""')

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_channel(self.event))
            self.db.query('listen %s' % self._quote_channel(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_channel(self.event))
            self.db.query(
                'unlisten %s' % self._quote_channel(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Returns the result of the query, or None (without sending anything)
        if the handler is not currently listening.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify %s' % self._quote_channel(
                self.stop_event if stop else self.event)
            if payload:
                # Double single quotes so the payload stays one string
                # literal (assumes standard_conforming_strings is on, so
                # that backslashes have no special meaning).
                payload = ('%s' % (payload,)).replace("'", "''")
                q += ", '%s'" % payload
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
224
225
def pgnotify(*args, **kw):
    """Deprecated: create a NotificationHandler under its traditional name.

    Emits a DeprecationWarning attributed to the caller.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
231
232
233# The actual PostGreSQL database connection interface:
234
235class DB(object):
236    """Wrapper class for the _pg connection type."""
237
238    def __init__(self, *args, **kw):
239        """Create a new connection
240
241        You can pass either the connection parameters or an existing
242        _pg or pgdb connection. This allows you to use the methods
243        of the classic pg interface with a DB-API 2 pgdb connection.
244        """
245        if not args and len(kw) == 1:
246            db = kw.get('db')
247        elif not kw and len(args) == 1:
248            db = args[0]
249        else:
250            db = None
251        if db:
252            if isinstance(db, DB):
253                db = db.db
254            else:
255                try:
256                    db = db._cnx
257                except AttributeError:
258                    pass
259        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
260            db = connect(*args, **kw)
261            self._closeable = True
262        else:
263            self._closeable = False
264        self.db = db
265        self.dbname = db.db
266        self._regtypes = False
267        self._attnames = {}
268        self._pkeys = {}
269        self._privileges = {}
270        self._args = args, kw
271        self.debug = None  # For debugging scripts, this can be set
272            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
273            # * to a file object to write debug statements or
274            # * to a callable object which takes a string argument
275            # * to any other true value to just print debug statements
276
277    def __getattr__(self, name):
278        # All undefined members are same as in underlying connection:
279        if self.db:
280            return getattr(self.db, name)
281        else:
282            raise _int_error('Connection is not valid')
283
284    def __dir__(self):
285        # Custom dir function including the attributes of the connection:
286        attrs = set(self.__class__.__dict__)
287        attrs.update(self.__dict__)
288        attrs.update(dir(self.db))
289        return sorted(attrs)
290
291    # Context manager methods
292
293    def __enter__(self):
294        """Enter the runtime context. This will start a transactio."""
295        self.begin()
296        return self
297
298    def __exit__(self, et, ev, tb):
299        """Exit the runtime context. This will end the transaction."""
300        if et is None and ev is None and tb is None:
301            self.commit()
302        else:
303            self.rollback()
304
305    # Auxiliary methods
306
307    def _do_debug(self, *args):
308        """Print a debug message"""
309        if self.debug:
310            s = '\n'.join(args)
311            if isinstance(self.debug, basestring):
312                print(self.debug % s)
313            elif hasattr(self.debug, 'write'):
314                self.debug.write(s + '\n')
315            elif callable(self.debug):
316                self.debug(s)
317            else:
318                print(s)
319
320    def _escape_qualified_name(self, s):
321        """Escape a qualified name.
322
323        Escapes the name for use as an SQL identifier, unless the
324        name contains a dot, in which case the name is ambiguous
325        (could be a qualified name or just a name with a dot in it)
326        and must be quoted manually by the caller.
327        """
328        if '.' not in s:
329            s = self.escape_identifier(s)
330        return s
331
332    @staticmethod
333    def _make_bool(d):
334        """Get boolean value corresponding to d."""
335        return bool(d) if get_bool() else ('t' if d else 'f')
336
337    _bool_true_values = frozenset('t true 1 y yes on'.split())
338
339    def _prepare_bool(self, d):
340        """Prepare a boolean parameter."""
341        if isinstance(d, basestring):
342            if not d:
343                return None
344            d = d.lower() in self._bool_true_values
345        return 't' if d else 'f'
346
347    _date_literals = frozenset('current_date current_time'
348        ' current_timestamp localtime localtimestamp'.split())
349
350    def _prepare_date(self, d):
351        """Prepare a date parameter."""
352        if not d:
353            return None
354        if isinstance(d, basestring) and d.lower() in self._date_literals:
355            raise ValueError
356        return d
357
358    def _prepare_num(self, d):
359        """Prepare a numeric parameter."""
360        if not d and d != 0:
361            return None
362        return d
363
364    def _prepare_bytea(self, d):
365        """Prepare a bytea parameter."""
366        return self.escape_bytea(d)
367
368    _prepare_funcs = dict(  # quote methods for each type
369        bool=_prepare_bool, date=_prepare_date,
370        int=_prepare_num, num=_prepare_num, float=_prepare_num,
371        money=_prepare_num, bytea=_prepare_bytea)
372
373    def _prepare_param(self, value, typ, params):
374        """Prepare and add a parameter to the list."""
375        if value is not None and typ != 'text':
376            prepare = self._prepare_funcs[typ]
377            try:
378                value = prepare(self, value)
379            except ValueError:
380                return value
381        params.append(value)
382        return '$%d' % len(params)
383
384    def _list_params(self, params):
385        """Create a human readable parameter list."""
386        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
387
388    @staticmethod
389    def _prepare_qualified_param(name, param):
390        """Quote parameter representing a qualified name.
391
392        Escapes the name for use as an SQL parameter, unless the
393        name contains a dot, in which case the name is ambiguous
394        (could be a qualified name or just a name with a dot in it)
395        and must be quoted manually by the caller.
396
397        """
398        if isinstance(param, int):
399            param = "$%d" % param
400        if '.' not in name:
401            param = 'quote_ident(%s)' % (param,)
402        return param
403
404    # Public methods
405
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well.
    # The right-hand side is the module-level unescape_bytea (presumably
    # provided by "from _pg import *"); wrapping it in staticmethod means
    # it is called without the DB instance as first argument.
    unescape_bytea = staticmethod(unescape_bytea)
409
410    def close(self):
411        """Close the database connection."""
412        # Wraps shared library function so we can track state.
413        if self._closeable:
414            if self.db:
415                self.db.close()
416                self.db = None
417            else:
418                raise _int_error('Connection already closed')
419
420    def reset(self):
421        """Reset connection with current parameters.
422
423        All derived queries and large objects derived from this connection
424        will not be usable after this call.
425
426        """
427        if self.db:
428            self.db.reset()
429        else:
430            raise _int_error('Connection already closed')
431
432    def reopen(self):
433        """Reopen connection to the database.
434
435        Used in case we need another connection to the same database.
436        Note that we can still reopen a database that we have closed.
437
438        """
439        # There is no such shared library function.
440        if self._closeable:
441            db = connect(*self._args[0], **self._args[1])
442            if self.db:
443                self.db.close()
444            self.db = db
445
446    def begin(self, mode=None):
447        """Begin a transaction."""
448        qstr = 'BEGIN'
449        if mode:
450            qstr += ' ' + mode
451        return self.query(qstr)
452
453    start = begin
454
455    def commit(self):
456        """Commit the current transaction."""
457        return self.query('COMMIT')
458
459    end = commit
460
461    def rollback(self, name=None):
462        """Roll back the current transaction."""
463        qstr = 'ROLLBACK'
464        if name:
465            qstr += ' TO ' + name
466        return self.query(qstr)
467
468    abort = rollback
469
470    def savepoint(self, name):
471        """Define a new savepoint within the current transaction."""
472        return self.query('SAVEPOINT ' + name)
473
474    def release(self, name):
475        """Destroy a previously defined savepoint."""
476        return self.query('RELEASE ' + name)
477
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises TypeError if parameter has an unsupported type, is empty, or
        contains a name that is not a non-empty string.
        """
        # Normalize the input; 'values' records the kind of result the
        # caller expects (None = single value, a list, or a dict).
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter  # the passed dict is filled in and returned
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # For dict results, map each normalized name back to the key the
        # caller used; otherwise collect the normalized names in order.
        params = {} if isinstance(values, dict) else []
        for key in parameter:
            param = key.strip().lower() if isinstance(
                key, basestring) else None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' overrides everything else and always yields a new
                # dict built from the first two columns of each SHOW ALL
                # row (a dict passed by the caller is not updated here).
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # Only reached if the loop did not break, i.e. 'all' was not
            # requested: query the parameters one by one.
            for param in params:
                # NOTE(review): param is interpolated into the SQL text
                # unquoted; callers should not pass untrusted names.
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
535
536    def set_parameter(self, parameter, value=None, local=False):
537        """Set the value of a run-time parameter.
538
539        If the parameter and the value are strings, the run-time parameter
540        will be set to that value.  If no value or None is passed as a value,
541        then the run-time parameter will be restored to its default value.
542
543        You can set several parameters at once by passing a list of parameter
544        names, together with a single value that all parameters should be
545        set to or with a corresponding list of values.  You can also pass
546        the parameters as a set if you only provide a single value.
547        Finally, you can pass a dict with parameter names as keys.  In this
548        case, you should not pass a value, since the values for the parameters
549        will be taken from the dict.
550
551        By passing the special name 'all' as the parameter, you can reset
552        all existing settable run-time parameters to their default values.
553
554        If you set local to True, then the command takes effect for only the
555        current transaction.  After commit() or rollback(), the session-level
556        setting takes effect again.  Setting local to True will appear to
557        have no effect if it is executed outside a transaction, since the
558        transaction will end immediately.
559        """
560        if isinstance(parameter, basestring):
561            parameter = {parameter: value}
562        elif isinstance(parameter, (list, tuple)):
563            if isinstance(value, (list, tuple)):
564                parameter = dict(zip(parameter, value))
565            else:
566                parameter = dict.fromkeys(parameter, value)
567        elif isinstance(parameter, (set, frozenset)):
568            if isinstance(value, (list, tuple, set, frozenset)):
569                value = set(value)
570                if len(value) == 1:
571                    value = value.pop()
572            if not(value is None or isinstance(value, basestring)):
573                raise ValueError('A single value must be specified'
574                    ' when parameter is a set')
575            parameter = dict.fromkeys(parameter, value)
576        elif isinstance(parameter, dict):
577            if value is not None:
578                raise ValueError('A value must not be specified'
579                    ' when parameter is a dictionary')
580        else:
581            raise TypeError(
582                'The parameter must be a string, list, set or dict')
583        if not parameter:
584            raise TypeError('No parameter has been specified')
585        params = {}
586        for key, value in parameter.items():
587            param = key.strip().lower() if isinstance(
588                key, basestring) else None
589            if not param:
590                raise TypeError('Invalid parameter')
591            if param == 'all':
592                if value is not None:
593                    raise ValueError('A value must ot be specified'
594                        " when parameter is 'all'")
595                params = {'all': None}
596                break
597            params[param] = value
598        local = ' LOCAL' if local else ''
599        for param, value in params.items():
600            if value is None:
601                q = 'RESET%s %s' % (local, param)
602            else:
603                q = 'SET%s %s TO %s' % (local, param, value)
604            self._do_debug(q)
605            self.db.query(q)
606
607    def query(self, qstr, *args):
608        """Execute a SQL command string.
609
610        This method simply sends a SQL query to the database. If the query is
611        an insert statement that inserted exactly one row into a table that
612        has OIDs, the return value is the OID of the newly inserted row.
613        If the query is an update or delete statement, or an insert statement
614        that did not insert exactly one row in a table with OIDs, then the
615        number of rows affected is returned as a string. If it is a statement
616        that returns rows as a result (usually a select statement, but maybe
617        also an "insert/update ... returning" statement), this method returns
618        a Query object that can be accessed via getresult() or dictresult()
619        or simply printed. Otherwise, it returns `None`.
620
621        The query can contain numbered parameters of the form $1 in place
622        of any data constant. Arguments given after the query string will
623        be substituted for the corresponding numbered parameter. Parameter
624        values can also be given as a single list or tuple argument.
625        """
626        # Wraps shared library function for debugging.
627        if not self.db:
628            raise _int_error('Connection is not valid')
629        self._do_debug(qstr)
630        return self.db.query(qstr, args)
631
632    def pkey(self, table, flush=False):
633        """Get or set the primary key of a table.
634
635        Composite primary keys are represented as frozensets. Note that
636        this raises a KeyError if the table does not have a primary key.
637
638        If flush is set then the internal cache for primary keys will
639        be flushed. This may be necessary after the database schema or
640        the search path has been changed.
641        """
642        pkeys = self._pkeys
643        if flush:
644            pkeys.clear()
645            self._do_debug('The pkey cache has been flushed')
646        try:  # cache lookup
647            pkey = pkeys[table]
648        except KeyError:  # cache miss, check the database
649            q = ("SELECT a.attname FROM pg_index i"
650                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
651                " AND a.attnum = ANY(i.indkey)"
652                " AND NOT a.attisdropped"
653                " WHERE i.indrelid=%s::regclass"
654                " AND i.indisprimary") % (
655                    self._prepare_qualified_param(table, 1),)
656            pkey = self.db.query(q, (table,)).getresult()
657            if not pkey:
658                raise KeyError('Table %s has no primary key' % table)
659            if len(pkey) > 1:
660                pkey = frozenset(k[0] for k in pkey)
661            else:
662                pkey = pkey[0][0]
663            pkeys[table] = pkey  # cache it
664        return pkey
665
666    def get_databases(self):
667        """Get list of databases in the system."""
668        return [s[0] for s in
669            self.db.query('SELECT datname FROM pg_database').getresult()]
670
671    def get_relations(self, kinds=None):
672        """Get list of relations in connected database of specified kinds.
673
674        If kinds is None or empty, all kinds of relations are returned.
675        Otherwise kinds can be a string or sequence of type letters
676        specifying which kind of relations you want to list.
677        """
678        where = " AND r.relkind IN (%s)" % ','.join(
679            ["'%s'" % k for k in kinds]) if kinds else ''
680        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
681            " FROM pg_class r"
682            " JOIN pg_namespace s ON s.oid = r.relnamespace"
683            " WHERE s.nspname NOT SIMILAR"
684            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
685            " ORDER BY s.nspname, r.relname") % where
686        return [r[0] for r in self.db.query(q).getresult()]
687
688    def get_tables(self):
689        """Return list of tables in connected database."""
690        return self.get_relations('r')
691
692    def get_attnames(self, table, flush=False):
693        """Given the name of a table, dig out the set of attribute names.
694
695        Returns a dictionary of attribute names (the names are the keys,
696        the values are the names of the attributes' types).
697
698        If your Python version supports this, the dictionary will be an
699        OrderedDictionary with the column names in the right order.
700
701        If flush is set then the internal cache for attribute names will
702        be flushed. This may be necessary after the database schema or
703        the search path has been changed.
704
705        By default, only a limited number of simple types will be returned.
706        You can get the regular types after calling use_regtypes(True).
707        """
708        attnames = self._attnames
709        if flush:
710            attnames.clear()
711            self._do_debug('The attnames cache has been flushed')
712        try:  # cache lookup
713            names = attnames[table]
714        except KeyError:  # cache miss, check the database
715            q = ("SELECT a.attname, t.typname%s"
716                " FROM pg_attribute a"
717                " JOIN pg_type t ON t.oid = a.atttypid"
718                " WHERE a.attrelid = %s::regclass"
719                " AND (a.attnum > 0 OR a.attname = 'oid')"
720                " AND NOT a.attisdropped ORDER BY a.attnum") % (
721                    '::regtype' if self._regtypes else '',
722                    self._prepare_qualified_param(table, 1))
723            names = self.db.query(q, (table,)).getresult()
724            if not self._regtypes:
725                names = ((name, _simpletype(typ)) for name, typ in names)
726            names = OrderedDict(names)
727            attnames[table] = names  # cache it
728        return names
729
730    def use_regtypes(self, regtypes=None):
731        """Use regular type names instead of simplified type names."""
732        if regtypes is None:
733            return self._regtypes
734        else:
735            regtypes = bool(regtypes)
736            if regtypes != self._regtypes:
737                self._regtypes = regtypes
738                self._attnames.clear()
739            return regtypes
740
741    def has_table_privilege(self, table, privilege='select'):
742        """Check whether current user has specified table privilege."""
743        privilege = privilege.lower()
744        try:  # ask cache
745            return self._privileges[(table, privilege)]
746        except KeyError:  # cache miss, ask the database
747            q = "SELECT has_table_privilege(%s, $2)" % (
748                self._prepare_qualified_param(table, 1),)
749            q = self.db.query(q, (table, privilege))
750            ret = q.getresult()[0][0] == self._make_bool(True)
751            self._privileges[(table, privilege)] = ret  # cache it
752            return ret
753
754    def get(self, table, row, keyname=None):
755        """Get a row from a database table or view.
756
757        This method is the basic mechanism to get a single row.  The keyname
758        that the key specifies a unique row.  If keyname is not specified
759        then the primary key for the table is used.  If row is a dictionary
760        then the value for the key is taken from it and it is modified to
761        include the new values, replacing existing values where necessary.
762        For a composite key, keyname can also be a sequence of key names.
763        The OID is also put into the dictionary if the table has one, but
764        in order to allow the caller to work with multiple tables, it is
765        munged as "oid(table)".
766        """
767        if table.endswith('*'):  # scan descendant tables?
768            table = table[:-1].rstrip()  # need parent table name
769        if not keyname:
770            # use the primary key by default
771            try:
772                keyname = self.pkey(table)
773            except KeyError:
774                raise _prg_error('Table %s has no primary key' % table)
775        attnames = self.get_attnames(table)
776        params = []
777        param = partial(self._prepare_param, params=params)
778        col = self.escape_identifier
779        # We want the oid for later updates if that isn't the key.
780        # To allow users to work with multiple tables, we munge
781        # the name of the "oid" key by adding the name of the table.
782        qoid = _oid_key(table)
783        if keyname == 'oid':
784            if isinstance(row, dict):
785                if qoid not in row:
786                    raise _prg_error('%s not in row' % qoid)
787            else:
788                row = {qoid: row}
789            what = '*'
790            where = 'oid = %s' % param(row[qoid], 'int')
791        else:
792            keyname = [keyname] if isinstance(
793                keyname, basestring) else sorted(keyname)
794            if not isinstance(row, dict):
795                if len(keyname) > 1:
796                    raise _prg_error('Composite key needs dict as row')
797                row = dict((k, row) for k in keyname)
798            what = ', '.join(col(k) for k in attnames)
799            where = ' AND '.join('%s = %s' % (
800                col(k), param(row[k], attnames[k])) for k in keyname)
801        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
802            what, self._escape_qualified_name(table), where)
803        self._do_debug(q, params)
804        q = self.db.query(q, params)
805        res = q.dictresult()
806        if not res:
807            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
808                table, where, self._list_params(params)))
809        for n, value in res[0].items():
810            if n == 'oid':
811                n = qoid
812            elif attnames.get(n) == 'bytea':
813                value = self.unescape_bytea(value)
814            row[n] = value
815        return row
816
817    def insert(self, table, row=None, **kw):
818        """Insert a row into a database table.
819
820        This method inserts a row into a table.  The name of the table must
821        be passed as the first parameter.  The other parameters are used for
822        providing the data of the row that shall be inserted into the table.
823        If a dictionary is supplied as the second parameter, it starts with
824        that.  Otherwise it uses a blank dictionary. Either way the dictionary
825        is updated from the keywords.
826
827        The dictionary is then reloaded with the values actually inserted in
828        order to pick up values modified by rules, triggers, etc.
829
830        Note: The method currently doesn't support insert into views
831        although PostgreSQL does.
832        """
833        if 'oid' in kw:
834            del kw['oid']
835        if row is None:
836            row = {}
837        row.update(kw)
838        attnames = self.get_attnames(table)
839        params = []
840        param = partial(self._prepare_param, params=params)
841        col = self.escape_identifier
842        names, values = [], []
843        for n in attnames:
844            if n in row:
845                names.append(col(n))
846                values.append(param(row[n], attnames[n]))
847        names, values = ', '.join(names), ', '.join(values)
848        ret = 'oid, *' if 'oid' in attnames else '*'
849        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
850            self._escape_qualified_name(table), names, values, ret)
851        self._do_debug(q, params)
852        q = self.db.query(q, params)
853        res = q.dictresult()  # this will always return a row
854        for n, value in res[0].items():
855            if n == 'oid':
856                n = _oid_key(table)
857            elif attnames.get(n) == 'bytea' and value is not None:
858                value = self.unescape_bytea(value)
859            row[n] = value
860        return row
861
862    def update(self, table, row=None, **kw):
863        """Update an existing row in a database table.
864
865        Similar to insert but updates an existing row.  The update is based
866        on the OID value as munged by get or passed as keyword, or on the
867        primary key of the table.  The dictionary is modified to reflect
868        any changes caused by the update due to triggers, rules, default
869        values, etc.
870        """
871        # Update always works on the oid which get() returns if available,
872        # otherwise use the primary key.  Fail if neither.
873        # Note that we only accept oid key from named args for safety.
874        qoid = _oid_key(table)
875        if 'oid' in kw:
876            kw[qoid] = kw.pop('oid')
877        if row is None:
878            row = {}
879        row.update(kw)
880        attnames = self.get_attnames(table)
881        params = []
882        param = partial(self._prepare_param, params=params)
883        col = self.escape_identifier
884        if qoid in row:
885            where = 'oid = %s' % param(row[qoid], 'int')
886            keyname = []
887        else:
888            try:
889                keyname = self.pkey(table)
890            except KeyError:
891                raise _prg_error('Table %s has no primary key' % table)
892            keyname = [keyname] if isinstance(
893                keyname, basestring) else sorted(keyname)
894            try:
895                where = ' AND '.join('%s = %s' % (
896                    col(k), param(row[k], attnames[k])) for k in keyname)
897            except KeyError:
898                raise _prg_error('Update operation needs primary key or oid')
899        keyname = set(keyname)
900        keyname.add('oid')
901        values = []
902        for n in attnames:
903            if n in row and n not in keyname:
904                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
905        if not values:
906            return row
907        values = ', '.join(values)
908        ret = 'oid, *' if 'oid' in attnames else '*'
909        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
910            self._escape_qualified_name(table), values, where, ret)
911        self._do_debug(q, params)
912        q = self.db.query(q, params)
913        res = q.dictresult()
914        if res:  # may be empty when row does not exist
915            for n, value in res[0].items():
916                if n == 'oid':
917                    n = qoid
918                elif attnames.get(n) == 'bytea' and value is not None:
919                    value = self.unescape_bytea(value)
920                row[n] = value
921        return row
922
923    def upsert(self, table, row=None, **kw):
924        """Insert a row into a database table with conflict resolution
925
926        This method inserts a row into a table, but instead of raising a
927        ProgrammingError exception in case a row with the same primary key
928        already exists, an update will be executed instead.  This will be
929        performed as a single atomic operation on the database, so race
930        conditions can be avoided.
931
932        Like the insert method, the first parameter is the name of the
933        table and the second parameter can be used to pass the values to
934        be inserted as a dictionary.
935
936        Unlike the insert und update statement, keyword parameters are not
937        used to modify the dictionary, but to specify which columns shall
938        be updated in case of a conflict, and in which way:
939
940        A value of False or None means the column shall not be updated,
941        a value of True means the column shall be updated with the value
942        that has been proposed for insertion, i.e. has been passed as value
943        in the dictionary.  Columns that are not specified by keywords but
944        appear as keys in the dictionary are also updated like in the case
945        keywords had been passed with the value True.
946
947        So if in the case of a conflict you want to update every column that
948        has been passed in the dictionary row , you would call upsert(table, row).
949        If you don't want to do anything in case of a conflict, i.e. leave
950        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
951
952        If you need more fine-grained control of what gets updated, you can
953        also pass strings in the keyword parameters.  These strings will
954        be used as SQL expressions for the update columns.  In these
955        expressions you can refer to the value that already exists in
956        the table by prefixing the column name with "included.", and to
957        the value that has been proposed for insertion by prefixing the
958        column name with the "excluded."
959
960        The dictionary is modified in any case to reflect the values in
961        the database after the operation has completed.
962
963        Note: The method uses the PostgreSQL "upsert" feature which is
964        only available since PostgreSQL 9.5.
965        """
966        if 'oid' in kw:
967            del kw['oid']
968        if row is None:
969            row = {}
970        attnames = self.get_attnames(table)
971        params = []
972        param = partial(self._prepare_param,params=params)
973        col = self.escape_identifier
974        names, values, updates = [], [], []
975        for n in attnames:
976            if n in row:
977                names.append(col(n))
978                values.append(param(row[n], attnames[n]))
979        names, values = ', '.join(names), ', '.join(values)
980        try:
981            keyname = self.pkey(table)
982        except KeyError:
983            raise _prg_error('Table %s has no primary key' % table)
984        keyname = [keyname] if isinstance(
985            keyname, basestring) else sorted(keyname)
986        target = ', '.join(col(k) for k in keyname)
987        update = []
988        keyname = set(keyname)
989        keyname.add('oid')
990        for n in attnames:
991            if n not in keyname:
992                value = kw.get(n, True)
993                if value:
994                    if not isinstance(value, basestring):
995                        value = 'excluded.%s' % col(n)
996                    update.append('%s = %s' % (col(n), value))
997        if not values:
998            return row
999        do = 'update set %s' % ', '.join(update) if update else 'nothing'
1000        ret = 'oid, *' if 'oid' in attnames else '*'
1001        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
1002            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
1003                self._escape_qualified_name(table), names, values,
1004                target, do, ret)
1005        self._do_debug(q, params)
1006        try:
1007            q = self.db.query(q, params)
1008        except ProgrammingError:
1009            if self.server_version < 90500:
1010                raise _prg_error(
1011                    'Upsert operation is not supported by PostgreSQL version')
1012            raise  # re-raise original error
1013        res = q.dictresult()
1014        if update:  # may be empty with "do nothing"
1015            for n, value in res[0].items():
1016                if n == 'oid':
1017                    n = _oid_key(table)
1018                elif attnames.get(n) == 'bytea' and value is not None:
1019                    value = self.unescape_bytea(value)
1020                row[n] = value
1021        else:
1022            self.get(table, row)
1023        return row
1024
1025    def clear(self, table, row=None):
1026        """Clear all the attributes to values determined by the types.
1027
1028        Numeric types are set to 0, Booleans are set to false, and everything
1029        else is set to the empty string.  If the row argument is present,
1030        it is used as the row dictionary and any entries matching attribute
1031        names are cleared with everything else left unchanged.
1032        """
1033        # At some point we will need a way to get defaults from a table.
1034        if row is None:
1035            row = {}  # empty if argument is not present
1036        attnames = self.get_attnames(table)
1037        for n, t in attnames.items():
1038            if n == 'oid':
1039                continue
1040            if t in ('int', 'integer', 'smallint', 'bigint',
1041                    'float', 'real', 'double precision',
1042                    'num', 'numeric', 'money'):
1043                row[n] = 0
1044            elif t in ('bool', 'boolean'):
1045                row[n] = self._make_bool(False)
1046            else:
1047                row[n] = ''
1048        return row
1049
1050    def delete(self, table, row=None, **kw):
1051        """Delete an existing row in a database table.
1052
1053        This method deletes the row from a table.  It deletes based on the
1054        OID value as munged by get or passed as keyword, or on the primary
1055        key of the table.  The return value is the number of deleted rows
1056        (i.e. 0 if the row did not exist and 1 if the row was deleted).
1057        """
1058        # Like update, delete works on the oid.
1059        # One day we will be testing that the record to be deleted
1060        # isn't referenced somewhere (or else PostgreSQL will).
1061        # Note that we only accept oid key from named args for safety.
1062        qoid = _oid_key(table)
1063        if 'oid' in kw:
1064            kw[qoid] = kw.pop('oid')
1065        if row is None:
1066            row = {}
1067        row.update(kw)
1068        params = []
1069        param = partial(self._prepare_param, params=params)
1070        if qoid in row:
1071            where = 'oid = %s' % param(row[qoid], 'int')
1072        else:
1073            try:
1074                keyname = self.pkey(table)
1075            except KeyError:
1076                raise _prg_error('Table %s has no primary key' % table)
1077            keyname = [keyname] if isinstance(
1078                keyname, basestring) else sorted(keyname)
1079            attnames = self.get_attnames(table)
1080            col = self.escape_identifier
1081            try:
1082                where = ' AND '.join('%s = %s' % (
1083                    col(k), param(row[k], attnames[k])) for k in keyname)
1084            except KeyError:
1085                raise _prg_error('Delete operation needs primary key or oid')
1086        q = 'DELETE FROM %s WHERE %s' % (
1087            self._escape_qualified_name(table), where)
1088        self._do_debug(q, params)
1089        res = self.db.query(q, params)
1090        return int(res)
1091
1092    def truncate(self, table, restart=False, cascade=False, only=False):
1093        """Empty a table or set of tables.
1094
1095        This method quickly removes all rows from the given table or set
1096        of tables.  It has the same effect as an unqualified DELETE on each
1097        table, but since it does not actually scan the tables it is faster.
1098        Furthermore, it reclaims disk space immediately, rather than requiring
1099        a subsequent VACUUM operation. This is most useful on large tables.
1100
1101        If restart is set to True, sequences owned by columns of the truncated
1102        table(s) are automatically restarted.  If cascade is set to True, it
1103        also truncates all tables that have foreign-key references to any of
1104        the named tables.  If the parameter only is not set to True, all the
1105        descendant tables (if any) will also be truncated. Optionally, a '*'
1106        can be specified after the table name to explicitly indicate that
1107        descendant tables are included.
1108        """
1109        if isinstance(table, basestring):
1110            only = {table: only}
1111            table = [table]
1112        elif isinstance(table, (list, tuple)):
1113            if isinstance(only, (list, tuple)):
1114                only = dict(zip(table, only))
1115            else:
1116                only = dict.fromkeys(table, only)
1117        elif isinstance(table, (set, frozenset)):
1118            only = dict.fromkeys(table, only)
1119        else:
1120            raise TypeError('The table must be a string, list or set')
1121        if not (restart is None or isinstance(restart, (bool, int))):
1122            raise TypeError('Invalid type for the restart option')
1123        if not (cascade is None or isinstance(cascade, (bool, int))):
1124            raise TypeError('Invalid type for the cascade option')
1125        tables = []
1126        for t in table:
1127            u = only.get(t)
1128            if not (u is None or isinstance(u, (bool, int))):
1129                raise TypeError('Invalid type for the only option')
1130            if t.endswith('*'):
1131                if u:
1132                    raise ValueError(
1133                        'Contradictory table name and only options')
1134                t = t[:-1].rstrip()
1135            t = self._escape_qualified_name(t)
1136            if u:
1137                t = 'ONLY %s' % t
1138            tables.append(t)
1139        q = ['TRUNCATE', ', '.join(tables)]
1140        if restart:
1141            q.append('RESTART IDENTITY')
1142        if cascade:
1143            q.append('CASCADE')
1144        q = ' '.join(q)
1145        self._do_debug(q)
1146        return self.query(q)
1147
1148    def notification_handler(self,
1149            event, callback, arg_dict=None, timeout=None, stop_event=None):
1150        """Get notification handler that will run the given callback."""
1151        return NotificationHandler(self,
1152            event, callback, arg_dict, timeout, stop_event)
1153
1154
# if run as script, print some information

if __name__ == '__main__':
    # print() separates its arguments with a space, so the output reads
    # "PyGreSQL version X.Y" (plain concatenation omitted the space)
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.