source: trunk/pg.py @ 743

Last change on this file since 743 was 743, checked in by cito, 4 years ago

Test error messages and security of the get() method

The get() method should be immune to SQL injection through apostrophes in
values, and should give a proper and helpful error message if a row is not found.

  • Property svn:keywords set to Id
File size: 35.9 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 743 2016-01-14 17:32:01Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40
41try:
42    from collections import OrderedDict
43except ImportError:  # Python 2.6 or 3.0
44    OrderedDict = dict
45
46try:
47    basestring
48except NameError:  # Python >= 3.0
49    basestring = (str, bytes)
50
# Register decimal.Decimal with the _pg module via set_decimal();
# presumably _pg uses it to convert numeric values it returns (confirm
# against the _pg documentation).
set_decimal(Decimal)
52
53
# Auxiliary functions that are independent of a DB connection:
55
56def _oid_key(table):
57    """Build oid key from a table name."""
58    return 'oid(%s)' % table
59
60
61def _simpletype(typ):
62    """Determine a simplified name a pg_type name."""
63    if typ.startswith('bool'):
64        return 'bool'
65    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
66        return 'date'
67    if typ.startswith(('cid', 'oid', 'int', 'xid')):
68        return 'int'
69    if typ.startswith('float'):
70        return 'float'
71    if typ.startswith('numeric'):
72        return 'num'
73    if typ.startswith('money'):
74        return 'money'
75    if typ.startswith('bytea'):
76        return 'bytea'
77    return 'text'
78
79
80def _namedresult(q):
81    """Get query result as named tuples."""
82    row = namedtuple('Row', q.listfields())
83    return [row(*r) for r in q.getresult()]
84
# Register _namedresult with the _pg module via set_namedresult();
# presumably this backs the namedresult() method of query objects
# (confirm against the _pg documentation).
set_namedresult(_namedresult)
86
87
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an exception of class cls with message msg.

    The sqlstate attribute of the exception is set to None, since no
    SQLSTATE code from the server is attached to such an error.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
93
94
def _int_error(msg):
    """Create an InternalError with message msg and empty sqlstate."""
    return _db_error(msg, cls=InternalError)
98
99
def _prg_error(msg):
    """Create a ProgrammingError with message msg and empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
103
104
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB instance is
                   replaced by its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.
        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    @staticmethod
    def _quote_ident(name):
        """Quote name as an SQL identifier.

        Embedded double quotes are doubled, so that any channel name
        yields a valid identifier instead of breaking the statement.
        """
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    @staticmethod
    def _quote_literal(value):
        """Quote value as an SQL string literal.

        Apostrophes are doubled. If the value contains backslashes, an
        escape string constant (E'...') with doubled backslashes is used.
        Both forms are interpreted the same way regardless of the
        server's standard_conforming_strings setting.
        """
        value = '%s' % (value,)
        if '\\' in value:
            return "E'%s'" % value.replace('\\', '\\\\').replace("'", "''")
        return "'%s'" % value.replace("'", "''")

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_ident(self.event))
            self.db.query('listen %s' % self._quote_ident(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_ident(self.event))
            self.db.query('unlisten %s' % self._quote_ident(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Sends a NOTIFY for the event (or for the stop event if stop is
        set), with the given payload if any, and returns the result of
        the query. Nothing is sent if the handler is not listening.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.
        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify %s' % self._quote_ident(
                self.stop_event if stop else self.event)
            if payload:
                # The payload is quoted as a literal, so that apostrophes
                # in it cannot break (or inject into) the statement.
                q += ', %s' % self._quote_literal(payload)
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' values (the latter presumably being the
        payload) are inserted into <arg_dict>, and the callback is invoked
        with <arg_dict>. If the NOTIFY message is stop_<event>, the
        handler UNLISTENs both <event> and stop_<event> and exits.
        If the timeout expires, the handler UNLISTENs and invokes the
        callback with None. The close parameter is currently not used.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
209
210
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    All arguments are passed on to NotificationHandler unchanged; a
    DeprecationWarning pointing at the caller is issued first.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
216
217
# The actual PostgreSQL database connection interface:
219
220class DB(object):
221    """Wrapper class for the _pg connection type."""
222
223    def __init__(self, *args, **kw):
224        """Create a new connection
225
226        You can pass either the connection parameters or an existing
227        _pg or pgdb connection. This allows you to use the methods
228        of the classic pg interface with a DB-API 2 pgdb connection.
229        """
230        if not args and len(kw) == 1:
231            db = kw.get('db')
232        elif not kw and len(args) == 1:
233            db = args[0]
234        else:
235            db = None
236        if db:
237            if isinstance(db, DB):
238                db = db.db
239            else:
240                try:
241                    db = db._cnx
242                except AttributeError:
243                    pass
244        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
245            db = connect(*args, **kw)
246            self._closeable = True
247        else:
248            self._closeable = False
249        self.db = db
250        self.dbname = db.db
251        self._regtypes = False
252        self._attnames = {}
253        self._pkeys = {}
254        self._privileges = {}
255        self._args = args, kw
256        self.debug = None  # For debugging scripts, this can be set
257            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
258            # * to a file object to write debug statements or
259            # * to a callable object which takes a string argument
260            # * to any other true value to just print debug statements
261
262    def __getattr__(self, name):
263        # All undefined members are same as in underlying connection:
264        if self.db:
265            return getattr(self.db, name)
266        else:
267            raise _int_error('Connection is not valid')
268
269    def __dir__(self):
270        # Custom dir function including the attributes of the connection:
271        attrs = set(self.__class__.__dict__)
272        attrs.update(self.__dict__)
273        attrs.update(dir(self.db))
274        return sorted(attrs)
275
276    # Context manager methods
277
278    def __enter__(self):
279        """Enter the runtime context. This will start a transactio."""
280        self.begin()
281        return self
282
283    def __exit__(self, et, ev, tb):
284        """Exit the runtime context. This will end the transaction."""
285        if et is None and ev is None and tb is None:
286            self.commit()
287        else:
288            self.rollback()
289
290    # Auxiliary methods
291
292    def _do_debug(self, *args):
293        """Print a debug message"""
294        if self.debug:
295            s = '\n'.join(args)
296            if isinstance(self.debug, basestring):
297                print(self.debug % s)
298            elif hasattr(self.debug, 'write'):
299                self.debug.write(s + '\n')
300            elif callable(self.debug):
301                self.debug(s)
302            else:
303                print(s)
304
305    def _escape_qualified_name(self, s):
306        """Escape a qualified name.
307
308        Escapes the name for use as an SQL identifier, unless the
309        name contains a dot, in which case the name is ambiguous
310        (could be a qualified name or just a name with a dot in it)
311        and must be quoted manually by the caller.
312        """
313        if '.' not in s:
314            s = self.escape_identifier(s)
315        return s
316
317    @staticmethod
318    def _make_bool(d):
319        """Get boolean value corresponding to d."""
320        return bool(d) if get_bool() else ('t' if d else 'f')
321
322    _bool_true_values = frozenset('t true 1 y yes on'.split())
323
324    def _prepare_bool(self, d):
325        """Prepare a boolean parameter."""
326        if isinstance(d, basestring):
327            if not d:
328                return None
329            d = d.lower() in self._bool_true_values
330        return 't' if d else 'f'
331
332    _date_literals = frozenset('current_date current_time'
333        ' current_timestamp localtime localtimestamp'.split())
334
335    def _prepare_date(self, d):
336        """Prepare a date parameter."""
337        if not d:
338            return None
339        if isinstance(d, basestring) and d.lower() in self._date_literals:
340            raise ValueError
341        return d
342
343    def _prepare_num(self, d):
344        """Prepare a numeric parameter."""
345        if not d and d != 0:
346            return None
347        return d
348
349    def _prepare_bytea(self, d):
350        """Prepare a bytea parameter."""
351        return self.escape_bytea(d)
352
353    _prepare_funcs = dict(  # quote methods for each type
354        bool=_prepare_bool, date=_prepare_date,
355        int=_prepare_num, num=_prepare_num, float=_prepare_num,
356        money=_prepare_num, bytea=_prepare_bytea)
357
358    def _prepare_param(self, value, typ, params):
359        """Prepare and add a parameter to the list."""
360        if value is not None and typ != 'text':
361            try:
362                prepare = self._prepare_funcs[typ]
363            except KeyError:
364                pass
365            else:
366                try:
367                    value = prepare(self, value)
368                except ValueError:
369                    return value
370        params.append(value)
371        return '$%d' % len(params)
372
373    def _list_params(self, params):
374        """Create a human readable parameter list."""
375        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
376
377    @staticmethod
378    def _prepare_qualified_param(name, param):
379        """Quote parameter representing a qualified name.
380
381        Escapes the name for use as an SQL parameter, unless the
382        name contains a dot, in which case the name is ambiguous
383        (could be a qualified name or just a name with a dot in it)
384        and must be quoted manually by the caller.
385
386        """
387        if isinstance(param, int):
388            param = "$%d" % param
389        if '.' not in name:
390            param = 'quote_ident(%s)' % (param,)
391        return param
392
    # Public methods

    # escape_string and escape_bytea are methods of the underlying
    # connection (reached through __getattr__), so for symmetry the
    # module-level unescape_bytea function (presumably from _pg) is
    # exposed as a static method as well.
    unescape_bytea = staticmethod(unescape_bytea)
398
399    def close(self):
400        """Close the database connection."""
401        # Wraps shared library function so we can track state.
402        if self._closeable:
403            if self.db:
404                self.db.close()
405                self.db = None
406            else:
407                raise _int_error('Connection already closed')
408
409    def reset(self):
410        """Reset connection with current parameters.
411
412        All derived queries and large objects derived from this connection
413        will not be usable after this call.
414
415        """
416        if self.db:
417            self.db.reset()
418        else:
419            raise _int_error('Connection already closed')
420
421    def reopen(self):
422        """Reopen connection to the database.
423
424        Used in case we need another connection to the same database.
425        Note that we can still reopen a database that we have closed.
426
427        """
428        # There is no such shared library function.
429        if self._closeable:
430            db = connect(*self._args[0], **self._args[1])
431            if self.db:
432                self.db.close()
433            self.db = db
434
435    def begin(self, mode=None):
436        """Begin a transaction."""
437        qstr = 'BEGIN'
438        if mode:
439            qstr += ' ' + mode
440        return self.query(qstr)
441
442    start = begin
443
444    def commit(self):
445        """Commit the current transaction."""
446        return self.query('COMMIT')
447
448    end = commit
449
450    def rollback(self, name=None):
451        """Roll back the current transaction."""
452        qstr = 'ROLLBACK'
453        if name:
454            qstr += ' TO ' + name
455        return self.query(qstr)
456
457    def savepoint(self, name):
458        """Define a new savepoint within the current transaction."""
459        return self.query('SAVEPOINT ' + name)
460
461    def release(self, name):
462        """Destroy a previously defined savepoint."""
463        return self.query('RELEASE ' + name)
464
465    def query(self, qstr, *args):
466        """Execute a SQL command string.
467
468        This method simply sends a SQL query to the database. If the query is
469        an insert statement that inserted exactly one row into a table that
470        has OIDs, the return value is the OID of the newly inserted row.
471        If the query is an update or delete statement, or an insert statement
472        that did not insert exactly one row in a table with OIDs, then the
473        number of rows affected is returned as a string. If it is a statement
474        that returns rows as a result (usually a select statement, but maybe
475        also an "insert/update ... returning" statement), this method returns
476        a Query object that can be accessed via getresult() or dictresult()
477        or simply printed. Otherwise, it returns `None`.
478
479        The query can contain numbered parameters of the form $1 in place
480        of any data constant. Arguments given after the query string will
481        be substituted for the corresponding numbered parameter. Parameter
482        values can also be given as a single list or tuple argument.
483        """
484        # Wraps shared library function for debugging.
485        if not self.db:
486            raise _int_error('Connection is not valid')
487        self._do_debug(qstr)
488        return self.db.query(qstr, args)
489
490    def pkey(self, table, flush=False):
491        """Get or set the primary key of a table.
492
493        Composite primary keys are represented as frozensets. Note that
494        this raises a KeyError if the table does not have a primary key.
495
496        If flush is set then the internal cache for primary keys will
497        be flushed. This may be necessary after the database schema or
498        the search path has been changed.
499        """
500        pkeys = self._pkeys
501        if flush:
502            pkeys.clear()
503            self._do_debug('The pkey cache has been flushed')
504        try:  # cache lookup
505            pkey = pkeys[table]
506        except KeyError:  # cache miss, check the database
507            q = ("SELECT a.attname FROM pg_index i"
508                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
509                " AND a.attnum = ANY(i.indkey)"
510                " AND NOT a.attisdropped"
511                " WHERE i.indrelid=%s::regclass"
512                " AND i.indisprimary") % (
513                    self._prepare_qualified_param(table, 1),)
514            pkey = self.db.query(q, (table,)).getresult()
515            if not pkey:
516                raise KeyError('Table %s has no primary key' % table)
517            if len(pkey) > 1:
518                pkey = frozenset(k[0] for k in pkey)
519            else:
520                pkey = pkey[0][0]
521            pkeys[table] = pkey  # cache it
522        return pkey
523
524    def get_databases(self):
525        """Get list of databases in the system."""
526        return [s[0] for s in
527            self.db.query('SELECT datname FROM pg_database').getresult()]
528
529    def get_relations(self, kinds=None):
530        """Get list of relations in connected database of specified kinds.
531
532        If kinds is None or empty, all kinds of relations are returned.
533        Otherwise kinds can be a string or sequence of type letters
534        specifying which kind of relations you want to list.
535        """
536        where = " AND r.relkind IN (%s)" % ','.join(
537            ["'%s'" % k for k in kinds]) if kinds else ''
538        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
539            " FROM pg_class r"
540            " JOIN pg_namespace s ON s.oid = r.relnamespace"
541            " WHERE s.nspname NOT SIMILAR"
542            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
543            " ORDER BY s.nspname, r.relname") % where
544        return [r[0] for r in self.db.query(q).getresult()]
545
546    def get_tables(self):
547        """Return list of tables in connected database."""
548        return self.get_relations('r')
549
550    def get_attnames(self, table, flush=False):
551        """Given the name of a table, dig out the set of attribute names.
552
553        Returns a dictionary of attribute names (the names are the keys,
554        the values are the names of the attributes' types).
555
556        If your Python version supports this, the dictionary will be an
557        OrderedDictionary with the column names in the right order.
558
559        If flush is set then the internal cache for attribute names will
560        be flushed. This may be necessary after the database schema or
561        the search path has been changed.
562
563        By default, only a limited number of simple types will be returned.
564        You can get the regular types after calling use_regtypes(True).
565        """
566        attnames = self._attnames
567        if flush:
568            attnames.clear()
569            self._do_debug('The attnames cache has been flushed')
570        try:  # cache lookup
571            names = attnames[table]
572        except KeyError:  # cache miss, check the database
573            q = ("SELECT a.attname, t.typname%s"
574                " FROM pg_attribute a"
575                " JOIN pg_type t ON t.oid = a.atttypid"
576                " WHERE a.attrelid = %s::regclass"
577                " AND (a.attnum > 0 OR a.attname = 'oid')"
578                " AND NOT a.attisdropped ORDER BY a.attnum") % (
579                    '::regtype' if self._regtypes else '',
580                    self._prepare_qualified_param(table, 1))
581            names = self.db.query(q, (table,)).getresult()
582            if not names:
583                raise KeyError('Table %s does not exist' % table)
584            if not self._regtypes:
585                names = ((name, _simpletype(typ)) for name, typ in names)
586            names = OrderedDict(names)
587            attnames[table] = names  # cache it
588        return names
589
590    def use_regtypes(self, regtypes=None):
591        """Use regular type names instead of simplified type names."""
592        if regtypes is None:
593            return self._regtypes
594        else:
595            regtypes = bool(regtypes)
596            if regtypes != self._regtypes:
597                self._regtypes = regtypes
598                self._attnames.clear()
599            return regtypes
600
601    def has_table_privilege(self, table, privilege='select'):
602        """Check whether current user has specified table privilege."""
603        privilege = privilege.lower()
604        try:  # ask cache
605            return self._privileges[(table, privilege)]
606        except KeyError:  # cache miss, ask the database
607            q = "SELECT has_table_privilege(%s, $2)" % (
608                self._prepare_qualified_param(table, 1),)
609            q = self.db.query(q, (table, privilege))
610            ret = q.getresult()[0][0] == self._make_bool(True)
611            self._privileges[(table, privilege)] = ret  # cache it
612            return ret
613
614    def get(self, table, row, keyname=None):
615        """Get a row from a database table or view.
616
617        This method is the basic mechanism to get a single row.  The keyname
618        that the key specifies a unique row.  If keyname is not specified
619        then the primary key for the table is used.  If row is a dictionary
620        then the value for the key is taken from it and it is modified to
621        include the new values, replacing existing values where necessary.
622        For a composite key, keyname can also be a sequence of key names.
623        The OID is also put into the dictionary if the table has one, but
624        in order to allow the caller to work with multiple tables, it is
625        munged as "oid(table)".
626        """
627        if table.endswith('*'):  # scan descendant tables?
628            table = table[:-1].rstrip()  # need parent table name
629        if not keyname:
630            # use the primary key by default
631            try:
632                keyname = self.pkey(table)
633            except KeyError:
634                raise _prg_error('Table %s has no primary key' % table)
635        attnames = self.get_attnames(table)
636        params = []
637        param = partial(self._prepare_param, params=params)
638        col = self.escape_identifier
639        # We want the oid for later updates if that isn't the key.
640        # To allow users to work with multiple tables, we munge
641        # the name of the "oid" key by adding the name of the table.
642        qoid = _oid_key(table)
643        if keyname == 'oid':
644            if isinstance(row, dict):
645                if qoid not in row:
646                    raise _db_error('%s not in row' % qoid)
647            else:
648                row = {qoid: row}
649            what = '*'
650            where = 'oid = %s' % param(row[qoid], 'int')
651        else:
652            keyname = [keyname] if isinstance(
653                keyname, basestring) else sorted(keyname)
654            if not isinstance(row, dict):
655                if len(keyname) > 1:
656                    raise _prg_error('Composite key needs dict as row')
657                row = dict((k, row) for k in keyname)
658            what = ', '.join(col(k) for k in attnames)
659            where = ' AND '.join('%s = %s' % (
660                col(k), param(row[k], attnames[k])) for k in keyname)
661        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
662            what, self._escape_qualified_name(table), where)
663        self._do_debug(q, params)
664        q = self.db.query(q, params)
665        res = q.dictresult()
666        if not res:
667            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
668                table, where, self._list_params(params)))
669        for n, value in res[0].items():
670            if n == 'oid':
671                n = qoid
672            elif attnames.get(n) == 'bytea':
673                value = self.unescape_bytea(value)
674            row[n] = value
675        return row
676
677    def insert(self, table, row=None, **kw):
678        """Insert a row into a database table.
679
680        This method inserts a row into a table.  The name of the table must
681        be passed as the first parameter.  The other parameters are used for
682        providing the data of the row that shall be inserted into the table.
683        If a dictionary is supplied as the second parameter, it starts with
684        that.  Otherwise it uses a blank dictionary. Either way the dictionary
685        is updated from the keywords.
686
687        The dictionary is then reloaded with the values actually inserted in
688        order to pick up values modified by rules, triggers, etc.
689
690        Note: The method currently doesn't support insert into views
691        although PostgreSQL does.
692        """
693        if 'oid' in kw:
694            del kw['oid']
695        if row is None:
696            row = {}
697        row.update(kw)
698        attnames = self.get_attnames(table)
699        params = []
700        param = partial(self._prepare_param, params=params)
701        col = self.escape_identifier
702        names, values = [], []
703        for n in attnames:
704            if n in row:
705                names.append(col(n))
706                values.append(param(row[n], attnames[n]))
707        names, values = ', '.join(names), ', '.join(values)
708        ret = 'oid, *' if 'oid' in attnames else '*'
709        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
710            self._escape_qualified_name(table), names, values, ret)
711        self._do_debug(q, params)
712        q = self.db.query(q, params)
713        res = q.dictresult()
714        if not res:
715            raise _int_error('Insert operation did not return new values')
716        for n, value in res[0].items():
717            if n == 'oid':
718                n = _oid_key(table)
719            elif attnames.get(n) == 'bytea' and value is not None:
720                value = self.unescape_bytea(value)
721            row[n] = value
722        return row
723
724    def update(self, table, row=None, **kw):
725        """Update an existing row in a database table.
726
727        Similar to insert but updates an existing row.  The update is based
728        on the OID value as munged by get or passed as keyword, or on the
729        primary key of the table.  The dictionary is modified to reflect
730        any changes caused by the update due to triggers, rules, default
731        values, etc.
732        """
733        # Update always works on the oid which get() returns if available,
734        # otherwise use the primary key.  Fail if neither.
735        # Note that we only accept oid key from named args for safety.
736        qoid = _oid_key(table)
737        if 'oid' in kw:
738            kw[qoid] = kw['oid']
739            del kw['oid']
740        if row is None:
741            row = {}
742        row.update(kw)
743        attnames = self.get_attnames(table)
744        params = []
745        param = partial(self._prepare_param, params=params)
746        col = self.escape_identifier
747        if qoid in row:
748            where = 'oid = %s' % param(row[qoid], 'int')
749            keyname = []
750        else:
751            try:
752                keyname = self.pkey(table)
753            except KeyError:
754                raise _prg_error('Table %s has no primary key' % table)
755            keyname = [keyname] if isinstance(
756                keyname, basestring) else sorted(keyname)
757            try:
758                where = ' AND '.join('%s = %s' % (
759                    col(k), param(row[k], attnames[k])) for k in keyname)
760            except KeyError:
761                raise _prg_error('Update operation needs primary key or oid')
762        keyname = set(keyname)
763        keyname.add('oid')
764        values = []
765        for n in attnames:
766            if n in row and n not in keyname:
767                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
768        if not values:
769            return row
770        values = ', '.join(values)
771        ret = 'oid, *' if 'oid' in attnames else '*'
772        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
773            self._escape_qualified_name(table), values, where, ret)
774        self._do_debug(q, params)
775        q = self.db.query(q, params)
776        res = q.dictresult()
777        if res:  # may be empty when row does not exist
778            for n, value in res[0].items():
779                if n == 'oid':
780                    n = qoid
781                elif attnames.get(n) == 'bytea' and value is not None:
782                    value = self.unescape_bytea(value)
783                row[n] = value
784        return row
785
786    def upsert(self, table, row=None, **kw):
787        """Insert a row into a database table with conflict resolution
788
789        This method inserts a row into a table, but instead of raising a
790        ProgrammingError exception in case a row with the same primary key
791        already exists, an update will be executed instead.  This will be
792        performed as a single atomic operation on the database, so race
793        conditions can be avoided.
794
795        Like the insert method, the first parameter is the name of the
796        table and the second parameter can be used to pass the values to
797        be inserted as a dictionary.
798
799        Unlike the insert und update statement, keyword parameters are not
800        used to modify the dictionary, but to specify which columns shall
801        be updated in case of a conflict, and in which way:
802
803        A value of False or None means the column shall not be updated,
804        a value of True means the column shall be updated with the value
805        that has been proposed for insertion, i.e. has been passed as value
806        in the dictionary.  Columns that are not specified by keywords but
807        appear as keys in the dictionary are also updated like in the case
808        keywords had been passed with the value True.
809
810        So if in the case of a conflict you want to update every column that
811        has been passed in the dictionary row , you would call upsert(table, row).
812        If you don't want to do anything in case of a conflict, i.e. leave
813        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
814
815        If you need more fine-grained control of what gets updated, you can
816        also pass strings in the keyword parameters.  These strings will
817        be used as SQL expressions for the update columns.  In these
818        expressions you can refer to the value that already exists in
819        the table by prefixing the column name with "included.", and to
820        the value that has been proposed for insertion by prefixing the
821        column name with the "excluded."
822
823        The dictionary is modified in any case to reflect the values in
824        the database after the operation has completed.
825
826        Note: The method uses the PostgreSQL "upsert" feature which is
827        only available since PostgreSQL 9.5.
828        """
829        if 'oid' in kw:
830            del kw['oid']
831        if row is None:
832            row = {}
833        attnames = self.get_attnames(table)
834        params = []
835        param = partial(self._prepare_param,params=params)
836        col = self.escape_identifier
837        names, values, updates = [], [], []
838        for n in attnames:
839            if n in row:
840                names.append(col(n))
841                values.append(param(row[n], attnames[n]))
842        names, values = ', '.join(names), ', '.join(values)
843        try:
844            keyname = self.pkey(table)
845        except KeyError:
846            raise _prg_error('Table %s has no primary key' % table)
847        keyname = [keyname] if isinstance(
848            keyname, basestring) else sorted(keyname)
849        try:
850            target = ', '.join(col(k) for k in keyname)
851        except KeyError:
852            raise _prg_error('Upsert operation needs primary key or oid')
853        update = []
854        keyname = set(keyname)
855        keyname.add('oid')
856        for n in attnames:
857            if n not in keyname:
858                value = kw.get(n, True)
859                if value:
860                    if not isinstance(value, basestring):
861                        value = 'excluded.%s' % col(n)
862                    update.append('%s = %s' % (col(n), value))
863        if not values and not update:
864            return row
865        do = 'update set %s' % ', '.join(update) if update else 'nothing'
866        ret = 'oid, *' if 'oid' in attnames else '*'
867        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
868            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
869                self._escape_qualified_name(table), names, values,
870                target, do, ret)
871        self._do_debug(q, params)
872        try:
873            q = self.db.query(q, params)
874        except ProgrammingError:
875            if self.server_version < 90500:
876                raise _prg_error(
877                    'Upsert operation is not supported by PostgreSQL version')
878            raise  # re-raise original error
879        res = q.dictresult()
880        if res:  # may be empty with "do nothing"
881            for n, value in res[0].items():
882                if n == 'oid':
883                    n = _oid_key(table)
884                elif attnames.get(n) == 'bytea':
885                    value = self.unescape_bytea(value)
886                row[n] = value
887        elif update:
888            raise _int_error('Upsert operation did not return new values')
889        else:
890            self.get(table, row)
891        return row
892
893    def clear(self, table, row=None):
894        """Clear all the attributes to values determined by the types.
895
896        Numeric types are set to 0, Booleans are set to false, and everything
897        else is set to the empty string.  If the row argument is present,
898        it is used as the row dictionary and any entries matching attribute
899        names are cleared with everything else left unchanged.
900        """
901        # At some point we will need a way to get defaults from a table.
902        if row is None:
903            row = {}  # empty if argument is not present
904        attnames = self.get_attnames(table)
905        for n, t in attnames.items():
906            if n == 'oid':
907                continue
908            if t in ('int', 'integer', 'smallint', 'bigint',
909                    'float', 'real', 'double precision',
910                    'num', 'numeric', 'money'):
911                row[n] = 0
912            elif t in ('bool', 'boolean'):
913                row[n] = self._make_bool(False)
914            else:
915                row[n] = ''
916        return row
917
918    def delete(self, table, row=None, **kw):
919        """Delete an existing row in a database table.
920
921        This method deletes the row from a table.  It deletes based on the
922        OID value as munged by get or passed as keyword, or on the primary
923        key of the table.  The return value is the number of deleted rows
924        (i.e. 0 if the row did not exist and 1 if the row was deleted).
925        """
926        # Like update, delete works on the oid.
927        # One day we will be testing that the record to be deleted
928        # isn't referenced somewhere (or else PostgreSQL will).
929        # Note that we only accept oid key from named args for safety.
930        qoid = _oid_key(table)
931        if 'oid' in kw:
932            kw[qoid] = kw['oid']
933            del kw['oid']
934        if row is None:
935            row = {}
936        row.update(kw)
937        params = []
938        param = partial(self._prepare_param, params=params)
939        if qoid in row:
940            where = 'oid = %s' % param(row[qoid], 'int')
941        else:
942            try:
943                keyname = self.pkey(table)
944            except KeyError:
945                raise _prg_error('Table %s has no primary key' % table)
946            keyname = [keyname] if isinstance(
947                keyname, basestring) else sorted(keyname)
948            attnames = self.get_attnames(table)
949            col = self.escape_identifier
950            try:
951                where = ' AND '.join('%s = %s' % (
952                    col(k), param(row[k], attnames[k])) for k in keyname)
953            except KeyError:
954                raise _prg_error('Delete operation needs primary key or oid')
955        q = 'DELETE FROM %s WHERE %s' % (
956            self._escape_qualified_name(table), where)
957        self._do_debug(q, params)
958        res = self.db.query(q, params)
959        return int(res)
960
961    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
962        """Get notification handler that will run the given callback."""
963        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
964
965
# If run as a script, print the version and the module documentation.

if __name__ == '__main__':
    # Pass the version as a separate argument so that print inserts a
    # space between the label and the version number.
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.