source: trunk/pg.py @ 740

Last change on this file since 740 was 740, checked in by cito, 4 years ago

Reformat some error messages and docstrings

Try to achieve a somewhat consistent style of docstrings and error
messages in the trunk. The docstrings use PEP 257, with a slight
variation between C code and Python code. The error messages are
capitalized and do not end with a period. (I prefer the periods,
but most Python code I have seen doesn't use them.)

(I know, a foolish consistency is the hobgoblin of little minds.)

  • Property svn:keywords set to Id
File size: 35.7 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 740 2016-01-14 02:35:55Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40
41try:
42    from collections import OrderedDict
43except ImportError:  # Python 2.6 or 3.0
44    OrderedDict = dict
45
46try:
47    basestring
48except NameError:  # Python >= 3.0
49    basestring = (str, bytes)
50
51set_decimal(Decimal)
52
53
54# Auxiliary functions that are independent from a DB connection:
55
56def _oid_key(table):
57    """Build oid key from a table name."""
58    return 'oid(%s)' % table
59
60
61def _simpletype(typ):
62    """Determine a simplified name a pg_type name."""
63    if typ.startswith('bool'):
64        return 'bool'
65    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
66        return 'date'
67    if typ.startswith(('cid', 'oid', 'int', 'xid')):
68        return 'int'
69    if typ.startswith('float'):
70        return 'float'
71    if typ.startswith('numeric'):
72        return 'num'
73    if typ.startswith('money'):
74        return 'money'
75    if typ.startswith('bytea'):
76        return 'bytea'
77    return 'text'
78
79
80def _namedresult(q):
81    """Get query result as named tuples."""
82    row = namedtuple('Row', q.listfields())
83    return [row(*r) for r in q.getresult()]
84
85set_namedresult(_namedresult)
86
87
88def _db_error(msg, cls=DatabaseError):
89    """Return DatabaseError with empty sqlstate attribute."""
90    error = cls(msg)
91    error.sqlstate = None
92    return error
93
94
95def _int_error(msg):
96    """Return InternalError."""
97    return _db_error(msg, InternalError)
98
99
100def _prg_error(msg):
101    """Return ProgrammingError."""
102    return _db_error(msg, ProgrammingError)
103
104
105class NotificationHandler(object):
106    """A PostgreSQL client-side asynchronous notification handler."""
107
108    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
109        """Initialize the notification handler.
110
111        db       - PostgreSQL connection object.
112        event    - Event (notification channel) to LISTEN for.
113        callback - Event callback function.
114        arg_dict - A dictionary passed as the argument to the callback.
115        timeout  - Timeout in seconds; a floating point number denotes
116                   fractions of seconds. If it is absent or None, the
117                   callers will never time out.
118        """
119        if isinstance(db, DB):
120            db = db.db
121        self.db = db
122        self.event = event
123        self.stop_event = 'stop_%s' % event
124        self.listening = False
125        self.callback = callback
126        if arg_dict is None:
127            arg_dict = {}
128        self.arg_dict = arg_dict
129        self.timeout = timeout
130
131    def __del__(self):
132        self.close()
133
134    def close(self):
135        """Stop listening and close the connection."""
136        if self.db:
137            self.unlisten()
138            self.db.close()
139            self.db = None
140
141    def listen(self):
142        """Start listening for the event and the stop event."""
143        if not self.listening:
144            self.db.query('listen "%s"' % self.event)
145            self.db.query('listen "%s"' % self.stop_event)
146            self.listening = True
147
148    def unlisten(self):
149        """Stop listening for the event and the stop event."""
150        if self.listening:
151            self.db.query('unlisten "%s"' % self.event)
152            self.db.query('unlisten "%s"' % self.stop_event)
153            self.listening = False
154
155    def notify(self, db=None, stop=False, payload=None):
156        """Generate a notification.
157
158        Note: If the main loop is running in another thread, you must pass
159        a different database connection to avoid a collision.
160        """
161        if not db:
162            db = self.db
163        if self.listening:
164            q = 'notify "%s"' % (self.stop_event if stop else self.event)
165            if payload:
166                q += ", '%s'" % payload
167            return db.query(q)
168
169    def __call__(self, close=False):
170        """Invoke the notification handler.
171
172        The handler is a loop that actually LISTENs for two NOTIFY messages:
173
174        <event> and stop_<event>.
175
176        When either of these NOTIFY messages are received, its associated
177        'pid' and 'event' are inserted into <arg_dict>, and the callback is
178        invoked with <arg_dict>. If the NOTIFY message is stop_<event>, the
179        handler UNLISTENs both <event> and stop_<event> and exits.
180
181        Note: If you run this loop in another thread, don't use the same
182        database connection for database operations in the main thread.
183        """
184        self.listen()
185        _ilist = [self.db.fileno()]
186
187        while self.listening:
188            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
189            if ilist:
190                while self.listening:
191                    notice = self.db.getnotify()
192                    if not notice:  # no more messages
193                        break
194                    event, pid, extra = notice
195                    if event not in (self.event, self.stop_event):
196                        self.unlisten()
197                        raise _db_error(
198                            'Listening for "%s" and "%s", but notified of "%s"'
199                            % (self.event, self.stop_event, event))
200                    if event == self.stop_event:
201                        self.unlisten()
202                    self.arg_dict['pid'] = pid
203                    self.arg_dict['event'] = event
204                    self.arg_dict['extra'] = extra
205                    self.callback(self.arg_dict)
206            else:   # we timed out
207                self.unlisten()
208                self.callback(None)
209
210
211def pgnotify(*args, **kw):
212    """Same as NotificationHandler, under the traditional name."""
213    warnings.warn("pgnotify is deprecated, use NotificationHandler instead",
214        DeprecationWarning, stacklevel=2)
215    return NotificationHandler(*args, **kw)
216
217
218# The actual PostGreSQL database connection interface:
219
220class DB(object):
221    """Wrapper class for the _pg connection type."""
222
223    def __init__(self, *args, **kw):
224        """Create a new connection
225
226        You can pass either the connection parameters or an existing
227        _pg or pgdb connection. This allows you to use the methods
228        of the classic pg interface with a DB-API 2 pgdb connection.
229        """
230        if not args and len(kw) == 1:
231            db = kw.get('db')
232        elif not kw and len(args) == 1:
233            db = args[0]
234        else:
235            db = None
236        if db:
237            if isinstance(db, DB):
238                db = db.db
239            else:
240                try:
241                    db = db._cnx
242                except AttributeError:
243                    pass
244        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
245            db = connect(*args, **kw)
246            self._closeable = True
247        else:
248            self._closeable = False
249        self.db = db
250        self.dbname = db.db
251        self._regtypes = False
252        self._attnames = {}
253        self._pkeys = {}
254        self._privileges = {}
255        self._args = args, kw
256        self.debug = None  # For debugging scripts, this can be set
257            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
258            # * to a file object to write debug statements or
259            # * to a callable object which takes a string argument
260            # * to any other true value to just print debug statements
261
262    def __getattr__(self, name):
263        # All undefined members are same as in underlying connection:
264        if self.db:
265            return getattr(self.db, name)
266        else:
267            raise _int_error('Connection is not valid')
268
269    def __dir__(self):
270        # Custom dir function including the attributes of the connection:
271        attrs = set(self.__class__.__dict__)
272        attrs.update(self.__dict__)
273        attrs.update(dir(self.db))
274        return sorted(attrs)
275
276    # Context manager methods
277
278    def __enter__(self):
279        """Enter the runtime context. This will start a transactio."""
280        self.begin()
281        return self
282
283    def __exit__(self, et, ev, tb):
284        """Exit the runtime context. This will end the transaction."""
285        if et is None and ev is None and tb is None:
286            self.commit()
287        else:
288            self.rollback()
289
290    # Auxiliary methods
291
292    def _do_debug(self, *args):
293        """Print a debug message"""
294        if self.debug:
295            s = '\n'.join(args)
296            if isinstance(self.debug, basestring):
297                print(self.debug % s)
298            elif hasattr(self.debug, 'write'):
299                self.debug.write(s + '\n')
300            elif callable(self.debug):
301                self.debug(s)
302            else:
303                print(s)
304
305    def _escape_qualified_name(self, s):
306        """Escape a qualified name.
307
308        Escapes the name for use as an SQL identifier, unless the
309        name contains a dot, in which case the name is ambiguous
310        (could be a qualified name or just a name with a dot in it)
311        and must be quoted manually by the caller.
312        """
313        if '.' not in s:
314            s = self.escape_identifier(s)
315        return s
316
317    @staticmethod
318    def _make_bool(d):
319        """Get boolean value corresponding to d."""
320        return bool(d) if get_bool() else ('t' if d else 'f')
321
322    _bool_true_values = frozenset('t true 1 y yes on'.split())
323
324    def _prepare_bool(self, d):
325        """Prepare a boolean parameter."""
326        if isinstance(d, basestring):
327            if not d:
328                return None
329            d = d.lower() in self._bool_true_values
330        return 't' if d else 'f'
331
332    _date_literals = frozenset('current_date current_time'
333        ' current_timestamp localtime localtimestamp'.split())
334
335    def _prepare_date(self, d):
336        """Prepare a date parameter."""
337        if not d:
338            return None
339        if isinstance(d, basestring) and d.lower() in self._date_literals:
340            raise ValueError
341        return d
342
343    def _prepare_num(self, d):
344        """Prepare a numeric parameter."""
345        if not d and d != 0:
346            return None
347        return d
348
349    def _prepare_bytea(self, d):
350        """Prepare a bytea parameter."""
351        return self.escape_bytea(d)
352
353    _prepare_funcs = dict(  # quote methods for each type
354        bool=_prepare_bool, date=_prepare_date,
355        int=_prepare_num, num=_prepare_num, float=_prepare_num,
356        money=_prepare_num, bytea=_prepare_bytea)
357
358    def _prepare_param(self, value, typ, params):
359        """Prepare and add a parameter to the list."""
360        if value is not None and typ != 'text':
361            try:
362                prepare = self._prepare_funcs[typ]
363            except KeyError:
364                pass
365            else:
366                try:
367                    value = prepare(self, value)
368                except ValueError:
369                    return value
370        params.append(value)
371        return '$%d' % len(params)
372
373    @staticmethod
374    def _prepare_qualified_param(name, param):
375        """Quote parameter representing a qualified name.
376
377        Escapes the name for use as an SQL parameter, unless the
378        name contains a dot, in which case the name is ambiguous
379        (could be a qualified name or just a name with a dot in it)
380        and must be quoted manually by the caller.
381
382        """
383        if isinstance(param, int):
384            param = "$%d" % param
385        if '.' not in name:
386            param = 'quote_ident(%s)' % (param,)
387        return param
388
389    # Public methods
390
391    # escape_string and escape_bytea exist as methods,
392    # so we define unescape_bytea as a method as well
393    unescape_bytea = staticmethod(unescape_bytea)
394
395    def close(self):
396        """Close the database connection."""
397        # Wraps shared library function so we can track state.
398        if self._closeable:
399            if self.db:
400                self.db.close()
401                self.db = None
402            else:
403                raise _int_error('Connection already closed')
404
405    def reset(self):
406        """Reset connection with current parameters.
407
408        All derived queries and large objects derived from this connection
409        will not be usable after this call.
410
411        """
412        if self.db:
413            self.db.reset()
414        else:
415            raise _int_error('Connection already closed')
416
417    def reopen(self):
418        """Reopen connection to the database.
419
420        Used in case we need another connection to the same database.
421        Note that we can still reopen a database that we have closed.
422
423        """
424        # There is no such shared library function.
425        if self._closeable:
426            db = connect(*self._args[0], **self._args[1])
427            if self.db:
428                self.db.close()
429            self.db = db
430
431    def begin(self, mode=None):
432        """Begin a transaction."""
433        qstr = 'BEGIN'
434        if mode:
435            qstr += ' ' + mode
436        return self.query(qstr)
437
438    start = begin
439
440    def commit(self):
441        """Commit the current transaction."""
442        return self.query('COMMIT')
443
444    end = commit
445
446    def rollback(self, name=None):
447        """Roll back the current transaction."""
448        qstr = 'ROLLBACK'
449        if name:
450            qstr += ' TO ' + name
451        return self.query(qstr)
452
453    def savepoint(self, name):
454        """Define a new savepoint within the current transaction."""
455        return self.query('SAVEPOINT ' + name)
456
457    def release(self, name):
458        """Destroy a previously defined savepoint."""
459        return self.query('RELEASE ' + name)
460
461    def query(self, qstr, *args):
462        """Execute a SQL command string.
463
464        This method simply sends a SQL query to the database. If the query is
465        an insert statement that inserted exactly one row into a table that
466        has OIDs, the return value is the OID of the newly inserted row.
467        If the query is an update or delete statement, or an insert statement
468        that did not insert exactly one row in a table with OIDs, then the
469        number of rows affected is returned as a string. If it is a statement
470        that returns rows as a result (usually a select statement, but maybe
471        also an "insert/update ... returning" statement), this method returns
472        a Query object that can be accessed via getresult() or dictresult()
473        or simply printed. Otherwise, it returns `None`.
474
475        The query can contain numbered parameters of the form $1 in place
476        of any data constant. Arguments given after the query string will
477        be substituted for the corresponding numbered parameter. Parameter
478        values can also be given as a single list or tuple argument.
479        """
480        # Wraps shared library function for debugging.
481        if not self.db:
482            raise _int_error('Connection is not valid')
483        self._do_debug(qstr)
484        return self.db.query(qstr, args)
485
486    def pkey(self, table, flush=False):
487        """Get or set the primary key of a table.
488
489        Composite primary keys are represented as frozensets. Note that
490        this raises a KeyError if the table does not have a primary key.
491
492        If flush is set then the internal cache for primary keys will
493        be flushed. This may be necessary after the database schema or
494        the search path has been changed.
495        """
496        pkeys = self._pkeys
497        if flush:
498            pkeys.clear()
499            self._do_debug('The pkey cache has been flushed')
500        try:  # cache lookup
501            pkey = pkeys[table]
502        except KeyError:  # cache miss, check the database
503            q = ("SELECT a.attname FROM pg_index i"
504                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
505                " AND a.attnum = ANY(i.indkey)"
506                " AND NOT a.attisdropped"
507                " WHERE i.indrelid=%s::regclass"
508                " AND i.indisprimary") % (
509                    self._prepare_qualified_param(table, 1),)
510            pkey = self.db.query(q, (table,)).getresult()
511            if not pkey:
512                raise KeyError('Table %s has no primary key' % table)
513            if len(pkey) > 1:
514                pkey = frozenset(k[0] for k in pkey)
515            else:
516                pkey = pkey[0][0]
517            pkeys[table] = pkey  # cache it
518        return pkey
519
520    def get_databases(self):
521        """Get list of databases in the system."""
522        return [s[0] for s in
523            self.db.query('SELECT datname FROM pg_database').getresult()]
524
525    def get_relations(self, kinds=None):
526        """Get list of relations in connected database of specified kinds.
527
528        If kinds is None or empty, all kinds of relations are returned.
529        Otherwise kinds can be a string or sequence of type letters
530        specifying which kind of relations you want to list.
531        """
532        where = " AND r.relkind IN (%s)" % ','.join(
533            ["'%s'" % k for k in kinds]) if kinds else ''
534        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
535            " FROM pg_class r"
536            " JOIN pg_namespace s ON s.oid = r.relnamespace"
537            " WHERE s.nspname NOT SIMILAR"
538            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
539            " ORDER BY s.nspname, r.relname") % where
540        return [r[0] for r in self.db.query(q).getresult()]
541
542    def get_tables(self):
543        """Return list of tables in connected database."""
544        return self.get_relations('r')
545
546    def get_attnames(self, table, flush=False):
547        """Given the name of a table, dig out the set of attribute names.
548
549        Returns a dictionary of attribute names (the names are the keys,
550        the values are the names of the attributes' types).
551
552        If your Python version supports this, the dictionary will be an
553        OrderedDictionary with the column names in the right order.
554
555        If flush is set then the internal cache for attribute names will
556        be flushed. This may be necessary after the database schema or
557        the search path has been changed.
558
559        By default, only a limited number of simple types will be returned.
560        You can get the regular types after calling use_regtypes(True).
561        """
562        attnames = self._attnames
563        if flush:
564            attnames.clear()
565            self._do_debug('The attnames cache has been flushed')
566        try:  # cache lookup
567            names = attnames[table]
568        except KeyError:  # cache miss, check the database
569            q = ("SELECT a.attname, t.typname%s"
570                " FROM pg_attribute a"
571                " JOIN pg_type t ON t.oid = a.atttypid"
572                " WHERE a.attrelid = %s::regclass"
573                " AND (a.attnum > 0 OR a.attname = 'oid')"
574                " AND NOT a.attisdropped ORDER BY a.attnum") % (
575                    '::regtype' if self._regtypes else '',
576                    self._prepare_qualified_param(table, 1))
577            names = self.db.query(q, (table,)).getresult()
578            if not names:
579                raise KeyError('Table %s does not exist' % table)
580            if not self._regtypes:
581                names = ((name, _simpletype(typ)) for name, typ in names)
582            names = OrderedDict(names)
583            attnames[table] = names  # cache it
584        return names
585
586    def use_regtypes(self, regtypes=None):
587        """Use regular type names instead of simplified type names."""
588        if regtypes is None:
589            return self._regtypes
590        else:
591            regtypes = bool(regtypes)
592            if regtypes != self._regtypes:
593                self._regtypes = regtypes
594                self._attnames.clear()
595            return regtypes
596
597    def has_table_privilege(self, table, privilege='select'):
598        """Check whether current user has specified table privilege."""
599        privilege = privilege.lower()
600        try:  # ask cache
601            return self._privileges[(table, privilege)]
602        except KeyError:  # cache miss, ask the database
603            q = "SELECT has_table_privilege(%s, $2)" % (
604                self._prepare_qualified_param(table, 1),)
605            q = self.db.query(q, (table, privilege))
606            ret = q.getresult()[0][0] == self._make_bool(True)
607            self._privileges[(table, privilege)] = ret  # cache it
608            return ret
609
610    def get(self, table, row, keyname=None):
611        """Get a row from a database table or view.
612
613        This method is the basic mechanism to get a single row.  The keyname
614        that the key specifies a unique row.  If keyname is not specified
615        then the primary key for the table is used.  If row is a dictionary
616        then the value for the key is taken from it and it is modified to
617        include the new values, replacing existing values where necessary.
618        For a composite key, keyname can also be a sequence of key names.
619        The OID is also put into the dictionary if the table has one, but
620        in order to allow the caller to work with multiple tables, it is
621        munged as "oid(table)".
622        """
623        if table.endswith('*'):  # scan descendant tables?
624            table = table[:-1].rstrip()  # need parent table name
625        if not keyname:
626            # use the primary key by default
627            try:
628                keyname = self.pkey(table)
629            except KeyError:
630                raise _prg_error('Table %s has no primary key' % table)
631        attnames = self.get_attnames(table)
632        params = []
633        param = partial(self._prepare_param, params=params)
634        col = self.escape_identifier
635        # We want the oid for later updates if that isn't the key.
636        # To allow users to work with multiple tables, we munge
637        # the name of the "oid" key by adding the name of the table.
638        qoid = _oid_key(table)
639        if keyname == 'oid':
640            if isinstance(row, dict):
641                if qoid not in row:
642                    raise _db_error('%s not in row' % qoid)
643            else:
644                row = {qoid: row}
645            what = '*'
646            where = 'oid = %s' % param(row[qoid], 'int')
647        else:
648            keyname = [keyname] if isinstance(
649                keyname, basestring) else sorted(keyname)
650            if not isinstance(row, dict):
651                if len(keyname) > 1:
652                    raise _prg_error('Composite key needs dict as row')
653                row = dict((k, row) for k in keyname)
654            what = ', '.join(col(k) for k in attnames)
655            where = ' AND '.join('%s = %s' % (
656                col(k), param(row[k], attnames[k])) for k in keyname)
657        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
658            what, self._escape_qualified_name(table), where)
659        self._do_debug(q, params)
660        q = self.db.query(q, params)
661        res = q.dictresult()
662        if not res:
663            raise _db_error('No such record in %s where %s' % (table, where))
664        for n, value in res[0].items():
665            if n == 'oid':
666                n = qoid
667            elif attnames.get(n) == 'bytea':
668                value = self.unescape_bytea(value)
669            row[n] = value
670        return row
671
672    def insert(self, table, row=None, **kw):
673        """Insert a row into a database table.
674
675        This method inserts a row into a table.  The name of the table must
676        be passed as the first parameter.  The other parameters are used for
677        providing the data of the row that shall be inserted into the table.
678        If a dictionary is supplied as the second parameter, it starts with
679        that.  Otherwise it uses a blank dictionary. Either way the dictionary
680        is updated from the keywords.
681
682        The dictionary is then reloaded with the values actually inserted in
683        order to pick up values modified by rules, triggers, etc.
684
685        Note: The method currently doesn't support insert into views
686        although PostgreSQL does.
687        """
688        if 'oid' in kw:
689            del kw['oid']
690        if row is None:
691            row = {}
692        row.update(kw)
693        attnames = self.get_attnames(table)
694        params = []
695        param = partial(self._prepare_param, params=params)
696        col = self.escape_identifier
697        names, values = [], []
698        for n in attnames:
699            if n in row:
700                names.append(col(n))
701                values.append(param(row[n], attnames[n]))
702        names, values = ', '.join(names), ', '.join(values)
703        ret = 'oid, *' if 'oid' in attnames else '*'
704        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
705            self._escape_qualified_name(table), names, values, ret)
706        self._do_debug(q, params)
707        q = self.db.query(q, params)
708        res = q.dictresult()
709        if not res:
710            raise _int_error('Insert operation did not return new values')
711        for n, value in res[0].items():
712            if n == 'oid':
713                n = _oid_key(table)
714            elif attnames.get(n) == 'bytea' and value is not None:
715                value = self.unescape_bytea(value)
716            row[n] = value
717        return row
718
719    def update(self, table, row=None, **kw):
720        """Update an existing row in a database table.
721
722        Similar to insert but updates an existing row.  The update is based
723        on the OID value as munged by get or passed as keyword, or on the
724        primary key of the table.  The dictionary is modified to reflect
725        any changes caused by the update due to triggers, rules, default
726        values, etc.
727        """
728        # Update always works on the oid which get() returns if available,
729        # otherwise use the primary key.  Fail if neither.
730        # Note that we only accept oid key from named args for safety.
731        qoid = _oid_key(table)
732        if 'oid' in kw:
733            kw[qoid] = kw['oid']
734            del kw['oid']
735        if row is None:
736            row = {}
737        row.update(kw)
738        attnames = self.get_attnames(table)
739        params = []
740        param = partial(self._prepare_param, params=params)
741        col = self.escape_identifier
742        if qoid in row:
743            where = 'oid = %s' % param(row[qoid], 'int')
744            keyname = []
745        else:
746            try:
747                keyname = self.pkey(table)
748            except KeyError:
749                raise _prg_error('Table %s has no primary key' % table)
750            keyname = [keyname] if isinstance(
751                keyname, basestring) else sorted(keyname)
752            try:
753                where = ' AND '.join('%s = %s' % (
754                    col(k), param(row[k], attnames[k])) for k in keyname)
755            except KeyError:
756                raise _prg_error('Update operation needs primary key or oid')
757        keyname = set(keyname)
758        keyname.add('oid')
759        values = []
760        for n in attnames:
761            if n in row and n not in keyname:
762                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
763        if not values:
764            return row
765        values = ', '.join(values)
766        ret = 'oid, *' if 'oid' in attnames else '*'
767        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
768            self._escape_qualified_name(table), values, where, ret)
769        self._do_debug(q, params)
770        q = self.db.query(q, params)
771        res = q.dictresult()
772        if res:  # may be empty when row does not exist
773            for n, value in res[0].items():
774                if n == 'oid':
775                    n = qoid
776                elif attnames.get(n) == 'bytea' and value is not None:
777                    value = self.unescape_bytea(value)
778                row[n] = value
779        return row
780
781    def upsert(self, table, row=None, **kw):
782        """Insert a row into a database table with conflict resolution
783
784        This method inserts a row into a table, but instead of raising a
785        ProgrammingError exception in case a row with the same primary key
786        already exists, an update will be executed instead.  This will be
787        performed as a single atomic operation on the database, so race
788        conditions can be avoided.
789
790        Like the insert method, the first parameter is the name of the
791        table and the second parameter can be used to pass the values to
792        be inserted as a dictionary.
793
794        Unlike the insert und update statement, keyword parameters are not
795        used to modify the dictionary, but to specify which columns shall
796        be updated in case of a conflict, and in which way:
797
798        A value of False or None means the column shall not be updated,
799        a value of True means the column shall be updated with the value
800        that has been proposed for insertion, i.e. has been passed as value
801        in the dictionary.  Columns that are not specified by keywords but
802        appear as keys in the dictionary are also updated like in the case
803        keywords had been passed with the value True.
804
805        So if in the case of a conflict you want to update every column that
806        has been passed in the dictionary row , you would call upsert(table, row).
807        If you don't want to do anything in case of a conflict, i.e. leave
808        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
809
810        If you need more fine-grained control of what gets updated, you can
811        also pass strings in the keyword parameters.  These strings will
812        be used as SQL expressions for the update columns.  In these
813        expressions you can refer to the value that already exists in
814        the table by prefixing the column name with "included.", and to
815        the value that has been proposed for insertion by prefixing the
816        column name with the "excluded."
817
818        The dictionary is modified in any case to reflect the values in
819        the database after the operation has completed.
820
821        Note: The method uses the PostgreSQL "upsert" feature which is
822        only available since PostgreSQL 9.5.
823        """
824        if 'oid' in kw:
825            del kw['oid']
826        if row is None:
827            row = {}
828        attnames = self.get_attnames(table)
829        params = []
830        param = partial(self._prepare_param,params=params)
831        col = self.escape_identifier
832        names, values, updates = [], [], []
833        for n in attnames:
834            if n in row:
835                names.append(col(n))
836                values.append(param(row[n], attnames[n]))
837        names, values = ', '.join(names), ', '.join(values)
838        try:
839            keyname = self.pkey(table)
840        except KeyError:
841            raise _prg_error('Table %s has no primary key' % table)
842        keyname = [keyname] if isinstance(
843            keyname, basestring) else sorted(keyname)
844        try:
845            target = ', '.join(col(k) for k in keyname)
846        except KeyError:
847            raise _prg_error('Upsert operation needs primary key or oid')
848        update = []
849        keyname = set(keyname)
850        keyname.add('oid')
851        for n in attnames:
852            if n not in keyname:
853                value = kw.get(n, True)
854                if value:
855                    if not isinstance(value, basestring):
856                        value = 'excluded.%s' % col(n)
857                    update.append('%s = %s' % (col(n), value))
858        if not values and not update:
859            return row
860        do = 'update set %s' % ', '.join(update) if update else 'nothing'
861        ret = 'oid, *' if 'oid' in attnames else '*'
862        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
863            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
864                self._escape_qualified_name(table), names, values,
865                target, do, ret)
866        self._do_debug(q, params)
867        try:
868            q = self.db.query(q, params)
869        except ProgrammingError:
870            if self.server_version < 90500:
871                raise _prg_error(
872                    'Upsert operation is not supported by PostgreSQL version')
873            raise  # re-raise original error
874        res = q.dictresult()
875        if res:  # may be empty with "do nothing"
876            for n, value in res[0].items():
877                if n == 'oid':
878                    n = _oid_key(table)
879                elif attnames.get(n) == 'bytea':
880                    value = self.unescape_bytea(value)
881                row[n] = value
882        elif update:
883            raise _int_error('Upsert operation did not return new values')
884        else:
885            self.get(table, row)
886        return row
887
888    def clear(self, table, row=None):
889        """Clear all the attributes to values determined by the types.
890
891        Numeric types are set to 0, Booleans are set to false, and everything
892        else is set to the empty string.  If the row argument is present,
893        it is used as the row dictionary and any entries matching attribute
894        names are cleared with everything else left unchanged.
895        """
896        # At some point we will need a way to get defaults from a table.
897        if row is None:
898            row = {}  # empty if argument is not present
899        attnames = self.get_attnames(table)
900        for n, t in attnames.items():
901            if n == 'oid':
902                continue
903            if t in ('int', 'integer', 'smallint', 'bigint',
904                    'float', 'real', 'double precision',
905                    'num', 'numeric', 'money'):
906                row[n] = 0
907            elif t in ('bool', 'boolean'):
908                row[n] = self._make_bool(False)
909            else:
910                row[n] = ''
911        return row
912
913    def delete(self, table, row=None, **kw):
914        """Delete an existing row in a database table.
915
916        This method deletes the row from a table.  It deletes based on the
917        OID value as munged by get or passed as keyword, or on the primary
918        key of the table.  The return value is the number of deleted rows
919        (i.e. 0 if the row did not exist and 1 if the row was deleted).
920        """
921        # Like update, delete works on the oid.
922        # One day we will be testing that the record to be deleted
923        # isn't referenced somewhere (or else PostgreSQL will).
924        # Note that we only accept oid key from named args for safety.
925        qoid = _oid_key(table)
926        if 'oid' in kw:
927            kw[qoid] = kw['oid']
928            del kw['oid']
929        if row is None:
930            row = {}
931        row.update(kw)
932        params = []
933        param = partial(self._prepare_param, params=params)
934        if qoid in row:
935            where = 'oid = %s' % param(row[qoid], 'int')
936        else:
937            try:
938                keyname = self.pkey(table)
939            except KeyError:
940                raise _prg_error('Table %s has no primary key' % table)
941            keyname = [keyname] if isinstance(
942                keyname, basestring) else sorted(keyname)
943            attnames = self.get_attnames(table)
944            col = self.escape_identifier
945            try:
946                where = ' AND '.join('%s = %s' % (
947                    col(k), param(row[k], attnames[k])) for k in keyname)
948            except KeyError:
949                raise _prg_error('Delete operation needs primary key or oid')
950        q = 'DELETE FROM %s WHERE %s' % (
951            self._escape_qualified_name(table), where)
952        self._do_debug(q, params)
953        res = self.db.query(q, params)
954        return int(res)
955
956    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
957        """Get notification handler that will run the given callback."""
958        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
959
960
961# if run as script, print some information
962
963if __name__ == '__main__':
964    print('PyGreSQL version' + version)
965    print('')
966    print(__doc__)
Note: See TracBrowser for help on using the repository browser.