source: branches/4.x/pg.py @ 771

Last change on this file since 771 was 771, checked in by cito, 4 years ago

Back port some minor fixes from the trunk

This also gives a better error message if the test runner does not support unittest2.

  • Property svn:keywords set to Id
File size: 43.4 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 771 2016-01-20 18:22:10Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from _pg import *
32
33import select
34import warnings
35try:
36    frozenset
37except NameError:  # Python < 2.4, unsupported
38    from sets import ImmutableSet as frozenset
39try:
40    from decimal import Decimal
41    set_decimal(Decimal)
42except ImportError:  # Python < 2.4, unsupported
43    Decimal = float
44try:
45    from collections import namedtuple
46except ImportError:  # Python < 2.6
47    namedtuple = None
48
49
50# Auxiliary functions that are independent from a DB connection:
51
52def _is_quoted(s):
53    """Check whether this string is a quoted identifier."""
54    s = s.replace('_', 'a')
55    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
56
57
58def _is_unquoted(s):
59    """Check whether this string is an unquoted identifier."""
60    s = s.replace('_', 'a')
61    return s.isalnum() and not s[:1].isdigit()
62
63
def _split_first_part(s):
    """Split the first part off a dot separated identifier string.

    Returns a list with one or two elements.  The first element is the
    first part of s: if it was quoted, the quotes are removed and doubled
    quotes are unescaped; if it is a plain unquoted identifier, it is
    converted to lower case.  The second element, if present, is the
    remaining string after the separating dot.
    """
    s = s.lstrip()
    if s[:1] == '"':
        p = []
        s = s.split('"', 3)[1:]
        p.append(s[0])
        # a doubled quote ("") inside the quoted part stands for one quote
        while len(s) == 3 and s[1] == '':
            p.append('"')
            s = s[2].split('"', 2)
            p.append(s[0])
        p = [''.join(p)]
        s = '"'.join(s[1:]).lstrip()
        if s:
            if s[:1] == '.':  # the dot follows the closing quote directly
                p.append(s[1:])
            else:  # further characters still belong to the first part
                s = _split_first_part(s)
                p[0] += s[0]
                if len(s) > 1:
                    p.append(s[1])
    else:
        p = s.split('.', 1)
        s = p[0].rstrip()
        if _is_unquoted(s):
            s = s.lower()
        p[0] = s
    return p
92
93
def _split_parts(s):
    """Split a dot separated identifier string into the list of its parts.

    Each part is normalized as done by _split_first_part.
    """
    parts = []
    rest = s
    while rest:
        head = _split_first_part(rest)
        parts.append(head[0])
        if len(head) < 2:  # no more dots
            break
        rest = head[1]
    return parts
104
105
def _join_parts(s):
    """Join identifier parts with dots, quoting the parts that need it.

    Parts that must be quoted (see _is_quoted) are enclosed in double
    quotes, and double quotes inside them are doubled, which is the same
    escaping that _split_first_part undoes when parsing such names.
    """
    joined = []
    for p in s:
        if _is_quoted(p):
            p = '"%s"' % p.replace('"', '""')
        joined.append(p)
    return '.'.join(joined)
109
110
111def _oid_key(qcl):
112    """Build oid key from qualified class name."""
113    return 'oid(%s)' % qcl
114
115
if namedtuple:

    def _namedresult(q):
        """Get query result as a list of named tuples.

        Field names that namedtuple cannot use directly (names that are
        not valid Python identifiers, or duplicate names) are replaced by
        positional names like '_1' where namedtuple supports renaming.
        """
        fields = q.listfields()
        try:
            row = namedtuple('Row', fields, rename=True)
        except TypeError:  # Python 2.6: namedtuple has no rename parameter
            row = namedtuple('Row', fields)
        return [row(*r) for r in q.getresult()]

    set_namedresult(_namedresult)
124
125
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The exception gets the message msg and a sqlstate attribute of None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
131
132
def _int_error(msg):
    """Create an InternalError with message msg and no sqlstate."""
    return _db_error(msg, cls=InternalError)
136
137
def _prg_error(msg):
    """Create a ProgrammingError with message msg and no sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
141
142
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # guard against instances whose __init__ did not complete
        if getattr(self, 'listening', False):
            self.unlisten()

    def _quote_ident(name):
        """Return name as a double-quoted SQL identifier.

        Embedded double quotes are doubled so that they cannot terminate
        the identifier early.
        """
        name = '%s' % (name,)
        return '"%s"' % name.replace('"', '""')
    _quote_ident = staticmethod(_quote_ident)

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_ident(self.event))
            self.db.query('listen %s' % self._quote_ident(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_ident(self.event))
            self.db.query('unlisten %s' % self._quote_ident(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.  It is
        escaped with the escape_string() method of the connection used.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent (and None is returned) if the handler is not
        listening; otherwise the result of the notify query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            if stop:
                channel = self.stop_event
            else:
                channel = self.event
            q = 'notify %s' % self._quote_ident(channel)
            if payload:
                # the payload is embedded in a string literal, so it must
                # be escaped (e.g. single quotes) to keep the SQL valid
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and return.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        If will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
262
263
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its deprecated traditional name.

    All arguments are passed on unchanged; a DeprecationWarning is issued
    that points to the caller.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
269
270
# The actual PostgreSQL database connection interface:
272
273class DB(object):
274    """Wrapper class for the _pg connection type."""
275
276    def __init__(self, *args, **kw):
277        """Create a new connection.
278
279        You can pass either the connection parameters or an existing
280        _pg or pgdb connection. This allows you to use the methods
281        of the classic pg interface with a DB-API 2 pgdb connection.
282
283        """
284        if not args and len(kw) == 1:
285            db = kw.get('db')
286        elif not kw and len(args) == 1:
287            db = args[0]
288        else:
289            db = None
290        if db:
291            if isinstance(db, DB):
292                db = db.db
293            else:
294                try:
295                    db = db._cnx
296                except AttributeError:
297                    pass
298        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
299            db = connect(*args, **kw)
300            self._closeable = True
301        else:
302            self._closeable = False
303        self.db = db
304        self.dbname = db.db
305        self._regtypes = False
306        self._attnames = {}
307        self._pkeys = {}
308        self._privileges = {}
309        self._args = args, kw
310        self.debug = None  # For debugging scripts, this can be set
311            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
312            # * to a file object to write debug statements or
313            # * to a callable object which takes a string argument
314            # * to any other true value to just print debug statements
315
316    def __getattr__(self, name):
317        # All undefined members are same as in underlying connection:
318        if self.db:
319            return getattr(self.db, name)
320        else:
321            raise _int_error('Connection is not valid')
322
323    # Context manager methods
324
325    def __enter__(self):
326        """Enter the runtime context. This will start a transaction."""
327        self.begin()
328        return self
329
330    def __exit__(self, et, ev, tb):
331        """Exit the runtime context. This will end the transaction."""
332        if et is None and ev is None and tb is None:
333            self.commit()
334        else:
335            self.rollback()
336
337    # Auxiliary methods
338
339    def _do_debug(self, s):
340        """Print a debug message."""
341        if self.debug:
342            if isinstance(self.debug, basestring):
343                print(self.debug % s)
344            elif isinstance(self.debug, file):
345                self.debug.write(s + '\n')
346            elif callable(self.debug):
347                self.debug(s)
348            else:
349                print(s)
350
351    def _make_bool(d):
352        """Get boolean value corresponding to d."""
353        if get_bool():
354            return bool(d)
355        return d and 't' or 'f'
356    _make_bool = staticmethod(_make_bool)
357
358    def _quote_text(self, d):
359        """Quote text value."""
360        if not isinstance(d, basestring):
361            d = str(d)
362        return "'%s'" % self.escape_string(d)
363
364    _bool_true = frozenset('t true 1 y yes on'.split())
365
366    def _quote_bool(self, d):
367        """Quote boolean value."""
368        if isinstance(d, basestring):
369            if not d:
370                return 'NULL'
371            d = d.lower() in self._bool_true
372        return d and "'t'" or "'f'"
373
374    _date_literals = frozenset('current_date current_time'
375        ' current_timestamp localtime localtimestamp'.split())
376
377    def _quote_date(self, d):
378        """Quote date value."""
379        if not d:
380            return 'NULL'
381        if isinstance(d, basestring) and d.lower() in self._date_literals:
382            return d
383        return self._quote_text(d)
384
385    def _quote_num(self, d):
386        """Quote numeric value."""
387        if not d and d != 0:
388            return 'NULL'
389        return str(d)
390
391    def _quote_money(self, d):
392        """Quote money value."""
393        if d is None or d == '':
394            return 'NULL'
395        if not isinstance(d, basestring):
396            d = str(d)
397        return d
398
399    _quote_funcs = dict(  # quote methods for each type
400        text=_quote_text, bool=_quote_bool, date=_quote_date,
401        int=_quote_num, num=_quote_num, float=_quote_num,
402        money=_quote_money)
403
404    def _quote(self, d, t):
405        """Return quotes if needed."""
406        if d is None:
407            return 'NULL'
408        try:
409            quote_func = self._quote_funcs[t]
410        except KeyError:
411            quote_func = self._quote_funcs['text']
412        return quote_func(self, d)
413
414    def _split_schema(self, cl):
415        """Return schema and name of object separately.
416
417        This auxiliary function splits off the namespace (schema)
418        belonging to the class with the name cl. If the class name
419        is not qualified, the function is able to determine the schema
420        of the class, taking into account the current search path.
421
422        """
423        s = _split_parts(cl)
424        if len(s) > 1:  # name already qualified?
425            # should be database.schema.table or schema.table
426            if len(s) > 3:
427                raise _prg_error('Too many dots in class name %s' % cl)
428            schema, cl = s[-2:]
429        else:
430            cl = s[0]
431            # determine search path
432            q = 'SELECT current_schemas(TRUE)'
433            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
434            if schemas:  # non-empty path
435                # search schema for this object in the current search path
436                # (we could also use unnest with ordinality here to spare
437                # one query, but this is only possible since PostgreSQL 9.4)
438                q = ' UNION '.join(
439                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
440                        % s for s in enumerate(schemas)])
441                q = ("SELECT nspname FROM pg_class r"
442                    " JOIN pg_namespace s ON r.relnamespace = s.oid"
443                    " JOIN (%s) AS p USING (nspname)"
444                    " WHERE r.relname = $1 ORDER BY n LIMIT 1" % q)
445                schema = self.db.query(q, (cl,)).getresult()
446                if schema:  # schema found
447                    schema = schema[0][0]
448                else:  # object not found in current search path
449                    schema = 'public'
450            else:  # empty path
451                schema = 'public'
452        return schema, cl
453
454    def _add_schema(self, cl):
455        """Ensure that the class name is prefixed with a schema name."""
456        return _join_parts(self._split_schema(cl))
457
458    # Public methods
459
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well.
    # This wraps the module-level unescape_bytea function (presumably the
    # one from _pg, brought in by the star import) as a static method, so
    # it can be called as db.unescape_bytea(...) without passing self.
    unescape_bytea = staticmethod(unescape_bytea)
463
464    def close(self):
465        """Close the database connection."""
466        # Wraps shared library function so we can track state.
467        if self._closeable:
468            if self.db:
469                self.db.close()
470                self.db = None
471            else:
472                raise _int_error('Connection already closed')
473
474    def reset(self):
475        """Reset connection with current parameters.
476
477        All derived queries and large objects derived from this connection
478        will not be usable after this call.
479
480        """
481        if self.db:
482            self.db.reset()
483        else:
484            raise _int_error('Connection already closed')
485
486    def reopen(self):
487        """Reopen connection to the database.
488
489        Used in case we need another connection to the same database.
490        Note that we can still reopen a database that we have closed.
491
492        """
493        # There is no such shared library function.
494        if self._closeable:
495            db = connect(*self._args[0], **self._args[1])
496            if self.db:
497                self.db.close()
498            self.db = db
499
500    def begin(self, mode=None):
501        """Begin a transaction."""
502        qstr = 'BEGIN'
503        if mode:
504            qstr += ' ' + mode
505        return self.query(qstr)
506
507    start = begin
508
509    def commit(self):
510        """Commit the current transaction."""
511        return self.query('COMMIT')
512
513    end = commit
514
515    def rollback(self, name=None):
516        """Roll back the current transaction."""
517        qstr = 'ROLLBACK'
518        if name:
519            qstr += ' TO ' + name
520        return self.query(qstr)
521
522    abort = rollback
523
524    def savepoint(self, name):
525        """Define a new savepoint within the current transaction."""
526        return self.query('SAVEPOINT ' + name)
527
528    def release(self, name):
529        """Destroy a previously defined savepoint."""
530        return self.query('RELEASE ' + name)
531
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises TypeError if parameter has an unsupported type, is empty, or
        contains a name that is not a non-empty string.
        """
        # Decide how results are returned: values is None for a single
        # value, a list for list/tuple input, or a dict for set/dict input.
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter  # results are stored in the passed dict
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # for dict results, params maps normalized names to the original keys
        if isinstance(values, dict):
            params = {}
        else:
            params = []
        for key in parameter:
            if isinstance(key, basestring):
                param = key.strip().lower()
            else:
                param = None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' overrides all other names: a new dict built from the
                # first two columns of SHOW ALL is returned (a dict passed
                # as parameter is not updated in this case)
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # no 'all' among the names: query each parameter separately
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
594
595    def set_parameter(self, parameter, value=None, local=False):
596        """Set the value of a run-time parameter.
597
598        If the parameter and the value are strings, the run-time parameter
599        will be set to that value.  If no value or None is passed as a value,
600        then the run-time parameter will be restored to its default value.
601
602        You can set several parameters at once by passing a list of parameter
603        names, together with a single value that all parameters should be
604        set to or with a corresponding list of values.  You can also pass
605        the parameters as a set if you only provide a single value.
606        Finally, you can pass a dict with parameter names as keys.  In this
607        case, you should not pass a value, since the values for the parameters
608        will be taken from the dict.
609
610        By passing the special name 'all' as the parameter, you can reset
611        all existing settable run-time parameters to their default values.
612
613        If you set local to True, then the command takes effect for only the
614        current transaction.  After commit() or rollback(), the session-level
615        setting takes effect again.  Setting local to True will appear to
616        have no effect if it is executed outside a transaction, since the
617        transaction will end immediately.
618        """
619        if isinstance(parameter, basestring):
620            parameter = {parameter: value}
621        elif isinstance(parameter, (list, tuple)):
622            if isinstance(value, (list, tuple)):
623                parameter = dict(zip(parameter, value))
624            else:
625                parameter = dict.fromkeys(parameter, value)
626        elif isinstance(parameter, (set, frozenset)):
627            if isinstance(value, (list, tuple, set, frozenset)):
628                value = set(value)
629                if len(value) == 1:
630                    value = value.pop()
631            if not(value is None or isinstance(value, basestring)):
632                raise ValueError('A single value must be specified'
633                    ' when parameter is a set')
634            parameter = dict.fromkeys(parameter, value)
635        elif isinstance(parameter, dict):
636            if value is not None:
637                raise ValueError('A value must not be specified'
638                    ' when parameter is a dictionary')
639        else:
640            raise TypeError(
641                'The parameter must be a string, list, set or dict')
642        if not parameter:
643            raise TypeError('No parameter has been specified')
644        params = {}
645        for key, value in parameter.items():
646            if isinstance(key, basestring):
647                param = key.strip().lower()
648            else:
649                param = None
650            if not param:
651                raise TypeError('Invalid parameter')
652            if param == 'all':
653                if value is not None:
654                    raise ValueError('A value must ot be specified'
655                        " when parameter is 'all'")
656                params = {'all': None}
657                break
658            params[param] = value
659        local = local and ' LOCAL' or ''
660        for param, value in params.items():
661            if value is None:
662                q = 'RESET%s %s' % (local, param)
663            else:
664                q = 'SET%s %s TO %s' % (local, param, value)
665            self._do_debug(q)
666            self.db.query(q)
667
668    def query(self, qstr, *args):
669        """Executes a SQL command string.
670
671        This method simply sends a SQL query to the database.  If the query is
672        an insert statement that inserted exactly one row into a table that
673        has OIDs, the return value is the OID of the newly inserted row.
674        If the query is an update or delete statement, or an insert statement
675        that did not insert exactly one row in a table with OIDs, then the
676        number of rows affected is returned as a string.  If it is a statement
677        that returns rows as a result (usually a select statement, but maybe
678        also an "insert/update ... returning" statement), this method returns
679        a pgqueryobject that can be accessed via getresult() or dictresult()
680        or simply printed.  Otherwise, it returns `None`.
681
682        The query can contain numbered parameters of the form $1 in place
683        of any data constant.  Arguments given after the query string will
684        be substituted for the corresponding numbered parameter.  Parameter
685        values can also be given as a single list or tuple argument.
686
687        Note that the query string must not be passed as a unicode value,
688        but you can pass arguments as unicode values if they can be decoded
689        using the current client encoding.
690
691        """
692        # Wraps shared library function for debugging.
693        if not self.db:
694            raise _int_error('Connection is not valid')
695        self._do_debug(qstr)
696        return self.db.query(qstr, args)
697
698    def pkey(self, cl, newpkey=None):
699        """This method gets or sets the primary key of a class.
700
701        Composite primary keys are represented as frozensets. Note that
702        this raises a KeyError if the table does not have a primary key.
703
704        If newpkey is set and is not a dictionary then set that
705        value as the primary key of the class.  If it is a dictionary
706        then replace the internal cache of primary keys with a copy of it.
707
708        """
709        # First see if the caller is supplying a dictionary
710        if isinstance(newpkey, dict):
711            # make sure that all classes have a namespace
712            self._pkeys = dict([
713                ('.' in cl and cl or 'public.' + cl, pkey)
714                for cl, pkey in newpkey.items()])
715            return self._pkeys
716
717        qcl = self._add_schema(cl)  # build fully qualified class name
718        # Check if the caller is supplying a new primary key for the class
719        if newpkey:
720            self._pkeys[qcl] = newpkey
721            return newpkey
722
723        # Get all the primary keys at once
724        if qcl not in self._pkeys:
725            # if not found, check again in case it was added after we started
726            self._pkeys = {}
727            if self.server_version >= 80200:
728                # the ANY syntax works correctly only with PostgreSQL >= 8.2
729                any_indkey = "= ANY (i.indkey)"
730            else:
731                any_indkey = "IN (%s)" % ', '.join(
732                    ['i.indkey[%d]' % i for i in range(16)])
733            q = ("SELECT s.nspname, r.relname, a.attname"
734                " FROM pg_class r"
735                " JOIN pg_namespace s ON s.oid = r.relnamespace"
736                " AND s.nspname NOT SIMILAR"
737                " TO 'pg/_%|information/_schema' ESCAPE '/'"
738                " JOIN pg_attribute a ON a.attrelid = r.oid"
739                " AND NOT a.attisdropped"
740                " JOIN pg_index i ON i.indrelid = r.oid"
741                " AND i.indisprimary AND a.attnum " + any_indkey)
742            for r in self.db.query(q).getresult():
743                cl, pkey = _join_parts(r[:2]), r[2]
744                self._pkeys.setdefault(cl, []).append(pkey)
745            # (only) for composite primary keys, the values will be frozensets
746            for cl, pkey in self._pkeys.items():
747                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
748            self._do_debug(self._pkeys)
749
750        # will raise an exception if primary key doesn't exist
751        return self._pkeys[qcl]
752
753    def get_databases(self):
754        """Get list of databases in the system."""
755        return [s[0] for s in
756            self.db.query('SELECT datname FROM pg_database').getresult()]
757
758    def get_relations(self, kinds=None):
759        """Get list of relations in connected database of specified kinds.
760
761        If kinds is None or empty, all kinds of relations are returned.
762        Otherwise kinds can be a string or sequence of type letters
763        specifying which kind of relations you want to list.
764
765        """
766        where = kinds and " AND r.relkind IN (%s)" % ','.join(
767            ["'%s'" % k for k in kinds]) or ''
768        q = ("SELECT s.nspname, r.relname"
769            " FROM pg_class r"
770            " JOIN pg_namespace s ON s.oid = r.relnamespace"
771            " WHERE s.nspname NOT SIMILAR"
772            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
773            " ORDER BY 1, 2") % where
774        return [_join_parts(r) for r in self.db.query(q).getresult()]
775
776    def get_tables(self):
777        """Return list of tables in connected database."""
778        return self.get_relations('r')
779
780    def get_attnames(self, cl, newattnames=None):
781        """Given the name of a table, digs out the set of attribute names.
782
783        Returns a dictionary of attribute names (the names are the keys,
784        the values are the names of the attributes' types).
785        If the optional newattnames exists, it must be a dictionary and
786        will become the new attribute names dictionary.
787
788        By default, only a limited number of simple types will be returned.
789        You can get the regular types after calling use_regtypes(True).
790
791        """
792        if isinstance(newattnames, dict):
793            self._attnames = newattnames
794            return
795        elif newattnames:
796            raise _prg_error('If supplied, newattnames must be a dictionary')
797        cl = self._split_schema(cl)  # split into schema and class
798        qcl = _join_parts(cl)  # build fully qualified name
799        # May as well cache them:
800        if qcl in self._attnames:
801            return self._attnames[qcl]
802        if qcl not in self.get_relations('rv'):
803            raise _prg_error('Class %s does not exist' % qcl)
804
805        q = ("SELECT a.attname, t.typname%s"
806            " FROM pg_class r"
807            " JOIN pg_namespace s ON r.relnamespace = s.oid"
808            " JOIN pg_attribute a ON a.attrelid = r.oid"
809            " JOIN pg_type t ON t.oid = a.atttypid"
810            " WHERE s.nspname = $1 AND r.relname = $2"
811            " AND (a.attnum > 0 OR a.attname = 'oid')"
812            " AND NOT a.attisdropped") % (
813                self._regtypes and '::regtype' or '',)
814        q = self.db.query(q, cl).getresult()
815
816        if self._regtypes:
817            t = dict(q)
818        else:
819            t = {}
820            for att, typ in q:
821                if typ.startswith('bool'):
822                    typ = 'bool'
823                elif typ.startswith('abstime'):
824                    typ = 'date'
825                elif typ.startswith('date'):
826                    typ = 'date'
827                elif typ.startswith('interval'):
828                    typ = 'date'
829                elif typ.startswith('timestamp'):
830                    typ = 'date'
831                elif typ.startswith('oid'):
832                    typ = 'int'
833                elif typ.startswith('int'):
834                    typ = 'int'
835                elif typ.startswith('float'):
836                    typ = 'float'
837                elif typ.startswith('numeric'):
838                    typ = 'num'
839                elif typ.startswith('money'):
840                    typ = 'money'
841                else:
842                    typ = 'text'
843                t[att] = typ
844
845        self._attnames[qcl] = t  # cache it
846        return self._attnames[qcl]
847
848    def use_regtypes(self, regtypes=None):
849        """Use regular type names instead of simplified type names."""
850        if regtypes is None:
851            return self._regtypes
852        else:
853            regtypes = bool(regtypes)
854            if regtypes != self._regtypes:
855                self._regtypes = regtypes
856                self._attnames.clear()
857            return regtypes
858
859    def has_table_privilege(self, cl, privilege='select'):
860        """Check whether current user has specified table privilege."""
861        qcl = self._add_schema(cl)
862        privilege = privilege.lower()
863        try:
864            return self._privileges[(qcl, privilege)]
865        except KeyError:
866            q = "SELECT has_table_privilege($1, $2)"
867            q = self.db.query(q, (qcl, privilege))
868            ret = q.getresult()[0][0] == self._make_bool(True)
869            self._privileges[(qcl, privilege)] = ret
870            return ret
871
872    def get(self, cl, arg, keyname=None):
873        """Get a row from a database table or view.
874
875        This method is the basic mechanism to get a single row.  It assumes
876        that the key specifies a unique row.  If keyname is not specified
877        then the primary key for the table is used.  If arg is a dictionary
878        then the value for the key is taken from it and it is modified to
879        include the new values, replacing existing values where necessary.
880        For a composite key, keyname can also be a sequence of key names.
881        The OID is also put into the dictionary if the table has one, but
882        in order to allow the caller to work with multiple tables, it is
883        munged as oid(schema.table).
884
885        """
886        if cl.endswith('*'):  # scan descendant tables?
887            cl = cl[:-1].rstrip()  # need parent table name
888        # build qualified class name
889        qcl = self._add_schema(cl)
890        # To allow users to work with multiple tables,
891        # we munge the name of the "oid" key
892        qoid = _oid_key(qcl)
893        if not keyname:
894            # use the primary key by default
895            try:
896                keyname = self.pkey(qcl)
897            except KeyError:
898                raise _prg_error('Class %s has no primary key' % qcl)
899        # We want the oid for later updates if that isn't the key
900        if keyname == 'oid':
901            if isinstance(arg, dict):
902                if qoid not in arg:
903                    raise _prg_error('%s not in arg' % qoid)
904            else:
905                arg = {qoid: arg}
906            where = 'oid = %s' % arg[qoid]
907            attnames = '*'
908        else:
909            attnames = self.get_attnames(qcl)
910            if isinstance(keyname, basestring):
911                keyname = (keyname,)
912            if not isinstance(arg, dict):
913                if len(keyname) > 1:
914                    raise _prg_error('Composite key needs dict as arg')
915                arg = dict([(k, arg) for k in keyname])
916            where = ' AND '.join(['%s = %s'
917                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
918            attnames = ', '.join(attnames)
919        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
920        self._do_debug(q)
921        res = self.db.query(q).dictresult()
922        if not res:
923            raise _db_error('No such record in %s where %s' % (qcl, where))
924        for att, value in res[0].items():
925            arg[att == 'oid' and qoid or att] = value
926        return arg
927
928    def insert(self, cl, d=None, **kw):
929        """Insert a row into a database table.
930
931        This method inserts a row into a table.  The name of the table must
932        be passed as the first parameter.  The other parameters are used for
933        providing the data of the row that shall be inserted into the table.
934        If a dictionary is supplied as the second parameter, it starts with
935        that.  Otherwise it uses a blank dictionary. Either way the dictionary
936        is updated from the keywords.
937
938        The dictionary is then, if possible, reloaded with the values actually
939        inserted in order to pick up values modified by rules, triggers, etc.
940
941        """
942        qcl = self._add_schema(cl)
943        qoid = _oid_key(qcl)
944        if d is None:
945            d = {}
946        d.update(kw)
947        attnames = self.get_attnames(qcl)
948        names, values = [], []
949        for n in attnames:
950            if n != 'oid' and n in d:
951                names.append('"%s"' % n)
952                values.append(self._quote(d[n], attnames[n]))
953        names, values = ', '.join(names), ', '.join(values)
954        selectable = self.has_table_privilege(qcl)
955        if selectable and self.server_version >= 80200:
956            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
957        else:
958            ret = ''
959        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
960        self._do_debug(q)
961        res = self.db.query(q)
962        if ret:
963            res = res.dictresult()
964            for att, value in res[0].items():
965                d[att == 'oid' and qoid or att] = value
966        elif isinstance(res, int):
967            d[qoid] = res
968            if selectable:
969                self.get(qcl, d, 'oid')
970        elif selectable:
971            if qoid in d:
972                self.get(qcl, d, 'oid')
973            else:
974                try:
975                    self.get(qcl, d)
976                except ProgrammingError:
977                    pass  # table has no primary key
978        return d
979
980    def update(self, cl, d=None, **kw):
981        """Update an existing row in a database table.
982
983        Similar to insert but updates an existing row.  The update is based
984        on the OID value as munged by get() or passed as keyword, or on the
985        primary key of the table.  The dictionary is modified, if possible,
986        to reflect any changes caused by the update due to triggers, rules,
987        default values, etc.
988
989        """
990        # Update always works on the oid which get() returns if available,
991        # otherwise use the primary key.  Fail if neither.
992        # Note that we only accept oid key from named args for safety.
993        qcl = self._add_schema(cl)
994        qoid = _oid_key(qcl)
995        if 'oid' in kw:
996            kw[qoid] = kw['oid']
997            del kw['oid']
998        if d is None:
999            d = {}
1000        d.update(kw)
1001        attnames = self.get_attnames(qcl)
1002        if qoid in d:
1003            where = 'oid = %s' % d[qoid]
1004            keyname = ()
1005        else:
1006            try:
1007                keyname = self.pkey(qcl)
1008            except KeyError:
1009                raise _prg_error('Class %s has no primary key' % qcl)
1010            if isinstance(keyname, basestring):
1011                keyname = (keyname,)
1012            try:
1013                where = ' AND '.join(['%s = %s'
1014                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1015            except KeyError:
1016                raise _prg_error('Update needs primary key or oid.')
1017        values = []
1018        for n in attnames:
1019            if n in d and n not in keyname:
1020                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
1021        if not values:
1022            return d
1023        values = ', '.join(values)
1024        selectable = self.has_table_privilege(qcl)
1025        if selectable and self.server_version >= 80200:
1026            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
1027        else:
1028            ret = ''
1029        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
1030        self._do_debug(q)
1031        res = self.db.query(q)
1032        if ret:
1033            res = res.dictresult()[0]
1034            for att, value in res.items():
1035                d[att == 'oid' and qoid or att] = value
1036        else:
1037            if selectable:
1038                if qoid in d:
1039                    self.get(qcl, d, 'oid')
1040                else:
1041                    self.get(qcl, d)
1042        return d
1043
1044    def clear(self, cl, d=None):
1045        """Clear all the attributes to values determined by the types.
1046
1047        Numeric types are set to 0, Booleans are set to false, and everything
1048        else is set to the empty string.  If the second argument is present,
1049        it is used as the row dictionary and any entries matching attribute
1050        names are cleared with everything else left unchanged.
1051
1052        """
1053        # At some point we will need a way to get defaults from a table.
1054        qcl = self._add_schema(cl)
1055        if d is None:
1056            d = {}  # empty if argument is not present
1057        attnames = self.get_attnames(qcl)
1058        for n, t in attnames.items():
1059            if n == 'oid':
1060                continue
1061            if t in ('int', 'integer', 'smallint', 'bigint',
1062                    'float', 'real', 'double precision',
1063                    'num', 'numeric', 'money'):
1064                d[n] = 0
1065            elif t in ('bool', 'boolean'):
1066                d[n] = self._make_bool(False)
1067            else:
1068                d[n] = ''
1069        return d
1070
1071    def delete(self, cl, d=None, **kw):
1072        """Delete an existing row in a database table.
1073
1074        This method deletes the row from a table.  It deletes based on the
1075        OID value as munged by get() or passed as keyword, or on the primary
1076        key of the table.  The return value is the number of deleted rows
1077        (i.e. 0 if the row did not exist and 1 if the row was deleted).
1078
1079        """
1080        # Like update, delete works on the oid.
1081        # One day we will be testing that the record to be deleted
1082        # isn't referenced somewhere (or else PostgreSQL will).
1083        # Note that we only accept oid key from named args for safety
1084        qcl = self._add_schema(cl)
1085        qoid = _oid_key(qcl)
1086        if 'oid' in kw:
1087            kw[qoid] = kw['oid']
1088            del kw['oid']
1089        if d is None:
1090            d = {}
1091        d.update(kw)
1092        if qoid in d:
1093            where = 'oid = %s' % d[qoid]
1094        else:
1095            try:
1096                keyname = self.pkey(qcl)
1097            except KeyError:
1098                raise _prg_error('Class %s has no primary key' % qcl)
1099            if isinstance(keyname, basestring):
1100                keyname = (keyname,)
1101            attnames = self.get_attnames(qcl)
1102            try:
1103                where = ' AND '.join(['%s = %s'
1104                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1105            except KeyError:
1106                raise _prg_error('Delete needs primary key or oid.')
1107        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
1108        self._do_debug(q)
1109        return int(self.db.query(q))
1110
1111    def truncate(self, table, restart=False, cascade=False, only=False):
1112        """Empty a table or set of tables.
1113
1114        This method quickly removes all rows from the given table or set
1115        of tables.  It has the same effect as an unqualified DELETE on each
1116        table, but since it does not actually scan the tables it is faster.
1117        Furthermore, it reclaims disk space immediately, rather than requiring
1118        a subsequent VACUUM operation. This is most useful on large tables.
1119
1120        If restart is set to True, sequences owned by columns of the truncated
1121        table(s) are automatically restarted.  If cascade is set to True, it
1122        also truncates all tables that have foreign-key references to any of
1123        the named tables.  If the parameter only is not set to True, all the
1124        descendant tables (if any) will also be truncated. Optionally, a '*'
1125        can be specified after the table name to explicitly indicate that
1126        descendant tables are included.
1127        """
1128        if isinstance(table, basestring):
1129            only = {table: only}
1130            table = [table]
1131        elif isinstance(table, (list, tuple)):
1132            if isinstance(only, (list, tuple)):
1133                only = dict(zip(table, only))
1134            else:
1135                only = dict.fromkeys(table, only)
1136        elif isinstance(table, (set, frozenset)):
1137            only = dict.fromkeys(table, only)
1138        else:
1139            raise TypeError('The table must be a string, list or set')
1140        if not (restart is None or isinstance(restart, (bool, int))):
1141            raise TypeError('Invalid type for the restart option')
1142        if not (cascade is None or isinstance(cascade, (bool, int))):
1143            raise TypeError('Invalid type for the cascade option')
1144        tables = []
1145        for t in table:
1146            u = only.get(t)
1147            if not (u is None or isinstance(u, (bool, int))):
1148                raise TypeError('Invalid type for the only option')
1149            if t.endswith('*'):
1150                if u:
1151                    raise ValueError(
1152                        'Contradictory table name and only options')
1153                t = t[:-1].rstrip()
1154            t = self._add_schema(t)
1155            if u:
1156                t = 'ONLY %s' % t
1157            tables.append(t)
1158        q = ['TRUNCATE', ', '.join(tables)]
1159        if restart:
1160            q.append('RESTART IDENTITY')
1161        if cascade:
1162            q.append('CASCADE')
1163        q = ' '.join(q)
1164        self._do_debug(q)
1165        return self.db.query(q)
1166
1167    def notification_handler(self,
1168            event, callback, arg_dict=None, timeout=None, stop_event=None):
1169        """Get notification handler that will run the given callback."""
1170        return NotificationHandler(self,
1171            event, callback, arg_dict, timeout, stop_event)
1172
1173
# if run as script, print some information

if __name__ == '__main__':
    # version is presumably provided by "from _pg import *"
    print('PyGreSQL version %s' % version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.