source: trunk/pg.py @ 769

Last change on this file since 769 was 769, checked in by cito, 4 years ago

Removed misleading sentence from docs

The insert() method does in fact support inserting into views in newer
PostgreSQL versions or when the necessary rules have been created.

  • Property svn:keywords set to Id
File size: 46.9 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 769 2016-01-19 17:00:26Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15"""
16
17# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
18#
19# Contributions made by Ch. Zwerschke and others.
20#
21# The notification handler is based on pgnotify which is
22# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
23#
24# Permission to use, copy, modify, and distribute this software and its
25# documentation for any purpose and without fee is hereby granted,
26# provided that the above copyright notice appear in all copies and that
27# both that copyright notice and this permission notice appear in
28# supporting documentation.
29
30from __future__ import print_function
31
32from _pg import *
33
34import select
35import warnings
36
37from decimal import Decimal
38from collections import namedtuple
39from functools import partial
40
41try:
42    from collections import OrderedDict
43except ImportError:  # Python 2.6 or 3.0
44    OrderedDict = dict
45
46try:
47    basestring
48except NameError:  # Python >= 3.0
49    basestring = (str, bytes)
50
51set_decimal(Decimal)
52
53
54# Auxiliary functions that are independent from a DB connection:
55
56def _oid_key(table):
57    """Build oid key from a table name."""
58    return 'oid(%s)' % table
59
60
61def _simpletype(typ):
62    """Determine a simplified name a pg_type name."""
63    if typ.startswith('bool'):
64        return 'bool'
65    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
66        return 'date'
67    if typ.startswith(('cid', 'oid', 'int', 'xid')):
68        return 'int'
69    if typ.startswith('float'):
70        return 'float'
71    if typ.startswith('numeric'):
72        return 'num'
73    if typ.startswith('money'):
74        return 'money'
75    if typ.startswith('bytea'):
76        return 'bytea'
77    return 'text'
78
79
80def _namedresult(q):
81    """Get query result as named tuples."""
82    row = namedtuple('Row', q.listfields())
83    return [row(*r) for r in q.getresult()]
84
# register the named tuple factory with the _pg module
set_namedresult(_namedresult)


def _db_error(msg, cls=DatabaseError):
    """Create an error of class cls (DatabaseError by default).

    Errors created here are raised by this module itself and do not
    carry an SQLSTATE code, so their sqlstate attribute is set to None.
    """
    err = cls(msg)
    err.sqlstate = None
    return err
93
94
def _int_error(msg):
    """Create an InternalError with msg and an empty sqlstate."""
    return _db_error(msg, cls=InternalError)


def _prg_error(msg):
    """Create a ProgrammingError with msg and an empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
103
104
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # If __init__ failed early, the listening attribute may not exist;
        # do not add a spurious AttributeError during garbage collection.
        if getattr(self, 'listening', False):
            self.unlisten()

    @staticmethod
    def _quote_channel(name):
        """Quote a channel name for use as an SQL identifier.

        Embedded double quotes are doubled, so that a channel name
        containing them can neither break nor alter the statement.
        """
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_channel(self.event))
            self.db.query('listen %s' % self._quote_channel(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_channel(self.event))
            self.db.query(
                'unlisten %s' % self._quote_channel(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.
        The payload is escaped with the escape_string() method of the
        connection, so it may safely contain quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent (and None is returned) if the handler is not
        listening; otherwise the result of the query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify %s' % self._quote_channel(
                self.stop_event if stop else self.event)
            if payload:
                # the payload goes into a string literal, so quotes in it
                # must be escaped to not terminate the literal early
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
224
225
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    All arguments are passed on to NotificationHandler unchanged;
    a DeprecationWarning pointing at the caller is issued first.
    """
    msg = "pgnotify is deprecated, use NotificationHandler instead"
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
231
232
233# The actual PostgreSQL database connection interface:
234
235class DB(object):
236    """Wrapper class for the _pg connection type."""
237
238    def __init__(self, *args, **kw):
239        """Create a new connection
240
241        You can pass either the connection parameters or an existing
242        _pg or pgdb connection. This allows you to use the methods
243        of the classic pg interface with a DB-API 2 pgdb connection.
244        """
245        if not args and len(kw) == 1:
246            db = kw.get('db')
247        elif not kw and len(args) == 1:
248            db = args[0]
249        else:
250            db = None
251        if db:
252            if isinstance(db, DB):
253                db = db.db
254            else:
255                try:
256                    db = db._cnx
257                except AttributeError:
258                    pass
259        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
260            db = connect(*args, **kw)
261            self._closeable = True
262        else:
263            self._closeable = False
264        self.db = db
265        self.dbname = db.db
266        self._regtypes = False
267        self._attnames = {}
268        self._pkeys = {}
269        self._privileges = {}
270        self._args = args, kw
271        self.debug = None  # For debugging scripts, this can be set
272            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
273            # * to a file object to write debug statements or
274            # * to a callable object which takes a string argument
275            # * to any other true value to just print debug statements
276
277    def __getattr__(self, name):
278        # All undefined members are same as in underlying connection:
279        if self.db:
280            return getattr(self.db, name)
281        else:
282            raise _int_error('Connection is not valid')
283
284    def __dir__(self):
285        # Custom dir function including the attributes of the connection:
286        attrs = set(self.__class__.__dict__)
287        attrs.update(self.__dict__)
288        attrs.update(dir(self.db))
289        return sorted(attrs)
290
291    # Context manager methods
292
293    def __enter__(self):
294        """Enter the runtime context. This will start a transactio."""
295        self.begin()
296        return self
297
298    def __exit__(self, et, ev, tb):
299        """Exit the runtime context. This will end the transaction."""
300        if et is None and ev is None and tb is None:
301            self.commit()
302        else:
303            self.rollback()
304
305    # Auxiliary methods
306
307    def _do_debug(self, *args):
308        """Print a debug message"""
309        if self.debug:
310            s = '\n'.join(args)
311            if isinstance(self.debug, basestring):
312                print(self.debug % s)
313            elif hasattr(self.debug, 'write'):
314                self.debug.write(s + '\n')
315            elif callable(self.debug):
316                self.debug(s)
317            else:
318                print(s)
319
320    def _escape_qualified_name(self, s):
321        """Escape a qualified name.
322
323        Escapes the name for use as an SQL identifier, unless the
324        name contains a dot, in which case the name is ambiguous
325        (could be a qualified name or just a name with a dot in it)
326        and must be quoted manually by the caller.
327        """
328        if '.' not in s:
329            s = self.escape_identifier(s)
330        return s
331
332    @staticmethod
333    def _make_bool(d):
334        """Get boolean value corresponding to d."""
335        return bool(d) if get_bool() else ('t' if d else 'f')
336
337    _bool_true_values = frozenset('t true 1 y yes on'.split())
338
339    def _prepare_bool(self, d):
340        """Prepare a boolean parameter."""
341        if isinstance(d, basestring):
342            if not d:
343                return None
344            d = d.lower() in self._bool_true_values
345        return 't' if d else 'f'
346
347    _date_literals = frozenset('current_date current_time'
348        ' current_timestamp localtime localtimestamp'.split())
349
350    def _prepare_date(self, d):
351        """Prepare a date parameter."""
352        if not d:
353            return None
354        if isinstance(d, basestring) and d.lower() in self._date_literals:
355            raise ValueError
356        return d
357
358    _num_types = frozenset('int float num money'
359        ' int2 int4 int8 float4 float8 numeric money'.split())
360
361    def _prepare_num(self, d):
362        """Prepare a numeric parameter."""
363        if not d and d != 0:
364            return None
365        return d
366
367    def _prepare_bytea(self, d):
368        """Prepare a bytea parameter."""
369        return self.escape_bytea(d)
370
371    _prepare_funcs = dict(  # quote methods for each type
372        bool=_prepare_bool, date=_prepare_date,
373        int=_prepare_num, num=_prepare_num, float=_prepare_num,
374        money=_prepare_num, bytea=_prepare_bytea)
375
376    def _prepare_param(self, value, typ, params):
377        """Prepare and add a parameter to the list."""
378        if value is not None and typ != 'text':
379            prepare = self._prepare_funcs[typ]
380            try:
381                value = prepare(self, value)
382            except ValueError:
383                return value
384        params.append(value)
385        return '$%d' % len(params)
386
387    def _list_params(self, params):
388        """Create a human readable parameter list."""
389        return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1))
390
391    @staticmethod
392    def _prepare_qualified_param(name, param):
393        """Quote parameter representing a qualified name.
394
395        Escapes the name for use as an SQL parameter, unless the
396        name contains a dot, in which case the name is ambiguous
397        (could be a qualified name or just a name with a dot in it)
398        and must be quoted manually by the caller.
399
400        """
401        if isinstance(param, int):
402            param = "$%d" % param
403        if '.' not in name:
404            param = 'quote_ident(%s)' % (param,)
405        return param
406
407    # Public methods
408
409    # escape_string and escape_bytea exist as methods,
410    # so we define unescape_bytea as a method as well
411    unescape_bytea = staticmethod(unescape_bytea)
412
413    def close(self):
414        """Close the database connection."""
415        # Wraps shared library function so we can track state.
416        if self._closeable:
417            if self.db:
418                self.db.close()
419                self.db = None
420            else:
421                raise _int_error('Connection already closed')
422
423    def reset(self):
424        """Reset connection with current parameters.
425
426        All derived queries and large objects derived from this connection
427        will not be usable after this call.
428
429        """
430        if self.db:
431            self.db.reset()
432        else:
433            raise _int_error('Connection already closed')
434
435    def reopen(self):
436        """Reopen connection to the database.
437
438        Used in case we need another connection to the same database.
439        Note that we can still reopen a database that we have closed.
440
441        """
442        # There is no such shared library function.
443        if self._closeable:
444            db = connect(*self._args[0], **self._args[1])
445            if self.db:
446                self.db.close()
447            self.db = db
448
449    def begin(self, mode=None):
450        """Begin a transaction."""
451        qstr = 'BEGIN'
452        if mode:
453            qstr += ' ' + mode
454        return self.query(qstr)
455
456    start = begin
457
458    def commit(self):
459        """Commit the current transaction."""
460        return self.query('COMMIT')
461
462    end = commit
463
464    def rollback(self, name=None):
465        """Roll back the current transaction."""
466        qstr = 'ROLLBACK'
467        if name:
468            qstr += ' TO ' + name
469        return self.query(qstr)
470
471    abort = rollback
472
473    def savepoint(self, name):
474        """Define a new savepoint within the current transaction."""
475        return self.query('SAVEPOINT ' + name)
476
477    def release(self, name):
478        """Destroy a previously defined savepoint."""
479        return self.query('RELEASE ' + name)
480
481    def get_parameter(self, parameter):
482        """Get the value of a run-time parameter.
483
484        If the parameter is a string, the return value will also be a string
485        that is the current setting of the run-time parameter with that name.
486
487        You can get several parameters at once by passing a list, set or dict.
488        When passing a list of parameter names, the return value will be a
489        corresponding list of parameter settings.  When passing a set of
490        parameter names, a new dict will be returned, mapping these parameter
491        names to their settings.  Finally, if you pass a dict as parameter,
492        its values will be set to the current parameter settings corresponding
493        to its keys.
494
495        By passing the special name 'all' as the parameter, you can get a dict
496        of all existing configuration parameters.
497        """
498        if isinstance(parameter, basestring):
499            parameter = [parameter]
500            values = None
501        elif isinstance(parameter, (list, tuple)):
502            values = []
503        elif isinstance(parameter, (set, frozenset)):
504            values = {}
505        elif isinstance(parameter, dict):
506            values = parameter
507        else:
508            raise TypeError(
509                'The parameter must be a string, list, set or dict')
510        if not parameter:
511            raise TypeError('No parameter has been specified')
512        params = {} if isinstance(values, dict) else []
513        for key in parameter:
514            param = key.strip().lower() if isinstance(
515                key, basestring) else None
516            if not param:
517                raise TypeError('Invalid parameter')
518            if param == 'all':
519                q = 'SHOW ALL'
520                values = self.db.query(q).getresult()
521                values = dict(value[:2] for value in values)
522                break
523            if isinstance(values, dict):
524                params[param] = key
525            else:
526                params.append(param)
527        else:
528            for param in params:
529                q = 'SHOW %s' % (param,)
530                value = self.db.query(q).getresult()[0][0]
531                if values is None:
532                    values = value
533                elif isinstance(values, list):
534                    values.append(value)
535                else:
536                    values[params[param]] = value
537        return values
538
539    def set_parameter(self, parameter, value=None, local=False):
540        """Set the value of a run-time parameter.
541
542        If the parameter and the value are strings, the run-time parameter
543        will be set to that value.  If no value or None is passed as a value,
544        then the run-time parameter will be restored to its default value.
545
546        You can set several parameters at once by passing a list of parameter
547        names, together with a single value that all parameters should be
548        set to or with a corresponding list of values.  You can also pass
549        the parameters as a set if you only provide a single value.
550        Finally, you can pass a dict with parameter names as keys.  In this
551        case, you should not pass a value, since the values for the parameters
552        will be taken from the dict.
553
554        By passing the special name 'all' as the parameter, you can reset
555        all existing settable run-time parameters to their default values.
556
557        If you set local to True, then the command takes effect for only the
558        current transaction.  After commit() or rollback(), the session-level
559        setting takes effect again.  Setting local to True will appear to
560        have no effect if it is executed outside a transaction, since the
561        transaction will end immediately.
562        """
563        if isinstance(parameter, basestring):
564            parameter = {parameter: value}
565        elif isinstance(parameter, (list, tuple)):
566            if isinstance(value, (list, tuple)):
567                parameter = dict(zip(parameter, value))
568            else:
569                parameter = dict.fromkeys(parameter, value)
570        elif isinstance(parameter, (set, frozenset)):
571            if isinstance(value, (list, tuple, set, frozenset)):
572                value = set(value)
573                if len(value) == 1:
574                    value = value.pop()
575            if not(value is None or isinstance(value, basestring)):
576                raise ValueError('A single value must be specified'
577                    ' when parameter is a set')
578            parameter = dict.fromkeys(parameter, value)
579        elif isinstance(parameter, dict):
580            if value is not None:
581                raise ValueError('A value must not be specified'
582                    ' when parameter is a dictionary')
583        else:
584            raise TypeError(
585                'The parameter must be a string, list, set or dict')
586        if not parameter:
587            raise TypeError('No parameter has been specified')
588        params = {}
589        for key, value in parameter.items():
590            param = key.strip().lower() if isinstance(
591                key, basestring) else None
592            if not param:
593                raise TypeError('Invalid parameter')
594            if param == 'all':
595                if value is not None:
596                    raise ValueError('A value must ot be specified'
597                        " when parameter is 'all'")
598                params = {'all': None}
599                break
600            params[param] = value
601        local = ' LOCAL' if local else ''
602        for param, value in params.items():
603            if value is None:
604                q = 'RESET%s %s' % (local, param)
605            else:
606                q = 'SET%s %s TO %s' % (local, param, value)
607            self._do_debug(q)
608            self.db.query(q)
609
610    def query(self, qstr, *args):
611        """Execute a SQL command string.
612
613        This method simply sends a SQL query to the database.  If the query is
614        an insert statement that inserted exactly one row into a table that
615        has OIDs, the return value is the OID of the newly inserted row.
616        If the query is an update or delete statement, or an insert statement
617        that did not insert exactly one row in a table with OIDs, then the
618        number of rows affected is returned as a string.  If it is a statement
619        that returns rows as a result (usually a select statement, but maybe
620        also an "insert/update ... returning" statement), this method returns
621        a Query object that can be accessed via getresult() or dictresult()
622        or simply printed.  Otherwise, it returns `None`.
623
624        The query can contain numbered parameters of the form $1 in place
625        of any data constant.  Arguments given after the query string will
626        be substituted for the corresponding numbered parameter.  Parameter
627        values can also be given as a single list or tuple argument.
628        """
629        # Wraps shared library function for debugging.
630        if not self.db:
631            raise _int_error('Connection is not valid')
632        self._do_debug(qstr)
633        return self.db.query(qstr, args)
634
635    def pkey(self, table, composite=False, flush=False):
636        """Get or set the primary key of a table.
637
638        Single primary keys are returned as strings unless you
639        set the composite flag.  Composite primary keys are always
640        represented as tuples.  Note that this raises a KeyError
641        if the table does not have a primary key.
642
643        If flush is set then the internal cache for primary keys will
644        be flushed.  This may be necessary after the database schema or
645        the search path has been changed.
646        """
647        pkeys = self._pkeys
648        if flush:
649            pkeys.clear()
650            self._do_debug('The pkey cache has been flushed')
651        try:  # cache lookup
652            pkey = pkeys[table]
653        except KeyError:  # cache miss, check the database
654            q = ("SELECT a.attname, a.attnum, i.indkey FROM pg_index i"
655                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
656                " AND a.attnum = ANY(i.indkey)"
657                " AND NOT a.attisdropped"
658                " WHERE i.indrelid=%s::regclass"
659                " AND i.indisprimary ORDER BY a.attnum") % (
660                    self._prepare_qualified_param(table, 1),)
661            pkey = self.db.query(q, (table,)).getresult()
662            if not pkey:
663                raise KeyError('Table %s has no primary key' % table)
664            # we want to use the order defined in the primary key index here,
665            # not the order as defined by the columns in the table
666            if len(pkey) > 1:
667                indkey = [int(k) for k in pkey[0][2].split()]
668                pkey = sorted(pkey, key=lambda row: indkey.index(row[1]))
669                pkey = tuple(row[0] for row in pkey)
670            else:
671                pkey = pkey[0][0]
672            pkeys[table] = pkey  # cache it
673        if composite and not isinstance(pkey, tuple):
674            pkey = (pkey,)
675        return pkey
676
677    def get_databases(self):
678        """Get list of databases in the system."""
679        return [s[0] for s in
680            self.db.query('SELECT datname FROM pg_database').getresult()]
681
682    def get_relations(self, kinds=None):
683        """Get list of relations in connected database of specified kinds.
684
685        If kinds is None or empty, all kinds of relations are returned.
686        Otherwise kinds can be a string or sequence of type letters
687        specifying which kind of relations you want to list.
688        """
689        where = " AND r.relkind IN (%s)" % ','.join(
690            ["'%s'" % k for k in kinds]) if kinds else ''
691        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
692            " FROM pg_class r"
693            " JOIN pg_namespace s ON s.oid = r.relnamespace"
694            " WHERE s.nspname NOT SIMILAR"
695            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
696            " ORDER BY s.nspname, r.relname") % where
697        return [r[0] for r in self.db.query(q).getresult()]
698
699    def get_tables(self):
700        """Return list of tables in connected database."""
701        return self.get_relations('r')
702
703    def get_attnames(self, table, flush=False):
704        """Given the name of a table, dig out the set of attribute names.
705
706        Returns a dictionary of attribute names (the names are the keys,
707        the values are the names of the attributes' types).
708
709        If your Python version supports this, the dictionary will be an
710        OrderedDictionary with the column names in the right order.
711
712        If flush is set then the internal cache for attribute names will
713        be flushed. This may be necessary after the database schema or
714        the search path has been changed.
715
716        By default, only a limited number of simple types will be returned.
717        You can get the regular types after calling use_regtypes(True).
718        """
719        attnames = self._attnames
720        if flush:
721            attnames.clear()
722            self._do_debug('The attnames cache has been flushed')
723        try:  # cache lookup
724            names = attnames[table]
725        except KeyError:  # cache miss, check the database
726            q = ("SELECT a.attname, t.typname%s"
727                " FROM pg_attribute a"
728                " JOIN pg_type t ON t.oid = a.atttypid"
729                " WHERE a.attrelid = %s::regclass"
730                " AND (a.attnum > 0 OR a.attname = 'oid')"
731                " AND NOT a.attisdropped ORDER BY a.attnum") % (
732                    '::regtype' if self._regtypes else '',
733                    self._prepare_qualified_param(table, 1))
734            names = self.db.query(q, (table,)).getresult()
735            if not self._regtypes:
736                names = ((name, _simpletype(typ)) for name, typ in names)
737            names = OrderedDict(names)
738            attnames[table] = names  # cache it
739        return names
740
741    def use_regtypes(self, regtypes=None):
742        """Use regular type names instead of simplified type names."""
743        if regtypes is None:
744            return self._regtypes
745        else:
746            regtypes = bool(regtypes)
747            if regtypes != self._regtypes:
748                self._regtypes = regtypes
749                self._attnames.clear()
750            return regtypes
751
752    def has_table_privilege(self, table, privilege='select'):
753        """Check whether current user has specified table privilege."""
754        privilege = privilege.lower()
755        try:  # ask cache
756            return self._privileges[(table, privilege)]
757        except KeyError:  # cache miss, ask the database
758            q = "SELECT has_table_privilege(%s, $2)" % (
759                self._prepare_qualified_param(table, 1),)
760            q = self.db.query(q, (table, privilege))
761            ret = q.getresult()[0][0] == self._make_bool(True)
762            self._privileges[(table, privilege)] = ret  # cache it
763            return ret
764
765    def get(self, table, row, keyname=None):
766        """Get a row from a database table or view.
767
768        This method is the basic mechanism to get a single row.  It assumes
769        that the keyname specifies a unique row.  It must be the name of a
770        single column or a tuple of column names.  If the keyname is not
771        specified, then the primary key for the table is used.
772
773        If row is a dictionary, then the value for the key is taken from it.
774        Otherwise, the row must be a single value or a tuple of values
775        corresponding to the passed keyname or primary key.  The fetched row
776        from the table will be returned as a new dictionary or used to replace
777        the existing values when row was passed as aa dictionary.
778
779        The OID is also put into the dictionary if the table has one, but
780        in order to allow the caller to work with multiple tables, it is
781        munged as "oid(table)" using the actual name of the table.
782        """
783        if table.endswith('*'):  # hint for descendant tables can be ignored
784            table = table[:-1].rstrip()
785        attnames = self.get_attnames(table)
786        qoid = _oid_key(table) if 'oid' in attnames else None
787        if keyname and isinstance(keyname, basestring):
788            keyname = (keyname,)
789        if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row:
790            row['oid'] = row[qoid]
791        if not keyname:
792            try:  # if keyname is not specified, try using the primary key
793                keyname = self.pkey(table, True)
794            except KeyError:  # the table has no primary key
795                # try using the oid instead
796                if qoid and isinstance(row, dict) and 'oid' in row:
797                    keyname = ('oid',)
798                else:
799                    raise _prg_error('Table %s has no primary key' % table)
800            else:  # the table has a primary key
801                # check whether all key columns have values
802                if isinstance(row, dict) and not set(keyname).issubset(row):
803                    # try using the oid instead
804                    if qoid and 'oid' in row:
805                        keyname = ('oid',)
806                    else:
807                        raise KeyError(
808                            'Missing value in row for specified keyname')
809        if not isinstance(row, dict):
810            if not isinstance(row, (tuple, list)):
811                row = [row]
812            if len(keyname) != len(row):
813                raise KeyError(
814                    'Differing number of items in keyname and row')
815            row = dict(zip(keyname, row))
816        params = []
817        param = partial(self._prepare_param, params=params)
818        col = self.escape_identifier
819        what = 'oid, *' if qoid else '*'
820        where = ' AND '.join('%s = %s' % (
821            col(k), param(row[k], attnames[k])) for k in keyname)
822        if 'oid' in row:
823            if qoid:
824                row[qoid] = row['oid']
825            del row['oid']
826        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
827            what, self._escape_qualified_name(table), where)
828        self._do_debug(q, params)
829        q = self.db.query(q, params)
830        res = q.dictresult()
831        if not res:
832            raise _db_error('No such record in %s\nwhere %s\nwith %s' % (
833                table, where, self._list_params(params)))
834        for n, value in res[0].items():
835            if qoid and n == 'oid':
836                n = qoid
837            elif value is not None and attnames.get(n) == 'bytea':
838                value = self.unescape_bytea(value)
839            row[n] = value
840        return row
841
842    def insert(self, table, row=None, **kw):
843        """Insert a row into a database table.
844
845        This method inserts a row into a table.  The name of the table must
846        be passed as the first parameter.  The other parameters are used for
847        providing the data of the row that shall be inserted into the table.
848        If a dictionary is supplied as the second parameter, it starts with
849        that.  Otherwise it uses a blank dictionary. Either way the dictionary
850        is updated from the keywords.
851
852        The dictionary is then reloaded with the values actually inserted in
853        order to pick up values modified by rules, triggers, etc.
854        """
855        if table.endswith('*'):  # hint for descendant tables can be ignored
856            table = table[:-1].rstrip()
857        if row is None:
858            row = {}
859        row.update(kw)
860        if 'oid' in row:
861            del row['oid']  # do not insert oid
862        attnames = self.get_attnames(table)
863        qoid = _oid_key(table) if 'oid' in attnames else None
864        params = []
865        param = partial(self._prepare_param, params=params)
866        col = self.escape_identifier
867        names, values = [], []
868        for n in attnames:
869            if n in row:
870                names.append(col(n))
871                values.append(param(row[n], attnames[n]))
872        names, values = ', '.join(names), ', '.join(values)
873        ret = 'oid, *' if qoid else '*'
874        q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
875            self._escape_qualified_name(table), names, values, ret)
876        self._do_debug(q, params)
877        q = self.db.query(q, params)
878        res = q.dictresult()
879        if res:  # this should always be true
880            for n, value in res[0].items():
881                if qoid and n == 'oid':
882                    n = qoid
883                elif value is not None and attnames.get(n) == 'bytea':
884                    value = self.unescape_bytea(value)
885                row[n] = value
886        return row
887
    def update(self, table, row=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the primary key of the table or the OID value as munged by get
        or passed as keyword.

        The dictionary is then modified to reflect any changes caused by the
        update due to triggers, rules, default values, etc.

        Only table columns that appear in the row and are not part of the
        key used in the WHERE clause are set.  If there is no such column,
        no query is sent and the row is returned as it is.  If no database
        row matched, no values from the database are merged into the row.

        Raises the error built by _prg_error if the table has no primary key
        and no OID is available, or a KeyError if the row lacks some primary
        key values and no OID is available.
        """
        if table.endswith('*'):
            table = table[:-1].rstrip()  # need parent table name
        attnames = self.get_attnames(table)
        # name under which the oid is stored in the row, if the table has one
        qoid = _oid_key(table) if 'oid' in attnames else None
        if row is None:
            row = {}
        elif 'oid' in row:
            del row['oid']  # only accept oid key from named args for safety
        row.update(kw)
        # fall back to the munged oid (as set by get) if no plain oid given
        if qoid and qoid in row and 'oid' not in row:
            row['oid'] = row[qoid]
        try:  # try using the primary key
            keyname = self.pkey(table, True)
        except KeyError:  # the table has no primary key
            # try using the oid instead
            if qoid and 'oid' in row:
                keyname = ('oid',)
            else:
                raise _prg_error('Table %s has no primary key' % table)
        else:  # the table has a primary key
            # check whether all key columns have values
            if not set(keyname).issubset(row):
                # try using the oid instead
                if qoid and 'oid' in row:
                    keyname = ('oid',)
                else:
                    raise KeyError('Missing primary key in row')
        params = []
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        # the WHERE parameters are collected first, the SET parameters after;
        # NOTE(review): placeholder numbering presumably follows the position
        # in params, so keep this order when changing the code
        where = ' AND '.join('%s = %s' % (
            col(k), param(row[k], attnames[k])) for k in keyname)
        if 'oid' in row:
            # the oid is only kept in the row under its munged name; removing
            # the plain key also keeps it out of the SET clause below
            if qoid:
                row[qoid] = row['oid']
            del row['oid']
        values = []
        keyname = set(keyname)
        for n in attnames:
            if n in row and n not in keyname:
                values.append('%s = %s' % (col(n), param(row[n], attnames[n])))
        if not values:
            return row  # nothing to set, so do not query the database
        values = ', '.join(values)
        ret = 'oid, *' if qoid else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(table), values, where, ret)
        self._do_debug(q, params)
        q = self.db.query(q, params)
        res = q.dictresult()
        if res:  # may be empty when row does not exist
            for n, value in res[0].items():
                if qoid and n == 'oid':
                    n = qoid
                elif value is not None and attnames.get(n) == 'bytea':
                    value = self.unescape_bytea(value)
                row[n] = value
        return row
956
957    def upsert(self, table, row=None, **kw):
958        """Insert a row into a database table with conflict resolution
959
960        This method inserts a row into a table, but instead of raising a
961        ProgrammingError exception in case a row with the same primary key
962        already exists, an update will be executed instead.  This will be
963        performed as a single atomic operation on the database, so race
964        conditions can be avoided.
965
966        Like the insert method, the first parameter is the name of the
967        table and the second parameter can be used to pass the values to
968        be inserted as a dictionary.
969
970        Unlike the insert und update statement, keyword parameters are not
971        used to modify the dictionary, but to specify which columns shall
972        be updated in case of a conflict, and in which way:
973
974        A value of False or None means the column shall not be updated,
975        a value of True means the column shall be updated with the value
976        that has been proposed for insertion, i.e. has been passed as value
977        in the dictionary.  Columns that are not specified by keywords but
978        appear as keys in the dictionary are also updated like in the case
979        keywords had been passed with the value True.
980
981        So if in the case of a conflict you want to update every column that
982        has been passed in the dictionary row , you would call upsert(table, row).
983        If you don't want to do anything in case of a conflict, i.e. leave
984        the existing row as it is, call upsert(table, row, **dict.fromkeys(row)).
985
986        If you need more fine-grained control of what gets updated, you can
987        also pass strings in the keyword parameters.  These strings will
988        be used as SQL expressions for the update columns.  In these
989        expressions you can refer to the value that already exists in
990        the table by prefixing the column name with "included.", and to
991        the value that has been proposed for insertion by prefixing the
992        column name with the "excluded."
993
994        The dictionary is modified in any case to reflect the values in
995        the database after the operation has completed.
996
997        Note: The method uses the PostgreSQL "upsert" feature which is
998        only available since PostgreSQL 9.5.
999        """
1000        if table.endswith('*'):  # hint for descendant tables can be ignored
1001            table = table[:-1].rstrip()
1002        if row is None:
1003            row = {}
1004        if 'oid' in row:
1005            del row['oid']  # do not insert oid
1006        if 'oid' in kw:
1007            del kw['oid']  # do not update oid
1008        attnames = self.get_attnames(table)
1009        qoid = _oid_key(table) if 'oid' in attnames else None
1010        params = []
1011        param = partial(self._prepare_param,params=params)
1012        col = self.escape_identifier
1013        names, values, updates = [], [], []
1014        for n in attnames:
1015            if n in row:
1016                names.append(col(n))
1017                values.append(param(row[n], attnames[n]))
1018        names, values = ', '.join(names), ', '.join(values)
1019        try:
1020            keyname = self.pkey(table, True)
1021        except KeyError:
1022            raise _prg_error('Table %s has no primary key' % table)
1023        target = ', '.join(col(k) for k in keyname)
1024        update = []
1025        keyname = set(keyname)
1026        keyname.add('oid')
1027        for n in attnames:
1028            if n not in keyname:
1029                value = kw.get(n, True)
1030                if value:
1031                    if not isinstance(value, basestring):
1032                        value = 'excluded.%s' % col(n)
1033                    update.append('%s = %s' % (col(n), value))
1034        if not values:
1035            return row
1036        do = 'update set %s' % ', '.join(update) if update else 'nothing'
1037        ret = 'oid, *' if qoid else '*'
1038        q = ('INSERT INTO %s AS included (%s) VALUES (%s)'
1039            ' ON CONFLICT (%s) DO %s RETURNING %s') % (
1040                self._escape_qualified_name(table), names, values,
1041                target, do, ret)
1042        self._do_debug(q, params)
1043        try:
1044            q = self.db.query(q, params)
1045        except ProgrammingError:
1046            if self.server_version < 90500:
1047                raise _prg_error(
1048                    'Upsert operation is not supported by PostgreSQL version')
1049            raise  # re-raise original error
1050        res = q.dictresult()
1051        if res:  # may be empty with "do nothing"
1052            for n, value in res[0].items():
1053                if qoid and n == 'oid':
1054                    n = qoid
1055                elif value is not None and attnames.get(n) == 'bytea':
1056                    value = self.unescape_bytea(value)
1057                row[n] = value
1058        else:
1059            self.get(table, row)
1060        return row
1061
1062    def clear(self, table, row=None):
1063        """Clear all the attributes to values determined by the types.
1064
1065        Numeric types are set to 0, Booleans are set to false, and everything
1066        else is set to the empty string.  If the row argument is present,
1067        it is used as the row dictionary and any entries matching attribute
1068        names are cleared with everything else left unchanged.
1069        """
1070        # At some point we will need a way to get defaults from a table.
1071        if row is None:
1072            row = {}  # empty if argument is not present
1073        attnames = self.get_attnames(table)
1074        for n, t in attnames.items():
1075            if n == 'oid':
1076                continue
1077            if t in self._num_types:
1078                row[n] = 0
1079            elif t == 'bool':
1080                row[n] = self._make_bool(False)
1081            else:
1082                row[n] = ''
1083        return row
1084
1085    def delete(self, table, row=None, **kw):
1086        """Delete an existing row in a database table.
1087
1088        This method deletes the row from a table.  It deletes based on the
1089        primary key of the table or the OID value as munged by get() or
1090        passed as keyword.
1091
1092        The return value is the number of deleted rows (i.e. 0 if the row
1093        did not exist and 1 if the row was deleted).
1094
1095        Note that if the row cannot be deleted because e.g. it is still
1096        referenced by another table, this method raises a ProgrammingError.
1097        """
1098        if table.endswith('*'):  # hint for descendant tables can be ignored
1099            table = table[:-1].rstrip()
1100        attnames = self.get_attnames(table)
1101        qoid = _oid_key(table) if 'oid' in attnames else None
1102        if row is None:
1103            row = {}
1104        elif 'oid' in row:
1105            del row['oid']  # only accept oid key from named args for safety
1106        row.update(kw)
1107        if qoid and qoid in row and 'oid' not in row:
1108            row['oid'] = row[qoid]
1109        try:  # try using the primary key
1110            keyname = self.pkey(table, True)
1111        except KeyError:  # the table has no primary key
1112            # try using the oid instead
1113            if qoid and 'oid' in row:
1114                keyname = ('oid',)
1115            else:
1116                raise _prg_error('Table %s has no primary key' % table)
1117        else:  # the table has a primary key
1118            # check whether all key columns have values
1119            if not set(keyname).issubset(row):
1120                # try using the oid instead
1121                if qoid and 'oid' in row:
1122                    keyname = ('oid',)
1123                else:
1124                    raise KeyError('Missing primary key in row')
1125        params = []
1126        param = partial(self._prepare_param, params=params)
1127        col = self.escape_identifier
1128        where = ' AND '.join('%s = %s' % (
1129            col(k), param(row[k], attnames[k])) for k in keyname)
1130        if 'oid' in row:
1131            if qoid:
1132                row[qoid] = row['oid']
1133            del row['oid']
1134        q = 'DELETE FROM %s WHERE %s' % (
1135            self._escape_qualified_name(table), where)
1136        self._do_debug(q, params)
1137        res = self.db.query(q, params)
1138        return int(res)
1139
1140    def truncate(self, table, restart=False, cascade=False, only=False):
1141        """Empty a table or set of tables.
1142
1143        This method quickly removes all rows from the given table or set
1144        of tables.  It has the same effect as an unqualified DELETE on each
1145        table, but since it does not actually scan the tables it is faster.
1146        Furthermore, it reclaims disk space immediately, rather than requiring
1147        a subsequent VACUUM operation. This is most useful on large tables.
1148
1149        If restart is set to True, sequences owned by columns of the truncated
1150        table(s) are automatically restarted.  If cascade is set to True, it
1151        also truncates all tables that have foreign-key references to any of
1152        the named tables.  If the parameter only is not set to True, all the
1153        descendant tables (if any) will also be truncated. Optionally, a '*'
1154        can be specified after the table name to explicitly indicate that
1155        descendant tables are included.
1156        """
1157        if isinstance(table, basestring):
1158            only = {table: only}
1159            table = [table]
1160        elif isinstance(table, (list, tuple)):
1161            if isinstance(only, (list, tuple)):
1162                only = dict(zip(table, only))
1163            else:
1164                only = dict.fromkeys(table, only)
1165        elif isinstance(table, (set, frozenset)):
1166            only = dict.fromkeys(table, only)
1167        else:
1168            raise TypeError('The table must be a string, list or set')
1169        if not (restart is None or isinstance(restart, (bool, int))):
1170            raise TypeError('Invalid type for the restart option')
1171        if not (cascade is None or isinstance(cascade, (bool, int))):
1172            raise TypeError('Invalid type for the cascade option')
1173        tables = []
1174        for t in table:
1175            u = only.get(t)
1176            if not (u is None or isinstance(u, (bool, int))):
1177                raise TypeError('Invalid type for the only option')
1178            if t.endswith('*'):
1179                if u:
1180                    raise ValueError(
1181                        'Contradictory table name and only options')
1182                t = t[:-1].rstrip()
1183            t = self._escape_qualified_name(t)
1184            if u:
1185                t = 'ONLY %s' % t
1186            tables.append(t)
1187        q = ['TRUNCATE', ', '.join(tables)]
1188        if restart:
1189            q.append('RESTART IDENTITY')
1190        if cascade:
1191            q.append('CASCADE')
1192        q = ' '.join(q)
1193        self._do_debug(q)
1194        return self.query(q)
1195
1196    def notification_handler(self,
1197            event, callback, arg_dict=None, timeout=None, stop_event=None):
1198        """Get notification handler that will run the given callback."""
1199        return NotificationHandler(self,
1200            event, callback, arg_dict, timeout, stop_event)
1201
1202
# if run as script, print some information

if __name__ == '__main__':
    # pass the version as a separate argument so that print() inserts
    # the separating space ('PyGreSQL version' + version had none)
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.