source: branches/4.x/pg.py @ 762

Last change on this file since 762 was 762, checked in by cito, 4 years ago

Docs and 100% test coverage for NotificationHandler

  • Property svn:keywords set to Id
File size: 43.5 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 762 2016-01-17 16:32:18Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from _pg import *
32
33import select
34import warnings
35try:
36    frozenset
37except NameError:  # Python < 2.4, unsupported
38    from sets import ImmutableSet as frozenset
39try:
40    from decimal import Decimal
41    set_decimal(Decimal)
42except ImportError:  # Python < 2.4, unsupported
43    Decimal = float
44try:
45    from collections import namedtuple
46except ImportError:  # Python < 2.6
47    namedtuple = None
48
49
50# Auxiliary functions that are independent from a DB connection:
51
52def _is_quoted(s):
53    """Check whether this string is a quoted identifier."""
54    s = s.replace('_', 'a')
55    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
56
57
58def _is_unquoted(s):
59    """Check whether this string is an unquoted identifier."""
60    s = s.replace('_', 'a')
61    return s.isalnum() and not s[:1].isdigit()
62
63
64def _split_first_part(s):
65    """Split the first part of a dot separated string."""
66    s = s.lstrip()
67    if s[:1] == '"':
68        p = []
69        s = s.split('"', 3)[1:]
70        p.append(s[0])
71        while len(s) == 3 and s[1] == '':
72            p.append('"')
73            s = s[2].split('"', 2)
74            p.append(s[0])
75        p = [''.join(p)]
76        s = '"'.join(s[1:]).lstrip()
77        if s:
78            if s[:0] == '.':
79                p.append(s[1:])
80            else:
81                s = _split_first_part(s)
82                p[0] += s[0]
83                if len(s) > 1:
84                    p.append(s[1])
85    else:
86        p = s.split('.', 1)
87        s = p[0].rstrip()
88        if _is_unquoted(s):
89            s = s.lower()
90        p[0] = s
91    return p
92
93
94def _split_parts(s):
95    """Split all parts of a dot separated string."""
96    q = []
97    while s:
98        s = _split_first_part(s)
99        q.append(s[0])
100        if len(s) < 2:
101            break
102        s = s[1]
103    return q
104
105
106def _join_parts(s):
107    """Join all parts of a dot separated string."""
108    return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
109
110
111def _oid_key(qcl):
112    """Build oid key from qualified class name."""
113    return 'oid(%s)' % qcl
114
115
if namedtuple:

    def _namedresult(q):
        """Get query result as a list of named tuples.

        The tuple fields are named after the columns of the query result.
        Column names that are not valid Python identifiers (such as
        '?column?' for unnamed expressions) or that occur more than once
        would make namedtuple() fail, so they are replaced by names of the
        form _<index> where the rename option is available (Python 2.7+).
        """
        fields = q.listfields()
        try:
            row = namedtuple('Row', fields, rename=True)
        except TypeError:  # Python 2.6: namedtuple() has no rename option
            row = namedtuple('Row', fields)
        return [row(*r) for r in q.getresult()]

    set_namedresult(_namedresult)
124
125
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an exception with sqlstate set to None.

    The exception is an instance of cls, which is DatabaseError by default.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
131
132
def _int_error(msg):
    """Create (but do not raise) an InternalError with the given message."""
    return _db_error(msg, cls=InternalError)
136
137
def _prg_error(msg):
    """Create (but do not raise) a ProgrammingError with the given message."""
    return _db_error(msg, cls=ProgrammingError)
141
142
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def _quote_ident(name):
        """Return name as a double quoted SQL identifier.

        Embedded double quotes are doubled, so that channel names
        containing them do not break the generated SQL statements.
        """
        return '"%s"' % ('%s' % (name,)).replace('"', '""')
    _quote_ident = staticmethod(_quote_ident)

    def _quote_payload(payload):
        """Return payload as an escape string literal E'...'.

        Backslashes and single quotes are doubled.  Inside an escape string
        this is interpreted the same way regardless of the server setting
        standard_conforming_strings, so any payload is sent unchanged.
        """
        text = ('%s' % (payload,)).replace('\\', '\\\\').replace("'", "''")
        return "E'%s'" % text
    _quote_payload = staticmethod(_quote_payload)

    def __del__(self):
        # stop listening when the handler is garbage collected
        self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_ident(self.event))
            self.db.query('listen %s' % self._quote_ident(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_ident(self.event))
            self.db.query('unlisten %s' % self._quote_ident(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.  The
        payload is escaped, so it may contain quotes and backslashes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent (and None is returned) if the handler is not
        listening; otherwise the result of the query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            if stop:
                event = self.stop_event
            else:
                event = self.event
            q = 'notify %s' % self._quote_ident(event)
            if payload:
                q += ', %s' % self._quote_payload(payload)
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
262
263
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    Emits a DeprecationWarning attributed to the caller, then passes all
    arguments through to NotificationHandler unchanged.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
269
270
271# The actual PostgreSQL database connection interface:
272
273class DB(object):
274    """Wrapper class for the _pg connection type."""
275
276    def __init__(self, *args, **kw):
277        """Create a new connection.
278
279        You can pass either the connection parameters or an existing
280        _pg or pgdb connection. This allows you to use the methods
281        of the classic pg interface with a DB-API 2 pgdb connection.
282
283        """
284        if not args and len(kw) == 1:
285            db = kw.get('db')
286        elif not kw and len(args) == 1:
287            db = args[0]
288        else:
289            db = None
290        if db:
291            if isinstance(db, DB):
292                db = db.db
293            else:
294                try:
295                    db = db._cnx
296                except AttributeError:
297                    pass
298        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
299            db = connect(*args, **kw)
300            self._closeable = True
301        else:
302            self._closeable = False
303        self.db = db
304        self.dbname = db.db
305        self._regtypes = False
306        self._attnames = {}
307        self._pkeys = {}
308        self._privileges = {}
309        self._args = args, kw
310        self.debug = None  # For debugging scripts, this can be set
311            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
312            # * to a file object to write debug statements or
313            # * to a callable object which takes a string argument
314            # * to any other true value to just print debug statements
315
316    def __getattr__(self, name):
317        # All undefined members are same as in underlying connection:
318        if self.db:
319            return getattr(self.db, name)
320        else:
321            raise _int_error('Connection is not valid')
322
323    # Context manager methods
324
325    def __enter__(self):
326        """Enter the runtime context. This will start a transaction."""
327        self.begin()
328        return self
329
330    def __exit__(self, et, ev, tb):
331        """Exit the runtime context. This will end the transaction."""
332        if et is None and ev is None and tb is None:
333            self.commit()
334        else:
335            self.rollback()
336
337    # Auxiliary methods
338
339    def _do_debug(self, s):
340        """Print a debug message."""
341        if self.debug:
342            if isinstance(self.debug, basestring):
343                print(self.debug % s)
344            elif isinstance(self.debug, file):
345                self.debug.write(s + '\n')
346            elif callable(self.debug):
347                self.debug(s)
348            else:
349                print(s)
350
351    def _make_bool(d):
352        """Get boolean value corresponding to d."""
353        if get_bool():
354            return bool(d)
355        return d and 't' or 'f'
356    _make_bool = staticmethod(_make_bool)
357
358    def _quote_text(self, d):
359        """Quote text value."""
360        if not isinstance(d, basestring):
361            d = str(d)
362        return "'%s'" % self.escape_string(d)
363
364    _bool_true = frozenset('t true 1 y yes on'.split())
365
366    def _quote_bool(self, d):
367        """Quote boolean value."""
368        if isinstance(d, basestring):
369            if not d:
370                return 'NULL'
371            d = d.lower() in self._bool_true
372        return d and "'t'" or "'f'"
373
374    _date_literals = frozenset('current_date current_time'
375        ' current_timestamp localtime localtimestamp'.split())
376
377    def _quote_date(self, d):
378        """Quote date value."""
379        if not d:
380            return 'NULL'
381        if isinstance(d, basestring) and d.lower() in self._date_literals:
382            return d
383        return self._quote_text(d)
384
385    def _quote_num(self, d):
386        """Quote numeric value."""
387        if not d and d != 0:
388            return 'NULL'
389        return str(d)
390
391    def _quote_money(self, d):
392        """Quote money value."""
393        if d is None or d == '':
394            return 'NULL'
395        if not isinstance(d, basestring):
396            d = str(d)
397        return d
398
399    _quote_funcs = dict(  # quote methods for each type
400        text=_quote_text, bool=_quote_bool, date=_quote_date,
401        int=_quote_num, num=_quote_num, float=_quote_num,
402        money=_quote_money)
403
404    def _quote(self, d, t):
405        """Return quotes if needed."""
406        if d is None:
407            return 'NULL'
408        try:
409            quote_func = self._quote_funcs[t]
410        except KeyError:
411            quote_func = self._quote_funcs['text']
412        return quote_func(self, d)
413
414    def _split_schema(self, cl):
415        """Return schema and name of object separately.
416
417        This auxiliary function splits off the namespace (schema)
418        belonging to the class with the name cl. If the class name
419        is not qualified, the function is able to determine the schema
420        of the class, taking into account the current search path.
421
422        """
423        s = _split_parts(cl)
424        if len(s) > 1:  # name already qualified?
425            # should be database.schema.table or schema.table
426            if len(s) > 3:
427                raise _prg_error('Too many dots in class name %s' % cl)
428            schema, cl = s[-2:]
429        else:
430            cl = s[0]
431            # determine search path
432            q = 'SELECT current_schemas(TRUE)'
433            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
434            if schemas:  # non-empty path
435                # search schema for this object in the current search path
436                # (we could also use unnest with ordinality here to spare
437                # one query, but this is only possible since PostgreSQL 9.4)
438                q = ' UNION '.join(
439                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
440                        % s for s in enumerate(schemas)])
441                q = ("SELECT nspname FROM pg_class r"
442                    " JOIN pg_namespace s ON r.relnamespace = s.oid"
443                    " JOIN (%s) AS p USING (nspname)"
444                    " WHERE r.relname = $1 ORDER BY n LIMIT 1" % q)
445                schema = self.db.query(q, (cl,)).getresult()
446                if schema:  # schema found
447                    schema = schema[0][0]
448                else:  # object not found in current search path
449                    schema = 'public'
450            else:  # empty path
451                schema = 'public'
452        return schema, cl
453
454    def _add_schema(self, cl):
455        """Ensure that the class name is prefixed with a schema name."""
456        return _join_parts(self._split_schema(cl))
457
458    # Public methods
459
460    # escape_string and escape_bytea exist as methods,
461    # so we define unescape_bytea as a method as well
462    unescape_bytea = staticmethod(unescape_bytea)
463
464    def close(self):
465        """Close the database connection."""
466        # Wraps shared library function so we can track state.
467        if self._closeable:
468            if self.db:
469                self.db.close()
470                self.db = None
471            else:
472                raise _int_error('Connection already closed')
473
474    def reset(self):
475        """Reset connection with current parameters.
476
477        All derived queries and large objects derived from this connection
478        will not be usable after this call.
479
480        """
481        if self.db:
482            self.db.reset()
483        else:
484            raise _int_error('Connection already closed')
485
486    def reopen(self):
487        """Reopen connection to the database.
488
489        Used in case we need another connection to the same database.
490        Note that we can still reopen a database that we have closed.
491
492        """
493        # There is no such shared library function.
494        if self._closeable:
495            db = connect(*self._args[0], **self._args[1])
496            if self.db:
497                self.db.close()
498            self.db = db
499
500    def begin(self, mode=None):
501        """Begin a transaction."""
502        qstr = 'BEGIN'
503        if mode:
504            qstr += ' ' + mode
505        return self.query(qstr)
506
507    start = begin
508
509    def commit(self):
510        """Commit the current transaction."""
511        return self.query('COMMIT')
512
513    end = commit
514
515    def rollback(self, name=None):
516        """Roll back the current transaction."""
517        qstr = 'ROLLBACK'
518        if name:
519            qstr += ' TO ' + name
520        return self.query(qstr)
521
522    def savepoint(self, name):
523        """Define a new savepoint within the current transaction."""
524        return self.query('SAVEPOINT ' + name)
525
526    def release(self, name):
527        """Destroy a previously defined savepoint."""
528        return self.query('RELEASE ' + name)
529
530    def get_parameter(self, parameter):
531        """Get the value of a run-time parameter.
532
533        If the parameter is a string, the return value will also be a string
534        that is the current setting of the run-time parameter with that name.
535
536        You can get several parameters at once by passing a list, set or dict.
537        When passing a list of parameter names, the return value will be a
538        corresponding list of parameter settings.  When passing a set of
539        parameter names, a new dict will be returned, mapping these parameter
540        names to their settings.  Finally, if you pass a dict as parameter,
541        its values will be set to the current parameter settings corresponding
542        to its keys.
543
544        By passing the special name 'all' as the parameter, you can get a dict
545        of all existing configuration parameters.
546        """
547        if isinstance(parameter, basestring):
548            parameter = [parameter]
549            values = None
550        elif isinstance(parameter, (list, tuple)):
551            values = []
552        elif isinstance(parameter, (set, frozenset)):
553            values = {}
554        elif isinstance(parameter, dict):
555            values = parameter
556        else:
557            raise TypeError(
558                'The parameter must be a string, list, set or dict')
559        if not parameter:
560            raise TypeError('No parameter has been specified')
561        if isinstance(values, dict):
562            params = {}
563        else:
564            params = []
565        for key in parameter:
566            if isinstance(key, basestring):
567                param = key.strip().lower()
568            else:
569                param = None
570            if not param:
571                raise TypeError('Invalid parameter')
572            if param == 'all':
573                q = 'SHOW ALL'
574                values = self.db.query(q).getresult()
575                values = dict(value[:2] for value in values)
576                break
577            if isinstance(values, dict):
578                params[param] = key
579            else:
580                params.append(param)
581        else:
582            for param in params:
583                q = 'SHOW %s' % (param,)
584                value = self.db.query(q).getresult()[0][0]
585                if values is None:
586                    values = value
587                elif isinstance(values, list):
588                    values.append(value)
589                else:
590                    values[params[param]] = value
591        return values
592
593    def set_parameter(self, parameter, value=None, local=False):
594        """Set the value of a run-time parameter.
595
596        If the parameter and the value are strings, the run-time parameter
597        will be set to that value.  If no value or None is passed as a value,
598        then the run-time parameter will be restored to its default value.
599
600        You can set several parameters at once by passing a list of parameter
601        names, together with a single value that all parameters should be
602        set to or with a corresponding list of values.  You can also pass
603        the parameters as a set if you only provide a single value.
604        Finally, you can pass a dict with parameter names as keys.  In this
605        case, you should not pass a value, since the values for the parameters
606        will be taken from the dict.
607
608        By passing the special name 'all' as the parameter, you can reset
609        all existing settable run-time parameters to their default values.
610
611        If you set local to True, then the command takes effect for only the
612        current transaction.  After commit() or rollback(), the session-level
613        setting takes effect again.  Setting local to True will appear to
614        have no effect if it is executed outside a transaction, since the
615        transaction will end immediately.
616        """
617        if isinstance(parameter, basestring):
618            parameter = {parameter: value}
619        elif isinstance(parameter, (list, tuple)):
620            if isinstance(value, (list, tuple)):
621                parameter = dict(zip(parameter, value))
622            else:
623                parameter = dict.fromkeys(parameter, value)
624        elif isinstance(parameter, (set, frozenset)):
625            if isinstance(value, (list, tuple, set, frozenset)):
626                value = set(value)
627                if len(value) == 1:
628                    value = value.pop()
629            if not(value is None or isinstance(value, basestring)):
630                raise ValueError('A single value must be specified'
631                    ' when parameter is a set')
632            parameter = dict.fromkeys(parameter, value)
633        elif isinstance(parameter, dict):
634            if value is not None:
635                raise ValueError('A value must not be specified'
636                    ' when parameter is a dictionary')
637        else:
638            raise TypeError(
639                'The parameter must be a string, list, set or dict')
640        if not parameter:
641            raise TypeError('No parameter has been specified')
642        params = {}
643        for key, value in parameter.items():
644            if isinstance(key, basestring):
645                param = key.strip().lower()
646            else:
647                param = None
648            if not param:
649                raise TypeError('Invalid parameter')
650            if param == 'all':
651                if value is not None:
652                    raise ValueError('A value must ot be specified'
653                        " when parameter is 'all'")
654                params = {'all': None}
655                break
656            params[param] = value
657        local = local and ' LOCAL' or ''
658        for param, value in params.items():
659            if value is None:
660                q = 'RESET%s %s' % (local, param)
661            else:
662                q = 'SET%s %s TO %s' % (local, param, value)
663            self._do_debug(q)
664            self.db.query(q)
665
666    def query(self, qstr, *args):
667        """Executes a SQL command string.
668
669        This method simply sends a SQL query to the database. If the query is
670        an insert statement that inserted exactly one row into a table that
671        has OIDs, the return value is the OID of the newly inserted row.
672        If the query is an update or delete statement, or an insert statement
673        that did not insert exactly one row in a table with OIDs, then the
674        number of rows affected is returned as a string. If it is a statement
675        that returns rows as a result (usually a select statement, but maybe
676        also an "insert/update ... returning" statement), this method returns
677        a pgqueryobject that can be accessed via getresult() or dictresult()
678        or simply printed. Otherwise, it returns `None`.
679
680        The query can contain numbered parameters of the form $1 in place
681        of any data constant. Arguments given after the query string will
682        be substituted for the corresponding numbered parameter. Parameter
683        values can also be given as a single list or tuple argument.
684
685        Note that the query string must not be passed as a unicode value,
686        but you can pass arguments as unicode values if they can be decoded
687        using the current client encoding.
688
689        """
690        # Wraps shared library function for debugging.
691        if not self.db:
692            raise _int_error('Connection is not valid')
693        self._do_debug(qstr)
694        return self.db.query(qstr, args)
695
696    def pkey(self, cl, newpkey=None):
697        """This method gets or sets the primary key of a class.
698
699        Composite primary keys are represented as frozensets. Note that
700        this raises a KeyError if the table does not have a primary key.
701
702        If newpkey is set and is not a dictionary then set that
703        value as the primary key of the class.  If it is a dictionary
704        then replace the internal cache of primary keys with a copy of it.
705
706        """
707        # First see if the caller is supplying a dictionary
708        if isinstance(newpkey, dict):
709            # make sure that all classes have a namespace
710            self._pkeys = dict([
711                ('.' in cl and cl or 'public.' + cl, pkey)
712                for cl, pkey in newpkey.items()])
713            return self._pkeys
714
715        qcl = self._add_schema(cl)  # build fully qualified class name
716        # Check if the caller is supplying a new primary key for the class
717        if newpkey:
718            self._pkeys[qcl] = newpkey
719            return newpkey
720
721        # Get all the primary keys at once
722        if qcl not in self._pkeys:
723            # if not found, check again in case it was added after we started
724            self._pkeys = {}
725            if self.server_version >= 80200:
726                # the ANY syntax works correctly only with PostgreSQL >= 8.2
727                any_indkey = "= ANY (i.indkey)"
728            else:
729                any_indkey = "IN (%s)" % ', '.join(
730                    ['i.indkey[%d]' % i for i in range(16)])
731            q = ("SELECT s.nspname, r.relname, a.attname"
732                " FROM pg_class r"
733                " JOIN pg_namespace s ON s.oid = r.relnamespace"
734                " AND s.nspname NOT SIMILAR"
735                " TO 'pg/_%|information/_schema' ESCAPE '/'"
736                " JOIN pg_attribute a ON a.attrelid = r.oid"
737                " AND NOT a.attisdropped"
738                " JOIN pg_index i ON i.indrelid = r.oid"
739                " AND i.indisprimary AND a.attnum " + any_indkey)
740            for r in self.db.query(q).getresult():
741                cl, pkey = _join_parts(r[:2]), r[2]
742                self._pkeys.setdefault(cl, []).append(pkey)
743            # (only) for composite primary keys, the values will be frozensets
744            for cl, pkey in self._pkeys.items():
745                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
746            self._do_debug(self._pkeys)
747
748        # will raise an exception if primary key doesn't exist
749        return self._pkeys[qcl]
750
751    def get_databases(self):
752        """Get list of databases in the system."""
753        return [s[0] for s in
754            self.db.query('SELECT datname FROM pg_database').getresult()]
755
756    def get_relations(self, kinds=None):
757        """Get list of relations in connected database of specified kinds.
758
759        If kinds is None or empty, all kinds of relations are returned.
760        Otherwise kinds can be a string or sequence of type letters
761        specifying which kind of relations you want to list.
762
763        """
764        where = kinds and " AND r.relkind IN (%s)" % ','.join(
765            ["'%s'" % k for k in kinds]) or ''
766        q = ("SELECT s.nspname, r.relname"
767            " FROM pg_class r"
768            " JOIN pg_namespace s ON s.oid = r.relnamespace"
769            " WHERE s.nspname NOT SIMILAR"
770            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
771            " ORDER BY 1, 2") % where
772        return [_join_parts(r) for r in self.db.query(q).getresult()]
773
774    def get_tables(self):
775        """Return list of tables in connected database."""
776        return self.get_relations('r')
777
778    def get_attnames(self, cl, newattnames=None):
779        """Given the name of a table, digs out the set of attribute names.
780
781        Returns a dictionary of attribute names (the names are the keys,
782        the values are the names of the attributes' types).
783        If the optional newattnames exists, it must be a dictionary and
784        will become the new attribute names dictionary.
785
786        By default, only a limited number of simple types will be returned.
787        You can get the regular types after calling use_regtypes(True).
788
789        """
790        if isinstance(newattnames, dict):
791            self._attnames = newattnames
792            return
793        elif newattnames:
794            raise _prg_error('If supplied, newattnames must be a dictionary')
795        cl = self._split_schema(cl)  # split into schema and class
796        qcl = _join_parts(cl)  # build fully qualified name
797        # May as well cache them:
798        if qcl in self._attnames:
799            return self._attnames[qcl]
800        if qcl not in self.get_relations('rv'):
801            raise _prg_error('Class %s does not exist' % qcl)
802
803        q = ("SELECT a.attname, t.typname%s"
804            " FROM pg_class r"
805            " JOIN pg_namespace s ON r.relnamespace = s.oid"
806            " JOIN pg_attribute a ON a.attrelid = r.oid"
807            " JOIN pg_type t ON t.oid = a.atttypid"
808            " WHERE s.nspname = $1 AND r.relname = $2"
809            " AND (a.attnum > 0 OR a.attname = 'oid')"
810            " AND NOT a.attisdropped") % (
811                self._regtypes and '::regtype' or '',)
812        q = self.db.query(q, cl).getresult()
813
814        if self._regtypes:
815            t = dict(q)
816        else:
817            t = {}
818            for att, typ in q:
819                if typ.startswith('bool'):
820                    typ = 'bool'
821                elif typ.startswith('abstime'):
822                    typ = 'date'
823                elif typ.startswith('date'):
824                    typ = 'date'
825                elif typ.startswith('interval'):
826                    typ = 'date'
827                elif typ.startswith('timestamp'):
828                    typ = 'date'
829                elif typ.startswith('oid'):
830                    typ = 'int'
831                elif typ.startswith('int'):
832                    typ = 'int'
833                elif typ.startswith('float'):
834                    typ = 'float'
835                elif typ.startswith('numeric'):
836                    typ = 'num'
837                elif typ.startswith('money'):
838                    typ = 'money'
839                else:
840                    typ = 'text'
841                t[att] = typ
842
843        self._attnames[qcl] = t  # cache it
844        return self._attnames[qcl]
845
846    def use_regtypes(self, regtypes=None):
847        """Use regular type names instead of simplified type names."""
848        if regtypes is None:
849            return self._regtypes
850        else:
851            regtypes = bool(regtypes)
852            if regtypes != self._regtypes:
853                self._regtypes = regtypes
854                self._attnames.clear()
855            return regtypes
856
857    def has_table_privilege(self, cl, privilege='select'):
858        """Check whether current user has specified table privilege."""
859        qcl = self._add_schema(cl)
860        privilege = privilege.lower()
861        try:
862            return self._privileges[(qcl, privilege)]
863        except KeyError:
864            q = "SELECT has_table_privilege($1, $2)"
865            q = self.db.query(q, (qcl, privilege))
866            ret = q.getresult()[0][0] == self._make_bool(True)
867            self._privileges[(qcl, privilege)] = ret
868            return ret
869
870    def get(self, cl, arg, keyname=None):
871        """Get a row from a database table or view.
872
873        This method is the basic mechanism to get a single row.  The keyname
874        that the key specifies a unique row.  If keyname is not specified
875        then the primary key for the table is used.  If arg is a dictionary
876        then the value for the key is taken from it and it is modified to
877        include the new values, replacing existing values where necessary.
878        For a composite key, keyname can also be a sequence of key names.
879        The OID is also put into the dictionary if the table has one, but
880        in order to allow the caller to work with multiple tables, it is
881        munged as oid(schema.table).
882
883        """
884        if cl.endswith('*'):  # scan descendant tables?
885            cl = cl[:-1].rstrip()  # need parent table name
886        # build qualified class name
887        qcl = self._add_schema(cl)
888        # To allow users to work with multiple tables,
889        # we munge the name of the "oid" key
890        qoid = _oid_key(qcl)
891        if not keyname:
892            # use the primary key by default
893            try:
894                keyname = self.pkey(qcl)
895            except KeyError:
896                raise _prg_error('Class %s has no primary key' % qcl)
897        # We want the oid for later updates if that isn't the key
898        if keyname == 'oid':
899            if isinstance(arg, dict):
900                if qoid not in arg:
901                    raise _db_error('%s not in arg' % qoid)
902            else:
903                arg = {qoid: arg}
904            where = 'oid = %s' % arg[qoid]
905            attnames = '*'
906        else:
907            attnames = self.get_attnames(qcl)
908            if isinstance(keyname, basestring):
909                keyname = (keyname,)
910            if not isinstance(arg, dict):
911                if len(keyname) > 1:
912                    raise _prg_error('Composite key needs dict as arg')
913                arg = dict([(k, arg) for k in keyname])
914            where = ' AND '.join(['%s = %s'
915                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
916            attnames = ', '.join(attnames)
917        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
918        self._do_debug(q)
919        res = self.db.query(q).dictresult()
920        if not res:
921            raise _db_error('No such record in %s where %s' % (qcl, where))
922        for att, value in res[0].items():
923            arg[att == 'oid' and qoid or att] = value
924        return arg
925
926    def insert(self, cl, d=None, **kw):
927        """Insert a row into a database table.
928
929        This method inserts a row into a table.  The name of the table must
930        be passed as the first parameter.  The other parameters are used for
931        providing the data of the row that shall be inserted into the table.
932        If a dictionary is supplied as the second parameter, it starts with
933        that.  Otherwise it uses a blank dictionary. Either way the dictionary
934        is updated from the keywords.
935
936        The dictionary is then, if possible, reloaded with the values actually
937        inserted in order to pick up values modified by rules, triggers, etc.
938
939        Note: The method currently doesn't support insert into views
940        although PostgreSQL does.
941
942        """
943        qcl = self._add_schema(cl)
944        qoid = _oid_key(qcl)
945        if d is None:
946            d = {}
947        d.update(kw)
948        attnames = self.get_attnames(qcl)
949        names, values = [], []
950        for n in attnames:
951            if n != 'oid' and n in d:
952                names.append('"%s"' % n)
953                values.append(self._quote(d[n], attnames[n]))
954        names, values = ', '.join(names), ', '.join(values)
955        selectable = self.has_table_privilege(qcl)
956        if selectable and self.server_version >= 80200:
957            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
958        else:
959            ret = ''
960        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
961        self._do_debug(q)
962        res = self.db.query(q)
963        if ret:
964            res = res.dictresult()
965            for att, value in res[0].items():
966                d[att == 'oid' and qoid or att] = value
967        elif isinstance(res, int):
968            d[qoid] = res
969            if selectable:
970                self.get(qcl, d, 'oid')
971        elif selectable:
972            if qoid in d:
973                self.get(qcl, d, 'oid')
974            else:
975                try:
976                    self.get(qcl, d)
977                except ProgrammingError:
978                    pass  # table has no primary key
979        return d
980
981    def update(self, cl, d=None, **kw):
982        """Update an existing row in a database table.
983
984        Similar to insert but updates an existing row.  The update is based
985        on the OID value as munged by get or passed as keyword, or on the
986        primary key of the table.  The dictionary is modified, if possible,
987        to reflect any changes caused by the update due to triggers, rules,
988        default values, etc.
989
990        """
991        # Update always works on the oid which get() returns if available,
992        # otherwise use the primary key.  Fail if neither.
993        # Note that we only accept oid key from named args for safety.
994        qcl = self._add_schema(cl)
995        qoid = _oid_key(qcl)
996        if 'oid' in kw:
997            kw[qoid] = kw['oid']
998            del kw['oid']
999        if d is None:
1000            d = {}
1001        d.update(kw)
1002        attnames = self.get_attnames(qcl)
1003        if qoid in d:
1004            where = 'oid = %s' % d[qoid]
1005            keyname = ()
1006        else:
1007            try:
1008                keyname = self.pkey(qcl)
1009            except KeyError:
1010                raise _prg_error('Class %s has no primary key' % qcl)
1011            if isinstance(keyname, basestring):
1012                keyname = (keyname,)
1013            try:
1014                where = ' AND '.join(['%s = %s'
1015                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1016            except KeyError:
1017                raise _prg_error('Update needs primary key or oid.')
1018        values = []
1019        for n in attnames:
1020            if n in d and n not in keyname:
1021                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
1022        if not values:
1023            return d
1024        values = ', '.join(values)
1025        selectable = self.has_table_privilege(qcl)
1026        if selectable and self.server_version >= 80200:
1027            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
1028        else:
1029            ret = ''
1030        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
1031        self._do_debug(q)
1032        res = self.db.query(q)
1033        if ret:
1034            res = res.dictresult()[0]
1035            for att, value in res.items():
1036                d[att == 'oid' and qoid or att] = value
1037        else:
1038            if selectable:
1039                if qoid in d:
1040                    self.get(qcl, d, 'oid')
1041                else:
1042                    self.get(qcl, d)
1043        return d
1044
1045    def clear(self, cl, d=None):
1046        """Clear all the attributes to values determined by the types.
1047
1048        Numeric types are set to 0, Booleans are set to false, and everything
1049        else is set to the empty string.  If the second argument is present,
1050        it is used as the row dictionary and any entries matching attribute
1051        names are cleared with everything else left unchanged.
1052
1053        """
1054        # At some point we will need a way to get defaults from a table.
1055        qcl = self._add_schema(cl)
1056        if d is None:
1057            d = {}  # empty if argument is not present
1058        attnames = self.get_attnames(qcl)
1059        for n, t in attnames.items():
1060            if n == 'oid':
1061                continue
1062            if t in ('int', 'integer', 'smallint', 'bigint',
1063                    'float', 'real', 'double precision',
1064                    'num', 'numeric', 'money'):
1065                d[n] = 0
1066            elif t in ('bool', 'boolean'):
1067                d[n] = self._make_bool(False)
1068            else:
1069                d[n] = ''
1070        return d
1071
1072    def delete(self, cl, d=None, **kw):
1073        """Delete an existing row in a database table.
1074
1075        This method deletes the row from a table.  It deletes based on the
1076        OID value as munged by get or passed as keyword, or on the primary
1077        key of the table.  The return value is the number of deleted rows
1078        (i.e. 0 if the row did not exist and 1 if the row was deleted).
1079
1080        """
1081        # Like update, delete works on the oid.
1082        # One day we will be testing that the record to be deleted
1083        # isn't referenced somewhere (or else PostgreSQL will).
1084        # Note that we only accept oid key from named args for safety
1085        qcl = self._add_schema(cl)
1086        qoid = _oid_key(qcl)
1087        if 'oid' in kw:
1088            kw[qoid] = kw['oid']
1089            del kw['oid']
1090        if d is None:
1091            d = {}
1092        d.update(kw)
1093        if qoid in d:
1094            where = 'oid = %s' % d[qoid]
1095        else:
1096            try:
1097                keyname = self.pkey(qcl)
1098            except KeyError:
1099                raise _prg_error('Class %s has no primary key' % qcl)
1100            if isinstance(keyname, basestring):
1101                keyname = (keyname,)
1102            attnames = self.get_attnames(qcl)
1103            try:
1104                where = ' AND '.join(['%s = %s'
1105                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1106            except KeyError:
1107                raise _prg_error('Delete needs primary key or oid.')
1108        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
1109        self._do_debug(q)
1110        return int(self.db.query(q))
1111
1112    def truncate(self, table, restart=False, cascade=False, only=False):
1113        """Empty a table or set of tables.
1114
1115        This method quickly removes all rows from the given table or set
1116        of tables.  It has the same effect as an unqualified DELETE on each
1117        table, but since it does not actually scan the tables it is faster.
1118        Furthermore, it reclaims disk space immediately, rather than requiring
1119        a subsequent VACUUM operation. This is most useful on large tables.
1120
1121        If restart is set to True, sequences owned by columns of the truncated
1122        table(s) are automatically restarted.  If cascade is set to True, it
1123        also truncates all tables that have foreign-key references to any of
1124        the named tables.  If the parameter only is not set to True, all the
1125        descendant tables (if any) will also be truncated. Optionally, a '*'
1126        can be specified after the table name to explicitly indicate that
1127        descendant tables are included.
1128        """
1129        if isinstance(table, basestring):
1130            only = {table: only}
1131            table = [table]
1132        elif isinstance(table, (list, tuple)):
1133            if isinstance(only, (list, tuple)):
1134                only = dict(zip(table, only))
1135            else:
1136                only = dict.fromkeys(table, only)
1137        elif isinstance(table, (set, frozenset)):
1138            only = dict.fromkeys(table, only)
1139        else:
1140            raise TypeError('The table must be a string, list or set')
1141        if not (restart is None or isinstance(restart, (bool, int))):
1142            raise TypeError('Invalid type for the restart option')
1143        if not (cascade is None or isinstance(cascade, (bool, int))):
1144            raise TypeError('Invalid type for the cascade option')
1145        tables = []
1146        for t in table:
1147            u = only.get(t)
1148            if not (u is None or isinstance(u, (bool, int))):
1149                raise TypeError('Invalid type for the only option')
1150            if t.endswith('*'):
1151                if u:
1152                    raise ValueError(
1153                        'Contradictory table name and only options')
1154                t = t[:-1].rstrip()
1155            t = self._add_schema(t)
1156            if u:
1157                t = 'ONLY %s' % t
1158            tables.append(t)
1159        q = ['TRUNCATE', ', '.join(tables)]
1160        if restart:
1161            q.append('RESTART IDENTITY')
1162        if cascade:
1163            q.append('CASCADE')
1164        q = ' '.join(q)
1165        self._do_debug(q)
1166        return self.query(q)
1167
1168    def notification_handler(self,
1169            event, callback, arg_dict=None, timeout=None, stop_event=None):
1170        """Get notification handler that will run the given callback."""
1171        return NotificationHandler(self,
1172            event, callback, arg_dict, timeout, stop_event)
1173
1174
1175# if run as script, print some information
1176
if __name__ == '__main__':
    # separate the label from the version number with a space
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.