source: trunk/pg.py @ 765

Last change on this file since 765 was 765, checked in by cito, 3 years ago

Improve support for access by primary key

Composite primary keys are now returned as tuples instead of frozensets,
where the ordering of the tuple reflects the primary key index.

Primary keys now take precedence if both OID and primary key are available
(this was solved the other way around in 4.x). Use of OIDs is thus slightly
more discouraged, though it still works as before for tables with OIDs where
no primary key is available.

This changeset also clarifies some docstrings, makes the code a bit clearer,
handles and tests some more edge cases (pg module still has 100% coverage).

  • Property svn:keywords set to Id
File size: 47.0 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 765 2016-01-18 23:21:44Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40
41try:
42    from collections import OrderedDict
43except ImportError:  # Python 2.6 or 3.0
44    OrderedDict = dict
45
46try:
47    basestring
48except NameError:  # Python >= 3.0
49    basestring = (str, bytes)
50
51set_decimal(Decimal)
52
53
54# Auxiliary functions that are independent from a DB connection:
55
56def _oid_key(table):
57    """Build oid key from a table name."""
58    return 'oid(%s)' % table
59
60
61def _simpletype(typ):
62    """Determine a simplified name a pg_type name."""
63    if typ.startswith('bool'):
64        return 'bool'
65    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
66        return 'date'
67    if typ.startswith(('cid', 'oid', 'int', 'xid')):
68        return 'int'
69    if typ.startswith('float'):
70        return 'float'
71    if typ.startswith('numeric'):
72        return 'num'
73    if typ.startswith('money'):
74        return 'money'
75    if typ.startswith('bytea'):
76        return 'bytea'
77    return 'text'
78
79
80def _namedresult(q):
81    """Get query result as named tuples."""
82    row = namedtuple('Row', q.listfields())
83    return [row(*r) for r in q.getresult()]
84
85set_namedresult(_namedresult)
86
87
def _db_error(msg, cls=DatabaseError):
    """Create an error of the given class with sqlstate set to None.

    The error is returned, not raised; the caller is expected to raise it.
    """
    exc = cls(msg)
    exc.sqlstate = None  # not raised by the server, so there is no SQLSTATE
    return exc
93
94
def _int_error(msg):
    """Create an InternalError with an empty sqlstate attribute."""
    return _db_error(msg, cls=InternalError)
98
99
def _prg_error(msg):
    """Create a ProgrammingError with an empty sqlstate attribute."""
    return _db_error(msg, cls=ProgrammingError)
103
104
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.
        The callback is invoked unconditionally by the handler loop, so
        it must be callable by the time the handler is run.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is garbage collected."""
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is escaped as a string literal using the
        escape_string() method of the connection that sends it.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent (and None is returned) if the handler is not
        currently listening; otherwise the result of the query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # The payload ends up inside a quoted SQL literal, so it
                # must be escaped; otherwise a quote character in it would
                # terminate the literal and break or alter the command.
                q += ", '%s'" % db.escape_string('%s' % payload)
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        If a notification for any other event arrives, the handler stops
        listening and raises a database error.  If the timeout expires,
        the handler stops listening and calls the callback with None.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        Otherwise it will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            # only needed when we actually wait on the connection socket
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        # this ends both loops after the callback was called
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
224
225
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    Issues a DeprecationWarning attributed to the caller and passes
    all arguments on to NotificationHandler.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
231
232
233# The actual PostgreSQL database connection interface:
234
235class DB(object):
236    """Wrapper class for the _pg connection type."""
237
238    def __init__(self, *args, **kw):
239        """Create a new connection
240
241        You can pass either the connection parameters or an existing
242        _pg or pgdb connection. This allows you to use the methods
243        of the classic pg interface with a DB-API 2 pgdb connection.
244        """
245        if not args and len(kw) == 1:
246            db = kw.get('db')
247        elif not kw and len(args) == 1:
248            db = args[0]
249        else:
250            db = None
251        if db:
252            if isinstance(db, DB):
253                db = db.db
254            else:
255                try:
256                    db = db._cnx
257                except AttributeError:
258                    pass
259        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
260            db = connect(*args, **kw)
261            self._closeable = True
262        else:
263            self._closeable = False
264        self.db = db
265        self.dbname = db.db
266        self._regtypes = False
267        self._attnames = {}
268        self._pkeys = {}
269        self._privileges = {}
270        self._args = args, kw
271        self.debug = None  # For debugging scripts, this can be set
272            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
273            # * to a file object to write debug statements or
274            # * to a callable object which takes a string argument
275            # * to any other true value to just print debug statements
276
277    def __getattr__(self, name):
278        # All undefined members are same as in underlying connection:
279        if self.db:
280            return getattr(self.db, name)
281        else:
282            raise _int_error('Connection is not valid')
283
284    def __dir__(self):
285        # Custom dir function including the attributes of the connection:
286        attrs = set(self.__class__.__dict__)
287        attrs.update(self.__dict__)
288        attrs.update(dir(self.db))
289        return sorted(attrs)
290
291    # Context manager methods
292
293    def __enter__(self):
294        """Enter the runtime context. This will start a transactio."""
295        self.begin()
296        return self
297
298    def __exit__(self, et, ev, tb):
299        """Exit the runtime context. This will end the transaction."""
300        if et is None and ev is None and tb is None:
301            self.commit()
302        else:
303            self.rollback()
304
305    # Auxiliary methods
306
307    def _do_debug(self, *args):
308        """Print a debug message"""
309        if self.debug:
310            s = '\n'.join(args)
311            if isinstance(self.debug, basestring):
312                print(self.debug % s)
313            elif hasattr(self.debug, 'write'):
314                self.debug.write(s + '\n')
315            elif callable(self.debug):
316                self.debug(s)
317            else:
318                print(s)
319
320    def _escape_qualified_name(self, s):
321        """Escape a qualified name.
322
323        Escapes the name for use as an SQL identifier, unless the
324        name contains a dot, in which case the name is ambiguous
325        (could be a qualified name or just a name with a dot in it)
326        and must be quoted manually by the caller.
327        """
328        if '.' not in s:
329            s = self.escape_identifier(s)
330        return s
331
332    @staticmethod
333    def _make_bool(d):
334        """Get boolean value corresponding to d."""
335        return bool(d) if get_bool() else ('t' if d else 'f')
336
337    _bool_true_values = frozenset('t true 1 y yes on'.split())
338
339    def _prepare_bool(self, d):
340        """Prepare a boolean parameter."""
341        if isinstance(d, basestring):
342            if not d:
343                return None
344            d = d.lower() in self._bool_true_values
345        return 't' if d else 'f'
346
347    _date_literals = frozenset('current_date current_time'
348        ' current_timestamp localtime localtimestamp'.split())
349
350    def _prepare_date(self, d):
351        """Prepare a date parameter."""
352        if not d:
353            return None
354        if isinstance(d, basestring) and d.lower() in self._date_literals:
355            raise ValueError
356        return d
357
358    _num_types = frozenset('int float num money'
359        ' int2 int4 int8 float4 float8 numeric money'.split())
360
361    def _prepare_num(self, d):
362        """Prepare a numeric parameter."""
363        if not d and d != 0:
364            return None
365        return d
366
367    def _prepare_bytea(self, d):
368        """Prepare a bytea parameter."""
369        return self.escape_bytea(d)
370
371    _prepare_funcs = dict(  # quote methods for each type
372        bool=_prepare_bool, date=_prepare_date,
373        int=_prepare_num, num=_prepare_num, float=_prepare_num,
374        money=_prepare_num, bytea=_prepare_bytea)
375
376    def _prepare_param(self, value, typ, params):
377        """Prepare and add a parameter to the list."""
378        if value is not None and typ != 'text':
379            prepare = self._prepare_funcs[typ]
380            try:
381                value = prepare(self, value)
382            except ValueError:
383                return value
384        params.append(value)
385        return '$%d' % len(params)
386
387    def _list_params(self, params):
388        """Create a human readable parameter list."""
389        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
390
391    @staticmethod
392    def _prepare_qualified_param(name, param):
393        """Quote parameter representing a qualified name.
394
395        Escapes the name for use as an SQL parameter, unless the
396        name contains a dot, in which case the name is ambiguous
397        (could be a qualified name or just a name with a dot in it)
398        and must be quoted manually by the caller.
399
400        """
401        if isinstance(param, int):
402            param = "$%d" % param
403        if '.' not in name:
404            param = 'quote_ident(%s)' % (param,)
405        return param
406
407    # Public methods
408
409    # escape_string and escape_bytea exist as methods,
410    # so we define unescape_bytea as a method as well
411    unescape_bytea = staticmethod(unescape_bytea)
412
413    def close(self):
414        """Close the database connection."""
415        # Wraps shared library function so we can track state.
416        if self._closeable:
417            if self.db:
418                self.db.close()
419                self.db = None
420            else:
421                raise _int_error('Connection already closed')
422
423    def reset(self):
424        """Reset connection with current parameters.
425
426        All derived queries and large objects derived from this connection
427        will not be usable after this call.
428
429        """
430        if self.db:
431            self.db.reset()
432        else:
433            raise _int_error('Connection already closed')
434
435    def reopen(self):
436        """Reopen connection to the database.
437
438        Used in case we need another connection to the same database.
439        Note that we can still reopen a database that we have closed.
440
441        """
442        # There is no such shared library function.
443        if self._closeable:
444            db = connect(*self._args[0], **self._args[1])
445            if self.db:
446                self.db.close()
447            self.db = db
448
449    def begin(self, mode=None):
450        """Begin a transaction."""
451        qstr = 'BEGIN'
452        if mode:
453            qstr += ' ' + mode
454        return self.query(qstr)
455
456    start = begin
457
458    def commit(self):
459        """Commit the current transaction."""
460        return self.query('COMMIT')
461
462    end = commit
463
464    def rollback(self, name=None):
465        """Roll back the current transaction."""
466        qstr = 'ROLLBACK'
467        if name:
468            qstr += ' TO ' + name
469        return self.query(qstr)
470
471    abort = rollback
472
473    def savepoint(self, name):
474        """Define a new savepoint within the current transaction."""
475        return self.query('SAVEPOINT ' + name)
476
477    def release(self, name):
478        """Destroy a previously defined savepoint."""
479        return self.query('RELEASE ' + name)
480
481    def get_parameter(self, parameter):
482        """Get the value of a run-time parameter.
483
484        If the parameter is a string, the return value will also be a string
485        that is the current setting of the run-time parameter with that name.
486
487        You can get several parameters at once by passing a list, set or dict.
488        When passing a list of parameter names, the return value will be a
489        corresponding list of parameter settings.  When passing a set of
490        parameter names, a new dict will be returned, mapping these parameter
491        names to their settings.  Finally, if you pass a dict as parameter,
492        its values will be set to the current parameter settings corresponding
493        to its keys.
494
495        By passing the special name 'all' as the parameter, you can get a dict
496        of all existing configuration parameters.
497        """
498        if isinstance(parameter, basestring):
499            parameter = [parameter]
500            values = None
501        elif isinstance(parameter, (list, tuple)):
502            values = []
503        elif isinstance(parameter, (set, frozenset)):
504            values = {}
505        elif isinstance(parameter, dict):
506            values = parameter
507        else:
508            raise TypeError(
509                'The parameter must be a string, list, set or dict')
510        if not parameter:
511            raise TypeError('No parameter has been specified')
512        params = {} if isinstance(values, dict) else []
513        for key in parameter:
514            param = key.strip().lower() if isinstance(
515                key, basestring) else None
516            if not param:
517                raise TypeError('Invalid parameter')
518            if param == 'all':
519                q = 'SHOW ALL'
520                values = self.db.query(q).getresult()
521                values = dict(value[:2] for value in values)
522                break
523            if isinstance(values, dict):
524                params[param] = key
525            else:
526                params.append(param)
527        else:
528            for param in params:
529                q = 'SHOW %s' % (param,)
530                value = self.db.query(q).getresult()[0][0]
531                if values is None:
532                    values = value
533                elif isinstance(values, list):
534                    values.append(value)
535                else:
536                    values[params[param]] = value
537        return values
538
539    def set_parameter(self, parameter, value=None, local=False):
540        """Set the value of a run-time parameter.
541
542        If the parameter and the value are strings, the run-time parameter
543        will be set to that value.  If no value or None is passed as a value,
544        then the run-time parameter will be restored to its default value.
545
546        You can set several parameters at once by passing a list of parameter
547        names, together with a single value that all parameters should be
548        set to or with a corresponding list of values.  You can also pass
549        the parameters as a set if you only provide a single value.
550        Finally, you can pass a dict with parameter names as keys.  In this
551        case, you should not pass a value, since the values for the parameters
552        will be taken from the dict.
553
554        By passing the special name 'all' as the parameter, you can reset
555        all existing settable run-time parameters to their default values.
556
557        If you set local to True, then the command takes effect for only the
558        current transaction.  After commit() or rollback(), the session-level
559        setting takes effect again.  Setting local to True will appear to
560        have no effect if it is executed outside a transaction, since the
561        transaction will end immediately.
562        """
563        if isinstance(parameter, basestring):
564            parameter = {parameter: value}
565        elif isinstance(parameter, (list, tuple)):
566            if isinstance(value, (list, tuple)):
567                parameter = dict(zip(parameter, value))
568            else:
569                parameter = dict.fromkeys(parameter, value)
570        elif isinstance(parameter, (set, frozenset)):
571            if isinstance(value, (list, tuple, set, frozenset)):
572                value = set(value)
573                if len(value) == 1:
574                    value = value.pop()
575            if not(value is None or isinstance(value, basestring)):
576                raise ValueError('A single value must be specified'
577                    ' when parameter is a set')
578            parameter = dict.fromkeys(parameter, value)
579        elif isinstance(parameter, dict):
580            if value is not None:
581                raise ValueError('A value must not be specified'
582                    ' when parameter is a dictionary')
583        else:
584            raise TypeError(
585                'The parameter must be a string, list, set or dict')
586        if not parameter:
587            raise TypeError('No parameter has been specified')
588        params = {}
589        for key, value in parameter.items():
590            param = key.strip().lower() if isinstance(
591                key, basestring) else None
592            if not param:
593                raise TypeError('Invalid parameter')
594            if param == 'all':
595                if value is not None:
596                    raise ValueError('A value must ot be specified'
597                        " when parameter is 'all'")
598                params = {'all': None}
599                break
600            params[param] = value
601        local = ' LOCAL' if local else ''
602        for param, value in params.items():
603            if value is None:
604                q = 'RESET%s %s' % (local, param)
605            else:
606                q = 'SET%s %s TO %s' % (local, param, value)
607            self._do_debug(q)
608            self.db.query(q)
609
610    def query(self, qstr, *args):
611        """Execute a SQL command string.
612
613        This method simply sends a SQL query to the database.  If the query is
614        an insert statement that inserted exactly one row into a table that
615        has OIDs, the return value is the OID of the newly inserted row.
616        If the query is an update or delete statement, or an insert statement
617        that did not insert exactly one row in a table with OIDs, then the
618        number of rows affected is returned as a string.  If it is a statement
619        that returns rows as a result (usually a select statement, but maybe
620        also an "insert/update ... returning" statement), this method returns
621        a Query object that can be accessed via getresult() or dictresult()
622        or simply printed.  Otherwise, it returns `None`.
623
624        The query can contain numbered parameters of the form $1 in place
625        of any data constant.  Arguments given after the query string will
626        be substituted for the corresponding numbered parameter.  Parameter
627        values can also be given as a single list or tuple argument.
628        """
629        # Wraps shared library function for debugging.
630        if not self.db:
631            raise _int_error('Connection is not valid')
632        self._do_debug(qstr)
633        return self.db.query(qstr, args)
634
635    def pkey(self, table, composite=False, flush=False):
636        """Get or set the primary key of a table.
637
638        Single primary keys are returned as strings unless you
639        set the composite flag.  Composite primary keys are always
640        represented as tuples.  Note that this raises a KeyError
641        if the table does not have a primary key.
642
643        If flush is set then the internal cache for primary keys will
644        be flushed.  This may be necessary after the database schema or
645        the search path has been changed.
646        """
647        pkeys = self._pkeys
648        if flush:
649            pkeys.clear()
650            self._do_debug('The pkey cache has been flushed')
651        try:  # cache lookup
652            pkey = pkeys[table]
653        except KeyError:  # cache miss, check the database
654            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
655                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
656                " AND a.attnum = ANY(i.indkey)"
657                " AND NOT a.attisdropped"
658                " WHERE i.indrelid=%s::regclass"
659                " AND i.indisprimary ORDER BY a.attnum") % (
660                    self._prepare_qualified_param(table, 1),)
661            pkey = self.db.query(q, (table,)).getresult()
662            if not pkey:
663                raise KeyError('Table %s has no primary key' % table)
664            # we want to use the order defined in the primary key index here,
665            # not the order as defined by the columns in the table
666            if len(pkey) > 1:
667                indkey = [int(k) for k in pkey[0][2].split()]
668                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
669                pkey = tuple(row[0] for row in pkey)
670            else:
671                pkey = pkey[0][0]
672            pkeys[table] = pkey  # cache it
673        if composite and not isinstance(pkey, tuple):
674            pkey = (pkey,)
675        return pkey
676
677    def get_databases(self):
678        """Get list of databases in the system."""
679        return [s[0] for s in
680            self.db.query('SELECT datname FROM pg_database').getresult()]
681
682    def get_relations(self, kinds=None):
683        """Get list of relations in connected database of specified kinds.
684
685        If kinds is None or empty, all kinds of relations are returned.
686        Otherwise kinds can be a string or sequence of type letters
687        specifying which kind of relations you want to list.
688        """
689        where = " AND r.relkind IN (%s)" % ','.join(
690            ["'%s'" % k for k in kinds]) if kinds else ''
691        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
692            " FROM pg_class r"
693            " JOIN pg_namespace s ON s.oid = r.relnamespace"
694            " WHERE s.nspname NOT SIMILAR"
695            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
696            " ORDER BY s.nspname, r.relname") % where
697        return [r[0] for r in self.db.query(q).getresult()]
698
699    def get_tables(self):
700        """Return list of tables in connected database."""
701        return self.get_relations('r')
702
703    def get_attnames(self, table, flush=False):
704        """Given the name of a table, dig out the set of attribute names.
705
706        Returns a dictionary of attribute names (the names are the keys,
707        the values are the names of the attributes' types).
708
709        If your Python version supports this, the dictionary will be an
710        OrderedDictionary with the column names in the right order.
711
712        If flush is set then the internal cache for attribute names will
713        be flushed. This may be necessary after the database schema or
714        the search path has been changed.
715
716        By default, only a limited number of simple types will be returned.
717        You can get the regular types after calling use_regtypes(True).
718        """
719        attnames = self._attnames
720        if flush:
721            attnames.clear()
722            self._do_debug('The attnames cache has been flushed')
723        try:  # cache lookup
724            names = attnames[table]
725        except KeyError:  # cache miss, check the database
726            q = ("SELECT a.attname, t.typname%s"
727                " FROM pg_attribute a"
728                " JOIN pg_type t ON t.oid = a.atttypid"
729                " WHERE a.attrelid = %s::regclass"
730                " AND (a.attnum > 0 OR a.attname = 'oid')"
731                " AND NOT a.attisdropped ORDER BY a.attnum") % (
732                    '::regtype' if self._regtypes else '',
733                    self._prepare_qualified_param(table, 1))
734            names = self.db.query(q, (table,)).getresult()
735            if not self._regtypes:
736                names = ((name, _simpletype(typ)) for name, typ in names)
737            names = OrderedDict(names)
738            attnames[table] = names  # cache it
739        return names
740
741    def use_regtypes(self, regtypes=None):
742        """Use regular type names instead of simplified type names."""
743        if regtypes is None:
744            return self._regtypes
745        else:
746            regtypes = bool(regtypes)
747            if regtypes != self._regtypes:
748                self._regtypes = regtypes
749                self._attnames.clear()
750            return regtypes
751
752    def has_table_privilege(self, table, privilege='select'):
753        """Check whether current user has specified table privilege."""
754        privilege = privilege.lower()
755        try:  # ask cache
756            return self._privileges[(table, privilege)]
757        except KeyError:  # cache miss, ask the database
758            q = "SELECT has_table_privilege(%s, $2)" % (
759                self._prepare_qualified_param(table, 1),)
760            q = self.db.query(q, (table, privilege))
761            ret = q.getresult()[0][0] == self._make_bool(True)
762            self._privileges[(table, privilege)] = ret  # cache it
763            return ret
764
765    def get(self, table, row, keyname=None):
766        """Get a row from a database table or view.
767
768        This method is the basic mechanism to get a single row.  It assumes
769        that the keyname specifies a unique row.  It must be the name of a
770        single column or a tuple of column names.  If the keyname is not
771        specified, then the primary key for the table is used.
772
773        If row is a dictionary, then the value for the key is taken from it.
774        Otherwise, the row must be a single value or a tuple of values
775        corresponding to the passed keyname or primary key.  The fetched row
776        from the table will be returned as a new dictionary or used to replace
777        the existing values when row was passed as aa dictionary.
778
779        The OID is also put into the dictionary if the table has one, but
780        in order to allow the caller to work with multiple tables, it is
781        munged as "oid(table)" using the actual name of the table.
782        """
783        if table.endswith('*'):  # hint for descendant tables can be ignored
784            table = table[:-1].rstrip()
785        attnames = self.get_attnames(table)
786        qoid = _oid_key(table) if 'oid' in attnames else None
787        if keyname and isinstance(keyname, basestring):
788            keyname = (keyname,)
789        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
790            row['oid'] = row[qoid]
791        if not keyname:
792            try:  # if keyname is not specified, try using the primary key
793                keyname = self.pkey(table, True)
794            except KeyError:  # the table has no primary key
795                # try using the oid instead
796                if qoid and isinstance(row, dict) and 'oid' in row:
797                    keyname = ('oid',)
798                else:
799                    raise _prg_error('Table %s has no primary key' % table)
800            else:  # the table has a primary key
801                # check whether all key columns have values
802                if isinstance(row, dict) and not set(keyname).issubset(row):
803                    # try using the oid instead
804                    if qoid and 'oid' in row:
805                        keyname = ('oid',)
806                    else:
807                        raise KeyError(
808                            'Missing value in row for specified keyname')
809        if not isinstance(row, dict):
810            if not isinstance(row, (tuple, list)):
811                row = [row]
812            if len(keyname) != len(row):
813                raise KeyError(
814                    'Differing number of items in keyname and row')
815            row = dict(zip(keyname, row))
816        params = []
817        param = partial(self._prepare_param, params=params)
818        col = self.escape_identifier
819        what = 'oid, *' if qoid else '*'
820        where = ' AND '.join('%s = %s' % (
821            col(k), param(row[k], attnames[k])) for k in keyname)
822        if 'oid' in row:
823            if qoid:
824                row[qoid] = row['oid']
825            del row['oid']
826        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
827            what, self._escape_qualified_name(table), where)
828        self._do_debug(q, params)
829        q = self.db.query(q, params)
830        res = q.dictresult()
831        if not res:
832            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
833                table, where, self._list_params(params)))
834        for n, value in res[0].items():
835            if qoid and n == 'oid':
836                n = qoid
837            elif value is not None and attnames.get(n) == 'bytea':
838                value = self.unescape_bytea(value)
839            row[n] = value
840        return row
841
842    def insert(self, table, row=None, **kw):
843        """Insert a row into a database table.
844
845        This method inserts a row into a table.  The name of the table must
846        be passed as the first parameter.  The other parameters are used for
847        providing the data of the row that shall be inserted into the table.
848        If a dictionary is supplied as the second parameter, it starts with
849        that.  Otherwise it uses a blank dictionary. Either way the dictionary
850        is updated from the keywords.
851
852        The dictionary is then reloaded with the values actually inserted in
853        order to pick up values modified by rules, triggers, etc.
854
855        Note: The method currently doesn't support insert into views
856        although PostgreSQL does.
857        """
858        if table.endswith('*'):  # hint for descendant tables can be ignored
859            table = table[:-1].rstrip()
860        if row is None:
861            row = {}
862        row.update(kw)
863        if 'oid' in row:
864            del row['oid']  # do not insert oid
865        attnames = self.get_attnames(table)
866        qoid = _oid_key(table) if 'oid' in attnames else None
867        params = []
868        param = partial(self._prepare_param, params=params)
869        col = self.escape_identifier
870        names, values = [], []
871        for n in attnames:
872            if n in row:
873                names.append(col(n))
874                values.append(param(row[n], attnames[n]))
875        names, values = ', '.join(names), ', '.join(values)
876        ret = 'oid, *' if qoid else '*'
877        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
878            self._escape_qualified_name(table), names, values, ret)
879        self._do_debug(q, params)
880        q = self.db.query(q, params)
881        res = q.dictresult()
882        if res:  # this should always be true
883            for n, value in res[0].items():
884                if qoid and n == 'oid':
885                    n = qoid
886                elif value is not None and attnames.get(n) == 'bytea':
887                    value = self.unescape_bytea(value)
888                row[n] = value
889        return row
890
891    def update(self, table, row=None, **kw):
892        """Update an existing row in a database table.
893
894        Similar to insert but updates an existing row.  The update is based
895        on the primary key of the table or the OID value as munged by get
896        or passed as keyword.
897
898        The dictionary is then modified to reflect any changes caused by the
899        update due to triggers, rules, default values, etc.
900        """
901        if table.endswith('*'):
902            table = table[:-1].rstrip()  # need parent table name
903        attnames = self.get_attnames(table)
904        qoid = _oid_key(table) if 'oid' in attnames else None
905        if row is None:
906            row = {}
907        elif 'oid' in row:
908            del row['oid']  # only accept oid key from named args for safety
909        row.update(kw)
910        if qoid and qoid in row and 'oid' not in row:
911            row['oid'] = row[qoid]
912        try:  # try using the primary key
913            keyname = self.pkey(table, True)
914        except KeyError:  # the table has no primary key
915            # try using the oid instead
916            if qoid and 'oid' in row:
917                keyname = ('oid',)
918            else:
919                raise _prg_error('Table %s has no primary key' % table)
920        else:  # the table has a primary key
921            # check whether all key columns have values
922            if not set(keyname).issubset(row):
923                # try using the oid instead
924                if qoid and 'oid' in row:
925                    keyname = ('oid',)
926                else:
927                    raise KeyError('Missing primary key in row')
928        params = []
929        param = partial(self._prepare_param, params=params)
930        col = self.escape_identifier
931        where = ' AND '.join('%s = %s' % (
932            col(k), param(row[k], attnames[k])) for k in keyname)
933        if 'oid' in row:
934            if qoid:
935                row[qoid] = row['oid']
936            del row['oid']
937        values = []
938        keyname = set(keyname)
939        for n in attnames:
940            if n in row and n not in keyname:
941                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
942        if not values:
943            return row
944        values = ', '.join(values)
945        ret = 'oid, *' if qoid else '*'
946        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
947            self._escape_qualified_name(table), values, where, ret)
948        self._do_debug(q, params)
949        q = self.db.query(q, params)
950        res = q.dictresult()
951        if res:  # may be empty when row does not exist
952            for n, value in res[0].items():
953                if qoid and n == 'oid':
954                    n = qoid
955                elif value is not None and attnames.get(n) == 'bytea':
956                    value = self.unescape_bytea(value)
957                row[n] = value
958        return row
959
960    def upsert(self, table, row=None, **kw):
961        """Insert a row into a database table with conflict resolution
962
963        This method inserts a row into a table, but instead of raising a
964        ProgrammingError exception in case a row with the same primary key
965        already exists, an update will be executed instead.  This will be
966        performed as a single atomic operation on the database, so race
967        conditions can be avoided.
968
969        Like the insert method, the first parameter is the name of the
970        table and the second parameter can be used to pass the values to
971        be inserted as a dictionary.
972
973        Unlike the insert und update statement, keyword parameters are not
974        used to modify the dictionary, but to specify which columns shall
975        be updated in case of a conflict, and in which way:
976
977        A value of False or None means the column shall not be updated,
978        a value of True means the column shall be updated with the value
979        that has been proposed for insertion, i.e. has been passed as value
980        in the dictionary.  Columns that are not specified by keywords but
981        appear as keys in the dictionary are also updated like in the case
982        keywords had been passed with the value True.
983
984        So if in the case of a conflict you want to update every column that
985        has been passed in the dictionary row , you would call upsert(table, row).
986        If you don't want to do anything in case of a conflict, i.e. leave
987        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
988
989        If you need more fine-grained control of what gets updated, you can
990        also pass strings in the keyword parameters.  These strings will
991        be used as SQL expressions for the update columns.  In these
992        expressions you can refer to the value that already exists in
993        the table by prefixing the column name with "included.", and to
994        the value that has been proposed for insertion by prefixing the
995        column name with the "excluded."
996
997        The dictionary is modified in any case to reflect the values in
998        the database after the operation has completed.
999
1000        Note: The method uses the PostgreSQL "upsert" feature which is
1001        only available since PostgreSQL 9.5.
1002        """
1003        if table.endswith('*'):  # hint for descendant tables can be ignored
1004            table = table[:-1].rstrip()
1005        if row is None:
1006            row = {}
1007        if 'oid' in row:
1008            del row['oid']  # do not insert oid
1009        if 'oid' in kw:
1010            del kw['oid']  # do not update oid
1011        attnames = self.get_attnames(table)
1012        qoid = _oid_key(table) if 'oid' in attnames else None
1013        params = []
1014        param = partial(self._prepare_param,params=params)
1015        col = self.escape_identifier
1016        names, values, updates = [], [], []
1017        for n in attnames:
1018            if n in row:
1019                names.append(col(n))
1020                values.append(param(row[n], attnames[n]))
1021        names, values = ', '.join(names), ', '.join(values)
1022        try:
1023            keyname = self.pkey(table, True)
1024        except KeyError:
1025            raise _prg_error('Table %s has no primary key' % table)
1026        target = ', '.join(col(k) for k in keyname)
1027        update = []
1028        keyname = set(keyname)
1029        keyname.add('oid')
1030        for n in attnames:
1031            if n not in keyname:
1032                value = kw.get(n, True)
1033                if value:
1034                    if not isinstance(value, basestring):
1035                        value = 'excluded.%s' % col(n)
1036                    update.append('%s = %s' % (col(n), value))
1037        if not values:
1038            return row
1039        do = 'update set %s' % ', '.join(update) if update else 'nothing'
1040        ret = 'oid, *' if qoid else '*'
1041        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
1042            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
1043                self._escape_qualified_name(table), names, values,
1044                target, do, ret)
1045        self._do_debug(q, params)
1046        try:
1047            q = self.db.query(q, params)
1048        except ProgrammingError:
1049            if self.server_version < 90500:
1050                raise _prg_error(
1051                    'Upsert operation is not supported by PostgreSQL version')
1052            raise  # re-raise original error
1053        res = q.dictresult()
1054        if res:  # may be empty with "do nothing"
1055            for n, value in res[0].items():
1056                if qoid and n == 'oid':
1057                    n = qoid
1058                elif value is not None and attnames.get(n) == 'bytea':
1059                    value = self.unescape_bytea(value)
1060                row[n] = value
1061        else:
1062            self.get(table, row)
1063        return row
1064
1065    def clear(self, table, row=None):
1066        """Clear all the attributes to values determined by the types.
1067
1068        Numeric types are set to 0, Booleans are set to false, and everything
1069        else is set to the empty string.  If the row argument is present,
1070        it is used as the row dictionary and any entries matching attribute
1071        names are cleared with everything else left unchanged.
1072        """
1073        # At some point we will need a way to get defaults from a table.
1074        if row is None:
1075            row = {}  # empty if argument is not present
1076        attnames = self.get_attnames(table)
1077        for n, t in attnames.items():
1078            if n == 'oid':
1079                continue
1080            if t in self._num_types:
1081                row[n] = 0
1082            elif t == 'bool':
1083                row[n] = self._make_bool(False)
1084            else:
1085                row[n] = ''
1086        return row
1087
1088    def delete(self, table, row=None, **kw):
1089        """Delete an existing row in a database table.
1090
1091        This method deletes the row from a table.  It deletes based on the
1092        primary key of the table or the OID value as munged by get() or
1093        passed as keyword.
1094
1095        The return value is the number of deleted rows (i.e. 0 if the row
1096        did not exist and 1 if the row was deleted).
1097
1098        Note that if the row cannot be deleted because e.g. it is still
1099        referenced by another table, this method raises a ProgrammingError.
1100        """
1101        if table.endswith('*'):  # hint for descendant tables can be ignored
1102            table = table[:-1].rstrip()
1103        attnames = self.get_attnames(table)
1104        qoid = _oid_key(table) if 'oid' in attnames else None
1105        if row is None:
1106            row = {}
1107        elif 'oid' in row:
1108            del row['oid']  # only accept oid key from named args for safety
1109        row.update(kw)
1110        if qoid and qoid in row and 'oid' not in row:
1111            row['oid'] = row[qoid]
1112        try:  # try using the primary key
1113            keyname = self.pkey(table, True)
1114        except KeyError:  # the table has no primary key
1115            # try using the oid instead
1116            if qoid and 'oid' in row:
1117                keyname = ('oid',)
1118            else:
1119                raise _prg_error('Table %s has no primary key' % table)
1120        else:  # the table has a primary key
1121            # check whether all key columns have values
1122            if not set(keyname).issubset(row):
1123                # try using the oid instead
1124                if qoid and 'oid' in row:
1125                    keyname = ('oid',)
1126                else:
1127                    raise KeyError('Missing primary key in row')
1128        params = []
1129        param = partial(self._prepare_param, params=params)
1130        col = self.escape_identifier
1131        where = ' AND '.join('%s = %s' % (
1132            col(k), param(row[k], attnames[k])) for k in keyname)
1133        if 'oid' in row:
1134            if qoid:
1135                row[qoid] = row['oid']
1136            del row['oid']
1137        q = 'DELETE FROM %s WHERE %s' % (
1138            self._escape_qualified_name(table), where)
1139        self._do_debug(q, params)
1140        res = self.db.query(q, params)
1141        return int(res)
1142
1143    def truncate(self, table, restart=False, cascade=False, only=False):
1144        """Empty a table or set of tables.
1145
1146        This method quickly removes all rows from the given table or set
1147        of tables.  It has the same effect as an unqualified DELETE on each
1148        table, but since it does not actually scan the tables it is faster.
1149        Furthermore, it reclaims disk space immediately, rather than requiring
1150        a subsequent VACUUM operation. This is most useful on large tables.
1151
1152        If restart is set to True, sequences owned by columns of the truncated
1153        table(s) are automatically restarted.  If cascade is set to True, it
1154        also truncates all tables that have foreign-key references to any of
1155        the named tables.  If the parameter only is not set to True, all the
1156        descendant tables (if any) will also be truncated. Optionally, a '*'
1157        can be specified after the table name to explicitly indicate that
1158        descendant tables are included.
1159        """
1160        if isinstance(table, basestring):
1161            only = {table: only}
1162            table = [table]
1163        elif isinstance(table, (list, tuple)):
1164            if isinstance(only, (list, tuple)):
1165                only = dict(zip(table, only))
1166            else:
1167                only = dict.fromkeys(table, only)
1168        elif isinstance(table, (set, frozenset)):
1169            only = dict.fromkeys(table, only)
1170        else:
1171            raise TypeError('The table must be a string, list or set')
1172        if not (restart is None or isinstance(restart, (bool, int))):
1173            raise TypeError('Invalid type for the restart option')
1174        if not (cascade is None or isinstance(cascade, (bool, int))):
1175            raise TypeError('Invalid type for the cascade option')
1176        tables = []
1177        for t in table:
1178            u = only.get(t)
1179            if not (u is None or isinstance(u, (bool, int))):
1180                raise TypeError('Invalid type for the only option')
1181            if t.endswith('*'):
1182                if u:
1183                    raise ValueError(
1184                        'Contradictory table name and only options')
1185                t = t[:-1].rstrip()
1186            t = self._escape_qualified_name(t)
1187            if u:
1188                t = 'ONLY %s' % t
1189            tables.append(t)
1190        q = ['TRUNCATE', ', '.join(tables)]
1191        if restart:
1192            q.append('RESTART IDENTITY')
1193        if cascade:
1194            q.append('CASCADE')
1195        q = ' '.join(q)
1196        self._do_debug(q)
1197        return self.query(q)
1198
1199    def notification_handler(self,
1200            event, callback, arg_dict=None, timeout=None, stop_event=None):
1201        """Get notification handler that will run the given callback."""
1202        return NotificationHandler(self,
1203            event, callback, arg_dict, timeout, stop_event)
1204
1205
# if run as script, print some information

if __name__ == '__main__':
    # pass the version as a separate argument so that print() puts a
    # space between the label and the version number
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.