source: trunk/pg.py @ 739

Last change on this file since 739 was 739, checked in by cito, 4 years ago

Return ordered dict for attributes if possible

Sometimes it's important to know the order of the columns in a table.
By returning an OrderedDict instead of a dict in get_attnames, we can
deliver that information en passant, while staying backward compatible.

  • Property svn:keywords set to Id
File size: 35.7 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 739 2016-01-14 00:30:50Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40from functools import partial
41
42try:
43    from collections import OrderedDict
44except ImportError:  # Python 2.6 or 3.0
45    OrderedDict = dict
46
47try:
48    basestring
49except NameError:  # Python >= 3.0
50    basestring = (str, bytes)
51
52set_decimal(Decimal)
53
54
55# Auxiliary functions that are independent from a DB connection:
56
57def _oid_key(table):
58    """Build oid key from a table name."""
59    return 'oid(%s)' % table
60
61
62def _simpletype(typ):
63    """Determine a simplified name a pg_type name."""
64    if typ.startswith('bool'):
65        return 'bool'
66    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
67        return 'date'
68    if typ.startswith(('cid', 'oid', 'int', 'xid')):
69        return 'int'
70    if typ.startswith('float'):
71        return 'float'
72    if typ.startswith('numeric'):
73        return 'num'
74    if typ.startswith('money'):
75        return 'money'
76    if typ.startswith('bytea'):
77        return 'bytea'
78    return 'text'
79
80
81def _namedresult(q):
82    """Get query result as named tuples."""
83    row = namedtuple('Row', q.listfields())
84    return [row(*r) for r in q.getresult()]
85
86set_namedresult(_namedresult)
87
88
def _db_error(msg, cls=DatabaseError):
    """Build (but do not raise) an error of class cls with sqlstate None.

    By default a DatabaseError is created.

    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
94
95
def _int_error(msg):
    """Build (but do not raise) an InternalError with the given message."""
    return _db_error(msg, cls=InternalError)
99
100
def _prg_error(msg):
    """Build (but do not raise) a ProgrammingError with the given message."""
    return _db_error(msg, cls=ProgrammingError)
104
105
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object. If a DB wrapper instance
                   is passed, its underlying connection is used.
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    def close(self):
        """Stop listening and close the connection.

        This is a no-op if the handler has no connection (any more).

        """
        # This also runs from __del__, possibly on an instance whose
        # __init__ never ran (e.g. the constructor was called with
        # missing arguments), so the attribute may not even exist.
        if getattr(self, 'db', None):
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        db      - Connection to send the notification on; the connection
                  of the handler is used if this is not given.
        stop    - If true, notify the stop event instead of the event.
        payload - Optional payload string sent with the notification.

        Nothing is sent (and None is returned) unless the handler is
        currently listening; otherwise the result of the query is returned.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # NOTE(review): the payload is inserted into the SQL
                # as is, so the caller must make sure it is properly escaped.
                q += ", '%s'" % payload
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' (the payload) are inserted into <arg_dict>,
        and the callback is invoked with <arg_dict>. If the NOTIFY message is
        stop_<event>, the handler UNLISTENs both <event> and stop_<event> and
        exits.

        If the timeout expires, the handler UNLISTENs as well, invokes
        the callback with None and exits.

        The close parameter is accepted, but not used by this method.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                # drain all notifications that are currently pending
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        # this ends both loops after the callback has run
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
213
214
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its traditional, deprecated name.

    All arguments are passed on unchanged; a DeprecationWarning is issued.

    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
220
221
222# The actual PostgreSQL database connection interface:
223
224class DB(object):
225    """Wrapper class for the _pg connection type."""
226
227    def __init__(self, *args, **kw):
228        """Create a new connection.
229
230        You can pass either the connection parameters or an existing
231        _pg or pgdb connection. This allows you to use the methods
232        of the classic pg interface with a DB-API 2 pgdb connection.
233
234        """
235        if not args and len(kw) == 1:
236            db = kw.get('db')
237        elif not kw and len(args) == 1:
238            db = args[0]
239        else:
240            db = None
241        if db:
242            if isinstance(db, DB):
243                db = db.db
244            else:
245                try:
246                    db = db._cnx
247                except AttributeError:
248                    pass
249        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
250            db = connect(*args, **kw)
251            self._closeable = True
252        else:
253            self._closeable = False
254        self.db = db
255        self.dbname = db.db
256        self._regtypes = False
257        self._attnames = {}
258        self._pkeys = {}
259        self._privileges = {}
260        self._args = args, kw
261        self.debug = None  # For debugging scripts, this can be set
262            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
263            # * to a file object to write debug statements or
264            # * to a callable object which takes a string argument
265            # * to any other true value to just print debug statements
266
267    def __getattr__(self, name):
268        # All undefined members are same as in underlying connection:
269        if self.db:
270            return getattr(self.db, name)
271        else:
272            raise _int_error('Connection is not valid')
273
274    def __dir__(self):
275        # Custom dir function including the attributes of the connection:
276        attrs = set(self.__class__.__dict__)
277        attrs.update(self.__dict__)
278        attrs.update(dir(self.db))
279        return sorted(attrs)
280
281    # Context manager methods
282
283    def __enter__(self):
284        """Enter the runtime context. This will start a transaction."""
285        self.begin()
286        return self
287
288    def __exit__(self, et, ev, tb):
289        """Exit the runtime context. This will end the transaction."""
290        if et is None and ev is None and tb is None:
291            self.commit()
292        else:
293            self.rollback()
294
295    # Auxiliary methods
296
297    def _do_debug(self, *args):
298        """Print a debug message."""
299        if self.debug:
300            s = '\n'.join(args)
301            if isinstance(self.debug, basestring):
302                print(self.debug % s)
303            elif hasattr(self.debug, 'write'):
304                self.debug.write(s + '\n')
305            elif callable(self.debug):
306                self.debug(s)
307            else:
308                print(s)
309
310    def _escape_qualified_name(self, s):
311        """Escape a qualified name.
312
313        Escapes the name for use as an SQL identifier, unless the
314        name contains a dot, in which case the name is ambiguous
315        (could be a qualified name or just a name with a dot in it)
316        and must be quoted manually by the caller.
317
318        """
319        if '.' not in s:
320            s = self.escape_identifier(s)
321        return s
322
323    @staticmethod
324    def _make_bool(d):
325        """Get boolean value corresponding to d."""
326        return bool(d) if get_bool() else ('t' if d else 'f')
327
328    _bool_true_values = frozenset('t true 1 y yes on'.split())
329
330    def _prepare_bool(self, d):
331        """Prepare a boolean parameter."""
332        if isinstance(d, basestring):
333            if not d:
334                return None
335            d = d.lower() in self._bool_true_values
336        return 't' if d else 'f'
337
338    _date_literals = frozenset('current_date current_time'
339        ' current_timestamp localtime localtimestamp'.split())
340
341    def _prepare_date(self, d):
342        """Prepare a date parameter."""
343        if not d:
344            return None
345        if isinstance(d, basestring) and d.lower() in self._date_literals:
346            raise ValueError
347        return d
348
349    def _prepare_num(self, d):
350        """Prepare a numeric parameter."""
351        if not d and d != 0:
352            return None
353        return d
354
355    def _prepare_bytea(self, d):
356        """Prepare a bytea parameter."""
357        return self.escape_bytea(d)
358
359    _prepare_funcs = dict(  # quote methods for each type
360        bool=_prepare_bool, date=_prepare_date,
361        int=_prepare_num, num=_prepare_num, float=_prepare_num,
362        money=_prepare_num, bytea=_prepare_bytea)
363
364    def _prepare_param(self, value, typ, params):
365        """Prepare and add a parameter to the list."""
366        if value is not None and typ != 'text':
367            try:
368                prepare = self._prepare_funcs[typ]
369            except KeyError:
370                pass
371            else:
372                try:
373                    value = prepare(self, value)
374                except ValueError:
375                    return value
376        params.append(value)
377        return '$%d' % len(params)
378
379    @staticmethod
380    def _prepare_qualified_param(name, param):
381        """Quote parameter representing a qualified name.
382
383        Escapes the name for use as an SQL parameter, unless the
384        name contains a dot, in which case the name is ambiguous
385        (could be a qualified name or just a name with a dot in it)
386        and must be quoted manually by the caller.
387
388        """
389        if isinstance(param, int):
390            param = "$%d" % param
391        if '.' not in name:
392            param = 'quote_ident(%s)' % (param,)
393        return param
394
395    # Public methods
396
397    # escape_string and escape_bytea exist as methods,
398    # so we define unescape_bytea as a method as well
399    unescape_bytea = staticmethod(unescape_bytea)
400
401    def close(self):
402        """Close the database connection."""
403        # Wraps shared library function so we can track state.
404        if self._closeable:
405            if self.db:
406                self.db.close()
407                self.db = None
408            else:
409                raise _int_error('Connection already closed')
410
411    def reset(self):
412        """Reset connection with current parameters.
413
414        All derived queries and large objects derived from this connection
415        will not be usable after this call.
416
417        """
418        if self.db:
419            self.db.reset()
420        else:
421            raise _int_error('Connection already closed')
422
423    def reopen(self):
424        """Reopen connection to the database.
425
426        Used in case we need another connection to the same database.
427        Note that we can still reopen a database that we have closed.
428
429        """
430        # There is no such shared library function.
431        if self._closeable:
432            db = connect(*self._args[0], **self._args[1])
433            if self.db:
434                self.db.close()
435            self.db = db
436
437    def begin(self, mode=None):
438        """Begin a transaction."""
439        qstr = 'BEGIN'
440        if mode:
441            qstr += ' ' + mode
442        return self.query(qstr)
443
444    start = begin
445
446    def commit(self):
447        """Commit the current transaction."""
448        return self.query('COMMIT')
449
450    end = commit
451
452    def rollback(self, name=None):
453        """Roll back the current transaction."""
454        qstr = 'ROLLBACK'
455        if name:
456            qstr += ' TO ' + name
457        return self.query(qstr)
458
459    def savepoint(self, name):
460        """Define a new savepoint within the current transaction."""
461        return self.query('SAVEPOINT ' + name)
462
463    def release(self, name):
464        """Destroy a previously defined savepoint."""
465        return self.query('RELEASE ' + name)
466
467    def query(self, qstr, *args):
468        """Executes a SQL command string.
469
470        This method simply sends a SQL query to the database. If the query is
471        an insert statement that inserted exactly one row into a table that
472        has OIDs, the return value is the OID of the newly inserted row.
473        If the query is an update or delete statement, or an insert statement
474        that did not insert exactly one row in a table with OIDs, then the
475        number of rows affected is returned as a string. If it is a statement
476        that returns rows as a result (usually a select statement, but maybe
477        also an "insert/update ... returning" statement), this method returns
478        a Query object that can be accessed via getresult() or dictresult()
479        or simply printed. Otherwise, it returns `None`.
480
481        The query can contain numbered parameters of the form $1 in place
482        of any data constant. Arguments given after the query string will
483        be substituted for the corresponding numbered parameter. Parameter
484        values can also be given as a single list or tuple argument.
485
486        """
487        # Wraps shared library function for debugging.
488        if not self.db:
489            raise _int_error('Connection is not valid')
490        self._do_debug(qstr)
491        return self.db.query(qstr, args)
492
493    def pkey(self, table, flush=False):
494        """This method gets or sets the primary key of a table.
495
496        Composite primary keys are represented as frozensets. Note that
497        this raises a KeyError if the table does not have a primary key.
498
499        If flush is set then the internal cache for primary keys will
500        be flushed. This may be necessary after the database schema or
501        the search path has been changed.
502
503        """
504        pkeys = self._pkeys
505        if flush:
506            pkeys.clear()
507            self._do_debug('pkey cache has been flushed')
508        try:  # cache lookup
509            pkey = pkeys[table]
510        except KeyError:  # cache miss, check the database
511            q = ("SELECT a.attname FROM pg_index i"
512                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
513                " AND a.attnum = ANY(i.indkey)"
514                " AND NOT a.attisdropped"
515                " WHERE i.indrelid=%s::regclass"
516                " AND i.indisprimary") % (
517                    self._prepare_qualified_param(table, 1),)
518            pkey = self.db.query(q, (table,)).getresult()
519            if not pkey:
520                raise KeyError('Table %s has no primary key' % table)
521            if len(pkey) > 1:
522                pkey = frozenset(k[0] for k in pkey)
523            else:
524                pkey = pkey[0][0]
525            pkeys[table] = pkey  # cache it
526        return pkey
527
528    def get_databases(self):
529        """Get list of databases in the system."""
530        return [s[0] for s in
531            self.db.query('SELECT datname FROM pg_database').getresult()]
532
533    def get_relations(self, kinds=None):
534        """Get list of relations in connected database of specified kinds.
535
536        If kinds is None or empty, all kinds of relations are returned.
537        Otherwise kinds can be a string or sequence of type letters
538        specifying which kind of relations you want to list.
539
540        """
541        where = " AND r.relkind IN (%s)" % ','.join(
542            ["'%s'" % k for k in kinds]) if kinds else ''
543        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
544            " FROM pg_class r"
545            " JOIN pg_namespace s ON s.oid = r.relnamespace"
546            " WHERE s.nspname NOT SIMILAR"
547            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
548            " ORDER BY s.nspname, r.relname") % where
549        return [r[0] for r in self.db.query(q).getresult()]
550
551    def get_tables(self):
552        """Return list of tables in connected database."""
553        return self.get_relations('r')
554
555    def get_attnames(self, table, flush=False):
556        """Given the name of a table, digs out the set of attribute names.
557
558        Returns a dictionary of attribute names (the names are the keys,
559        the values are the names of the attributes' types).
560
561        If your Python version supports this, the dictionary will be an
562        OrderedDictionary with the column names in the right order.
563
564        If flush is set then the internal cache for attribute names will
565        be flushed. This may be necessary after the database schema or
566        the search path has been changed.
567
568        By default, only a limited number of simple types will be returned.
569        You can get the regular types after calling use_regtypes(True).
570
571        """
572        attnames = self._attnames
573        if flush:
574            attnames.clear()
575            self._do_debug('pkey cache has been flushed')
576        try:  # cache lookup
577            names = attnames[table]
578        except KeyError:  # cache miss, check the database
579            q = ("SELECT a.attname, t.typname%s"
580                " FROM pg_attribute a"
581                " JOIN pg_type t ON t.oid = a.atttypid"
582                " WHERE a.attrelid = %s::regclass"
583                " AND (a.attnum > 0 OR a.attname = 'oid')"
584                " AND NOT a.attisdropped ORDER BY a.attnum") % (
585                    '::regtype' if self._regtypes else '',
586                    self._prepare_qualified_param(table, 1))
587            names = self.db.query(q, (table,)).getresult()
588            if not names:
589                raise KeyError('Table %s does not exist' % table)
590            if not self._regtypes:
591                names = ((name, _simpletype(typ)) for name, typ in names)
592            names = OrderedDict(names)
593            attnames[table] = names  # cache it
594        return names
595
596    def use_regtypes(self, regtypes=None):
597        """Use regular type names instead of simplified type names."""
598        if regtypes is None:
599            return self._regtypes
600        else:
601            regtypes = bool(regtypes)
602            if regtypes != self._regtypes:
603                self._regtypes = regtypes
604                self._attnames.clear()
605            return regtypes
606
607    def has_table_privilege(self, table, privilege='select'):
608        """Check whether current user has specified table privilege."""
609        privilege = privilege.lower()
610        try:  # ask cache
611            return self._privileges[(table, privilege)]
612        except KeyError:  # cache miss, ask the database
613            q = "SELECT has_table_privilege(%s, $2)" % (
614                self._prepare_qualified_param(table, 1),)
615            q = self.db.query(q, (table, privilege))
616            ret = q.getresult()[0][0] == self._make_bool(True)
617            self._privileges[(table, privilege)] = ret  # cache it
618            return ret
619
620    def get(self, table, row, keyname=None):
621        """Get a row from a database table or view.
622
623        This method is the basic mechanism to get a single row.  The keyname
624        that the key specifies a unique row.  If keyname is not specified
625        then the primary key for the table is used.  If row is a dictionary
626        then the value for the key is taken from it and it is modified to
627        include the new values, replacing existing values where necessary.
628        For a composite key, keyname can also be a sequence of key names.
629        The OID is also put into the dictionary if the table has one, but
630        in order to allow the caller to work with multiple tables, it is
631        munged as "oid(table)".
632
633        """
634        if table.endswith('*'):  # scan descendant tables?
635            table = table[:-1].rstrip()  # need parent table name
636        if not keyname:
637            # use the primary key by default
638            try:
639                keyname = self.pkey(table)
640            except KeyError:
641                raise _prg_error('Table %s has no primary key' % table)
642        attnames = self.get_attnames(table)
643        params = []
644        param = partial(self._prepare_param, params=params)
645        col = self.escape_identifier
646        # We want the oid for later updates if that isn't the key.
647        # To allow users to work with multiple tables, we munge
648        # the name of the "oid" key by adding the name of the table.
649        qoid = _oid_key(table)
650        if keyname == 'oid':
651            if isinstance(row, dict):
652                if qoid not in row:
653                    raise _db_error('%s not in row' % qoid)
654            else:
655                row = {qoid: row}
656            what = '*'
657            where = 'oid = %s' % param(row[qoid], 'int')
658        else:
659            keyname = [keyname] if isinstance(
660                keyname, basestring) else sorted(keyname)
661            if not isinstance(row, dict):
662                if len(keyname) > 1:
663                    raise _prg_error('Composite key needs dict as row')
664                row = dict((k, row) for k in keyname)
665            what = ', '.join(col(k) for k in attnames)
666            where = ' AND '.join('%s = %s' % (
667                col(k), param(row[k], attnames[k])) for k in keyname)
668        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
669            what, self._escape_qualified_name(table), where)
670        self._do_debug(q, params)
671        q = self.db.query(q, params)
672        res = q.dictresult()
673        if not res:
674            raise _db_error('No such record in %s where %s' % (table, where))
675        for n, value in res[0].items():
676            if n == 'oid':
677                n = qoid
678            elif attnames.get(n) == 'bytea':
679                value = self.unescape_bytea(value)
680            row[n] = value
681        return row
682
683    def insert(self, table, row=None, **kw):
684        """Insert a row into a database table.
685
686        This method inserts a row into a table.  The name of the table must
687        be passed as the first parameter.  The other parameters are used for
688        providing the data of the row that shall be inserted into the table.
689        If a dictionary is supplied as the second parameter, it starts with
690        that.  Otherwise it uses a blank dictionary. Either way the dictionary
691        is updated from the keywords.
692
693        The dictionary is then reloaded with the values actually inserted in
694        order to pick up values modified by rules, triggers, etc.
695
696        Note: The method currently doesn't support insert into views
697        although PostgreSQL does.
698
699        """
700        if 'oid' in kw:
701            del kw['oid']
702        if row is None:
703            row = {}
704        row.update(kw)
705        attnames = self.get_attnames(table)
706        params = []
707        param = partial(self._prepare_param, params=params)
708        col = self.escape_identifier
709        names, values = [], []
710        for n in attnames:
711            if n in row:
712                names.append(col(n))
713                values.append(param(row[n], attnames[n]))
714        names, values = ', '.join(names), ', '.join(values)
715        ret = 'oid, *' if 'oid' in attnames else '*'
716        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
717            self._escape_qualified_name(table), names, values, ret)
718        self._do_debug(q, params)
719        q = self.db.query(q, params)
720        res = q.dictresult()
721        if not res:
722            raise _int_error('insert did not return new values')
723        for n, value in res[0].items():
724            if n == 'oid':
725                n = _oid_key(table)
726            elif attnames.get(n) == 'bytea' and value is not None:
727                value = self.unescape_bytea(value)
728            row[n] = value
729        return row
730
731    def update(self, table, row=None, **kw):
732        """Update an existing row in a database table.
733
734        Similar to insert but updates an existing row.  The update is based
735        on the OID value as munged by get or passed as keyword, or on the
736        primary key of the table.  The dictionary is modified to reflect
737        any changes caused by the update due to triggers, rules, default
738        values, etc.
739
740        """
741        # Update always works on the oid which get() returns if available,
742        # otherwise use the primary key.  Fail if neither.
743        # Note that we only accept oid key from named args for safety.
744        qoid = _oid_key(table)
745        if 'oid' in kw:
746            kw[qoid] = kw['oid']
747            del kw['oid']
748        if row is None:
749            row = {}
750        row.update(kw)
751        attnames = self.get_attnames(table)
752        params = []
753        param = partial(self._prepare_param, params=params)
754        col = self.escape_identifier
755        if qoid in row:
756            where = 'oid = %s' % param(row[qoid], 'int')
757            keyname = []
758        else:
759            try:
760                keyname = self.pkey(table)
761            except KeyError:
762                raise _prg_error('Table %s has no primary key' % table)
763            keyname = [keyname] if isinstance(
764                keyname, basestring) else sorted(keyname)
765            try:
766                where = ' AND '.join('%s = %s' % (
767                    col(k), param(row[k], attnames[k])) for k in keyname)
768            except KeyError:
769                raise _prg_error('update needs primary key or oid')
770        keyname = set(keyname)
771        keyname.add('oid')
772        values = []
773        for n in attnames:
774            if n in row and n not in keyname:
775                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
776        if not values:
777            return row
778        values = ', '.join(values)
779        ret = 'oid, *' if 'oid' in attnames else '*'
780        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
781            self._escape_qualified_name(table), values, where, ret)
782        self._do_debug(q, params)
783        q = self.db.query(q, params)
784        res = q.dictresult()
785        if res:  # may be empty when row does not exist
786            for n, value in res[0].items():
787                if n == 'oid':
788                    n = qoid
789                elif attnames.get(n) == 'bytea' and value is not None:
790                    value = self.unescape_bytea(value)
791                row[n] = value
792        return row
793
794    def upsert(self, table, row=None, **kw):
795        """Insert a row into a database table with conflict resolution.
796
797        This method inserts a row into a table, but instead of raising a
798        ProgrammingError exception in case a row with the same primary key
799        already exists, an update will be executed instead.  This will be
800        performed as a single atomic operation on the database, so race
801        conditions can be avoided.
802
803        Like the insert method, the first parameter is the name of the
804        table and the second parameter can be used to pass the values to
805        be inserted as a dictionary.
806
807        Unlike the insert und update statement, keyword parameters are not
808        used to modify the dictionary, but to specify which columns shall
809        be updated in case of a conflict, and in which way:
810
811        A value of False or None means the column shall not be updated,
812        a value of True means the column shall be updated with the value
813        that has been proposed for insertion, i.e. has been passed as value
814        in the dictionary.  Columns that are not specified by keywords but
815        appear as keys in the dictionary are also updated like in the case
816        keywords had been passed with the value True.
817
818        So if in the case of a conflict you want to update every column that
819        has been passed in the dictionary row , you would call upsert(table, row).
820        If you don't want to do anything in case of a conflict, i.e. leave
821        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
822
823        If you need more fine-grained control of what gets updated, you can
824        also pass strings in the keyword parameters.  These strings will
825        be used as SQL expressions for the update columns.  In these
826        expressions you can refer to the value that already exists in
827        the table by prefixing the column name with "included.", and to
828        the value that has been proposed for insertion by prefixing the
829        column name with the "excluded."
830
831        The dictionary is modified in any case to reflect the values in
832        the database after the operation has completed.
833
834        Note: The method uses the PostgreSQL "upsert" feature which is
835        only available since PostgreSQL 9.5.
836
837        """
838        if 'oid' in kw:
839            del kw['oid']
840        if row is None:
841            row = {}
842        attnames = self.get_attnames(table)
843        params = []
844        param = partial(self._prepare_param,params=params)
845        col = self.escape_identifier
846        names, values, updates = [], [], []
847        for n in attnames:
848            if n in row:
849                names.append(col(n))
850                values.append(param(row[n], attnames[n]))
851        names, values = ', '.join(names), ', '.join(values)
852        try:
853            keyname = self.pkey(table)
854        except KeyError:
855            raise _prg_error('Table %s has no primary key' % table)
856        keyname = [keyname] if isinstance(
857            keyname, basestring) else sorted(keyname)
858        try:
859            target = ', '.join(col(k) for k in keyname)
860        except KeyError:
861            raise _prg_error('upsert needs primary key or oid')
862        update = []
863        keyname = set(keyname)
864        keyname.add('oid')
865        for n in attnames:
866            if n not in keyname:
867                value = kw.get(n, True)
868                if value:
869                    if not isinstance(value, basestring):
870                        value = 'excluded.%s' % col(n)
871                    update.append('%s = %s' % (col(n), value))
872        if not values and not update:
873            return row
874        do = 'update set %s' % ', '.join(update) if update else 'nothing'
875        ret = 'oid, *' if 'oid' in attnames else '*'
876        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
877            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
878                self._escape_qualified_name(table), names, values,
879                target, do, ret)
880        self._do_debug(q, params)
881        try:
882            q = self.db.query(q, params)
883        except ProgrammingError:
884            if self.server_version < 90500:
885                raise _prg_error('upsert not supported by PostgreSQL version')
886            raise  # re-raise original error
887        res = q.dictresult()
888        if res:  # may be empty with "do nothing"
889            for n, value in res[0].items():
890                if n == 'oid':
891                    n = _oid_key(table)
892                elif attnames.get(n) == 'bytea':
893                    value = self.unescape_bytea(value)
894                row[n] = value
895        elif update:
896            raise _int_error('upsert did not return new values')
897        else:
898            self.get(table, row)
899        return row
900
901    def clear(self, table, row=None):
902        """Clear all the attributes to values determined by the types.
903
904        Numeric types are set to 0, Booleans are set to false, and everything
905        else is set to the empty string.  If the row argument is present,
906        it is used as the row dictionary and any entries matching attribute
907        names are cleared with everything else left unchanged.
908
909        """
910        # At some point we will need a way to get defaults from a table.
911        if row is None:
912            row = {}  # empty if argument is not present
913        attnames = self.get_attnames(table)
914        for n, t in attnames.items():
915            if n == 'oid':
916                continue
917            if t in ('int', 'integer', 'smallint', 'bigint',
918                    'float', 'real', 'double precision',
919                    'num', 'numeric', 'money'):
920                row[n] = 0
921            elif t in ('bool', 'boolean'):
922                row[n] = self._make_bool(False)
923            else:
924                row[n] = ''
925        return row
926
927    def delete(self, table, row=None, **kw):
928        """Delete an existing row in a database table.
929
930        This method deletes the row from a table.  It deletes based on the
931        OID value as munged by get or passed as keyword, or on the primary
932        key of the table.  The return value is the number of deleted rows
933        (i.e. 0 if the row did not exist and 1 if the row was deleted).
934
935        """
936        # Like update, delete works on the oid.
937        # One day we will be testing that the record to be deleted
938        # isn't referenced somewhere (or else PostgreSQL will).
939        # Note that we only accept oid key from named args for safety.
940        qoid = _oid_key(table)
941        if 'oid' in kw:
942            kw[qoid] = kw['oid']
943            del kw['oid']
944        if row is None:
945            row = {}
946        row.update(kw)
947        params = []
948        param = partial(self._prepare_param, params=params)
949        if qoid in row:
950            where = 'oid = %s' % param(row[qoid], 'int')
951        else:
952            try:
953                keyname = self.pkey(table)
954            except KeyError:
955                raise _prg_error('Table %s has no primary key' % table)
956            keyname = [keyname] if isinstance(
957                keyname, basestring) else sorted(keyname)
958            attnames = self.get_attnames(table)
959            col = self.escape_identifier
960            try:
961                where = ' AND '.join('%s = %s' % (
962                    col(k), param(row[k], attnames[k])) for k in keyname)
963            except KeyError:
964                raise _prg_error('delete needs primary key or oid')
965        q = 'DELETE FROM %s WHERE %s' % (
966            self._escape_qualified_name(table), where)
967        self._do_debug(q, params)
968        res = self.db.query(q, params)
969        return int(res)
970
971    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
972        """Get notification handler that will run the given callback."""
973        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
974
975
976# if run as script, print some information
977
if __name__ == '__main__':
    # print_function is imported from __future__ at the top of the file,
    # so the two arguments are printed separated by a single space
    # (the former string concatenation glued the version to the label).
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.