source: trunk/pg.py @ 738

Last change on this file since 738 was 738, checked in by cito, 4 years ago

Renamed some parameters to clarify their meaning

This is not a breaking change, because these are not keyword parameters. Also, we don't
change much in the 4.x branch in order to stay compatible.

Also, avoided the old terminology "class" for a table or table-like object in PostgreSQL,
because it may confuse people. The word "table" is so much clearer.

  • Property svn:keywords set to Id
File size: 35.4 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 738 2016-01-13 22:57:29Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40from functools import partial
41
# Python 2/3 compatibility: Python 3 has no basestring, so define it as
# the tuple of string types for the isinstance() checks in this module.
try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# Register Decimal with the _pg module (set_decimal presumably comes from
# the "from _pg import *" star import and makes numeric values be returned
# as Decimal instances -- confirm against the _pg documentation).
set_decimal(Decimal)
48
49
50# Auxiliary functions that are independent from a DB connection:
51
52def _oid_key(table):
53    """Build oid key from a table name."""
54    return 'oid(%s)' % table
55
56
57def _simpletype(typ):
58    """Determine a simplified name a pg_type name."""
59    if typ.startswith('bool'):
60        return 'bool'
61    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
62        return 'date'
63    if typ.startswith(('cid', 'oid', 'int', 'xid')):
64        return 'int'
65    if typ.startswith('float'):
66        return 'float'
67    if typ.startswith('numeric'):
68        return 'num'
69    if typ.startswith('money'):
70        return 'money'
71    if typ.startswith('bytea'):
72        return 'bytea'
73    return 'text'
74
75
76def _namedresult(q):
77    """Get query result as named tuples."""
78    row = namedtuple('Row', q.listfields())
79    return [row(*r) for r in q.getresult()]
80
# Register _namedresult with the _pg module. set_namedresult presumably
# comes from the "from _pg import *" star import and makes the query
# objects' namedresult() use this function -- confirm against _pg docs.
set_namedresult(_namedresult)
82
83
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) a DatabaseError or subclass instance.

    The error is created locally and has no SQLSTATE code, so its
    sqlstate attribute is set to None.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
89
90
def _int_error(msg):
    """Create (but do not raise) an InternalError with empty sqlstate."""
    return _db_error(msg, cls=InternalError)
94
95
def _prg_error(msg):
    """Create (but do not raise) a ProgrammingError with empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
99
100
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB instance is
                   unwrapped to its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    @staticmethod
    def _quote_channel(name):
        """Quote a notification channel name as an SQL identifier.

        Embedded double quotes are doubled, so that a name containing
        them cannot terminate the quoted identifier early.
        """
        return '"%s"' % ('%s' % name).replace('"', '""')

    @staticmethod
    def _quote_payload(payload):
        """Quote a notification payload as an SQL string literal.

        Embedded single quotes are doubled. If the payload contains
        backslashes, an escape string (E'...') with doubled backslashes
        is used, which the server reads the same way regardless of the
        standard_conforming_strings setting.
        """
        payload = ('%s' % payload).replace("'", "''")
        if '\\' in payload:
            return "E'%s'" % payload.replace('\\', '\\\\')
        return "'%s'" % payload

    def close(self):
        """Stop listening and close the connection."""
        # getattr() guards against __del__ running on an instance
        # whose db attribute has never been set.
        if getattr(self, 'db', None):
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_channel(self.event))
            self.db.query(
                'listen %s' % self._quote_channel(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_channel(self.event))
            self.db.query(
                'unlisten %s' % self._quote_channel(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Sends a NOTIFY for the event (or for the stop event if stop is
        set), with the given payload if it is not empty. Nothing is sent
        and None is returned if the handler is not listening.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            channel = self.stop_event if stop else self.event
            q = 'notify %s' % self._quote_channel(channel)
            if payload:
                q += ', %s' % self._quote_payload(payload)
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' (payload) are inserted into <arg_dict>,
        and the callback is invoked with <arg_dict>. If the NOTIFY message
        is stop_<event>, the handler UNLISTENs both <event> and
        stop_<event> and exits. If the timeout expires, the handler
        UNLISTENs and invokes the callback with None.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            # wait until the connection socket becomes readable
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                # drain all pending notifications
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
208
209
def pgnotify(*args, **kw):
    """Deprecated alias for creating a NotificationHandler.

    Issues a DeprecationWarning attributed to the caller and passes all
    arguments through to NotificationHandler unchanged.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
215
216
217# The actual PostgreSQL database connection interface:
218
219class DB(object):
220    """Wrapper class for the _pg connection type."""
221
222    def __init__(self, *args, **kw):
223        """Create a new connection.
224
225        You can pass either the connection parameters or an existing
226        _pg or pgdb connection. This allows you to use the methods
227        of the classic pg interface with a DB-API 2 pgdb connection.
228
229        """
230        if not args and len(kw) == 1:
231            db = kw.get('db')
232        elif not kw and len(args) == 1:
233            db = args[0]
234        else:
235            db = None
236        if db:
237            if isinstance(db, DB):
238                db = db.db
239            else:
240                try:
241                    db = db._cnx
242                except AttributeError:
243                    pass
244        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
245            db = connect(*args, **kw)
246            self._closeable = True
247        else:
248            self._closeable = False
249        self.db = db
250        self.dbname = db.db
251        self._regtypes = False
252        self._attnames = {}
253        self._pkeys = {}
254        self._privileges = {}
255        self._args = args, kw
256        self.debug = None  # For debugging scripts, this can be set
257            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
258            # * to a file object to write debug statements or
259            # * to a callable object which takes a string argument
260            # * to any other true value to just print debug statements
261
262    def __getattr__(self, name):
263        # All undefined members are same as in underlying connection:
264        if self.db:
265            return getattr(self.db, name)
266        else:
267            raise _int_error('Connection is not valid')
268
269    def __dir__(self):
270        # Custom dir function including the attributes of the connection:
271        attrs = set(self.__class__.__dict__)
272        attrs.update(self.__dict__)
273        attrs.update(dir(self.db))
274        return sorted(attrs)
275
276    # Context manager methods
277
278    def __enter__(self):
279        """Enter the runtime context. This will start a transaction."""
280        self.begin()
281        return self
282
283    def __exit__(self, et, ev, tb):
284        """Exit the runtime context. This will end the transaction."""
285        if et is None and ev is None and tb is None:
286            self.commit()
287        else:
288            self.rollback()
289
290    # Auxiliary methods
291
292    def _do_debug(self, *args):
293        """Print a debug message."""
294        if self.debug:
295            s = '\n'.join(args)
296            if isinstance(self.debug, basestring):
297                print(self.debug % s)
298            elif hasattr(self.debug, 'write'):
299                self.debug.write(s + '\n')
300            elif callable(self.debug):
301                self.debug(s)
302            else:
303                print(s)
304
305    def _escape_qualified_name(self, s):
306        """Escape a qualified name.
307
308        Escapes the name for use as an SQL identifier, unless the
309        name contains a dot, in which case the name is ambiguous
310        (could be a qualified name or just a name with a dot in it)
311        and must be quoted manually by the caller.
312
313        """
314        if '.' not in s:
315            s = self.escape_identifier(s)
316        return s
317
318    @staticmethod
319    def _make_bool(d):
320        """Get boolean value corresponding to d."""
321        return bool(d) if get_bool() else ('t' if d else 'f')
322
323    _bool_true_values = frozenset('t true 1 y yes on'.split())
324
325    def _prepare_bool(self, d):
326        """Prepare a boolean parameter."""
327        if isinstance(d, basestring):
328            if not d:
329                return None
330            d = d.lower() in self._bool_true_values
331        return 't' if d else 'f'
332
333    _date_literals = frozenset('current_date current_time'
334        ' current_timestamp localtime localtimestamp'.split())
335
336    def _prepare_date(self, d):
337        """Prepare a date parameter."""
338        if not d:
339            return None
340        if isinstance(d, basestring) and d.lower() in self._date_literals:
341            raise ValueError
342        return d
343
344    def _prepare_num(self, d):
345        """Prepare a numeric parameter."""
346        if not d and d != 0:
347            return None
348        return d
349
350    def _prepare_bytea(self, d):
351        """Prepare a bytea parameter."""
352        return self.escape_bytea(d)
353
354    _prepare_funcs = dict(  # quote methods for each type
355        bool=_prepare_bool, date=_prepare_date,
356        int=_prepare_num, num=_prepare_num, float=_prepare_num,
357        money=_prepare_num, bytea=_prepare_bytea)
358
359    def _prepare_param(self, value, typ, params):
360        """Prepare and add a parameter to the list."""
361        if value is not None and typ != 'text':
362            try:
363                prepare = self._prepare_funcs[typ]
364            except KeyError:
365                pass
366            else:
367                try:
368                    value = prepare(self, value)
369                except ValueError:
370                    return value
371        params.append(value)
372        return '$%d' % len(params)
373
374    @staticmethod
375    def _prepare_qualified_param(name, param):
376        """Quote parameter representing a qualified name.
377
378        Escapes the name for use as an SQL parameter, unless the
379        name contains a dot, in which case the name is ambiguous
380        (could be a qualified name or just a name with a dot in it)
381        and must be quoted manually by the caller.
382
383        """
384        if isinstance(param, int):
385            param = "$%d" % param
386        if '.' not in name:
387            param = 'quote_ident(%s)' % (param,)
388        return param
389
390    # Public methods
391
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (a staticmethod wrapping the module-level unescape_bytea, which
    # presumably comes from the "from _pg import *" star import).
    unescape_bytea = staticmethod(unescape_bytea)
395
396    def close(self):
397        """Close the database connection."""
398        # Wraps shared library function so we can track state.
399        if self._closeable:
400            if self.db:
401                self.db.close()
402                self.db = None
403            else:
404                raise _int_error('Connection already closed')
405
406    def reset(self):
407        """Reset connection with current parameters.
408
409        All derived queries and large objects derived from this connection
410        will not be usable after this call.
411
412        """
413        if self.db:
414            self.db.reset()
415        else:
416            raise _int_error('Connection already closed')
417
418    def reopen(self):
419        """Reopen connection to the database.
420
421        Used in case we need another connection to the same database.
422        Note that we can still reopen a database that we have closed.
423
424        """
425        # There is no such shared library function.
426        if self._closeable:
427            db = connect(*self._args[0], **self._args[1])
428            if self.db:
429                self.db.close()
430            self.db = db
431
432    def begin(self, mode=None):
433        """Begin a transaction."""
434        qstr = 'BEGIN'
435        if mode:
436            qstr += ' ' + mode
437        return self.query(qstr)
438
439    start = begin
440
441    def commit(self):
442        """Commit the current transaction."""
443        return self.query('COMMIT')
444
445    end = commit
446
447    def rollback(self, name=None):
448        """Roll back the current transaction."""
449        qstr = 'ROLLBACK'
450        if name:
451            qstr += ' TO ' + name
452        return self.query(qstr)
453
454    def savepoint(self, name):
455        """Define a new savepoint within the current transaction."""
456        return self.query('SAVEPOINT ' + name)
457
458    def release(self, name):
459        """Destroy a previously defined savepoint."""
460        return self.query('RELEASE ' + name)
461
462    def query(self, qstr, *args):
463        """Executes a SQL command string.
464
465        This method simply sends a SQL query to the database. If the query is
466        an insert statement that inserted exactly one row into a table that
467        has OIDs, the return value is the OID of the newly inserted row.
468        If the query is an update or delete statement, or an insert statement
469        that did not insert exactly one row in a table with OIDs, then the
470        number of rows affected is returned as a string. If it is a statement
471        that returns rows as a result (usually a select statement, but maybe
472        also an "insert/update ... returning" statement), this method returns
473        a Query object that can be accessed via getresult() or dictresult()
474        or simply printed. Otherwise, it returns `None`.
475
476        The query can contain numbered parameters of the form $1 in place
477        of any data constant. Arguments given after the query string will
478        be substituted for the corresponding numbered parameter. Parameter
479        values can also be given as a single list or tuple argument.
480
481        """
482        # Wraps shared library function for debugging.
483        if not self.db:
484            raise _int_error('Connection is not valid')
485        self._do_debug(qstr)
486        return self.db.query(qstr, args)
487
488    def pkey(self, table, flush=False):
489        """This method gets or sets the primary key of a table.
490
491        Composite primary keys are represented as frozensets. Note that
492        this raises a KeyError if the table does not have a primary key.
493
494        If flush is set then the internal cache for primary keys will
495        be flushed. This may be necessary after the database schema or
496        the search path has been changed.
497
498        """
499        pkeys = self._pkeys
500        if flush:
501            pkeys.clear()
502            self._do_debug('pkey cache has been flushed')
503        try:  # cache lookup
504            pkey = pkeys[table]
505        except KeyError:  # cache miss, check the database
506            q = ("SELECT a.attname FROM pg_index i"
507                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
508                " AND a.attnum = ANY(i.indkey)"
509                " AND NOT a.attisdropped"
510                " WHERE i.indrelid=%s::regclass"
511                " AND i.indisprimary") % (
512                    self._prepare_qualified_param(table, 1),)
513            pkey = self.db.query(q, (table,)).getresult()
514            if not pkey:
515                raise KeyError('Table %s has no primary key' % table)
516            if len(pkey) > 1:
517                pkey = frozenset(k[0] for k in pkey)
518            else:
519                pkey = pkey[0][0]
520            pkeys[table] = pkey  # cache it
521        return pkey
522
523    def get_databases(self):
524        """Get list of databases in the system."""
525        return [s[0] for s in
526            self.db.query('SELECT datname FROM pg_database').getresult()]
527
528    def get_relations(self, kinds=None):
529        """Get list of relations in connected database of specified kinds.
530
531        If kinds is None or empty, all kinds of relations are returned.
532        Otherwise kinds can be a string or sequence of type letters
533        specifying which kind of relations you want to list.
534
535        """
536        where = " AND r.relkind IN (%s)" % ','.join(
537            ["'%s'" % k for k in kinds]) if kinds else ''
538        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
539            " FROM pg_class r"
540            " JOIN pg_namespace s ON s.oid = r.relnamespace"
541            " WHERE s.nspname NOT SIMILAR"
542            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
543            " ORDER BY s.nspname, r.relname") % where
544        return [r[0] for r in self.db.query(q).getresult()]
545
546    def get_tables(self):
547        """Return list of tables in connected database."""
548        return self.get_relations('r')
549
550    def get_attnames(self, table, flush=False):
551        """Given the name of a table, digs out the set of attribute names.
552
553        Returns a dictionary of attribute names (the names are the keys,
554        the values are the names of the attributes' types).
555
556        If the optional newattnames exists, it must be a dictionary and
557        will become the new attribute names dictionary.
558
559        By default, only a limited number of simple types will be returned.
560        You can get the regular types after calling use_regtypes(True).
561
562        """
563        attnames = self._attnames
564        if flush:
565            attnames.clear()
566            self._do_debug('pkey cache has been flushed')
567        try:  # cache lookup
568            names = attnames[table]
569        except KeyError:  # cache miss, check the database
570            q = ("SELECT a.attname, t.typname%s"
571                " FROM pg_attribute a"
572                " JOIN pg_type t ON t.oid = a.atttypid"
573                " WHERE a.attrelid = %s::regclass"
574                " AND (a.attnum > 0 OR a.attname = 'oid')"
575                " AND NOT a.attisdropped") % (
576                    '::regtype' if self._regtypes else '',
577                    self._prepare_qualified_param(table, 1))
578            names = self.db.query(q, (table,)).getresult()
579            if not names:
580                raise KeyError('Table %s does not exist' % table)
581            if self._regtypes:
582                names = dict(names)
583            else:
584                names = dict((name, _simpletype(typ)) for name, typ in names)
585            attnames[table] = names  # cache it
586        return names
587
588    def use_regtypes(self, regtypes=None):
589        """Use regular type names instead of simplified type names."""
590        if regtypes is None:
591            return self._regtypes
592        else:
593            regtypes = bool(regtypes)
594            if regtypes != self._regtypes:
595                self._regtypes = regtypes
596                self._attnames.clear()
597            return regtypes
598
599    def has_table_privilege(self, table, privilege='select'):
600        """Check whether current user has specified table privilege."""
601        privilege = privilege.lower()
602        try:  # ask cache
603            return self._privileges[(table, privilege)]
604        except KeyError:  # cache miss, ask the database
605            q = "SELECT has_table_privilege(%s, $2)" % (
606                self._prepare_qualified_param(table, 1),)
607            q = self.db.query(q, (table, privilege))
608            ret = q.getresult()[0][0] == self._make_bool(True)
609            self._privileges[(table, privilege)] = ret  # cache it
610            return ret
611
612    def get(self, table, row, keyname=None):
613        """Get a row from a database table or view.
614
615        This method is the basic mechanism to get a single row.  The keyname
616        that the key specifies a unique row.  If keyname is not specified
617        then the primary key for the table is used.  If row is a dictionary
618        then the value for the key is taken from it and it is modified to
619        include the new values, replacing existing values where necessary.
620        For a composite key, keyname can also be a sequence of key names.
621        The OID is also put into the dictionary if the table has one, but
622        in order to allow the caller to work with multiple tables, it is
623        munged as "oid(table)".
624
625        """
626        if table.endswith('*'):  # scan descendant tables?
627            table = table[:-1].rstrip()  # need parent table name
628        if not keyname:
629            # use the primary key by default
630            try:
631                keyname = self.pkey(table)
632            except KeyError:
633                raise _prg_error('Table %s has no primary key' % table)
634        attnames = self.get_attnames(table)
635        params = []
636        param = partial(self._prepare_param, params=params)
637        col = self.escape_identifier
638        # We want the oid for later updates if that isn't the key.
639        # To allow users to work with multiple tables, we munge
640        # the name of the "oid" key by adding the name of the table.
641        qoid = _oid_key(table)
642        if keyname == 'oid':
643            if isinstance(row, dict):
644                if qoid not in row:
645                    raise _db_error('%s not in row' % qoid)
646            else:
647                row = {qoid: row}
648            what = '*'
649            where = 'oid = %s' % param(row[qoid], 'int')
650        else:
651            keyname = [keyname] if isinstance(
652                keyname, basestring) else sorted(keyname)
653            if not isinstance(row, dict):
654                if len(keyname) > 1:
655                    raise _prg_error('Composite key needs dict as row')
656                row = dict((k, row) for k in keyname)
657            what = ', '.join(col(k) for k in attnames)
658            where = ' AND '.join('%s = %s' % (
659                col(k), param(row[k], attnames[k])) for k in keyname)
660        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
661            what, self._escape_qualified_name(table), where)
662        self._do_debug(q, params)
663        q = self.db.query(q, params)
664        res = q.dictresult()
665        if not res:
666            raise _db_error('No such record in %s where %s' % (table, where))
667        for n, value in res[0].items():
668            if n == 'oid':
669                n = qoid
670            elif attnames.get(n) == 'bytea':
671                value = self.unescape_bytea(value)
672            row[n] = value
673        return row
674
675    def insert(self, table, row=None, **kw):
676        """Insert a row into a database table.
677
678        This method inserts a row into a table.  The name of the table must
679        be passed as the first parameter.  The other parameters are used for
680        providing the data of the row that shall be inserted into the table.
681        If a dictionary is supplied as the second parameter, it starts with
682        that.  Otherwise it uses a blank dictionary. Either way the dictionary
683        is updated from the keywords.
684
685        The dictionary is then reloaded with the values actually inserted in
686        order to pick up values modified by rules, triggers, etc.
687
688        Note: The method currently doesn't support insert into views
689        although PostgreSQL does.
690
691        """
692        if 'oid' in kw:
693            del kw['oid']
694        if row is None:
695            row = {}
696        row.update(kw)
697        attnames = self.get_attnames(table)
698        params = []
699        param = partial(self._prepare_param, params=params)
700        col = self.escape_identifier
701        names, values = [], []
702        for n in attnames:
703            if n in row:
704                names.append(col(n))
705                values.append(param(row[n], attnames[n]))
706        names, values = ', '.join(names), ', '.join(values)
707        ret = 'oid, *' if 'oid' in attnames else '*'
708        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
709            self._escape_qualified_name(table), names, values, ret)
710        self._do_debug(q, params)
711        q = self.db.query(q, params)
712        res = q.dictresult()
713        if not res:
714            raise _int_error('insert did not return new values')
715        for n, value in res[0].items():
716            if n == 'oid':
717                n = _oid_key(table)
718            elif attnames.get(n) == 'bytea' and value is not None:
719                value = self.unescape_bytea(value)
720            row[n] = value
721        return row
722
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the OID value as munged by get or passed as keyword, or on the
        primary key of the table.  The dictionary is modified to reflect
        any changes caused by the update due to triggers, rules, default
        values, etc., and is returned.

        Raises ProgrammingError if neither an OID nor the complete primary
        key is available. If no columns other than the key columns are
        given, no query is executed and the dictionary is returned as is;
        the same happens after the query if no row was matched.

        """
        # Update always works on the oid which get() returns if available,
        # otherwise use the primary key.  Fail if neither.
        # Note that we only accept oid key from named args for safety.
        qoid = _oid_key(table)
        if 'oid' in kw:
            # a plain "oid" keyword is moved to the munged key
            kw[qoid] = kw['oid']
            del kw['oid']
        if row is None:
            row = {}
        row.update(kw)
        attnames = self.get_attnames(table)
        params = []
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        # The parameters of the WHERE clause are added first, so they
        # get the lowest placeholder numbers; the SET values follow.
        if qoid in row:
            where = 'oid = %s' % param(row[qoid], 'int')
            keyname = []
        else:
            try:
                keyname = self.pkey(table)
            except KeyError:
                raise _prg_error('Table %s has no primary key' % table)
            keyname = [keyname] if isinstance(
                keyname, basestring) else sorted(keyname)
            try:
                where = ' AND '.join('%s = %s' % (
                    col(k), param(row[k], attnames[k])) for k in keyname)
            except KeyError:
                # a key column is missing from row (or from the table)
                raise _prg_error('update needs primary key or oid')
        # key columns and the oid are never updated themselves
        keyname = set(keyname)
        keyname.add('oid')
        values = []
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
        if not values:
            return row
        values = ', '.join(values)
        ret = 'oid, *' if 'oid' in attnames else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            for n, value in res[0].items():
                if n == 'oid':
                    n = qoid
                elif attnames.get(n) == 'bytea' and value is not None:
                    value = self.unescape_bytea(value)
                row[n] = value
        return row
785
786    def upsert(self, table, row=None, **kw):
787        """Insert a row into a database table with conflict resolution.
788
789        This method inserts a row into a table, but instead of raising a
790        ProgrammingError exception in case a row with the same primary key
791        already exists, an update will be executed instead.  This will be
792        performed as a single atomic operation on the database, so race
793        conditions can be avoided.
794
795        Like the insert method, the first parameter is the name of the
796        table and the second parameter can be used to pass the values to
797        be inserted as a dictionary.
798
799        Unlike the insert und update statement, keyword parameters are not
800        used to modify the dictionary, but to specify which columns shall
801        be updated in case of a conflict, and in which way:
802
803        A value of False or None means the column shall not be updated,
804        a value of True means the column shall be updated with the value
805        that has been proposed for insertion, i.e. has been passed as value
806        in the dictionary.  Columns that are not specified by keywords but
807        appear as keys in the dictionary are also updated like in the case
808        keywords had been passed with the value True.
809
810        So if in the case of a conflict you want to update every column that
811        has been passed in the dictionary row , you would call upsert(table, row).
812        If you don't want to do anything in case of a conflict, i.e. leave
813        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
814
815        If you need more fine-grained control of what gets updated, you can
816        also pass strings in the keyword parameters.  These strings will
817        be used as SQL expressions for the update columns.  In these
818        expressions you can refer to the value that already exists in
819        the table by prefixing the column name with "included.", and to
820        the value that has been proposed for insertion by prefixing the
821        column name with the "excluded."
822
823        The dictionary is modified in any case to reflect the values in
824        the database after the operation has completed.
825
826        Note: The method uses the PostgreSQL "upsert" feature which is
827        only available since PostgreSQL 9.5.
828
829        """
830        if 'oid' in kw:
831            del kw['oid']
832        if row is None:
833            row = {}
834        attnames = self.get_attnames(table)
835        params = []
836        param = partial(self._prepare_param,params=params)
837        col = self.escape_identifier
838        names, values, updates = [], [], []
839        for n in attnames:
840            if n in row:
841                names.append(col(n))
842                values.append(param(row[n], attnames[n]))
843        names, values = ', '.join(names), ', '.join(values)
844        try:
845            keyname = self.pkey(table)
846        except KeyError:
847            raise _prg_error('Table %s has no primary key' % table)
848        keyname = [keyname] if isinstance(
849            keyname, basestring) else sorted(keyname)
850        try:
851            target = ', '.join(col(k) for k in keyname)
852        except KeyError:
853            raise _prg_error('upsert needs primary key or oid')
854        update = []
855        keyname = set(keyname)
856        keyname.add('oid')
857        for n in attnames:
858            if n not in keyname:
859                value = kw.get(n, True)
860                if value:
861                    if not isinstance(value, basestring):
862                        value = 'excluded.%s' % col(n)
863                    update.append('%s = %s' % (col(n), value))
864        if not values and not update:
865            return row
866        do = 'update set %s' % ', '.join(update) if update else 'nothing'
867        ret = 'oid, *' if 'oid' in attnames else '*'
868        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
869            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
870                self._escape_qualified_name(table), names, values,
871                target, do, ret)
872        self._do_debug(q, params)
873        try:
874            q = self.db.query(q, params)
875        except ProgrammingError:
876            if self.server_version < 90500:
877                raise _prg_error('upsert not supported by PostgreSQL version')
878            raise  # re-raise original error
879        res = q.dictresult()
880        if res:  # may be empty with "do nothing"
881            for n, value in res[0].items():
882                if n == 'oid':
883                    n = _oid_key(table)
884                elif attnames.get(n) == 'bytea':
885                    value = self.unescape_bytea(value)
886                row[n] = value
887        elif update:
888            raise _int_error('upsert did not return new values')
889        else:
890            self.get(table, row)
891        return row
892
893    def clear(self, table, row=None):
894        """Clear all the attributes to values determined by the types.
895
896        Numeric types are set to 0, Booleans are set to false, and everything
897        else is set to the empty string.  If the row argument is present,
898        it is used as the row dictionary and any entries matching attribute
899        names are cleared with everything else left unchanged.
900
901        """
902        # At some point we will need a way to get defaults from a table.
903        if row is None:
904            row = {}  # empty if argument is not present
905        attnames = self.get_attnames(table)
906        for n, t in attnames.items():
907            if n == 'oid':
908                continue
909            if t in ('int', 'integer', 'smallint', 'bigint',
910                    'float', 'real', 'double precision',
911                    'num', 'numeric', 'money'):
912                row[n] = 0
913            elif t in ('bool', 'boolean'):
914                row[n] = self._make_bool(False)
915            else:
916                row[n] = ''
917        return row
918
919    def delete(self, table, row=None, **kw):
920        """Delete an existing row in a database table.
921
922        This method deletes the row from a table.  It deletes based on the
923        OID value as munged by get or passed as keyword, or on the primary
924        key of the table.  The return value is the number of deleted rows
925        (i.e. 0 if the row did not exist and 1 if the row was deleted).
926
927        """
928        # Like update, delete works on the oid.
929        # One day we will be testing that the record to be deleted
930        # isn't referenced somewhere (or else PostgreSQL will).
931        # Note that we only accept oid key from named args for safety.
932        qoid = _oid_key(table)
933        if 'oid' in kw:
934            kw[qoid] = kw['oid']
935            del kw['oid']
936        if row is None:
937            row = {}
938        row.update(kw)
939        params = []
940        param = partial(self._prepare_param, params=params)
941        if qoid in row:
942            where = 'oid = %s' % param(row[qoid], 'int')
943        else:
944            try:
945                keyname = self.pkey(table)
946            except KeyError:
947                raise _prg_error('Table %s has no primary key' % table)
948            keyname = [keyname] if isinstance(
949                keyname, basestring) else sorted(keyname)
950            attnames = self.get_attnames(table)
951            col = self.escape_identifier
952            try:
953                where = ' AND '.join('%s = %s' % (
954                    col(k), param(row[k], attnames[k])) for k in keyname)
955            except KeyError:
956                raise _prg_error('delete needs primary key or oid')
957        q = 'DELETE FROM %s WHERE %s' % (
958            self._escape_qualified_name(table), where)
959        self._do_debug(q, params)
960        res = self.db.query(q, params)
961        return int(res)
962
963    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
964        """Get notification handler that will run the given callback."""
965        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
966
967
# if run as script, print some information

if __name__ == '__main__':
    # print() separates its arguments with a space, so the version
    # number no longer runs into the preceding text
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.