source: branches/4.x/pg.py @ 857

Last change on this file since 857 was 857, checked in by cito, 3 years ago

Add system parameter to get_relations()

Also fix a regression in the 4.x branch when using temporary tables,
related to filtering system tables (as discussed on the mailing list).

  • Property svn:keywords set to Id
File size: 44.1 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 857 2016-03-18 12:22:21Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from _pg import *
32
33import select
34import warnings
35try:
36    frozenset
37except NameError:  # Python < 2.4, unsupported
38    from sets import ImmutableSet as frozenset
39try:
40    from decimal import Decimal
41    set_decimal(Decimal)
42except ImportError:  # Python < 2.4, unsupported
43    Decimal = float
44try:
45    from collections import namedtuple
46except ImportError:  # Python < 2.6
47    namedtuple = None
48
49
50# Auxiliary functions that are independent from a DB connection:
51
52def _is_quoted(s):
53    """Check whether this string is a quoted identifier."""
54    s = s.replace('_', 'a')
55    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
56
57
58def _is_unquoted(s):
59    """Check whether this string is an unquoted identifier."""
60    s = s.replace('_', 'a')
61    return s.isalnum() and not s[:1].isdigit()
62
63
64def _split_first_part(s):
65    """Split the first part of a dot separated string."""
66    s = s.lstrip()
67    if s[:1] == '"':
68        p = []
69        s = s.split('"', 3)[1:]
70        p.append(s[0])
71        while len(s) == 3 and s[1] == '':
72            p.append('"')
73            s = s[2].split('"', 2)
74            p.append(s[0])
75        p = [''.join(p)]
76        s = '"'.join(s[1:]).lstrip()
77        if s:
78            if s[:0] == '.':
79                p.append(s[1:])
80            else:
81                s = _split_first_part(s)
82                p[0] += s[0]
83                if len(s) > 1:
84                    p.append(s[1])
85    else:
86        p = s.split('.', 1)
87        s = p[0].rstrip()
88        if _is_unquoted(s):
89            s = s.lower()
90        p[0] = s
91    return p
92
93
94def _split_parts(s):
95    """Split all parts of a dot separated string."""
96    q = []
97    while s:
98        s = _split_first_part(s)
99        q.append(s[0])
100        if len(s) < 2:
101            break
102        s = s[1]
103    return q
104
105
def _join_parts(s):
    """Join all parts of a dot separated string.

    Parts that need quoting (according to _is_quoted) are enclosed in
    double quotes.  Embedded double quotes are doubled, as required by
    SQL, so that the result can be split again with _split_parts.
    """
    parts = []
    for p in s:
        if _is_quoted(p):
            p = '"%s"' % p.replace('"', '""')
        parts.append(p)
    return '.'.join(parts)
109
110
111def _oid_key(qcl):
112    """Build oid key from qualified class name."""
113    return 'oid(%s)' % qcl
114
115
if namedtuple:  # named tuples are only available since Python 2.6

    def _namedresult(q):
        """Return all rows of the query result q as named tuples.

        The named tuple type is called 'Row'; its field names are
        the names returned by q.listfields().
        """
        Row = namedtuple('Row', q.listfields())
        rows = q.getresult()
        return [Row(*values) for values in rows]

    # register the function with the _pg module via set_namedresult()
    set_namedresult(_namedresult)
124
125
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an error of class cls with message msg.

    The sqlstate attribute of the returned error is set to None.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
131
132
def _int_error(msg):
    """Create an InternalError with message msg and no sqlstate."""
    return _db_error(msg, cls=InternalError)
136
137
def _prg_error(msg):
    """Create a ProgrammingError with message msg and no sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
141
142
143# The notification handler
144
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler.

    The handler listens on an event (notification channel) and on a
    stop event, and invokes a callback for the notifications received.
    """

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # __init__ may not have completed, so do not rely on attributes
        if getattr(self, 'listening', False):
            self.unlisten()

    def _quote_name(name):
        """Return name as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so that any channel name
        can be used in LISTEN, UNLISTEN and NOTIFY commands.
        """
        return '"%s"' % ('%s' % (name,)).replace('"', '""')
    _quote_name = staticmethod(_quote_name)

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_name(self.event))
            self.db.query('listen %s' % self._quote_name(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_name(self.event))
            self.db.query('unlisten %s' % self._quote_name(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent (and None is returned) if the handler is not
        listening; otherwise the result of the NOTIFY query is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            q = 'notify %s' % self._quote_name(
                stop and self.stop_event or self.event)
            if payload:
                # Send the payload as an escape string constant with
                # backslashes and single quotes doubled, so that it is
                # passed on literally, independent of the setting of
                # standard_conforming_strings.
                payload = ('%s' % (payload,)).replace(
                    '\\', '\\\\').replace("'", "''")
                q += ", E'%s'" % payload
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
264
265
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its deprecated traditional name.

    A DeprecationWarning is issued (attributed to the caller) and all
    arguments are passed on to NotificationHandler unchanged.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
271
272
# The actual PostgreSQL database connection interface:
274
275class DB(object):
276    """Wrapper class for the _pg connection type."""
277
278    def __init__(self, *args, **kw):
279        """Create a new connection.
280
281        You can pass either the connection parameters or an existing
282        _pg or pgdb connection. This allows you to use the methods
283        of the classic pg interface with a DB-API 2 pgdb connection.
284
285        """
286        if not args and len(kw) == 1:
287            db = kw.get('db')
288        elif not kw and len(args) == 1:
289            db = args[0]
290        else:
291            db = None
292        if db:
293            if isinstance(db, DB):
294                db = db.db
295            else:
296                try:
297                    db = db._cnx
298                except AttributeError:
299                    pass
300        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
301            db = connect(*args, **kw)
302            self._closeable = True
303        else:
304            self._closeable = False
305        self.db = db
306        self.dbname = db.db
307        self._regtypes = False
308        self._attnames = {}
309        self._pkeys = {}
310        self._privileges = {}
311        self._args = args, kw
312        self.debug = None  # For debugging scripts, this can be set
313            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
314            # * to a file object to write debug statements or
315            # * to a callable object which takes a string argument
316            # * to any other true value to just print debug statements
317
318    def __getattr__(self, name):
319        # All undefined members are same as in underlying connection:
320        if self.db:
321            return getattr(self.db, name)
322        else:
323            raise _int_error('Connection is not valid')
324
325    # Context manager methods
326
327    def __enter__(self):
328        """Enter the runtime context. This will start a transaction."""
329        self.begin()
330        return self
331
332    def __exit__(self, et, ev, tb):
333        """Exit the runtime context. This will end the transaction."""
334        if et is None and ev is None and tb is None:
335            self.commit()
336        else:
337            self.rollback()
338
339    # Auxiliary methods
340
341    def _do_debug(self, s):
342        """Print a debug message."""
343        if self.debug:
344            if isinstance(self.debug, basestring):
345                print(self.debug % s)
346            elif isinstance(self.debug, file):
347                self.debug.write(s + '\n')
348            elif callable(self.debug):
349                self.debug(s)
350            else:
351                print(s)
352
353    def _make_bool(d):
354        """Get boolean value corresponding to d."""
355        if get_bool():
356            return bool(d)
357        return d and 't' or 'f'
358    _make_bool = staticmethod(_make_bool)
359
360    def _quote_text(self, d):
361        """Quote text value."""
362        if not isinstance(d, basestring):
363            d = str(d)
364        return "'%s'" % self.escape_string(d)
365
366    _bool_true = frozenset('t true 1 y yes on'.split())
367
368    def _quote_bool(self, d):
369        """Quote boolean value."""
370        if isinstance(d, basestring):
371            if not d:
372                return 'NULL'
373            d = d.lower() in self._bool_true
374        return d and "'t'" or "'f'"
375
376    _date_literals = frozenset('current_date current_time'
377        ' current_timestamp localtime localtimestamp'.split())
378
379    def _quote_date(self, d):
380        """Quote date value."""
381        if not d:
382            return 'NULL'
383        if isinstance(d, basestring) and d.lower() in self._date_literals:
384            return d
385        return self._quote_text(d)
386
387    def _quote_num(self, d):
388        """Quote numeric value."""
389        if not d and d != 0:
390            return 'NULL'
391        return str(d)
392
393    def _quote_money(self, d):
394        """Quote money value."""
395        if d is None or d == '':
396            return 'NULL'
397        if not isinstance(d, basestring):
398            d = str(d)
399        return d
400
401    _quote_funcs = dict(  # quote methods for each type
402        text=_quote_text, bool=_quote_bool, date=_quote_date,
403        int=_quote_num, num=_quote_num, float=_quote_num,
404        money=_quote_money)
405
406    def _quote(self, d, t):
407        """Return quotes if needed."""
408        if d is None:
409            return 'NULL'
410        try:
411            quote_func = self._quote_funcs[t]
412        except KeyError:
413            quote_func = self._quote_funcs['text']
414        return quote_func(self, d)
415
416    def _split_schema(self, cl):
417        """Return schema and name of object separately.
418
419        This auxiliary function splits off the namespace (schema)
420        belonging to the class with the name cl. If the class name
421        is not qualified, the function is able to determine the schema
422        of the class, taking into account the current search path.
423
424        """
425        s = _split_parts(cl)
426        if len(s) > 1:  # name already qualified?
427            # should be database.schema.table or schema.table
428            if len(s) > 3:
429                raise _prg_error('Too many dots in class name %s' % cl)
430            schema, cl = s[-2:]
431        else:
432            cl = s[0]
433            # determine search path
434            q = 'SELECT current_schemas(TRUE)'
435            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
436            if schemas:  # non-empty path
437                # search schema for this object in the current search path
438                # (we could also use unnest with ordinality here to spare
439                # one query, but this is only possible since PostgreSQL 9.4)
440                q = ' UNION '.join(
441                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
442                        % s for s in enumerate(schemas)])
443                q = ("SELECT nspname FROM pg_class r"
444                    " JOIN pg_namespace s ON r.relnamespace = s.oid"
445                    " JOIN (%s) AS p USING (nspname)"
446                    " WHERE r.relname = $1 ORDER BY n LIMIT 1" % q)
447                schema = self.db.query(q, (cl,)).getresult()
448                if schema:  # schema found
449                    schema = schema[0][0]
450                else:  # object not found in current search path
451                    schema = 'public'
452            else:  # empty path
453                schema = 'public'
454        return schema, cl
455
456    def _add_schema(self, cl):
457        """Ensure that the class name is prefixed with a schema name."""
458        return _join_parts(self._split_schema(cl))
459
460    # Public methods
461
    # escape_string and escape_bytea exist as methods (of the underlying
    # connection, reached through __getattr__), so we define unescape_bytea
    # as a method as well, by wrapping the module-level function of the
    # same name (presumably from _pg) as a static method
    unescape_bytea = staticmethod(unescape_bytea)
465
466    def close(self):
467        """Close the database connection."""
468        # Wraps shared library function so we can track state.
469        if self._closeable:
470            if self.db:
471                self.db.close()
472                self.db = None
473            else:
474                raise _int_error('Connection already closed')
475
476    def reset(self):
477        """Reset connection with current parameters.
478
479        All derived queries and large objects derived from this connection
480        will not be usable after this call.
481
482        """
483        if self.db:
484            self.db.reset()
485        else:
486            raise _int_error('Connection already closed')
487
488    def reopen(self):
489        """Reopen connection to the database.
490
491        Used in case we need another connection to the same database.
492        Note that we can still reopen a database that we have closed.
493
494        """
495        # There is no such shared library function.
496        if self._closeable:
497            db = connect(*self._args[0], **self._args[1])
498            if self.db:
499                self.db.close()
500            self.db = db
501
502    def begin(self, mode=None):
503        """Begin a transaction."""
504        qstr = 'BEGIN'
505        if mode:
506            qstr += ' ' + mode
507        return self.query(qstr)
508
509    start = begin
510
511    def commit(self):
512        """Commit the current transaction."""
513        return self.query('COMMIT')
514
515    end = commit
516
517    def rollback(self, name=None):
518        """Roll back the current transaction."""
519        qstr = 'ROLLBACK'
520        if name:
521            qstr += ' TO ' + name
522        return self.query(qstr)
523
524    abort = rollback
525
526    def savepoint(self, name):
527        """Define a new savepoint within the current transaction."""
528        return self.query('SAVEPOINT ' + name)
529
530    def release(self, name):
531        """Destroy a previously defined savepoint."""
532        return self.query('RELEASE ' + name)
533
534    def get_parameter(self, parameter):
535        """Get the value of a run-time parameter.
536
537        If the parameter is a string, the return value will also be a string
538        that is the current setting of the run-time parameter with that name.
539
540        You can get several parameters at once by passing a list, set or dict.
541        When passing a list of parameter names, the return value will be a
542        corresponding list of parameter settings.  When passing a set of
543        parameter names, a new dict will be returned, mapping these parameter
544        names to their settings.  Finally, if you pass a dict as parameter,
545        its values will be set to the current parameter settings corresponding
546        to its keys.
547
548        By passing the special name 'all' as the parameter, you can get a dict
549        of all existing configuration parameters.
550        """
551        if isinstance(parameter, basestring):
552            parameter = [parameter]
553            values = None
554        elif isinstance(parameter, (list, tuple)):
555            values = []
556        elif isinstance(parameter, (set, frozenset)):
557            values = {}
558        elif isinstance(parameter, dict):
559            values = parameter
560        else:
561            raise TypeError(
562                'The parameter must be a string, list, set or dict')
563        if not parameter:
564            raise TypeError('No parameter has been specified')
565        if isinstance(values, dict):
566            params = {}
567        else:
568            params = []
569        for key in parameter:
570            if isinstance(key, basestring):
571                param = key.strip().lower()
572            else:
573                param = None
574            if not param:
575                raise TypeError('Invalid parameter')
576            if param == 'all':
577                q = 'SHOW ALL'
578                values = self.db.query(q).getresult()
579                values = dict(value[:2] for value in values)
580                break
581            if isinstance(values, dict):
582                params[param] = key
583            else:
584                params.append(param)
585        else:
586            for param in params:
587                q = 'SHOW %s' % (param,)
588                value = self.db.query(q).getresult()[0][0]
589                if values is None:
590                    values = value
591                elif isinstance(values, list):
592                    values.append(value)
593                else:
594                    values[params[param]] = value
595        return values
596
597    def set_parameter(self, parameter, value=None, local=False):
598        """Set the value of a run-time parameter.
599
600        If the parameter and the value are strings, the run-time parameter
601        will be set to that value.  If no value or None is passed as a value,
602        then the run-time parameter will be restored to its default value.
603
604        You can set several parameters at once by passing a list of parameter
605        names, together with a single value that all parameters should be
606        set to or with a corresponding list of values.  You can also pass
607        the parameters as a set if you only provide a single value.
608        Finally, you can pass a dict with parameter names as keys.  In this
609        case, you should not pass a value, since the values for the parameters
610        will be taken from the dict.
611
612        By passing the special name 'all' as the parameter, you can reset
613        all existing settable run-time parameters to their default values.
614
615        If you set local to True, then the command takes effect for only the
616        current transaction.  After commit() or rollback(), the session-level
617        setting takes effect again.  Setting local to True will appear to
618        have no effect if it is executed outside a transaction, since the
619        transaction will end immediately.
620        """
621        if isinstance(parameter, basestring):
622            parameter = {parameter: value}
623        elif isinstance(parameter, (list, tuple)):
624            if isinstance(value, (list, tuple)):
625                parameter = dict(zip(parameter, value))
626            else:
627                parameter = dict.fromkeys(parameter, value)
628        elif isinstance(parameter, (set, frozenset)):
629            if isinstance(value, (list, tuple, set, frozenset)):
630                value = set(value)
631                if len(value) == 1:
632                    value = value.pop()
633            if not(value is None or isinstance(value, basestring)):
634                raise ValueError('A single value must be specified'
635                    ' when parameter is a set')
636            parameter = dict.fromkeys(parameter, value)
637        elif isinstance(parameter, dict):
638            if value is not None:
639                raise ValueError('A value must not be specified'
640                    ' when parameter is a dictionary')
641        else:
642            raise TypeError(
643                'The parameter must be a string, list, set or dict')
644        if not parameter:
645            raise TypeError('No parameter has been specified')
646        params = {}
647        for key, value in parameter.items():
648            if isinstance(key, basestring):
649                param = key.strip().lower()
650            else:
651                param = None
652            if not param:
653                raise TypeError('Invalid parameter')
654            if param == 'all':
655                if value is not None:
656                    raise ValueError('A value must ot be specified'
657                        " when parameter is 'all'")
658                params = {'all': None}
659                break
660            params[param] = value
661        local = local and ' LOCAL' or ''
662        for param, value in params.items():
663            if value is None:
664                q = 'RESET%s %s' % (local, param)
665            else:
666                q = 'SET%s %s TO %s' % (local, param, value)
667            self._do_debug(q)
668            self.db.query(q)
669
670    def query(self, qstr, *args):
671        """Executes a SQL command string.
672
673        This method simply sends a SQL query to the database.  If the query is
674        an insert statement that inserted exactly one row into a table that
675        has OIDs, the return value is the OID of the newly inserted row.
676        If the query is an update or delete statement, or an insert statement
677        that did not insert exactly one row in a table with OIDs, then the
678        number of rows affected is returned as a string.  If it is a statement
679        that returns rows as a result (usually a select statement, but maybe
680        also an "insert/update ... returning" statement), this method returns
681        a pgqueryobject that can be accessed via getresult() or dictresult()
682        or simply printed.  Otherwise, it returns `None`.
683
684        The query can contain numbered parameters of the form $1 in place
685        of any data constant.  Arguments given after the query string will
686        be substituted for the corresponding numbered parameter.  Parameter
687        values can also be given as a single list or tuple argument.
688
689        Note that the query string must not be passed as a unicode value,
690        but you can pass arguments as unicode values if they can be decoded
691        using the current client encoding.
692
693        """
694        # Wraps shared library function for debugging.
695        if not self.db:
696            raise _int_error('Connection is not valid')
697        self._do_debug(qstr)
698        return self.db.query(qstr, args)
699
700    def pkey(self, cl, newpkey=None):
701        """This method gets or sets the primary key of a class.
702
703        Composite primary keys are represented as frozensets. Note that
704        this raises a KeyError if the table does not have a primary key.
705
706        If newpkey is set and is not a dictionary then set that
707        value as the primary key of the class.  If it is a dictionary
708        then replace the internal cache of primary keys with a copy of it.
709
710        """
711        # First see if the caller is supplying a dictionary
712        if isinstance(newpkey, dict):
713            # make sure that all classes have a namespace
714            self._pkeys = dict([
715                ('.' in cl and cl or 'public.' + cl, pkey)
716                for cl, pkey in newpkey.items()])
717            return self._pkeys
718
719        qcl = self._add_schema(cl)  # build fully qualified class name
720        # Check if the caller is supplying a new primary key for the class
721        if newpkey:
722            self._pkeys[qcl] = newpkey
723            return newpkey
724
725        # Get all the primary keys at once
726        if qcl not in self._pkeys:
727            # if not found, check again in case it was added after we started
728            self._pkeys = {}
729            if self.server_version >= 80200:
730                # the ANY syntax works correctly only with PostgreSQL >= 8.2
731                any_indkey = "= ANY (i.indkey)"
732            else:
733                any_indkey = "IN (%s)" % ', '.join(
734                    ['i.indkey[%d]' % i for i in range(16)])
735            q = ("SELECT s.nspname, r.relname, a.attname"
736                " FROM pg_class r"
737                " JOIN pg_namespace s ON s.oid = r.relnamespace"
738                " JOIN pg_attribute a ON a.attrelid = r.oid"
739                " AND NOT a.attisdropped"
740                " JOIN pg_index i ON i.indrelid = r.oid"
741                " AND i.indisprimary AND a.attnum %s"
742                " AND r.relkind IN ('r', 'v')" % any_indkey)
743            for r in self.db.query(q).getresult():
744                cl, pkey = _join_parts(r[:2]), r[2]
745                self._pkeys.setdefault(cl, []).append(pkey)
746            # (only) for composite primary keys, the values will be frozensets
747            for cl, pkey in self._pkeys.items():
748                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
749            self._do_debug(self._pkeys)
750
751        # will raise an exception if primary key doesn't exist
752        return self._pkeys[qcl]
753
754    def get_databases(self):
755        """Get list of databases in the system."""
756        return [s[0] for s in
757            self.db.query('SELECT datname FROM pg_database').getresult()]
758
759    def get_relations(self, kinds=None, system=False):
760        """Get list of relations in connected database of specified kinds.
761
762        If kinds is None or empty, all kinds of relations are returned.
763        Otherwise kinds can be a string or sequence of type letters
764        specifying which kind of relations you want to list.
765
766        Set the system flag if you want to get the system relations as well.
767        """
768        where = []
769        if kinds:
770            where.append("r.relkind IN (%s)" %
771                ','.join(["'%s'" % k for k in kinds]))
772        if not system:
773            where.append("s.nspname NOT SIMILAR"
774                " TO 'pg/_%|information/_schema' ESCAPE '/'")
775        where = where and " WHERE %s" % ' AND '.join(where) or ''
776        q = ("SELECT s.nspname, r.relname"
777            " FROM pg_class r"
778            " JOIN pg_namespace s ON s.oid = r.relnamespace%s"
779            " ORDER BY 1, 2") % where
780        return [_join_parts(r) for r in self.db.query(q).getresult()]
781
782    def get_tables(self, system=False):
783        """Return list of tables in connected database.
784
785        Set the system flag if you want to get the system tables as well.
786        """
787        return self.get_relations('r', system)
788
789    def get_attnames(self, cl, newattnames=None):
790        """Given the name of a table, digs out the set of attribute names.
791
792        Returns a dictionary of attribute names (the names are the keys,
793        the values are the names of the attributes' types).
794        If the optional newattnames exists, it must be a dictionary and
795        will become the new attribute names dictionary.
796
797        By default, only a limited number of simple types will be returned.
798        You can get the regular types after calling use_regtypes(True).
799
800        """
801        if isinstance(newattnames, dict):
802            self._attnames = newattnames
803            return
804        elif newattnames:
805            raise _prg_error('If supplied, newattnames must be a dictionary')
806        cl = self._split_schema(cl)  # split into schema and class
807        qcl = _join_parts(cl)  # build fully qualified name
808        # May as well cache them:
809        if qcl in self._attnames:
810            return self._attnames[qcl]
811
812        q = ("SELECT a.attname, t.typname%s"
813            " FROM pg_class r"
814            " JOIN pg_namespace s ON r.relnamespace = s.oid"
815            " JOIN pg_attribute a ON a.attrelid = r.oid"
816            " JOIN pg_type t ON t.oid = a.atttypid"
817            " WHERE s.nspname = $1 AND r.relname = $2"
818            " AND r.relkind IN ('r', 'v')"
819            " AND (a.attnum > 0 OR a.attname = 'oid')"
820            " AND NOT a.attisdropped") % (
821                self._regtypes and '::regtype' or '',)
822        q = self.db.query(q, cl).getresult()
823        if not q:
824            r = ("SELECT r.relnamespace"
825                 " FROM pg_class r"
826                 " JOIN pg_namespace s ON s.oid = r.relnamespace"
827                 " WHERE s.nspname =$1 AND r.relname = $2"
828                 " AND r.relkind IN ('r', 'v') LIMIT 1")
829            r = self.db.query(r, cl).getresult()
830            if not r:
831                raise _prg_error('Class %s does not exist' % qcl)
832
833        if self._regtypes:
834            t = dict(q)
835        else:
836            t = {}
837            for att, typ in q:
838                if typ.startswith('bool'):
839                    typ = 'bool'
840                elif typ.startswith('abstime'):
841                    typ = 'date'
842                elif typ.startswith('date'):
843                    typ = 'date'
844                elif typ.startswith('interval'):
845                    typ = 'date'
846                elif typ.startswith('timestamp'):
847                    typ = 'date'
848                elif typ.startswith('oid'):
849                    typ = 'int'
850                elif typ.startswith('int'):
851                    typ = 'int'
852                elif typ.startswith('float'):
853                    typ = 'float'
854                elif typ.startswith('numeric'):
855                    typ = 'num'
856                elif typ.startswith('money'):
857                    typ = 'money'
858                else:
859                    typ = 'text'
860                t[att] = typ
861
862        self._attnames[qcl] = t  # cache it
863        return self._attnames[qcl]
864
865    def use_regtypes(self, regtypes=None):
866        """Use regular type names instead of simplified type names."""
867        if regtypes is None:
868            return self._regtypes
869        else:
870            regtypes = bool(regtypes)
871            if regtypes != self._regtypes:
872                self._regtypes = regtypes
873                self._attnames.clear()
874            return regtypes
875
876    def has_table_privilege(self, cl, privilege='select'):
877        """Check whether current user has specified table privilege."""
878        qcl = self._add_schema(cl)
879        privilege = privilege.lower()
880        try:
881            return self._privileges[(qcl, privilege)]
882        except KeyError:
883            q = "SELECT has_table_privilege($1, $2)"
884            q = self.db.query(q, (qcl, privilege))
885            ret = q.getresult()[0][0] == self._make_bool(True)
886            self._privileges[(qcl, privilege)] = ret
887            return ret
888
889    def get(self, cl, arg, keyname=None):
890        """Get a row from a database table or view.
891
892        This method is the basic mechanism to get a single row.  It assumes
893        that the key specifies a unique row.  If keyname is not specified
894        then the primary key for the table is used.  If arg is a dictionary
895        then the value for the key is taken from it and it is modified to
896        include the new values, replacing existing values where necessary.
897        For a composite key, keyname can also be a sequence of key names.
898        The OID is also put into the dictionary if the table has one, but
899        in order to allow the caller to work with multiple tables, it is
900        munged as oid(schema.table).
901
902        """
903        if cl.endswith('*'):  # scan descendant tables?
904            cl = cl[:-1].rstrip()  # need parent table name
905        # build qualified class name
906        qcl = self._add_schema(cl)
907        # To allow users to work with multiple tables,
908        # we munge the name of the "oid" key
909        qoid = _oid_key(qcl)
910        if not keyname:
911            # use the primary key by default
912            try:
913                keyname = self.pkey(qcl)
914            except KeyError:
915                raise _prg_error('Class %s has no primary key' % qcl)
916        # We want the oid for later updates if that isn't the key
917        if keyname == 'oid':
918            if isinstance(arg, dict):
919                if qoid not in arg:
920                    raise _prg_error('%s not in arg' % qoid)
921            else:
922                arg = {qoid: arg}
923            where = 'oid = %s' % arg[qoid]
924            attnames = '*'
925        else:
926            attnames = self.get_attnames(qcl)
927            if isinstance(keyname, basestring):
928                keyname = (keyname,)
929            if not isinstance(arg, dict):
930                if len(keyname) > 1:
931                    raise _prg_error('Composite key needs dict as arg')
932                arg = dict([(k, arg) for k in keyname])
933            where = ' AND '.join(['%s = %s'
934                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
935            attnames = ', '.join(attnames)
936        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
937        self._do_debug(q)
938        res = self.db.query(q).dictresult()
939        if not res:
940            raise _db_error('No such record in %s where %s' % (qcl, where))
941        for att, value in res[0].items():
942            arg[att == 'oid' and qoid or att] = value
943        return arg
944
945    def insert(self, cl, d=None, **kw):
946        """Insert a row into a database table.
947
948        This method inserts a row into a table.  The name of the table must
949        be passed as the first parameter.  The other parameters are used for
950        providing the data of the row that shall be inserted into the table.
951        If a dictionary is supplied as the second parameter, it starts with
952        that.  Otherwise it uses a blank dictionary. Either way the dictionary
953        is updated from the keywords.
954
955        The dictionary is then, if possible, reloaded with the values actually
956        inserted in order to pick up values modified by rules, triggers, etc.
957
958        """
959        qcl = self._add_schema(cl)
960        qoid = _oid_key(qcl)
961        if d is None:
962            d = {}
963        d.update(kw)
964        attnames = self.get_attnames(qcl)
965        names, values = [], []
966        for n in attnames:
967            if n != 'oid' and n in d:
968                names.append('"%s"' % n)
969                values.append(self._quote(d[n], attnames[n]))
970        names, values = ', '.join(names), ', '.join(values)
971        selectable = self.has_table_privilege(qcl)
972        if selectable and self.server_version >= 80200:
973            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
974        else:
975            ret = ''
976        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
977        self._do_debug(q)
978        res = self.db.query(q)
979        if ret:
980            res = res.dictresult()
981            for att, value in res[0].items():
982                d[att == 'oid' and qoid or att] = value
983        elif isinstance(res, int):
984            d[qoid] = res
985            if selectable:
986                self.get(qcl, d, 'oid')
987        elif selectable:
988            if qoid in d:
989                self.get(qcl, d, 'oid')
990            else:
991                try:
992                    self.get(qcl, d)
993                except ProgrammingError:
994                    pass  # table has no primary key
995        return d
996
997    def update(self, cl, d=None, **kw):
998        """Update an existing row in a database table.
999
1000        Similar to insert but updates an existing row.  The update is based
1001        on the OID value as munged by get() or passed as keyword, or on the
1002        primary key of the table.  The dictionary is modified, if possible,
1003        to reflect any changes caused by the update due to triggers, rules,
1004        default values, etc.
1005
1006        """
1007        # Update always works on the oid which get() returns if available,
1008        # otherwise use the primary key.  Fail if neither.
1009        # Note that we only accept oid key from named args for safety.
1010        qcl = self._add_schema(cl)
1011        qoid = _oid_key(qcl)
1012        if 'oid' in kw:
1013            kw[qoid] = kw['oid']
1014            del kw['oid']
1015        if d is None:
1016            d = {}
1017        d.update(kw)
1018        attnames = self.get_attnames(qcl)
1019        if qoid in d:
1020            where = 'oid = %s' % d[qoid]
1021            keyname = ()
1022        else:
1023            try:
1024                keyname = self.pkey(qcl)
1025            except KeyError:
1026                raise _prg_error('Class %s has no primary key' % qcl)
1027            if isinstance(keyname, basestring):
1028                keyname = (keyname,)
1029            try:
1030                where = ' AND '.join(['%s = %s'
1031                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1032            except KeyError:
1033                raise _prg_error('Update needs primary key or oid.')
1034        values = []
1035        for n in attnames:
1036            if n in d and n not in keyname:
1037                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
1038        if not values:
1039            return d
1040        values = ', '.join(values)
1041        selectable = self.has_table_privilege(qcl)
1042        if selectable and self.server_version >= 80200:
1043            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
1044        else:
1045            ret = ''
1046        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
1047        self._do_debug(q)
1048        res = self.db.query(q)
1049        if ret:
1050            res = res.dictresult()[0]
1051            for att, value in res.items():
1052                d[att == 'oid' and qoid or att] = value
1053        else:
1054            if selectable:
1055                if qoid in d:
1056                    self.get(qcl, d, 'oid')
1057                else:
1058                    self.get(qcl, d)
1059        return d
1060
1061    def clear(self, cl, d=None):
1062        """Clear all the attributes to values determined by the types.
1063
1064        Numeric types are set to 0, Booleans are set to false, and everything
1065        else is set to the empty string.  If the second argument is present,
1066        it is used as the row dictionary and any entries matching attribute
1067        names are cleared with everything else left unchanged.
1068
1069        """
1070        # At some point we will need a way to get defaults from a table.
1071        qcl = self._add_schema(cl)
1072        if d is None:
1073            d = {}  # empty if argument is not present
1074        attnames = self.get_attnames(qcl)
1075        for n, t in attnames.items():
1076            if n == 'oid':
1077                continue
1078            if t in ('int', 'integer', 'smallint', 'bigint',
1079                    'float', 'real', 'double precision',
1080                    'num', 'numeric', 'money'):
1081                d[n] = 0
1082            elif t in ('bool', 'boolean'):
1083                d[n] = self._make_bool(False)
1084            else:
1085                d[n] = ''
1086        return d
1087
1088    def delete(self, cl, d=None, **kw):
1089        """Delete an existing row in a database table.
1090
1091        This method deletes the row from a table.  It deletes based on the
1092        OID value as munged by get() or passed as keyword, or on the primary
1093        key of the table.  The return value is the number of deleted rows
1094        (i.e. 0 if the row did not exist and 1 if the row was deleted).
1095
1096        """
1097        # Like update, delete works on the oid.
1098        # One day we will be testing that the record to be deleted
1099        # isn't referenced somewhere (or else PostgreSQL will).
1100        # Note that we only accept oid key from named args for safety
1101        qcl = self._add_schema(cl)
1102        qoid = _oid_key(qcl)
1103        if 'oid' in kw:
1104            kw[qoid] = kw['oid']
1105            del kw['oid']
1106        if d is None:
1107            d = {}
1108        d.update(kw)
1109        if qoid in d:
1110            where = 'oid = %s' % d[qoid]
1111        else:
1112            try:
1113                keyname = self.pkey(qcl)
1114            except KeyError:
1115                raise _prg_error('Class %s has no primary key' % qcl)
1116            if isinstance(keyname, basestring):
1117                keyname = (keyname,)
1118            attnames = self.get_attnames(qcl)
1119            try:
1120                where = ' AND '.join(['%s = %s'
1121                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1122            except KeyError:
1123                raise _prg_error('Delete needs primary key or oid.')
1124        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
1125        self._do_debug(q)
1126        return int(self.db.query(q))
1127
1128    def truncate(self, table, restart=False, cascade=False, only=False):
1129        """Empty a table or set of tables.
1130
1131        This method quickly removes all rows from the given table or set
1132        of tables.  It has the same effect as an unqualified DELETE on each
1133        table, but since it does not actually scan the tables it is faster.
1134        Furthermore, it reclaims disk space immediately, rather than requiring
1135        a subsequent VACUUM operation. This is most useful on large tables.
1136
1137        If restart is set to True, sequences owned by columns of the truncated
1138        table(s) are automatically restarted.  If cascade is set to True, it
1139        also truncates all tables that have foreign-key references to any of
1140        the named tables.  If the parameter only is not set to True, all the
1141        descendant tables (if any) will also be truncated. Optionally, a '*'
1142        can be specified after the table name to explicitly indicate that
1143        descendant tables are included.
1144        """
1145        if isinstance(table, basestring):
1146            only = {table: only}
1147            table = [table]
1148        elif isinstance(table, (list, tuple)):
1149            if isinstance(only, (list, tuple)):
1150                only = dict(zip(table, only))
1151            else:
1152                only = dict.fromkeys(table, only)
1153        elif isinstance(table, (set, frozenset)):
1154            only = dict.fromkeys(table, only)
1155        else:
1156            raise TypeError('The table must be a string, list or set')
1157        if not (restart is None or isinstance(restart, (bool, int))):
1158            raise TypeError('Invalid type for the restart option')
1159        if not (cascade is None or isinstance(cascade, (bool, int))):
1160            raise TypeError('Invalid type for the cascade option')
1161        tables = []
1162        for t in table:
1163            u = only.get(t)
1164            if not (u is None or isinstance(u, (bool, int))):
1165                raise TypeError('Invalid type for the only option')
1166            if t.endswith('*'):
1167                if u:
1168                    raise ValueError(
1169                        'Contradictory table name and only options')
1170                t = t[:-1].rstrip()
1171            t = self._add_schema(t)
1172            if u:
1173                t = 'ONLY %s' % t
1174            tables.append(t)
1175        q = ['TRUNCATE', ', '.join(tables)]
1176        if restart:
1177            q.append('RESTART IDENTITY')
1178        if cascade:
1179            q.append('CASCADE')
1180        q = ' '.join(q)
1181        self._do_debug(q)
1182        return self.db.query(q)
1183
1184    def notification_handler(self,
1185            event, callback, arg_dict=None, timeout=None, stop_event=None):
1186        """Get notification handler that will run the given callback."""
1187        return NotificationHandler(self,
1188            event, callback, arg_dict, timeout, stop_event)
1189
1190
1191# if run as script, print some information
1192
if __name__ == '__main__':
    # Print the version (separated from the label by a space; previously
    # the output read e.g. "PyGreSQL version4.2") and the module docstring.
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.