source: trunk/pg.py @ 747

Last change on this file since 747 was 747, checked in by cito, 4 years ago

Improve the get/set_parameter methods

In addition to a list, also allow a set as parameter.

  • Property svn:keywords set to Id
File size: 41.6 KB
RevLine 
[521]1#! /usr/bin/python
[222]2#
[99]3# pg.py
[222]4#
[404]5# $Id: pg.py 747 2016-01-15 11:06:51Z cito $
[222]6#
[14]7
[204]8"""PyGreSQL classic interface.
[14]9
[204]10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
[447]12level wrapper class named DB with additional functionality.
[204]13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
[505]15"""
[504]16
[737]17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
[505]18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
[504]29
[627]30from __future__ import print_function
31
[506]32from _pg import *
33
[487]34import select
[506]35import warnings
[14]36
[553]37from decimal import Decimal
38from collections import namedtuple
[730]39from functools import partial
[332]40
[591]41try:
[739]42    from collections import OrderedDict
43except ImportError:  # Python 2.6 or 3.0
44    OrderedDict = dict
45
try:
    basestring  # Python 2: common base class of str and unicode
except NameError:  # Python >= 3.0: emulate it as a tuple for isinstance()
    basestring = (str, bytes)

# Register Decimal with the C module (presumably used for numeric values
# returned by queries -- see the _pg documentation of set_decimal).
set_decimal(Decimal)
52
53
# Auxiliary functions that are independent of a DB connection:
[204]55
[738]56def _oid_key(table):
57    """Build oid key from a table name."""
58    return 'oid(%s)' % table
[204]59
[436]60
[729]61def _simpletype(typ):
62    """Determine a simplified name a pg_type name."""
63    if typ.startswith('bool'):
64        return 'bool'
65    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
66        return 'date'
67    if typ.startswith(('cid', 'oid', 'int', 'xid')):
68        return 'int'
69    if typ.startswith('float'):
70        return 'float'
71    if typ.startswith('numeric'):
72        return 'num'
73    if typ.startswith('money'):
74        return 'money'
75    if typ.startswith('bytea'):
76        return 'bytea'
77    return 'text'
[332]78
[436]79
[568]80def _namedresult(q):
81    """Get query result as named tuples."""
82    row = namedtuple('Row', q.listfields())
83    return [row(*r) for r in q.getresult()]
[447]84
# Register the named tuple factory defined above with the C module.
set_namedresult(_namedresult)
[447]86
87
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an error with an empty sqlstate.

    msg - the error message
    cls - the exception class to instantiate (DatabaseError by default)
    """
    exc = cls(msg)
    # There is no SQLSTATE code for errors created on the client side.
    exc.sqlstate = None
    return exc
[389]93
[436]94
def _int_error(msg):
    """Create an InternalError with the given message and no sqlstate."""
    error_class = InternalError
    return _db_error(msg, error_class)
98
[436]99
def _prg_error(msg):
    """Create a ProgrammingError with the given message and no sqlstate."""
    error_class = ProgrammingError
    return _db_error(msg, error_class)
103
[493]104
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB instance is unwrapped
                   to its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.
        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        db      - connection used for sending the notification; by default
                  the connection of the handler is used
        stop    - if true, the stop event is sent instead of the event
        payload - optional payload that is sent with the notification

        Nothing is sent (and None is returned) if the handler is not
        listening; otherwise the result of the NOTIFY query is returned.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.
        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # The payload is embedded as a string literal, so it must
                # be escaped; otherwise quotes in it would break the query.
                payload = db.escape_string('%s' % (payload,))
                q += ", '%s'" % payload
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' (the payload) are inserted into
        <arg_dict>, and the callback is invoked with <arg_dict>. If the
        NOTIFY message is stop_<event>, the handler UNLISTENs both <event>
        and stop_<event> and exits. If the timeout expires before any
        message arrives, the handler UNLISTENs, invokes the callback with
        None and exits. The close argument is currently not used.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
[487]209
210
def pgnotify(*args, **kw):
    """Same as NotificationHandler, under the traditional name.

    Deprecated: issues a DeprecationWarning attributed to the caller.
    """
    warnings.warn("pgnotify is deprecated, use NotificationHandler instead",
                  DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
[506]216
217
# The actual PostgreSQL database connection interface:
219
[357]220class DB(object):
[346]221    """Wrapper class for the _pg connection type."""
[14]222
[346]223    def __init__(self, *args, **kw):
[740]224        """Create a new connection
[333]225
[346]226        You can pass either the connection parameters or an existing
227        _pg or pgdb connection. This allows you to use the methods
228        of the classic pg interface with a DB-API 2 pgdb connection.
229        """
230        if not args and len(kw) == 1:
231            db = kw.get('db')
232        elif not kw and len(args) == 1:
233            db = args[0]
234        else:
235            db = None
236        if db:
237            if isinstance(db, DB):
238                db = db.db
239            else:
240                try:
241                    db = db._cnx
242                except AttributeError:
243                    pass
244        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
245            db = connect(*args, **kw)
[460]246            self._closeable = True
[346]247        else:
[460]248            self._closeable = False
[346]249        self.db = db
250        self.dbname = db.db
[442]251        self._regtypes = False
[346]252        self._attnames = {}
253        self._pkeys = {}
[391]254        self._privileges = {}
[346]255        self._args = args, kw
[436]256        self.debug = None  # For debugging scripts, this can be set
[346]257            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
[461]258            # * to a file object to write debug statements or
259            # * to a callable object which takes a string argument
260            # * to any other true value to just print debug statements
[14]261
[346]262    def __getattr__(self, name):
[567]263        # All undefined members are same as in underlying connection:
[346]264        if self.db:
265            return getattr(self.db, name)
266        else:
[434]267            raise _int_error('Connection is not valid')
[204]268
[567]269    def __dir__(self):
270        # Custom dir function including the attributes of the connection:
271        attrs = set(self.__class__.__dict__)
272        attrs.update(self.__dict__)
273        attrs.update(dir(self.db))
274        return sorted(attrs)
275
[461]276    # Context manager methods
277
278    def __enter__(self):
[740]279        """Enter the runtime context. This will start a transactio."""
[464]280        self.begin()
[461]281        return self
282
[464]283    def __exit__(self, et, ev, tb):
284        """Exit the runtime context. This will end the transaction."""
285        if et is None and ev is None and tb is None:
286            self.commit()
[461]287        else:
[464]288            self.rollback()
[461]289
[365]290    # Auxiliary methods
[287]291
[730]292    def _do_debug(self, *args):
[740]293        """Print a debug message"""
[346]294        if self.debug:
[730]295            s = '\n'.join(args)
[357]296            if isinstance(self.debug, basestring):
[520]297                print(self.debug % s)
[589]298            elif hasattr(self.debug, 'write'):
[590]299                self.debug.write(s + '\n')
[357]300            elif callable(self.debug):
[346]301                self.debug(s)
[460]302            else:
[520]303                print(s)
[120]304
[732]305    def _escape_qualified_name(self, s):
306        """Escape a qualified name.
307
308        Escapes the name for use as an SQL identifier, unless the
309        name contains a dot, in which case the name is ambiguous
310        (could be a qualified name or just a name with a dot in it)
311        and must be quoted manually by the caller.
312        """
313        if '.' not in s:
314            s = self.escape_identifier(s)
315        return s
316
[706]317    @staticmethod
318    def _make_bool(d):
319        """Get boolean value corresponding to d."""
320        return bool(d) if get_bool() else ('t' if d else 'f')
321
[730]322    _bool_true_values = frozenset('t true 1 y yes on'.split())
[204]323
[730]324    def _prepare_bool(self, d):
325        """Prepare a boolean parameter."""
[365]326        if isinstance(d, basestring):
327            if not d:
[730]328                return None
329            d = d.lower() in self._bool_true_values
330        return 't' if d else 'f'
[202]331
[365]332    _date_literals = frozenset('current_date current_time'
333        ' current_timestamp localtime localtimestamp'.split())
[151]334
[730]335    def _prepare_date(self, d):
336        """Prepare a date parameter."""
[365]337        if not d:
[730]338            return None
[365]339        if isinstance(d, basestring) and d.lower() in self._date_literals:
[730]340            raise ValueError
341        return d
[202]342
[730]343    def _prepare_num(self, d):
344        """Prepare a numeric parameter."""
[394]345        if not d and d != 0:
[730]346            return None
[493]347        return d
[27]348
[730]349    def _prepare_bytea(self, d):
[732]350        """Prepare a bytea parameter."""
[730]351        return self.escape_bytea(d)
[662]352
[730]353    _prepare_funcs = dict(  # quote methods for each type
354        bool=_prepare_bool, date=_prepare_date,
355        int=_prepare_num, num=_prepare_num, float=_prepare_num,
356        money=_prepare_num, bytea=_prepare_bytea)
[662]357
[730]358    def _prepare_param(self, value, typ, params):
359        """Prepare and add a parameter to the list."""
360        if value is not None and typ != 'text':
361            try:
362                prepare = self._prepare_funcs[typ]
363            except KeyError:
364                pass
365            else:
366                try:
367                    value = prepare(self, value)
368                except ValueError:
369                    return value
370        params.append(value)
371        return '$%d' % len(params)
[662]372
[743]373    def _list_params(self, params):
374        """Create a human readable parameter list."""
375        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
376
[732]377    @staticmethod
[738]378    def _prepare_qualified_param(name, param):
[732]379        """Quote parameter representing a qualified name.
380
381        Escapes the name for use as an SQL parameter, unless the
382        name contains a dot, in which case the name is ambiguous
383        (could be a qualified name or just a name with a dot in it)
384        and must be quoted manually by the caller.
385
386        """
387        if isinstance(param, int):
388            param = "$%d" % param
[738]389        if '.' not in name:
[732]390            param = 'quote_ident(%s)' % (param,)
391        return param
392
[365]393    # Public methods
394
395    # escape_string and escape_bytea exist as methods,
396    # so we define unescape_bytea as a method as well
397    unescape_bytea = staticmethod(unescape_bytea)
398
399    def close(self):
400        """Close the database connection."""
401        # Wraps shared library function so we can track state.
402        if self._closeable:
403            if self.db:
404                self.db.close()
405                self.db = None
406            else:
[434]407                raise _int_error('Connection already closed')
[365]408
[383]409    def reset(self):
410        """Reset connection with current parameters.
411
412        All derived queries and large objects derived from this connection
413        will not be usable after this call.
414
415        """
[436]416        if self.db:
417            self.db.reset()
418        else:
419            raise _int_error('Connection already closed')
[383]420
[365]421    def reopen(self):
422        """Reopen connection to the database.
423
424        Used in case we need another connection to the same database.
425        Note that we can still reopen a database that we have closed.
426
427        """
428        # There is no such shared library function.
429        if self._closeable:
430            db = connect(*self._args[0], **self._args[1])
431            if self.db:
432                self.db.close()
433            self.db = db
434
[461]435    def begin(self, mode=None):
436        """Begin a transaction."""
437        qstr = 'BEGIN'
438        if mode:
439            qstr += ' ' + mode
440        return self.query(qstr)
441
442    start = begin
443
444    def commit(self):
445        """Commit the current transaction."""
446        return self.query('COMMIT')
447
448    end = commit
449
450    def rollback(self, name=None):
[734]451        """Roll back the current transaction."""
[461]452        qstr = 'ROLLBACK'
453        if name:
454            qstr += ' TO ' + name
[464]455        return self.query(qstr)
[461]456
[734]457    def savepoint(self, name):
[461]458        """Define a new savepoint within the current transaction."""
[734]459        return self.query('SAVEPOINT ' + name)
[461]460
461    def release(self, name):
462        """Destroy a previously defined savepoint."""
463        return self.query('RELEASE ' + name)
464
[745]465    def get_parameter(self, parameter):
466        """Get the value of a run-time parameter.
467
468        If the parameter is a string, the return value will also be a string
469        that is the current setting of the run-time parameter with that name.
470
[747]471        You can get several parameters at once by passing a list, set or dict.
472        When passing a list of parameter names, the return value will be a
473        corresponding list of parameter settings.  When passing a set of
474        parameter names, a new dict will be returned, mapping these parameter
475        names to their settings.  Finally, if you pass a dict as parameter,
476        its values will be set to the current parameter settings corresponding
477        to its keys.
[745]478
479        By passing the special name 'all' as the parameter, you can get a dict
480        of all existing configuration parameters.
481        """
482        if isinstance(parameter, basestring):
483            parameter = [parameter]
484            values = None
485        elif isinstance(parameter, (list, tuple)):
486            values = []
[747]487        elif isinstance(parameter, (set, frozenset)):
488            values = {}
[745]489        elif isinstance(parameter, dict):
490            values = parameter
491        else:
[747]492            raise TypeError(
493                'The parameter must be a string, list, set or dict')
[745]494        if not parameter:
495            raise TypeError('No parameter has been specified')
496        params = {} if isinstance(values, dict) else []
497        for key in parameter:
498            param = key.strip().lower() if isinstance(
499                key, basestring) else None
500            if not param:
501                raise TypeError('Invalid parameter')
502            if param == 'all':
503                q = 'SHOW ALL'
504                values = self.db.query(q).getresult()
505                values = dict(value[:2] for value in values)
506                break
507            if isinstance(values, dict):
508                params[param] = key
509            else:
510                params.append(param)
511        else:
512            for param in params:
513                q = 'SHOW %s' % (param,)
514                value = self.db.query(q).getresult()[0][0]
515                if values is None:
516                    values = value
517                elif isinstance(values, list):
518                    values.append(value)
519                else:
520                    values[params[param]] = value
521        return values
522
523    def set_parameter(self, parameter, value=None, local=False):
524        """Set the value of a run-time parameter.
525
526        If the parameter and the value are strings, the run-time parameter
527        will be set to that value.  If no value or None is passed as a value,
528        then the run-time parameter will be restored to its default value.
529
[747]530        You can set several parameters at once by passing a list of parameter
531        names, together with a single value that all parameters should be
532        set to or with a corresponding list of values.  You can also pass
533        the parameters as a set if you only provide a single value.
534        Finally, you can pass a dict with parameter names as keys.  In this
535        case, you should not pass a value, since the values for the parameters
536        will be taken from the dict.
[745]537
538        By passing the special name 'all' as the parameter, you can reset
539        all existing settable run-time parameters to their default values.
540
541        If you set local to True, then the command takes effect for only the
542        current transaction.  After commit() or rollback(), the session-level
543        setting takes effect again.  Setting local to True will appear to
544        have no effect if it is executed outside a transaction, since the
545        transaction will end immediately.
546        """
547        if isinstance(parameter, basestring):
548            parameter = {parameter: value}
549        elif isinstance(parameter, (list, tuple)):
550            if isinstance(value, (list, tuple)):
551                parameter = dict(zip(parameter, value))
552            else:
553                parameter = dict.fromkeys(parameter, value)
[747]554        elif isinstance(parameter, (set, frozenset)):
555            if isinstance(value, (list, tuple, set, frozenset)):
556                value = set(value)
557                if len(value) == 1:
558                    value = value.pop()
559            if not(value is None or isinstance(value, basestring)):
560                raise ValueError('A single value must be specified'
561                    ' when parameter is a set')
562            parameter = dict.fromkeys(parameter, value)
[745]563        elif isinstance(parameter, dict):
564            if value is not None:
[747]565                raise ValueError('A value must not be specified'
566                    ' when parameter is a dictionary')
[745]567        else:
[747]568            raise TypeError(
569                'The parameter must be a string, list, set or dict')
[745]570        if not parameter:
571            raise TypeError('No parameter has been specified')
572        params = {}
573        for key, value in parameter.items():
574            param = key.strip().lower() if isinstance(
575                key, basestring) else None
576            if not param:
577                raise TypeError('Invalid parameter')
578            if param == 'all':
579                if value is not None:
580                    raise ValueError(
581                        "A value must ot be set when parameter is 'all'")
582                params = {'all': None}
583                break
584            params[param] = value
585        local = ' LOCAL' if local else ''
586        for param, value in params.items():
587            if value is None:
588                q = 'RESET%s %s' % (local, param)
589            else:
590                q = 'SET%s %s TO %s' % (local, param, value)
591            self._do_debug(q)
592            self.db.query(q)
593
[446]594    def query(self, qstr, *args):
[740]595        """Execute a SQL command string.
[365]596
597        This method simply sends a SQL query to the database. If the query is
[391]598        an insert statement that inserted exactly one row into a table that
599        has OIDs, the return value is the OID of the newly inserted row.
600        If the query is an update or delete statement, or an insert statement
601        that did not insert exactly one row in a table with OIDs, then the
[563]602        number of rows affected is returned as a string. If it is a statement
[391]603        that returns rows as a result (usually a select statement, but maybe
604        also an "insert/update ... returning" statement), this method returns
[723]605        a Query object that can be accessed via getresult() or dictresult()
[391]606        or simply printed. Otherwise, it returns `None`.
[365]607
[446]608        The query can contain numbered parameters of the form $1 in place
609        of any data constant. Arguments given after the query string will
610        be substituted for the corresponding numbered parameter. Parameter
611        values can also be given as a single list or tuple argument.
[365]612        """
613        # Wraps shared library function for debugging.
614        if not self.db:
[434]615            raise _int_error('Connection is not valid')
[365]616        self._do_debug(qstr)
[446]617        return self.db.query(qstr, args)
[365]618
[738]619    def pkey(self, table, flush=False):
[740]620        """Get or set the primary key of a table.
[202]621
[390]622        Composite primary keys are represented as frozensets. Note that
[729]623        this raises a KeyError if the table does not have a primary key.
[389]624
[729]625        If flush is set then the internal cache for primary keys will
626        be flushed. This may be necessary after the database schema or
627        the search path has been changed.
[346]628        """
[729]629        pkeys = self._pkeys
630        if flush:
631            pkeys.clear()
[740]632            self._do_debug('The pkey cache has been flushed')
[729]633        try:  # cache lookup
[738]634            pkey = pkeys[table]
[729]635        except KeyError:  # cache miss, check the database
636            q = ("SELECT a.attname FROM pg_index i"
637                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
638                " AND a.attnum = ANY(i.indkey)"
[720]639                " AND NOT a.attisdropped"
[729]640                " WHERE i.indrelid=%s::regclass"
[738]641                " AND i.indisprimary") % (
642                    self._prepare_qualified_param(table, 1),)
643            pkey = self.db.query(q, (table,)).getresult()
[729]644            if not pkey:
[738]645                raise KeyError('Table %s has no primary key' % table)
[729]646            if len(pkey) > 1:
[735]647                pkey = frozenset(k[0] for k in pkey)
[729]648            else:
649                pkey = pkey[0][0]
[738]650            pkeys[table] = pkey  # cache it
[729]651        return pkey
[385]652
[346]653    def get_databases(self):
654        """Get list of databases in the system."""
655        return [s[0] for s in
656            self.db.query('SELECT datname FROM pg_database').getresult()]
[14]657
[369]658    def get_relations(self, kinds=None):
[346]659        """Get list of relations in connected database of specified kinds.
[232]660
[727]661        If kinds is None or empty, all kinds of relations are returned.
662        Otherwise kinds can be a string or sequence of type letters
663        specifying which kind of relations you want to list.
[346]664        """
[720]665        where = " AND r.relkind IN (%s)" % ','.join(
666            ["'%s'" % k for k in kinds]) if kinds else ''
[729]667        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
[720]668            " FROM pg_class r"
669            " JOIN pg_namespace s ON s.oid = r.relnamespace"
670            " WHERE s.nspname NOT SIMILAR"
671            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
[729]672            " ORDER BY s.nspname, r.relname") % where
673        return [r[0] for r in self.db.query(q).getresult()]
[232]674
[346]675    def get_tables(self):
676        """Return list of tables in connected database."""
677        return self.get_relations('r')
[232]678
[738]679    def get_attnames(self, table, flush=False):
[740]680        """Given the name of a table, dig out the set of attribute names.
[121]681
[346]682        Returns a dictionary of attribute names (the names are the keys,
683        the values are the names of the attributes' types).
[729]684
[739]685        If your Python version supports this, the dictionary will be an
686        OrderedDictionary with the column names in the right order.
[202]687
[739]688        If flush is set then the internal cache for attribute names will
689        be flushed. This may be necessary after the database schema or
690        the search path has been changed.
691
[442]692        By default, only a limited number of simple types will be returned.
693        You can get the regular types after calling use_regtypes(True).
[346]694        """
[729]695        attnames = self._attnames
696        if flush:
697            attnames.clear()
[740]698            self._do_debug('The attnames cache has been flushed')
[729]699        try:  # cache lookup
[738]700            names = attnames[table]
[729]701        except KeyError:  # cache miss, check the database
702            q = ("SELECT a.attname, t.typname%s"
703                " FROM pg_attribute a"
704                " JOIN pg_type t ON t.oid = a.atttypid"
705                " WHERE a.attrelid = %s::regclass"
706                " AND (a.attnum > 0 OR a.attname = 'oid')"
[739]707                " AND NOT a.attisdropped ORDER BY a.attnum") % (
[729]708                    '::regtype' if self._regtypes else '',
[738]709                    self._prepare_qualified_param(table, 1))
710            names = self.db.query(q, (table,)).getresult()
[729]711            if not names:
[738]712                raise KeyError('Table %s does not exist' % table)
[739]713            if not self._regtypes:
714                names = ((name, _simpletype(typ)) for name, typ in names)
715            names = OrderedDict(names)
[738]716            attnames[table] = names  # cache it
[729]717        return names
[439]718
[442]719    def use_regtypes(self, regtypes=None):
720        """Use regular type names instead of simplified type names."""
721        if regtypes is None:
722            return self._regtypes
723        else:
724            regtypes = bool(regtypes)
725            if regtypes != self._regtypes:
726                self._regtypes = regtypes
727                self._attnames.clear()
728            return regtypes
729
[738]730    def has_table_privilege(self, table, privilege='select'):
[391]731        """Check whether current user has specified table privilege."""
732        privilege = privilege.lower()
[729]733        try:  # ask cache
[738]734            return self._privileges[(table, privilege)]
[729]735        except KeyError:  # cache miss, ask the database
736            q = "SELECT has_table_privilege(%s, $2)" % (
[738]737                self._prepare_qualified_param(table, 1),)
738            q = self.db.query(q, (table, privilege))
[721]739            ret = q.getresult()[0][0] == self._make_bool(True)
[738]740            self._privileges[(table, privilege)] = ret  # cache it
[391]741            return ret
742
[738]743    def get(self, table, row, keyname=None):
[718]744        """Get a row from a database table or view.
[14]745
[389]746        This method is the basic mechanism to get a single row.  The keyname
[346]747        that the key specifies a unique row.  If keyname is not specified
[738]748        then the primary key for the table is used.  If row is a dictionary
[346]749        then the value for the key is taken from it and it is modified to
750        include the new values, replacing existing values where necessary.
[389]751        For a composite key, keyname can also be a sequence of key names.
[380]752        The OID is also put into the dictionary if the table has one, but
753        in order to allow the caller to work with multiple tables, it is
[738]754        munged as "oid(table)".
[346]755        """
[738]756        if table.endswith('*'):  # scan descendant tables?
757            table = table[:-1].rstrip()  # need parent table name
[389]758        if not keyname:
759            # use the primary key by default
760            try:
[738]761                keyname = self.pkey(table)
[389]762            except KeyError:
[738]763                raise _prg_error('Table %s has no primary key' % table)
764        attnames = self.get_attnames(table)
[730]765        params = []
766        param = partial(self._prepare_param, params=params)
[732]767        col = self.escape_identifier
[735]768        # We want the oid for later updates if that isn't the key.
769        # To allow users to work with multiple tables, we munge
[738]770        # the name of the "oid" key by adding the name of the table.
771        qoid = _oid_key(table)
[346]772        if keyname == 'oid':
[738]773            if isinstance(row, dict):
774                if qoid not in row:
775                    raise _db_error('%s not in row' % qoid)
[389]776            else:
[738]777                row = {qoid: row}
[664]778            what = '*'
[738]779            where = 'oid = %s' % param(row[qoid], 'int')
[346]780        else:
[735]781            keyname = [keyname] if isinstance(
782                keyname, basestring) else sorted(keyname)
[738]783            if not isinstance(row, dict):
[389]784                if len(keyname) > 1:
[738]785                    raise _prg_error('Composite key needs dict as row')
786                row = dict((k, row) for k in keyname)
[732]787            what = ', '.join(col(k) for k in attnames)
[735]788            where = ' AND '.join('%s = %s' % (
[738]789                col(k), param(row[k], attnames[k])) for k in keyname)
[729]790        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
[738]791            what, self._escape_qualified_name(table), where)
[730]792        self._do_debug(q, params)
[735]793        q = self.db.query(q, params)
794        res = q.dictresult()
[346]795        if not res:
[743]796            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
797                table, where, self._list_params(params)))
[664]798        for n, value in res[0].items():
799            if n == 'oid':
800                n = qoid
801            elif attnames.get(n) == 'bytea':
802                value = self.unescape_bytea(value)
[738]803            row[n] = value
804        return row
[14]805
[738]806    def insert(self, table, row=None, **kw):
[718]807        """Insert a row into a database table.
[202]808
[727]809        This method inserts a row into a table.  The name of the table must
810        be passed as the first parameter.  The other parameters are used for
811        providing the data of the row that shall be inserted into the table.
812        If a dictionary is supplied as the second parameter, it starts with
813        that.  Otherwise it uses a blank dictionary. Either way the dictionary
814        is updated from the keywords.
[202]815
[733]816        The dictionary is then reloaded with the values actually inserted in
817        order to pick up values modified by rules, triggers, etc.
[382]818
[346]819        Note: The method currently doesn't support insert into views
820        although PostgreSQL does.
821        """
[735]822        if 'oid' in kw:
823            del kw['oid']
[738]824        if row is None:
825            row = {}
826        row.update(kw)
827        attnames = self.get_attnames(table)
[730]828        params = []
829        param = partial(self._prepare_param, params=params)
[732]830        col = self.escape_identifier
[493]831        names, values = [], []
[389]832        for n in attnames:
[738]833            if n in row:
[732]834                names.append(col(n))
[738]835                values.append(param(row[n], attnames[n]))
[389]836        names, values = ', '.join(names), ', '.join(values)
[733]837        ret = 'oid, *' if 'oid' in attnames else '*'
838        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
[738]839            self._escape_qualified_name(table), names, values, ret)
[730]840        self._do_debug(q, params)
[735]841        q = self.db.query(q, params)
842        res = q.dictresult()
843        if not res:
[740]844            raise _int_error('Insert operation did not return new values')
[735]845        for n, value in res[0].items():
[733]846            if n == 'oid':
[738]847                n = _oid_key(table)
[733]848            elif attnames.get(n) == 'bytea' and value is not None:
849                value = self.unescape_bytea(value)
[738]850            row[n] = value
851        return row
[14]852
[738]853    def update(self, table, row=None, **kw):
[346]854        """Update an existing row in a database table.
[14]855
[346]856        Similar to insert but updates an existing row.  The update is based
[391]857        on the OID value as munged by get or passed as keyword, or on the
[733]858        primary key of the table.  The dictionary is modified to reflect
859        any changes caused by the update due to triggers, rules, default
860        values, etc.
[346]861        """
[735]862        # Update always works on the oid which get() returns if available,
[346]863        # otherwise use the primary key.  Fail if neither.
[735]864        # Note that we only accept oid key from named args for safety.
[738]865        qoid = _oid_key(table)
[357]866        if 'oid' in kw:
[387]867            kw[qoid] = kw['oid']
[346]868            del kw['oid']
[738]869        if row is None:
870            row = {}
871        row.update(kw)
872        attnames = self.get_attnames(table)
[730]873        params = []
874        param = partial(self._prepare_param, params=params)
[732]875        col = self.escape_identifier
[738]876        if qoid in row:
877            where = 'oid = %s' % param(row[qoid], 'int')
[735]878            keyname = []
[346]879        else:
880            try:
[738]881                keyname = self.pkey(table)
[389]882            except KeyError:
[738]883                raise _prg_error('Table %s has no primary key' % table)
[735]884            keyname = [keyname] if isinstance(
885                keyname, basestring) else sorted(keyname)
[389]886            try:
[735]887                where = ' AND '.join('%s = %s' % (
[738]888                    col(k), param(row[k], attnames[k])) for k in keyname)
[389]889            except KeyError:
[740]890                raise _prg_error('Update operation needs primary key or oid')
[735]891        keyname = set(keyname)
892        keyname.add('oid')
[389]893        values = []
894        for n in attnames:
[738]895            if n in row and n not in keyname:
896                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
[389]897        if not values:
[738]898            return row
[389]899        values = ', '.join(values)
[733]900        ret = 'oid, *' if 'oid' in attnames else '*'
901        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
[738]902            self._escape_qualified_name(table), values, where, ret)
[730]903        self._do_debug(q, params)
[735]904        q = self.db.query(q, params)
905        res = q.dictresult()
906        if res:  # may be empty when row does not exist
907            for n, value in res[0].items():
908                if n == 'oid':
909                    n = qoid
910                elif attnames.get(n) == 'bytea' and value is not None:
911                    value = self.unescape_bytea(value)
[738]912                row[n] = value
913        return row
[14]914
[738]915    def upsert(self, table, row=None, **kw):
[740]916        """Insert a row into a database table with conflict resolution
[735]917
918        This method inserts a row into a table, but instead of raising a
919        ProgrammingError exception in case a row with the same primary key
920        already exists, an update will be executed instead.  This will be
921        performed as a single atomic operation on the database, so race
922        conditions can be avoided.
923
924        Like the insert method, the first parameter is the name of the
925        table and the second parameter can be used to pass the values to
926        be inserted as a dictionary.
927
928        Unlike the insert und update statement, keyword parameters are not
929        used to modify the dictionary, but to specify which columns shall
930        be updated in case of a conflict, and in which way:
931
932        A value of False or None means the column shall not be updated,
933        a value of True means the column shall be updated with the value
934        that has been proposed for insertion, i.e. has been passed as value
935        in the dictionary.  Columns that are not specified by keywords but
936        appear as keys in the dictionary are also updated like in the case
937        keywords had been passed with the value True.
938
939        So if in the case of a conflict you want to update every column that
[738]940        has been passed in the dictionary row , you would call upsert(table, row).
[735]941        If you don't want to do anything in case of a conflict, i.e. leave
[738]942        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
[735]943
944        If you need more fine-grained control of what gets updated, you can
945        also pass strings in the keyword parameters.  These strings will
946        be used as SQL expressions for the update columns.  In these
947        expressions you can refer to the value that already exists in
948        the table by prefixing the column name with "included.", and to
949        the value that has been proposed for insertion by prefixing the
950        column name with the "excluded."
951
952        The dictionary is modified in any case to reflect the values in
953        the database after the operation has completed.
954
955        Note: The method uses the PostgreSQL "upsert" feature which is
956        only available since PostgreSQL 9.5.
957        """
958        if 'oid' in kw:
959            del kw['oid']
[738]960        if row is None:
961            row = {}
962        attnames = self.get_attnames(table)
[735]963        params = []
964        param = partial(self._prepare_param,params=params)
965        col = self.escape_identifier
966        names, values, updates = [], [], []
967        for n in attnames:
[738]968            if n in row:
[735]969                names.append(col(n))
[738]970                values.append(param(row[n], attnames[n]))
[735]971        names, values = ', '.join(names), ', '.join(values)
972        try:
[738]973            keyname = self.pkey(table)
[735]974        except KeyError:
[738]975            raise _prg_error('Table %s has no primary key' % table)
[735]976        keyname = [keyname] if isinstance(
977            keyname, basestring) else sorted(keyname)
978        try:
979            target = ', '.join(col(k) for k in keyname)
980        except KeyError:
[740]981            raise _prg_error('Upsert operation needs primary key or oid')
[735]982        update = []
983        keyname = set(keyname)
984        keyname.add('oid')
985        for n in attnames:
986            if n not in keyname:
987                value = kw.get(n, True)
988                if value:
989                    if not isinstance(value, basestring):
990                        value = 'excluded.%s' % col(n)
991                    update.append('%s = %s' % (col(n), value))
992        if not values and not update:
[738]993            return row
[735]994        do = 'update set %s' % ', '.join(update) if update else 'nothing'
995        ret = 'oid, *' if 'oid' in attnames else '*'
996        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
997            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
[738]998                self._escape_qualified_name(table), names, values,
[735]999                target, do, ret)
1000        self._do_debug(q, params)
1001        try:
1002            q = self.db.query(q, params)
1003        except ProgrammingError:
1004            if self.server_version < 90500:
[740]1005                raise _prg_error(
1006                    'Upsert operation is not supported by PostgreSQL version')
[735]1007            raise  # re-raise original error
1008        res = q.dictresult()
1009        if res:  # may be empty with "do nothing"
1010            for n, value in res[0].items():
1011                if n == 'oid':
[738]1012                    n = _oid_key(table)
[735]1013                elif attnames.get(n) == 'bytea':
1014                    value = self.unescape_bytea(value)
[738]1015                row[n] = value
[735]1016        elif update:
[740]1017            raise _int_error('Upsert operation did not return new values')
[735]1018        else:
[738]1019            self.get(table, row)
1020        return row
[735]1021
[738]1022    def clear(self, table, row=None):
[436]1023        """Clear all the attributes to values determined by the types.
[202]1024
[706]1025        Numeric types are set to 0, Booleans are set to false, and everything
[738]1026        else is set to the empty string.  If the row argument is present,
1027        it is used as the row dictionary and any entries matching attribute
1028        names are cleared with everything else left unchanged.
[346]1029        """
1030        # At some point we will need a way to get defaults from a table.
[738]1031        if row is None:
1032            row = {}  # empty if argument is not present
1033        attnames = self.get_attnames(table)
[520]1034        for n, t in attnames.items():
[389]1035            if n == 'oid':
[346]1036                continue
[441]1037            if t in ('int', 'integer', 'smallint', 'bigint',
1038                    'float', 'real', 'double precision',
1039                    'num', 'numeric', 'money'):
[738]1040                row[n] = 0
[441]1041            elif t in ('bool', 'boolean'):
[738]1042                row[n] = self._make_bool(False)
[346]1043            else:
[738]1044                row[n] = ''
1045        return row
[14]1046
[738]1047    def delete(self, table, row=None, **kw):
[346]1048        """Delete an existing row in a database table.
[202]1049
[391]1050        This method deletes the row from a table.  It deletes based on the
1051        OID value as munged by get or passed as keyword, or on the primary
1052        key of the table.  The return value is the number of deleted rows
1053        (i.e. 0 if the row did not exist and 1 if the row was deleted).
[377]1054        """
[346]1055        # Like update, delete works on the oid.
1056        # One day we will be testing that the record to be deleted
1057        # isn't referenced somewhere (or else PostgreSQL will).
[735]1058        # Note that we only accept oid key from named args for safety.
[738]1059        qoid = _oid_key(table)
[357]1060        if 'oid' in kw:
[387]1061            kw[qoid] = kw['oid']
[346]1062            del kw['oid']
[738]1063        if row is None:
1064            row = {}
1065        row.update(kw)
[730]1066        params = []
1067        param = partial(self._prepare_param, params=params)
[738]1068        if qoid in row:
1069            where = 'oid = %s' % param(row[qoid], 'int')
[391]1070        else:
1071            try:
[738]1072                keyname = self.pkey(table)
[391]1073            except KeyError:
[738]1074                raise _prg_error('Table %s has no primary key' % table)
[735]1075            keyname = [keyname] if isinstance(
1076                keyname, basestring) else sorted(keyname)
[738]1077            attnames = self.get_attnames(table)
[732]1078            col = self.escape_identifier
[391]1079            try:
[738]1080                where = ' AND '.join('%s = %s' % (
1081                    col(k), param(row[k], attnames[k])) for k in keyname)
[391]1082            except KeyError:
[740]1083                raise _prg_error('Delete operation needs primary key or oid')
[732]1084        q = 'DELETE FROM %s WHERE %s' % (
[738]1085            self._escape_qualified_name(table), where)
[730]1086        self._do_debug(q, params)
[735]1087        res = self.db.query(q, params)
1088        return int(res)
[222]1089
[511]1090    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
[506]1091        """Get notification handler that will run the given callback."""
[511]1092        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
[332]1093
[506]1094
# if run as script, print some information

if __name__ == '__main__':
    # separate the label from the version number with a space
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.