source: trunk/pg.py @ 737

Last change on this file since 737 was 737, checked in by cito, 4 years ago

Update year in the copyright in two more files

  • Property svn:keywords set to Id
File size: 35.1 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 737 2016-01-13 22:16:53Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40from functools import partial
41
# Python 2/3 compatibility: the name basestring is used in isinstance()
# checks further down.  It only exists as a builtin in Python 2, so in
# Python 3 it is replaced by a tuple covering text and binary strings.
try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# Hand the Decimal class to the _pg extension module (presumably so that
# it is used for numeric values in query results -- confirm in _pg).
set_decimal(Decimal)
48
49
50# Auxiliary functions that are independent from a DB connection:
51
52def _oid_key(cl):
53    """Build oid key from a class name."""
54    return 'oid(%s)' % cl
55
56
57def _simpletype(typ):
58    """Determine a simplified name a pg_type name."""
59    if typ.startswith('bool'):
60        return 'bool'
61    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
62        return 'date'
63    if typ.startswith(('cid', 'oid', 'int', 'xid')):
64        return 'int'
65    if typ.startswith('float'):
66        return 'float'
67    if typ.startswith('numeric'):
68        return 'num'
69    if typ.startswith('money'):
70        return 'money'
71    if typ.startswith('bytea'):
72        return 'bytea'
73    return 'text'
74
75
76def _namedresult(q):
77    """Get query result as named tuples."""
78    row = namedtuple('Row', q.listfields())
79    return [row(*r) for r in q.getresult()]
80
# Register _namedresult with the _pg extension module (presumably as the
# implementation behind the query object's named-tuple result -- confirm
# against the _pg documentation).
set_namedresult(_namedresult)
82
83
def _db_error(msg, cls=DatabaseError):
    """Create (without raising) an exception of class cls for msg.

    The returned exception has its sqlstate attribute set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
89
90
def _int_error(msg):
    """Create an InternalError for msg, with sqlstate set to None."""
    return _db_error(msg, cls=InternalError)
94
95
def _prg_error(msg):
    """Create a ProgrammingError for msg, with sqlstate set to None."""
    return _db_error(msg, cls=ProgrammingError)
99
100
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object; if a DB wrapper instance
                   is passed, its underlying connection is used.
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        # a notification on this channel makes the handler stop listening
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        The notification is only sent while the handler is listening;
        otherwise nothing is sent and None is returned.  If stop is set,
        the stop event is notified instead of the event itself.  An
        optional payload can be sent along with the notification.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # Escape the payload so that quotes (or backslashes) in it
                # cannot terminate the string literal early and break or
                # inject SQL into the NOTIFY command.
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid' and 'event' are inserted into <arg_dict>, and the callback is
        invoked with <arg_dict>. If the NOTIFY message is stop_<event>, the
        handler UNLISTENs both <event> and stop_<event> and exits.

        If the timeout expires without a notification, the handler
        UNLISTENs and the callback is invoked with None.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            # wait until the connection's socket becomes readable
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                # drain all notifications that have arrived
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
208
209
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    Emits a DeprecationWarning attributed to the caller and forwards all
    arguments unchanged to NotificationHandler.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
215
216
217# The actual PostGreSQL database connection interface:
218
219class DB(object):
220    """Wrapper class for the _pg connection type."""
221
    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing connection may be passed as the only positional
        argument or as the only keyword argument "db"; any other
        arguments are passed on to connect() to open a new connection.
        Only connections opened here will be closed by close().

        """
        # Find a candidate for an existing connection object.
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                # wrapping another DB instance: share its connection
                db = db.db
            else:
                # NOTE(review): presumably a pgdb connection keeps the
                # underlying _pg connection in _cnx -- confirm in pgdb
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # Anything that does not look like a _pg connection (e.g. a
        # database name string) is treated as connection parameters.
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False  # see use_regtypes()
        self._attnames = {}  # cache for get_attnames()
        self._pkeys = {}  # cache for pkey()
        self._privileges = {}  # cache for has_table_privilege()
        self._args = args, kw  # used by reopen()
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
261
262    def __getattr__(self, name):
263        # All undefined members are same as in underlying connection:
264        if self.db:
265            return getattr(self.db, name)
266        else:
267            raise _int_error('Connection is not valid')
268
269    def __dir__(self):
270        # Custom dir function including the attributes of the connection:
271        attrs = set(self.__class__.__dict__)
272        attrs.update(self.__dict__)
273        attrs.update(dir(self.db))
274        return sorted(attrs)
275
276    # Context manager methods
277
278    def __enter__(self):
279        """Enter the runtime context. This will start a transaction."""
280        self.begin()
281        return self
282
283    def __exit__(self, et, ev, tb):
284        """Exit the runtime context. This will end the transaction."""
285        if et is None and ev is None and tb is None:
286            self.commit()
287        else:
288            self.rollback()
289
290    # Auxiliary methods
291
292    def _do_debug(self, *args):
293        """Print a debug message."""
294        if self.debug:
295            s = '\n'.join(args)
296            if isinstance(self.debug, basestring):
297                print(self.debug % s)
298            elif hasattr(self.debug, 'write'):
299                self.debug.write(s + '\n')
300            elif callable(self.debug):
301                self.debug(s)
302            else:
303                print(s)
304
305    def _escape_qualified_name(self, s):
306        """Escape a qualified name.
307
308        Escapes the name for use as an SQL identifier, unless the
309        name contains a dot, in which case the name is ambiguous
310        (could be a qualified name or just a name with a dot in it)
311        and must be quoted manually by the caller.
312
313        """
314        if '.' not in s:
315            s = self.escape_identifier(s)
316        return s
317
318    @staticmethod
319    def _make_bool(d):
320        """Get boolean value corresponding to d."""
321        return bool(d) if get_bool() else ('t' if d else 'f')
322
323    _bool_true_values = frozenset('t true 1 y yes on'.split())
324
325    def _prepare_bool(self, d):
326        """Prepare a boolean parameter."""
327        if isinstance(d, basestring):
328            if not d:
329                return None
330            d = d.lower() in self._bool_true_values
331        return 't' if d else 'f'
332
333    _date_literals = frozenset('current_date current_time'
334        ' current_timestamp localtime localtimestamp'.split())
335
336    def _prepare_date(self, d):
337        """Prepare a date parameter."""
338        if not d:
339            return None
340        if isinstance(d, basestring) and d.lower() in self._date_literals:
341            raise ValueError
342        return d
343
344    def _prepare_num(self, d):
345        """Prepare a numeric parameter."""
346        if not d and d != 0:
347            return None
348        return d
349
350    def _prepare_bytea(self, d):
351        """Prepare a bytea parameter."""
352        return self.escape_bytea(d)
353
354    _prepare_funcs = dict(  # quote methods for each type
355        bool=_prepare_bool, date=_prepare_date,
356        int=_prepare_num, num=_prepare_num, float=_prepare_num,
357        money=_prepare_num, bytea=_prepare_bytea)
358
359    def _prepare_param(self, value, typ, params):
360        """Prepare and add a parameter to the list."""
361        if value is not None and typ != 'text':
362            try:
363                prepare = self._prepare_funcs[typ]
364            except KeyError:
365                pass
366            else:
367                try:
368                    value = prepare(self, value)
369                except ValueError:
370                    return value
371        params.append(value)
372        return '$%d' % len(params)
373
374    @staticmethod
375    def _prepare_qualified_param(cl, param):
376        """Quote parameter representing a qualified name.
377
378        Escapes the name for use as an SQL parameter, unless the
379        name contains a dot, in which case the name is ambiguous
380        (could be a qualified name or just a name with a dot in it)
381        and must be quoted manually by the caller.
382
383        """
384        if isinstance(param, int):
385            param = "$%d" % param
386        if '.' not in cl:
387            param = 'quote_ident(%s)' % (param,)
388        return param
389
390    # Public methods
391
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (wrapping the module-level unescape_bytea, presumably provided by
    # "from _pg import *"; as a staticmethod it can be called on the class
    # as well as on an instance)
    unescape_bytea = staticmethod(unescape_bytea)
395
396    def close(self):
397        """Close the database connection."""
398        # Wraps shared library function so we can track state.
399        if self._closeable:
400            if self.db:
401                self.db.close()
402                self.db = None
403            else:
404                raise _int_error('Connection already closed')
405
406    def reset(self):
407        """Reset connection with current parameters.
408
409        All derived queries and large objects derived from this connection
410        will not be usable after this call.
411
412        """
413        if self.db:
414            self.db.reset()
415        else:
416            raise _int_error('Connection already closed')
417
418    def reopen(self):
419        """Reopen connection to the database.
420
421        Used in case we need another connection to the same database.
422        Note that we can still reopen a database that we have closed.
423
424        """
425        # There is no such shared library function.
426        if self._closeable:
427            db = connect(*self._args[0], **self._args[1])
428            if self.db:
429                self.db.close()
430            self.db = db
431
432    def begin(self, mode=None):
433        """Begin a transaction."""
434        qstr = 'BEGIN'
435        if mode:
436            qstr += ' ' + mode
437        return self.query(qstr)
438
439    start = begin
440
441    def commit(self):
442        """Commit the current transaction."""
443        return self.query('COMMIT')
444
445    end = commit
446
447    def rollback(self, name=None):
448        """Roll back the current transaction."""
449        qstr = 'ROLLBACK'
450        if name:
451            qstr += ' TO ' + name
452        return self.query(qstr)
453
454    def savepoint(self, name):
455        """Define a new savepoint within the current transaction."""
456        return self.query('SAVEPOINT ' + name)
457
458    def release(self, name):
459        """Destroy a previously defined savepoint."""
460        return self.query('RELEASE ' + name)
461
462    def query(self, qstr, *args):
463        """Executes a SQL command string.
464
465        This method simply sends a SQL query to the database. If the query is
466        an insert statement that inserted exactly one row into a table that
467        has OIDs, the return value is the OID of the newly inserted row.
468        If the query is an update or delete statement, or an insert statement
469        that did not insert exactly one row in a table with OIDs, then the
470        number of rows affected is returned as a string. If it is a statement
471        that returns rows as a result (usually a select statement, but maybe
472        also an "insert/update ... returning" statement), this method returns
473        a Query object that can be accessed via getresult() or dictresult()
474        or simply printed. Otherwise, it returns `None`.
475
476        The query can contain numbered parameters of the form $1 in place
477        of any data constant. Arguments given after the query string will
478        be substituted for the corresponding numbered parameter. Parameter
479        values can also be given as a single list or tuple argument.
480
481        """
482        # Wraps shared library function for debugging.
483        if not self.db:
484            raise _int_error('Connection is not valid')
485        self._do_debug(qstr)
486        return self.db.query(qstr, args)
487
488    def pkey(self, cl, flush=False):
489        """This method gets or sets the primary key of a class.
490
491        Composite primary keys are represented as frozensets. Note that
492        this raises a KeyError if the table does not have a primary key.
493
494        If flush is set then the internal cache for primary keys will
495        be flushed. This may be necessary after the database schema or
496        the search path has been changed.
497
498        """
499        pkeys = self._pkeys
500        if flush:
501            pkeys.clear()
502            self._do_debug('pkey cache has been flushed')
503        try:  # cache lookup
504            pkey = pkeys[cl]
505        except KeyError:  # cache miss, check the database
506            q = ("SELECT a.attname FROM pg_index i"
507                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
508                " AND a.attnum = ANY(i.indkey)"
509                " AND NOT a.attisdropped"
510                " WHERE i.indrelid=%s::regclass"
511                " AND i.indisprimary" % self._prepare_qualified_param(cl, 1))
512            pkey = self.db.query(q, (cl,)).getresult()
513            if not pkey:
514                raise KeyError('Class %s has no primary key' % cl)
515            if len(pkey) > 1:
516                pkey = frozenset(k[0] for k in pkey)
517            else:
518                pkey = pkey[0][0]
519            pkeys[cl] = pkey  # cache it
520        return pkey
521
522    def get_databases(self):
523        """Get list of databases in the system."""
524        return [s[0] for s in
525            self.db.query('SELECT datname FROM pg_database').getresult()]
526
527    def get_relations(self, kinds=None):
528        """Get list of relations in connected database of specified kinds.
529
530        If kinds is None or empty, all kinds of relations are returned.
531        Otherwise kinds can be a string or sequence of type letters
532        specifying which kind of relations you want to list.
533
534        """
535        where = " AND r.relkind IN (%s)" % ','.join(
536            ["'%s'" % k for k in kinds]) if kinds else ''
537        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
538            " FROM pg_class r"
539            " JOIN pg_namespace s ON s.oid = r.relnamespace"
540            " WHERE s.nspname NOT SIMILAR"
541            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
542            " ORDER BY s.nspname, r.relname") % where
543        return [r[0] for r in self.db.query(q).getresult()]
544
545    def get_tables(self):
546        """Return list of tables in connected database."""
547        return self.get_relations('r')
548
549    def get_attnames(self, cl, flush=False):
550        """Given the name of a table, digs out the set of attribute names.
551
552        Returns a dictionary of attribute names (the names are the keys,
553        the values are the names of the attributes' types).
554
555        If the optional newattnames exists, it must be a dictionary and
556        will become the new attribute names dictionary.
557
558        By default, only a limited number of simple types will be returned.
559        You can get the regular types after calling use_regtypes(True).
560
561        """
562        attnames = self._attnames
563        if flush:
564            attnames.clear()
565            self._do_debug('pkey cache has been flushed')
566        try:  # cache lookup
567            names = attnames[cl]
568        except KeyError:  # cache miss, check the database
569            q = ("SELECT a.attname, t.typname%s"
570                " FROM pg_attribute a"
571                " JOIN pg_type t ON t.oid = a.atttypid"
572                " WHERE a.attrelid = %s::regclass"
573                " AND (a.attnum > 0 OR a.attname = 'oid')"
574                " AND NOT a.attisdropped") % (
575                    '::regtype' if self._regtypes else '',
576                    self._prepare_qualified_param(cl, 1))
577            names = self.db.query(q, (cl,)).getresult()
578            if not names:
579                raise KeyError('Class %s does not exist' % cl)
580            if self._regtypes:
581                names = dict(names)
582            else:
583                names = dict((name, _simpletype(typ)) for name, typ in names)
584            attnames[cl] = names  # cache it
585        return names
586
587    def use_regtypes(self, regtypes=None):
588        """Use regular type names instead of simplified type names."""
589        if regtypes is None:
590            return self._regtypes
591        else:
592            regtypes = bool(regtypes)
593            if regtypes != self._regtypes:
594                self._regtypes = regtypes
595                self._attnames.clear()
596            return regtypes
597
598    def has_table_privilege(self, cl, privilege='select'):
599        """Check whether current user has specified table privilege."""
600        privilege = privilege.lower()
601        try:  # ask cache
602            return self._privileges[(cl, privilege)]
603        except KeyError:  # cache miss, ask the database
604            q = "SELECT has_table_privilege(%s, $2)" % (
605                self._prepare_qualified_param(cl, 1),)
606            q = self.db.query(q, (cl, privilege))
607            ret = q.getresult()[0][0] == self._make_bool(True)
608            self._privileges[(cl, privilege)] = ret  # cache it
609            return ret
610
611    def get(self, cl, arg, keyname=None):
612        """Get a row from a database table or view.
613
614        This method is the basic mechanism to get a single row.  The keyname
615        that the key specifies a unique row.  If keyname is not specified
616        then the primary key for the table is used.  If arg is a dictionary
617        then the value for the key is taken from it and it is modified to
618        include the new values, replacing existing values where necessary.
619        For a composite key, keyname can also be a sequence of key names.
620        The OID is also put into the dictionary if the table has one, but
621        in order to allow the caller to work with multiple tables, it is
622        munged as "oid(cl)".
623
624        """
625        if cl.endswith('*'):  # scan descendant tables?
626            cl = cl[:-1].rstrip()  # need parent table name
627        if not keyname:
628            # use the primary key by default
629            try:
630                keyname = self.pkey(cl)
631            except KeyError:
632                raise _prg_error('Class %s has no primary key' % cl)
633        attnames = self.get_attnames(cl)
634        params = []
635        param = partial(self._prepare_param, params=params)
636        col = self.escape_identifier
637        # We want the oid for later updates if that isn't the key.
638        # To allow users to work with multiple tables, we munge
639        # the name of the "oid" key by adding the name of the class.
640        qoid = _oid_key(cl)
641        if keyname == 'oid':
642            if isinstance(arg, dict):
643                if qoid not in arg:
644                    raise _db_error('%s not in arg' % qoid)
645            else:
646                arg = {qoid: arg}
647            what = '*'
648            where = 'oid = %s' % param(arg[qoid], 'int')
649        else:
650            keyname = [keyname] if isinstance(
651                keyname, basestring) else sorted(keyname)
652            if not isinstance(arg, dict):
653                if len(keyname) > 1:
654                    raise _prg_error('Composite key needs dict as arg')
655                arg = dict((k, arg) for k in keyname)
656            what = ', '.join(col(k) for k in attnames)
657            where = ' AND '.join('%s = %s' % (
658                col(k), param(arg[k], attnames[k])) for k in keyname)
659        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
660            what, self._escape_qualified_name(cl), where)
661        self._do_debug(q, params)
662        q = self.db.query(q, params)
663        res = q.dictresult()
664        if not res:
665            raise _db_error('No such record in %s where %s' % (cl, where))
666        for n, value in res[0].items():
667            if n == 'oid':
668                n = qoid
669            elif attnames.get(n) == 'bytea':
670                value = self.unescape_bytea(value)
671            arg[n] = value
672        return arg
673
674    def insert(self, cl, d=None, **kw):
675        """Insert a row into a database table.
676
677        This method inserts a row into a table.  The name of the table must
678        be passed as the first parameter.  The other parameters are used for
679        providing the data of the row that shall be inserted into the table.
680        If a dictionary is supplied as the second parameter, it starts with
681        that.  Otherwise it uses a blank dictionary. Either way the dictionary
682        is updated from the keywords.
683
684        The dictionary is then reloaded with the values actually inserted in
685        order to pick up values modified by rules, triggers, etc.
686
687        Note: The method currently doesn't support insert into views
688        although PostgreSQL does.
689
690        """
691        if 'oid' in kw:
692            del kw['oid']
693        if d is None:
694            d = {}
695        d.update(kw)
696        attnames = self.get_attnames(cl)
697        params = []
698        param = partial(self._prepare_param, params=params)
699        col = self.escape_identifier
700        names, values = [], []
701        for n in attnames:
702            if n in d:
703                names.append(col(n))
704                values.append(param(d[n], attnames[n]))
705        names, values = ', '.join(names), ', '.join(values)
706        ret = 'oid, *' if 'oid' in attnames else '*'
707        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
708            self._escape_qualified_name(cl), names, values, ret)
709        self._do_debug(q, params)
710        q = self.db.query(q, params)
711        res = q.dictresult()
712        if not res:
713            raise _int_error('insert did not return new values')
714        for n, value in res[0].items():
715            if n == 'oid':
716                n = _oid_key(cl)
717            elif attnames.get(n) == 'bytea' and value is not None:
718                value = self.unescape_bytea(value)
719            d[n] = value
720        return d
721
    def update(self, cl, d=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the OID value as munged by get or passed as keyword, or on the
        primary key of the table.  The dictionary is modified to reflect
        any changes caused by the update due to triggers, rules, default
        values, etc.

        Returns the dictionary.  It is returned unchanged if there are no
        columns to update or if no row matched.  Raises ProgrammingError if
        the table has no primary key (and no OID is given), or if a primary
        key value is missing from the dictionary.

        """
        # Update always works on the oid which get() returns if available,
        # otherwise use the primary key.  Fail if neither.
        # Note that we only accept oid key from named args for safety.
        qoid = _oid_key(cl)
        if 'oid' in kw:
            # a plain "oid" keyword is moved to the munged key "oid(cl)"
            kw[qoid] = kw['oid']
            del kw['oid']
        if d is None:
            d = {}
        d.update(kw)
        attnames = self.get_attnames(cl)
        params = []
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        if qoid in d:
            where = 'oid = %s' % param(d[qoid], 'int')
            keyname = []
        else:
            try:
                keyname = self.pkey(cl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % cl)
            # pkey() returns a string or a frozenset for composite keys
            keyname = [keyname] if isinstance(
                keyname, basestring) else sorted(keyname)
            try:
                where = ' AND '.join('%s = %s' % (
                    col(k), param(d[k], attnames[k])) for k in keyname)
            except KeyError:
                raise _prg_error('update needs primary key or oid')
        # key columns and the oid column are never part of the SET clause
        keyname = set(keyname)
        keyname.add('oid')
        values = []
        for n in attnames:
            if n in d and n not in keyname:
                values.append('%s = %s' % (col(n), param(d[n], attnames[n])))
        if not values:
            return d
        values = ', '.join(values)
        ret = 'oid, *' if 'oid' in attnames else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(cl), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            for n, value in res[0].items():
                if n == 'oid':
                    n = qoid
                elif attnames.get(n) == 'bytea' and value is not None:
                    value = self.unescape_bytea(value)
                d[n] = value
        return d
784
785    def upsert(self, cl, d=None, **kw):
786        """Insert a row into a database table with conflict resolution.
787
788        This method inserts a row into a table, but instead of raising a
789        ProgrammingError exception in case a row with the same primary key
790        already exists, an update will be executed instead.  This will be
791        performed as a single atomic operation on the database, so race
792        conditions can be avoided.
793
794        Like the insert method, the first parameter is the name of the
795        table and the second parameter can be used to pass the values to
796        be inserted as a dictionary.
797
798        Unlike the insert und update statement, keyword parameters are not
799        used to modify the dictionary, but to specify which columns shall
800        be updated in case of a conflict, and in which way:
801
802        A value of False or None means the column shall not be updated,
803        a value of True means the column shall be updated with the value
804        that has been proposed for insertion, i.e. has been passed as value
805        in the dictionary.  Columns that are not specified by keywords but
806        appear as keys in the dictionary are also updated like in the case
807        keywords had been passed with the value True.
808
809        So if in the case of a conflict you want to update every column that
810        has been passed in the dictionary d , you would call upsert(cl, d).
811        If you don't want to do anything in case of a conflict, i.e. leave
812        the existing row as it is, call upsert(cl, d, **dict.fromkeys(d)).
813
814        If you need more fine-grained control of what gets updated, you can
815        also pass strings in the keyword parameters.  These strings will
816        be used as SQL expressions for the update columns.  In these
817        expressions you can refer to the value that already exists in
818        the table by prefixing the column name with "included.", and to
819        the value that has been proposed for insertion by prefixing the
820        column name with the "excluded."
821
822        The dictionary is modified in any case to reflect the values in
823        the database after the operation has completed.
824
825        Note: The method uses the PostgreSQL "upsert" feature which is
826        only available since PostgreSQL 9.5.
827
828        """
829        if 'oid' in kw:
830            del kw['oid']
831        if d is None:
832            d = {}
833        attnames = self.get_attnames(cl)
834        params = []
835        param = partial(self._prepare_param,params=params)
836        col = self.escape_identifier
837        names, values, updates = [], [], []
838        for n in attnames:
839            if n in d:
840                names.append(col(n))
841                values.append(param(d[n], attnames[n]))
842        names, values = ', '.join(names), ', '.join(values)
843        try:
844            keyname = self.pkey(cl)
845        except KeyError:
846            raise _prg_error('Class %s has no primary key' % cl)
847        keyname = [keyname] if isinstance(
848            keyname, basestring) else sorted(keyname)
849        try:
850            target = ', '.join(col(k) for k in keyname)
851        except KeyError:
852            raise _prg_error('upsert needs primary key or oid')
853        update = []
854        keyname = set(keyname)
855        keyname.add('oid')
856        for n in attnames:
857            if n not in keyname:
858                value = kw.get(n, True)
859                if value:
860                    if not isinstance(value, basestring):
861                        value = 'excluded.%s' % col(n)
862                    update.append('%s = %s' % (col(n), value))
863        if not values and not update:
864            return d
865        do = 'update set %s' % ', '.join(update) if update else 'nothing'
866        ret = 'oid, *' if 'oid' in attnames else '*'
867        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
868            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
869                self._escape_qualified_name(cl), names, values,
870                target, do, ret)
871        self._do_debug(q, params)
872        try:
873            q = self.db.query(q, params)
874        except ProgrammingError:
875            if self.server_version < 90500:
876                raise _prg_error('upsert not supported by PostgreSQL version')
877            raise  # re-raise original error
878        res = q.dictresult()
879        if res:  # may be empty with "do nothing"
880            for n, value in res[0].items():
881                if n == 'oid':
882                    n = _oid_key(cl)
883                elif attnames.get(n) == 'bytea':
884                    value = self.unescape_bytea(value)
885                d[n] = value
886        elif update:
887            raise _int_error('upsert did not return new values')
888        else:
889            self.get(cl, d)
890        return d
891
892    def clear(self, cl, a=None):
893        """Clear all the attributes to values determined by the types.
894
895        Numeric types are set to 0, Booleans are set to false, and everything
896        else is set to the empty string.  If the array argument is present,
897        it is used as the array and any entries matching attribute names are
898        cleared with everything else left unchanged.
899
900        """
901        # At some point we will need a way to get defaults from a table.
902        if a is None:
903            a = {}  # empty if argument is not present
904        attnames = self.get_attnames(cl)
905        for n, t in attnames.items():
906            if n == 'oid':
907                continue
908            if t in ('int', 'integer', 'smallint', 'bigint',
909                    'float', 'real', 'double precision',
910                    'num', 'numeric', 'money'):
911                a[n] = 0
912            elif t in ('bool', 'boolean'):
913                a[n] = self._make_bool(False)
914            else:
915                a[n] = ''
916        return a
917
918    def delete(self, cl, d=None, **kw):
919        """Delete an existing row in a database table.
920
921        This method deletes the row from a table.  It deletes based on the
922        OID value as munged by get or passed as keyword, or on the primary
923        key of the table.  The return value is the number of deleted rows
924        (i.e. 0 if the row did not exist and 1 if the row was deleted).
925
926        """
927        # Like update, delete works on the oid.
928        # One day we will be testing that the record to be deleted
929        # isn't referenced somewhere (or else PostgreSQL will).
930        # Note that we only accept oid key from named args for safety.
931        qoid = _oid_key(cl)
932        if 'oid' in kw:
933            kw[qoid] = kw['oid']
934            del kw['oid']
935        if d is None:
936            d = {}
937        d.update(kw)
938        params = []
939        param = partial(self._prepare_param, params=params)
940        if qoid in d:
941            where = 'oid = %s' % param(d[qoid], 'int')
942        else:
943            try:
944                keyname = self.pkey(cl)
945            except KeyError:
946                raise _prg_error('Class %s has no primary key' % cl)
947            keyname = [keyname] if isinstance(
948                keyname, basestring) else sorted(keyname)
949            attnames = self.get_attnames(cl)
950            col = self.escape_identifier
951            try:
952                where = ' AND '.join('%s = %s'
953                    % (col(k), param(d[k], attnames[k])) for k in keyname)
954            except KeyError:
955                raise _prg_error('delete needs primary key or oid')
956        q = 'DELETE FROM %s WHERE %s' % (
957            self._escape_qualified_name(cl), where)
958        self._do_debug(q, params)
959        res = self.db.query(q, params)
960        return int(res)
961
962    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
963        """Get notification handler that will run the given callback."""
964        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
965
966
# If run as script, print some information about the module.

if __name__ == '__main__':
    # Let print() insert the separator so the version is not glued
    # to the label (previously printed as "PyGreSQL version5.0").
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.