source: trunk/pg.py @ 745

Last change on this file since 745 was 745, checked in by cito, 4 years ago

Add methods get/set_parameter to DB wrapper class

These methods can be used to get/set/reset run-time parameters,
even several at once.

Since this is pretty useful and will not break anything, I have
also back ported these additions to the 4.x branch.

Everything is well documented and tested, of course.

  • Property svn:keywords set to Id
File size: 40.7 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 745 2016-01-15 01:16:23Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40
41try:
42    from collections import OrderedDict
43except ImportError:  # Python 2.6 or 3.0
44    OrderedDict = dict
45
46try:
47    basestring
48except NameError:  # Python >= 3.0
49    basestring = (str, bytes)
50
51set_decimal(Decimal)
52
53
54# Auxiliary functions that are independent from a DB connection:
55
56def _oid_key(table):
57    """Build oid key from a table name."""
58    return 'oid(%s)' % table
59
60
61def _simpletype(typ):
62    """Determine a simplified name a pg_type name."""
63    if typ.startswith('bool'):
64        return 'bool'
65    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
66        return 'date'
67    if typ.startswith(('cid', 'oid', 'int', 'xid')):
68        return 'int'
69    if typ.startswith('float'):
70        return 'float'
71    if typ.startswith('numeric'):
72        return 'num'
73    if typ.startswith('money'):
74        return 'money'
75    if typ.startswith('bytea'):
76        return 'bytea'
77    return 'text'
78
79
80def _namedresult(q):
81    """Get query result as named tuples."""
82    row = namedtuple('Row', q.listfields())
83    return [row(*r) for r in q.getresult()]
84
85set_namedresult(_namedresult)
86
87
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an error whose sqlstate attribute is None.

    msg - the error message
    cls - the exception class to instantiate (DatabaseError by default)
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
93
94
def _int_error(msg):
    """Create an InternalError with the given message and no sqlstate."""
    return _db_error(msg, cls=InternalError)
98
99
def _prg_error(msg):
    """Create a ProgrammingError with the given message and no sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
103
104
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler.

    The handler LISTENs on a connection for an event and for the matching
    stop event (the event name prefixed with "stop_") and passes incoming
    notifications on to a callback function.
    """

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object; if a DB wrapper instance
                   is passed, its underlying connection is used.
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.
        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    def close(self):
        """Stop listening and close the connection.

        The connection is closed and released even if sending the
        UNLISTEN commands fails; the error is still propagated.
        """
        if self.db:
            try:
                self.unlisten()
            finally:
                # Never keep a half-closed handler around: without the
                # connection there is nothing left to listen on.
                self.listening = False
                self.db.close()
                self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        db      - connection used for sending; defaults to the connection
                  of this handler.
        stop    - if true, the stop event is sent instead of the event.
        payload - optional payload; it is escaped before being put into
                  the NOTIFY command as a string literal.

        Returns the result of the query, or None if the handler is
        currently not listening (nothing is sent in that case).

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.
        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # The payload ends up inside a quoted SQL literal, so it
                # must be escaped; '%s' keeps the original conversion of
                # non-string payloads to text.
                q += ", '%s'" % db.escape_string('%s' % payload)
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:
        the event and the stop event.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' (the payload) are inserted into arg_dict,
        and the callback is invoked with arg_dict. If the NOTIFY message is
        the stop event, the handler UNLISTENs both events and exits.

        If the timeout expires, the handler UNLISTENs and the callback is
        invoked with None. A notification for any other event raises a
        DatabaseError after UNLISTENing.

        The close argument is accepted but not used by this implementation.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
209
210
def pgnotify(*args, **kw):
    """Deprecated alias: create a NotificationHandler and warn about it.

    All arguments are passed on to NotificationHandler unchanged.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
216
217
218# The actual PostgreSQL database connection interface:
219
220class DB(object):
221    """Wrapper class for the _pg connection type."""
222
223    def __init__(self, *args, **kw):
224        """Create a new connection
225
226        You can pass either the connection parameters or an existing
227        _pg or pgdb connection. This allows you to use the methods
228        of the classic pg interface with a DB-API 2 pgdb connection.
229        """
230        if not args and len(kw) == 1:
231            db = kw.get('db')
232        elif not kw and len(args) == 1:
233            db = args[0]
234        else:
235            db = None
236        if db:
237            if isinstance(db, DB):
238                db = db.db
239            else:
240                try:
241                    db = db._cnx
242                except AttributeError:
243                    pass
244        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
245            db = connect(*args, **kw)
246            self._closeable = True
247        else:
248            self._closeable = False
249        self.db = db
250        self.dbname = db.db
251        self._regtypes = False
252        self._attnames = {}
253        self._pkeys = {}
254        self._privileges = {}
255        self._args = args, kw
256        self.debug = None  # For debugging scripts, this can be set
257            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
258            # * to a file object to write debug statements or
259            # * to a callable object which takes a string argument
260            # * to any other true value to just print debug statements
261
262    def __getattr__(self, name):
263        # All undefined members are same as in underlying connection:
264        if self.db:
265            return getattr(self.db, name)
266        else:
267            raise _int_error('Connection is not valid')
268
269    def __dir__(self):
270        # Custom dir function including the attributes of the connection:
271        attrs = set(self.__class__.__dict__)
272        attrs.update(self.__dict__)
273        attrs.update(dir(self.db))
274        return sorted(attrs)
275
276    # Context manager methods
277
278    def __enter__(self):
279        """Enter the runtime context. This will start a transactio."""
280        self.begin()
281        return self
282
283    def __exit__(self, et, ev, tb):
284        """Exit the runtime context. This will end the transaction."""
285        if et is None and ev is None and tb is None:
286            self.commit()
287        else:
288            self.rollback()
289
290    # Auxiliary methods
291
292    def _do_debug(self, *args):
293        """Print a debug message"""
294        if self.debug:
295            s = '\n'.join(args)
296            if isinstance(self.debug, basestring):
297                print(self.debug % s)
298            elif hasattr(self.debug, 'write'):
299                self.debug.write(s + '\n')
300            elif callable(self.debug):
301                self.debug(s)
302            else:
303                print(s)
304
305    def _escape_qualified_name(self, s):
306        """Escape a qualified name.
307
308        Escapes the name for use as an SQL identifier, unless the
309        name contains a dot, in which case the name is ambiguous
310        (could be a qualified name or just a name with a dot in it)
311        and must be quoted manually by the caller.
312        """
313        if '.' not in s:
314            s = self.escape_identifier(s)
315        return s
316
317    @staticmethod
318    def _make_bool(d):
319        """Get boolean value corresponding to d."""
320        return bool(d) if get_bool() else ('t' if d else 'f')
321
322    _bool_true_values = frozenset('t true 1 y yes on'.split())
323
324    def _prepare_bool(self, d):
325        """Prepare a boolean parameter."""
326        if isinstance(d, basestring):
327            if not d:
328                return None
329            d = d.lower() in self._bool_true_values
330        return 't' if d else 'f'
331
332    _date_literals = frozenset('current_date current_time'
333        ' current_timestamp localtime localtimestamp'.split())
334
335    def _prepare_date(self, d):
336        """Prepare a date parameter."""
337        if not d:
338            return None
339        if isinstance(d, basestring) and d.lower() in self._date_literals:
340            raise ValueError
341        return d
342
343    def _prepare_num(self, d):
344        """Prepare a numeric parameter."""
345        if not d and d != 0:
346            return None
347        return d
348
349    def _prepare_bytea(self, d):
350        """Prepare a bytea parameter."""
351        return self.escape_bytea(d)
352
353    _prepare_funcs = dict(  # quote methods for each type
354        bool=_prepare_bool, date=_prepare_date,
355        int=_prepare_num, num=_prepare_num, float=_prepare_num,
356        money=_prepare_num, bytea=_prepare_bytea)
357
358    def _prepare_param(self, value, typ, params):
359        """Prepare and add a parameter to the list."""
360        if value is not None and typ != 'text':
361            try:
362                prepare = self._prepare_funcs[typ]
363            except KeyError:
364                pass
365            else:
366                try:
367                    value = prepare(self, value)
368                except ValueError:
369                    return value
370        params.append(value)
371        return '$%d' % len(params)
372
373    def _list_params(self, params):
374        """Create a human readable parameter list."""
375        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
376
377    @staticmethod
378    def _prepare_qualified_param(name, param):
379        """Quote parameter representing a qualified name.
380
381        Escapes the name for use as an SQL parameter, unless the
382        name contains a dot, in which case the name is ambiguous
383        (could be a qualified name or just a name with a dot in it)
384        and must be quoted manually by the caller.
385
386        """
387        if isinstance(param, int):
388            param = "$%d" % param
389        if '.' not in name:
390            param = 'quote_ident(%s)' % (param,)
391        return param
392
393    # Public methods
394
395    # escape_string and escape_bytea exist as methods,
396    # so we define unescape_bytea as a method as well
397    unescape_bytea = staticmethod(unescape_bytea)
398
399    def close(self):
400        """Close the database connection."""
401        # Wraps shared library function so we can track state.
402        if self._closeable:
403            if self.db:
404                self.db.close()
405                self.db = None
406            else:
407                raise _int_error('Connection already closed')
408
409    def reset(self):
410        """Reset connection with current parameters.
411
412        All derived queries and large objects derived from this connection
413        will not be usable after this call.
414
415        """
416        if self.db:
417            self.db.reset()
418        else:
419            raise _int_error('Connection already closed')
420
421    def reopen(self):
422        """Reopen connection to the database.
423
424        Used in case we need another connection to the same database.
425        Note that we can still reopen a database that we have closed.
426
427        """
428        # There is no such shared library function.
429        if self._closeable:
430            db = connect(*self._args[0], **self._args[1])
431            if self.db:
432                self.db.close()
433            self.db = db
434
435    def begin(self, mode=None):
436        """Begin a transaction."""
437        qstr = 'BEGIN'
438        if mode:
439            qstr += ' ' + mode
440        return self.query(qstr)
441
442    start = begin
443
444    def commit(self):
445        """Commit the current transaction."""
446        return self.query('COMMIT')
447
448    end = commit
449
450    def rollback(self, name=None):
451        """Roll back the current transaction."""
452        qstr = 'ROLLBACK'
453        if name:
454            qstr += ' TO ' + name
455        return self.query(qstr)
456
457    def savepoint(self, name):
458        """Define a new savepoint within the current transaction."""
459        return self.query('SAVEPOINT ' + name)
460
461    def release(self, name):
462        """Destroy a previously defined savepoint."""
463        return self.query('RELEASE ' + name)
464
465    def get_parameter(self, parameter):
466        """Get the value of a run-time parameter.
467
468        If the parameter is a string, the return value will also be a string
469        that is the current setting of the run-time parameter with that name.
470
471        You can get several parameters at once by passing a list or tuple of
472        parameter names.  The return value will then be a corresponding list
473        of parameter settings.  If you pass a dict as parameter instead, its
474        values will be set to the parameter settings corresponding to its keys.
475
476        By passing the special name 'all' as the parameter, you can get a dict
477        of all existing configuration parameters.
478        """
479        if isinstance(parameter, basestring):
480            parameter = [parameter]
481            values = None
482        elif isinstance(parameter, (list, tuple)):
483            values = []
484        elif isinstance(parameter, dict):
485            values = parameter
486        else:
487            raise TypeError('The parameter must be a dict, list or string')
488        if not parameter:
489            raise TypeError('No parameter has been specified')
490        params = {} if isinstance(values, dict) else []
491        for key in parameter:
492            param = key.strip().lower() if isinstance(
493                key, basestring) else None
494            if not param:
495                raise TypeError('Invalid parameter')
496            if param == 'all':
497                q = 'SHOW ALL'
498                values = self.db.query(q).getresult()
499                values = dict(value[:2] for value in values)
500                break
501            if isinstance(values, dict):
502                params[param] = key
503            else:
504                params.append(param)
505        else:
506            for param in params:
507                q = 'SHOW %s' % (param,)
508                value = self.db.query(q).getresult()[0][0]
509                if values is None:
510                    values = value
511                elif isinstance(values, list):
512                    values.append(value)
513                else:
514                    values[params[param]] = value
515        return values
516
517    def set_parameter(self, parameter, value=None, local=False):
518        """Set the value of a run-time parameter.
519
520        If the parameter and the value are strings, the run-time parameter
521        will be set to that value.  If no value or None is passed as a value,
522        then the run-time parameter will be restored to its default value.
523
524        You can set several parameters at once by passing a list or tuple
525        of parameter names, with a single value that all parameters should
526        be set to or with a corresponding list or tuple of values.
527
528        You can also pass a dict as parameters.  In this case, you should
529        not pass a value, since the values will be taken from the dict.
530
531        By passing the special name 'all' as the parameter, you can reset
532        all existing settable run-time parameters to their default values.
533
534        If you set local to True, then the command takes effect for only the
535        current transaction.  After commit() or rollback(), the session-level
536        setting takes effect again.  Setting local to True will appear to
537        have no effect if it is executed outside a transaction, since the
538        transaction will end immediately.
539        """
540        if isinstance(parameter, basestring):
541            parameter = {parameter: value}
542        elif isinstance(parameter, (list, tuple)):
543            if isinstance(value, (list, tuple)):
544                parameter = dict(zip(parameter, value))
545            else:
546                parameter = dict.fromkeys(parameter, value)
547        elif isinstance(parameter, dict):
548            if value is not None:
549                raise ValueError(
550                    'A value must not be set when parameter is a dictionary')
551        else:
552            raise TypeError('The parameter must be a dict, list or string')
553        if not parameter:
554            raise TypeError('No parameter has been specified')
555        params = {}
556        for key, value in parameter.items():
557            param = key.strip().lower() if isinstance(
558                key, basestring) else None
559            if not param:
560                raise TypeError('Invalid parameter')
561            if param == 'all':
562                if value is not None:
563                    raise ValueError(
564                        "A value must ot be set when parameter is 'all'")
565                params = {'all': None}
566                break
567            params[param] = value
568        local = ' LOCAL' if local else ''
569        for param, value in params.items():
570            if value is None:
571                q = 'RESET%s %s' % (local, param)
572            else:
573                q = 'SET%s %s TO %s' % (local, param, value)
574            self._do_debug(q)
575            self.db.query(q)
576
577    def query(self, qstr, *args):
578        """Execute a SQL command string.
579
580        This method simply sends a SQL query to the database. If the query is
581        an insert statement that inserted exactly one row into a table that
582        has OIDs, the return value is the OID of the newly inserted row.
583        If the query is an update or delete statement, or an insert statement
584        that did not insert exactly one row in a table with OIDs, then the
585        number of rows affected is returned as a string. If it is a statement
586        that returns rows as a result (usually a select statement, but maybe
587        also an "insert/update ... returning" statement), this method returns
588        a Query object that can be accessed via getresult() or dictresult()
589        or simply printed. Otherwise, it returns `None`.
590
591        The query can contain numbered parameters of the form $1 in place
592        of any data constant. Arguments given after the query string will
593        be substituted for the corresponding numbered parameter. Parameter
594        values can also be given as a single list or tuple argument.
595        """
596        # Wraps shared library function for debugging.
597        if not self.db:
598            raise _int_error('Connection is not valid')
599        self._do_debug(qstr)
600        return self.db.query(qstr, args)
601
602    def pkey(self, table, flush=False):
603        """Get or set the primary key of a table.
604
605        Composite primary keys are represented as frozensets. Note that
606        this raises a KeyError if the table does not have a primary key.
607
608        If flush is set then the internal cache for primary keys will
609        be flushed. This may be necessary after the database schema or
610        the search path has been changed.
611        """
612        pkeys = self._pkeys
613        if flush:
614            pkeys.clear()
615            self._do_debug('The pkey cache has been flushed')
616        try:  # cache lookup
617            pkey = pkeys[table]
618        except KeyError:  # cache miss, check the database
619            q = ("SELECT a.attname FROM pg_index i"
620                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
621                " AND a.attnum = ANY(i.indkey)"
622                " AND NOT a.attisdropped"
623                " WHERE i.indrelid=%s::regclass"
624                " AND i.indisprimary") % (
625                    self._prepare_qualified_param(table, 1),)
626            pkey = self.db.query(q, (table,)).getresult()
627            if not pkey:
628                raise KeyError('Table %s has no primary key' % table)
629            if len(pkey) > 1:
630                pkey = frozenset(k[0] for k in pkey)
631            else:
632                pkey = pkey[0][0]
633            pkeys[table] = pkey  # cache it
634        return pkey
635
636    def get_databases(self):
637        """Get list of databases in the system."""
638        return [s[0] for s in
639            self.db.query('SELECT datname FROM pg_database').getresult()]
640
641    def get_relations(self, kinds=None):
642        """Get list of relations in connected database of specified kinds.
643
644        If kinds is None or empty, all kinds of relations are returned.
645        Otherwise kinds can be a string or sequence of type letters
646        specifying which kind of relations you want to list.
647        """
648        where = " AND r.relkind IN (%s)" % ','.join(
649            ["'%s'" % k for k in kinds]) if kinds else ''
650        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
651            " FROM pg_class r"
652            " JOIN pg_namespace s ON s.oid = r.relnamespace"
653            " WHERE s.nspname NOT SIMILAR"
654            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
655            " ORDER BY s.nspname, r.relname") % where
656        return [r[0] for r in self.db.query(q).getresult()]
657
658    def get_tables(self):
659        """Return list of tables in connected database."""
660        return self.get_relations('r')
661
662    def get_attnames(self, table, flush=False):
663        """Given the name of a table, dig out the set of attribute names.
664
665        Returns a dictionary of attribute names (the names are the keys,
666        the values are the names of the attributes' types).
667
668        If your Python version supports this, the dictionary will be an
669        OrderedDictionary with the column names in the right order.
670
671        If flush is set then the internal cache for attribute names will
672        be flushed. This may be necessary after the database schema or
673        the search path has been changed.
674
675        By default, only a limited number of simple types will be returned.
676        You can get the regular types after calling use_regtypes(True).
677        """
678        attnames = self._attnames
679        if flush:
680            attnames.clear()
681            self._do_debug('The attnames cache has been flushed')
682        try:  # cache lookup
683            names = attnames[table]
684        except KeyError:  # cache miss, check the database
685            q = ("SELECT a.attname, t.typname%s"
686                " FROM pg_attribute a"
687                " JOIN pg_type t ON t.oid = a.atttypid"
688                " WHERE a.attrelid = %s::regclass"
689                " AND (a.attnum > 0 OR a.attname = 'oid')"
690                " AND NOT a.attisdropped ORDER BY a.attnum") % (
691                    '::regtype' if self._regtypes else '',
692                    self._prepare_qualified_param(table, 1))
693            names = self.db.query(q, (table,)).getresult()
694            if not names:
695                raise KeyError('Table %s does not exist' % table)
696            if not self._regtypes:
697                names = ((name, _simpletype(typ)) for name, typ in names)
698            names = OrderedDict(names)
699            attnames[table] = names  # cache it
700        return names
701
702    def use_regtypes(self, regtypes=None):
703        """Use regular type names instead of simplified type names."""
704        if regtypes is None:
705            return self._regtypes
706        else:
707            regtypes = bool(regtypes)
708            if regtypes != self._regtypes:
709                self._regtypes = regtypes
710                self._attnames.clear()
711            return regtypes
712
713    def has_table_privilege(self, table, privilege='select'):
714        """Check whether current user has specified table privilege."""
715        privilege = privilege.lower()
716        try:  # ask cache
717            return self._privileges[(table, privilege)]
718        except KeyError:  # cache miss, ask the database
719            q = "SELECT has_table_privilege(%s, $2)" % (
720                self._prepare_qualified_param(table, 1),)
721            q = self.db.query(q, (table, privilege))
722            ret = q.getresult()[0][0] == self._make_bool(True)
723            self._privileges[(table, privilege)] = ret  # cache it
724            return ret
725
726    def get(self, table, row, keyname=None):
727        """Get a row from a database table or view.
728
729        This method is the basic mechanism to get a single row.  The keyname
730        that the key specifies a unique row.  If keyname is not specified
731        then the primary key for the table is used.  If row is a dictionary
732        then the value for the key is taken from it and it is modified to
733        include the new values, replacing existing values where necessary.
734        For a composite key, keyname can also be a sequence of key names.
735        The OID is also put into the dictionary if the table has one, but
736        in order to allow the caller to work with multiple tables, it is
737        munged as "oid(table)".
738        """
739        if table.endswith('*'):  # scan descendant tables?
740            table = table[:-1].rstrip()  # need parent table name
741        if not keyname:
742            # use the primary key by default
743            try:
744                keyname = self.pkey(table)
745            except KeyError:
746                raise _prg_error('Table %s has no primary key' % table)
747        attnames = self.get_attnames(table)
748        params = []
749        param = partial(self._prepare_param, params=params)
750        col = self.escape_identifier
751        # We want the oid for later updates if that isn't the key.
752        # To allow users to work with multiple tables, we munge
753        # the name of the "oid" key by adding the name of the table.
754        qoid = _oid_key(table)
755        if keyname == 'oid':
756            if isinstance(row, dict):
757                if qoid not in row:
758                    raise _db_error('%s not in row' % qoid)
759            else:
760                row = {qoid: row}
761            what = '*'
762            where = 'oid = %s' % param(row[qoid], 'int')
763        else:
764            keyname = [keyname] if isinstance(
765                keyname, basestring) else sorted(keyname)
766            if not isinstance(row, dict):
767                if len(keyname) > 1:
768                    raise _prg_error('Composite key needs dict as row')
769                row = dict((k, row) for k in keyname)
770            what = ', '.join(col(k) for k in attnames)
771            where = ' AND '.join('%s = %s' % (
772                col(k), param(row[k], attnames[k])) for k in keyname)
773        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
774            what, self._escape_qualified_name(table), where)
775        self._do_debug(q, params)
776        q = self.db.query(q, params)
777        res = q.dictresult()
778        if not res:
779            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
780                table, where, self._list_params(params)))
781        for n, value in res[0].items():
782            if n == 'oid':
783                n = qoid
784            elif attnames.get(n) == 'bytea':
785                value = self.unescape_bytea(value)
786            row[n] = value
787        return row
788
789    def insert(self, table, row=None, **kw):
790        """Insert a row into a database table.
791
792        This method inserts a row into a table.  The name of the table must
793        be passed as the first parameter.  The other parameters are used for
794        providing the data of the row that shall be inserted into the table.
795        If a dictionary is supplied as the second parameter, it starts with
796        that.  Otherwise it uses a blank dictionary. Either way the dictionary
797        is updated from the keywords.
798
799        The dictionary is then reloaded with the values actually inserted in
800        order to pick up values modified by rules, triggers, etc.
801
802        Note: The method currently doesn't support insert into views
803        although PostgreSQL does.
804        """
805        if 'oid' in kw:
806            del kw['oid']
807        if row is None:
808            row = {}
809        row.update(kw)
810        attnames = self.get_attnames(table)
811        params = []
812        param = partial(self._prepare_param, params=params)
813        col = self.escape_identifier
814        names, values = [], []
815        for n in attnames:
816            if n in row:
817                names.append(col(n))
818                values.append(param(row[n], attnames[n]))
819        names, values = ', '.join(names), ', '.join(values)
820        ret = 'oid, *' if 'oid' in attnames else '*'
821        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
822            self._escape_qualified_name(table), names, values, ret)
823        self._do_debug(q, params)
824        q = self.db.query(q, params)
825        res = q.dictresult()
826        if not res:
827            raise _int_error('Insert operation did not return new values')
828        for n, value in res[0].items():
829            if n == 'oid':
830                n = _oid_key(table)
831            elif attnames.get(n) == 'bytea' and value is not None:
832                value = self.unescape_bytea(value)
833            row[n] = value
834        return row
835
836    def update(self, table, row=None, **kw):
837        """Update an existing row in a database table.
838
839        Similar to insert but updates an existing row.  The update is based
840        on the OID value as munged by get or passed as keyword, or on the
841        primary key of the table.  The dictionary is modified to reflect
842        any changes caused by the update due to triggers, rules, default
843        values, etc.
844        """
845        # Update always works on the oid which get() returns if available,
846        # otherwise use the primary key.  Fail if neither.
847        # Note that we only accept oid key from named args for safety.
848        qoid = _oid_key(table)
849        if 'oid' in kw:
850            kw[qoid] = kw['oid']
851            del kw['oid']
852        if row is None:
853            row = {}
854        row.update(kw)
855        attnames = self.get_attnames(table)
856        params = []
857        param = partial(self._prepare_param, params=params)
858        col = self.escape_identifier
859        if qoid in row:
860            where = 'oid = %s' % param(row[qoid], 'int')
861            keyname = []
862        else:
863            try:
864                keyname = self.pkey(table)
865            except KeyError:
866                raise _prg_error('Table %s has no primary key' % table)
867            keyname = [keyname] if isinstance(
868                keyname, basestring) else sorted(keyname)
869            try:
870                where = ' AND '.join('%s = %s' % (
871                    col(k), param(row[k], attnames[k])) for k in keyname)
872            except KeyError:
873                raise _prg_error('Update operation needs primary key or oid')
874        keyname = set(keyname)
875        keyname.add('oid')
876        values = []
877        for n in attnames:
878            if n in row and n not in keyname:
879                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
880        if not values:
881            return row
882        values = ', '.join(values)
883        ret = 'oid, *' if 'oid' in attnames else '*'
884        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
885            self._escape_qualified_name(table), values, where, ret)
886        self._do_debug(q, params)
887        q = self.db.query(q, params)
888        res = q.dictresult()
889        if res:  # may be empty when row does not exist
890            for n, value in res[0].items():
891                if n == 'oid':
892                    n = qoid
893                elif attnames.get(n) == 'bytea' and value is not None:
894                    value = self.unescape_bytea(value)
895                row[n] = value
896        return row
897
898    def upsert(self, table, row=None, **kw):
899        """Insert a row into a database table with conflict resolution
900
901        This method inserts a row into a table, but instead of raising a
902        ProgrammingError exception in case a row with the same primary key
903        already exists, an update will be executed instead.  This will be
904        performed as a single atomic operation on the database, so race
905        conditions can be avoided.
906
907        Like the insert method, the first parameter is the name of the
908        table and the second parameter can be used to pass the values to
909        be inserted as a dictionary.
910
911        Unlike the insert und update statement, keyword parameters are not
912        used to modify the dictionary, but to specify which columns shall
913        be updated in case of a conflict, and in which way:
914
915        A value of False or None means the column shall not be updated,
916        a value of True means the column shall be updated with the value
917        that has been proposed for insertion, i.e. has been passed as value
918        in the dictionary.  Columns that are not specified by keywords but
919        appear as keys in the dictionary are also updated like in the case
920        keywords had been passed with the value True.
921
922        So if in the case of a conflict you want to update every column that
923        has been passed in the dictionary row , you would call upsert(table, row).
924        If you don't want to do anything in case of a conflict, i.e. leave
925        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
926
927        If you need more fine-grained control of what gets updated, you can
928        also pass strings in the keyword parameters.  These strings will
929        be used as SQL expressions for the update columns.  In these
930        expressions you can refer to the value that already exists in
931        the table by prefixing the column name with "included.", and to
932        the value that has been proposed for insertion by prefixing the
933        column name with the "excluded."
934
935        The dictionary is modified in any case to reflect the values in
936        the database after the operation has completed.
937
938        Note: The method uses the PostgreSQL "upsert" feature which is
939        only available since PostgreSQL 9.5.
940        """
941        if 'oid' in kw:
942            del kw['oid']
943        if row is None:
944            row = {}
945        attnames = self.get_attnames(table)
946        params = []
947        param = partial(self._prepare_param,params=params)
948        col = self.escape_identifier
949        names, values, updates = [], [], []
950        for n in attnames:
951            if n in row:
952                names.append(col(n))
953                values.append(param(row[n], attnames[n]))
954        names, values = ', '.join(names), ', '.join(values)
955        try:
956            keyname = self.pkey(table)
957        except KeyError:
958            raise _prg_error('Table %s has no primary key' % table)
959        keyname = [keyname] if isinstance(
960            keyname, basestring) else sorted(keyname)
961        try:
962            target = ', '.join(col(k) for k in keyname)
963        except KeyError:
964            raise _prg_error('Upsert operation needs primary key or oid')
965        update = []
966        keyname = set(keyname)
967        keyname.add('oid')
968        for n in attnames:
969            if n not in keyname:
970                value = kw.get(n, True)
971                if value:
972                    if not isinstance(value, basestring):
973                        value = 'excluded.%s' % col(n)
974                    update.append('%s = %s' % (col(n), value))
975        if not values and not update:
976            return row
977        do = 'update set %s' % ', '.join(update) if update else 'nothing'
978        ret = 'oid, *' if 'oid' in attnames else '*'
979        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
980            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
981                self._escape_qualified_name(table), names, values,
982                target, do, ret)
983        self._do_debug(q, params)
984        try:
985            q = self.db.query(q, params)
986        except ProgrammingError:
987            if self.server_version < 90500:
988                raise _prg_error(
989                    'Upsert operation is not supported by PostgreSQL version')
990            raise  # re-raise original error
991        res = q.dictresult()
992        if res:  # may be empty with "do nothing"
993            for n, value in res[0].items():
994                if n == 'oid':
995                    n = _oid_key(table)
996                elif attnames.get(n) == 'bytea':
997                    value = self.unescape_bytea(value)
998                row[n] = value
999        elif update:
1000            raise _int_error('Upsert operation did not return new values')
1001        else:
1002            self.get(table, row)
1003        return row
1004
1005    def clear(self, table, row=None):
1006        """Clear all the attributes to values determined by the types.
1007
1008        Numeric types are set to 0, Booleans are set to false, and everything
1009        else is set to the empty string.  If the row argument is present,
1010        it is used as the row dictionary and any entries matching attribute
1011        names are cleared with everything else left unchanged.
1012        """
1013        # At some point we will need a way to get defaults from a table.
1014        if row is None:
1015            row = {}  # empty if argument is not present
1016        attnames = self.get_attnames(table)
1017        for n, t in attnames.items():
1018            if n == 'oid':
1019                continue
1020            if t in ('int', 'integer', 'smallint', 'bigint',
1021                    'float', 'real', 'double precision',
1022                    'num', 'numeric', 'money'):
1023                row[n] = 0
1024            elif t in ('bool', 'boolean'):
1025                row[n] = self._make_bool(False)
1026            else:
1027                row[n] = ''
1028        return row
1029
1030    def delete(self, table, row=None, **kw):
1031        """Delete an existing row in a database table.
1032
1033        This method deletes the row from a table.  It deletes based on the
1034        OID value as munged by get or passed as keyword, or on the primary
1035        key of the table.  The return value is the number of deleted rows
1036        (i.e. 0 if the row did not exist and 1 if the row was deleted).
1037        """
1038        # Like update, delete works on the oid.
1039        # One day we will be testing that the record to be deleted
1040        # isn't referenced somewhere (or else PostgreSQL will).
1041        # Note that we only accept oid key from named args for safety.
1042        qoid = _oid_key(table)
1043        if 'oid' in kw:
1044            kw[qoid] = kw['oid']
1045            del kw['oid']
1046        if row is None:
1047            row = {}
1048        row.update(kw)
1049        params = []
1050        param = partial(self._prepare_param, params=params)
1051        if qoid in row:
1052            where = 'oid = %s' % param(row[qoid], 'int')
1053        else:
1054            try:
1055                keyname = self.pkey(table)
1056            except KeyError:
1057                raise _prg_error('Table %s has no primary key' % table)
1058            keyname = [keyname] if isinstance(
1059                keyname, basestring) else sorted(keyname)
1060            attnames = self.get_attnames(table)
1061            col = self.escape_identifier
1062            try:
1063                where = ' AND '.join('%s = %s' % (
1064                    col(k), param(row[k], attnames[k])) for k in keyname)
1065            except KeyError:
1066                raise _prg_error('Delete operation needs primary key or oid')
1067        q = 'DELETE FROM %s WHERE %s' % (
1068            self._escape_qualified_name(table), where)
1069        self._do_debug(q, params)
1070        res = self.db.query(q, params)
1071        return int(res)
1072
1073    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
1074        """Get notification handler that will run the given callback."""
1075        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
1076
1077
# if run as script, print some information

if __name__ == '__main__':
    # pass the version as separate argument so that it is printed
    # separated by a blank (print is a function in this module)
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.