source: branches/4.x/pg.py @ 775

Last change on this file since 775 was 775, checked in by cito, 4 years ago

Backport some minor doc fixes to the 4.x branch

  • Property svn:keywords set to Id
File size: 43.5 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 775 2016-01-21 19:07:16Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from _pg import *
32
33import select
34import warnings
35try:
36    frozenset
37except NameError:  # Python < 2.4, unsupported
38    from sets import ImmutableSet as frozenset
39try:
40    from decimal import Decimal
41    set_decimal(Decimal)
42except ImportError:  # Python < 2.4, unsupported
43    Decimal = float
44try:
45    from collections import namedtuple
46except ImportError:  # Python < 2.6
47    namedtuple = None
48
49
50# Auxiliary functions that are independent from a DB connection:
51
52def _is_quoted(s):
53    """Check whether this string is a quoted identifier."""
54    s = s.replace('_', 'a')
55    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
56
57
58def _is_unquoted(s):
59    """Check whether this string is an unquoted identifier."""
60    s = s.replace('_', 'a')
61    return s.isalnum() and not s[:1].isdigit()
62
63
64def _split_first_part(s):
65    """Split the first part of a dot separated string."""
66    s = s.lstrip()
67    if s[:1] == '"':
68        p = []
69        s = s.split('"', 3)[1:]
70        p.append(s[0])
71        while len(s) == 3 and s[1] == '':
72            p.append('"')
73            s = s[2].split('"', 2)
74            p.append(s[0])
75        p = [''.join(p)]
76        s = '"'.join(s[1:]).lstrip()
77        if s:
78            if s[:0] == '.':
79                p.append(s[1:])
80            else:
81                s = _split_first_part(s)
82                p[0] += s[0]
83                if len(s) > 1:
84                    p.append(s[1])
85    else:
86        p = s.split('.', 1)
87        s = p[0].rstrip()
88        if _is_unquoted(s):
89            s = s.lower()
90        p[0] = s
91    return p
92
93
def _split_parts(s):
    """Split a dot separated string into the list of all of its parts.

    Every part is parsed with _split_first_part, i.e. quoted parts are
    unquoted and plain unquoted identifiers are folded to lower case.
    """
    parts = []
    rest = s
    while rest:
        head = _split_first_part(rest)
        parts.append(head[0])
        if len(head) < 2:  # no further dot
            break
        rest = head[1]
    return parts
104
105
def _join_parts(s):
    """Join all parts of a dot separated string.

    Parts that need quoting (see _is_quoted) are enclosed in double quotes.
    Double quotes inside such a part are doubled, so that the result is a
    valid SQL identifier and can be split again with _split_parts.
    """
    parts = []
    for p in s:
        if _is_quoted(p):
            # an embedded quote must be written as "" inside a quoted name
            p = '"%s"' % p.replace('"', '""')
        parts.append(p)
    return '.'.join(parts)
109
110
111def _oid_key(qcl):
112    """Build oid key from qualified class name."""
113    return 'oid(%s)' % qcl
114
115
if namedtuple:

    def _namedresult(q):
        """Get query result as named tuples.

        Column names that cannot be used as namedtuple field names (names
        that are not valid Python identifiers, Python keywords, names
        starting with an underscore, or duplicate names) are replaced with
        positional names such as _1.  Without this, namedtuple() would
        raise a ValueError for such query results.
        """
        fields = q.listfields()
        try:
            row = namedtuple('Row', fields, rename=True)
        except TypeError:  # Python 2.6: namedtuple has no rename argument
            row = namedtuple('Row', fields)
        return [row(*r) for r in q.getresult()]

    set_namedresult(_namedresult)
124
125
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an error of class cls with message msg.

    The sqlstate attribute of the error is set to None, since there is
    no SQLSTATE code for errors created here.
    """
    err = cls(msg)
    err.sqlstate = None
    return err
131
132
def _int_error(msg):
    """Create an InternalError with message msg and an empty sqlstate."""
    return _db_error(msg, cls=InternalError)
136
137
def _prg_error(msg):
    """Create a ProgrammingError with message msg and an empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
141
142
143# The notification handler
144
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback=None,
            arg_dict=None, timeout=None, stop_event=None):
        """Initialize the notification handler.

        You must pass a PyGreSQL database connection, the name of an
        event (notification channel) to listen for and a callback function.

        You can also specify a dictionary arg_dict that will be passed as
        the single argument to the callback function, and a timeout value
        in seconds (a floating point number denotes fractions of seconds).
        If it is absent or None, the callers will never time out.  If the
        timeout is reached, the callback function will be called with a
        single argument that is None.  If you set the timeout to zero,
        the handler will poll notifications synchronously and return.

        You can specify the name of the event that will be used to signal
        the handler to stop listening as stop_event. By default, it will
        be the event name prefixed with 'stop_'.
        """
        self.db = db
        self.event = event
        self.stop_event = stop_event or 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening when the handler is garbage collected."""
        # getattr, because __init__ may not have run to completion
        if getattr(self, 'listening', False):
            self.unlisten()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Optionally, you can pass a payload with the notification.  It is
        escaped with the escape_string() method of the connection used,
        so it may contain quotes.

        If you set the stop flag, a stop notification will be sent that
        will cause the handler to stop listening.

        Nothing is sent if the handler is not listening.  Otherwise the
        result of the query() call of the connection is returned.

        Note: If the notification handler is running in another thread, you
        must pass a different database connection since PyGreSQL database
        connections are not thread-safe.
        """
        if self.listening:
            if not db:
                db = self.db
            if stop:
                event = self.stop_event
            else:
                event = self.event
            q = 'notify "%s"' % event
            if payload:
                # escape the payload so that quotes in it cannot end the
                # string literal early and break (or inject into) the query
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self):
        """Invoke the notification handler.

        The handler is a loop that listens for notifications on the event
        and stop event channels.  When either of these notifications are
        received, its associated 'pid', 'event' and 'extra' (the payload
        passed with the notification) are inserted into its arg_dict
        dictionary and the callback is invoked with this dictionary as
        a single argument.  When the handler receives a stop event, it
        stops listening to both events and returns.

        In the special case that the timeout of the handler has been set
        to zero, the handler will poll all events synchronously and return.
        It will keep listening until it receives a stop event.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        """
        self.listen()
        poll = self.timeout == 0
        if not poll:
            rlist = [self.db.fileno()]
        while self.listening:
            if poll or select.select(rlist, [], [], self.timeout)[0]:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'Listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict.update(pid=pid, event=event, extra=extra)
                    self.callback(self.arg_dict)
                if poll:
                    break
            else:   # we timed out
                self.unlisten()
                self.callback(None)
264
265
def pgnotify(*args, **kw):
    """Deprecated: create a NotificationHandler under its traditional name.

    A DeprecationWarning is issued for the caller, then all arguments are
    passed unchanged to NotificationHandler.
    """
    warnings.warn("pgnotify is deprecated, use NotificationHandler instead.",
        DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
271
272
# The actual PostgreSQL database connection interface:
274
275class DB(object):
276    """Wrapper class for the _pg connection type."""
277
278    def __init__(self, *args, **kw):
279        """Create a new connection.
280
281        You can pass either the connection parameters or an existing
282        _pg or pgdb connection. This allows you to use the methods
283        of the classic pg interface with a DB-API 2 pgdb connection.
284
285        """
286        if not args and len(kw) == 1:
287            db = kw.get('db')
288        elif not kw and len(args) == 1:
289            db = args[0]
290        else:
291            db = None
292        if db:
293            if isinstance(db, DB):
294                db = db.db
295            else:
296                try:
297                    db = db._cnx
298                except AttributeError:
299                    pass
300        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
301            db = connect(*args, **kw)
302            self._closeable = True
303        else:
304            self._closeable = False
305        self.db = db
306        self.dbname = db.db
307        self._regtypes = False
308        self._attnames = {}
309        self._pkeys = {}
310        self._privileges = {}
311        self._args = args, kw
312        self.debug = None  # For debugging scripts, this can be set
313            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
314            # * to a file object to write debug statements or
315            # * to a callable object which takes a string argument
316            # * to any other true value to just print debug statements
317
318    def __getattr__(self, name):
319        # All undefined members are same as in underlying connection:
320        if self.db:
321            return getattr(self.db, name)
322        else:
323            raise _int_error('Connection is not valid')
324
325    # Context manager methods
326
327    def __enter__(self):
328        """Enter the runtime context. This will start a transaction."""
329        self.begin()
330        return self
331
332    def __exit__(self, et, ev, tb):
333        """Exit the runtime context. This will end the transaction."""
334        if et is None and ev is None and tb is None:
335            self.commit()
336        else:
337            self.rollback()
338
339    # Auxiliary methods
340
341    def _do_debug(self, s):
342        """Print a debug message."""
343        if self.debug:
344            if isinstance(self.debug, basestring):
345                print(self.debug % s)
346            elif isinstance(self.debug, file):
347                self.debug.write(s + '\n')
348            elif callable(self.debug):
349                self.debug(s)
350            else:
351                print(s)
352
353    def _make_bool(d):
354        """Get boolean value corresponding to d."""
355        if get_bool():
356            return bool(d)
357        return d and 't' or 'f'
358    _make_bool = staticmethod(_make_bool)
359
360    def _quote_text(self, d):
361        """Quote text value."""
362        if not isinstance(d, basestring):
363            d = str(d)
364        return "'%s'" % self.escape_string(d)
365
366    _bool_true = frozenset('t true 1 y yes on'.split())
367
368    def _quote_bool(self, d):
369        """Quote boolean value."""
370        if isinstance(d, basestring):
371            if not d:
372                return 'NULL'
373            d = d.lower() in self._bool_true
374        return d and "'t'" or "'f'"
375
376    _date_literals = frozenset('current_date current_time'
377        ' current_timestamp localtime localtimestamp'.split())
378
379    def _quote_date(self, d):
380        """Quote date value."""
381        if not d:
382            return 'NULL'
383        if isinstance(d, basestring) and d.lower() in self._date_literals:
384            return d
385        return self._quote_text(d)
386
387    def _quote_num(self, d):
388        """Quote numeric value."""
389        if not d and d != 0:
390            return 'NULL'
391        return str(d)
392
393    def _quote_money(self, d):
394        """Quote money value."""
395        if d is None or d == '':
396            return 'NULL'
397        if not isinstance(d, basestring):
398            d = str(d)
399        return d
400
401    _quote_funcs = dict(  # quote methods for each type
402        text=_quote_text, bool=_quote_bool, date=_quote_date,
403        int=_quote_num, num=_quote_num, float=_quote_num,
404        money=_quote_money)
405
406    def _quote(self, d, t):
407        """Return quotes if needed."""
408        if d is None:
409            return 'NULL'
410        try:
411            quote_func = self._quote_funcs[t]
412        except KeyError:
413            quote_func = self._quote_funcs['text']
414        return quote_func(self, d)
415
416    def _split_schema(self, cl):
417        """Return schema and name of object separately.
418
419        This auxiliary function splits off the namespace (schema)
420        belonging to the class with the name cl. If the class name
421        is not qualified, the function is able to determine the schema
422        of the class, taking into account the current search path.
423
424        """
425        s = _split_parts(cl)
426        if len(s) > 1:  # name already qualified?
427            # should be database.schema.table or schema.table
428            if len(s) > 3:
429                raise _prg_error('Too many dots in class name %s' % cl)
430            schema, cl = s[-2:]
431        else:
432            cl = s[0]
433            # determine search path
434            q = 'SELECT current_schemas(TRUE)'
435            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
436            if schemas:  # non-empty path
437                # search schema for this object in the current search path
438                # (we could also use unnest with ordinality here to spare
439                # one query, but this is only possible since PostgreSQL 9.4)
440                q = ' UNION '.join(
441                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
442                        % s for s in enumerate(schemas)])
443                q = ("SELECT nspname FROM pg_class r"
444                    " JOIN pg_namespace s ON r.relnamespace = s.oid"
445                    " JOIN (%s) AS p USING (nspname)"
446                    " WHERE r.relname = $1 ORDER BY n LIMIT 1" % q)
447                schema = self.db.query(q, (cl,)).getresult()
448                if schema:  # schema found
449                    schema = schema[0][0]
450                else:  # object not found in current search path
451                    schema = 'public'
452            else:  # empty path
453                schema = 'public'
454        return schema, cl
455
456    def _add_schema(self, cl):
457        """Ensure that the class name is prefixed with a schema name."""
458        return _join_parts(self._split_schema(cl))
459
460    # Public methods
461
462    # escape_string and escape_bytea exist as methods,
463    # so we define unescape_bytea as a method as well
464    unescape_bytea = staticmethod(unescape_bytea)
465
466    def close(self):
467        """Close the database connection."""
468        # Wraps shared library function so we can track state.
469        if self._closeable:
470            if self.db:
471                self.db.close()
472                self.db = None
473            else:
474                raise _int_error('Connection already closed')
475
476    def reset(self):
477        """Reset connection with current parameters.
478
479        All derived queries and large objects derived from this connection
480        will not be usable after this call.
481
482        """
483        if self.db:
484            self.db.reset()
485        else:
486            raise _int_error('Connection already closed')
487
488    def reopen(self):
489        """Reopen connection to the database.
490
491        Used in case we need another connection to the same database.
492        Note that we can still reopen a database that we have closed.
493
494        """
495        # There is no such shared library function.
496        if self._closeable:
497            db = connect(*self._args[0], **self._args[1])
498            if self.db:
499                self.db.close()
500            self.db = db
501
502    def begin(self, mode=None):
503        """Begin a transaction."""
504        qstr = 'BEGIN'
505        if mode:
506            qstr += ' ' + mode
507        return self.query(qstr)
508
509    start = begin
510
511    def commit(self):
512        """Commit the current transaction."""
513        return self.query('COMMIT')
514
515    end = commit
516
517    def rollback(self, name=None):
518        """Roll back the current transaction."""
519        qstr = 'ROLLBACK'
520        if name:
521            qstr += ' TO ' + name
522        return self.query(qstr)
523
524    abort = rollback
525
526    def savepoint(self, name):
527        """Define a new savepoint within the current transaction."""
528        return self.query('SAVEPOINT ' + name)
529
530    def release(self, name):
531        """Destroy a previously defined savepoint."""
532        return self.query('RELEASE ' + name)
533
534    def get_parameter(self, parameter):
535        """Get the value of a run-time parameter.
536
537        If the parameter is a string, the return value will also be a string
538        that is the current setting of the run-time parameter with that name.
539
540        You can get several parameters at once by passing a list, set or dict.
541        When passing a list of parameter names, the return value will be a
542        corresponding list of parameter settings.  When passing a set of
543        parameter names, a new dict will be returned, mapping these parameter
544        names to their settings.  Finally, if you pass a dict as parameter,
545        its values will be set to the current parameter settings corresponding
546        to its keys.
547
548        By passing the special name 'all' as the parameter, you can get a dict
549        of all existing configuration parameters.
550        """
551        if isinstance(parameter, basestring):
552            parameter = [parameter]
553            values = None
554        elif isinstance(parameter, (list, tuple)):
555            values = []
556        elif isinstance(parameter, (set, frozenset)):
557            values = {}
558        elif isinstance(parameter, dict):
559            values = parameter
560        else:
561            raise TypeError(
562                'The parameter must be a string, list, set or dict')
563        if not parameter:
564            raise TypeError('No parameter has been specified')
565        if isinstance(values, dict):
566            params = {}
567        else:
568            params = []
569        for key in parameter:
570            if isinstance(key, basestring):
571                param = key.strip().lower()
572            else:
573                param = None
574            if not param:
575                raise TypeError('Invalid parameter')
576            if param == 'all':
577                q = 'SHOW ALL'
578                values = self.db.query(q).getresult()
579                values = dict(value[:2] for value in values)
580                break
581            if isinstance(values, dict):
582                params[param] = key
583            else:
584                params.append(param)
585        else:
586            for param in params:
587                q = 'SHOW %s' % (param,)
588                value = self.db.query(q).getresult()[0][0]
589                if values is None:
590                    values = value
591                elif isinstance(values, list):
592                    values.append(value)
593                else:
594                    values[params[param]] = value
595        return values
596
597    def set_parameter(self, parameter, value=None, local=False):
598        """Set the value of a run-time parameter.
599
600        If the parameter and the value are strings, the run-time parameter
601        will be set to that value.  If no value or None is passed as a value,
602        then the run-time parameter will be restored to its default value.
603
604        You can set several parameters at once by passing a list of parameter
605        names, together with a single value that all parameters should be
606        set to or with a corresponding list of values.  You can also pass
607        the parameters as a set if you only provide a single value.
608        Finally, you can pass a dict with parameter names as keys.  In this
609        case, you should not pass a value, since the values for the parameters
610        will be taken from the dict.
611
612        By passing the special name 'all' as the parameter, you can reset
613        all existing settable run-time parameters to their default values.
614
615        If you set local to True, then the command takes effect for only the
616        current transaction.  After commit() or rollback(), the session-level
617        setting takes effect again.  Setting local to True will appear to
618        have no effect if it is executed outside a transaction, since the
619        transaction will end immediately.
620        """
621        if isinstance(parameter, basestring):
622            parameter = {parameter: value}
623        elif isinstance(parameter, (list, tuple)):
624            if isinstance(value, (list, tuple)):
625                parameter = dict(zip(parameter, value))
626            else:
627                parameter = dict.fromkeys(parameter, value)
628        elif isinstance(parameter, (set, frozenset)):
629            if isinstance(value, (list, tuple, set, frozenset)):
630                value = set(value)
631                if len(value) == 1:
632                    value = value.pop()
633            if not(value is None or isinstance(value, basestring)):
634                raise ValueError('A single value must be specified'
635                    ' when parameter is a set')
636            parameter = dict.fromkeys(parameter, value)
637        elif isinstance(parameter, dict):
638            if value is not None:
639                raise ValueError('A value must not be specified'
640                    ' when parameter is a dictionary')
641        else:
642            raise TypeError(
643                'The parameter must be a string, list, set or dict')
644        if not parameter:
645            raise TypeError('No parameter has been specified')
646        params = {}
647        for key, value in parameter.items():
648            if isinstance(key, basestring):
649                param = key.strip().lower()
650            else:
651                param = None
652            if not param:
653                raise TypeError('Invalid parameter')
654            if param == 'all':
655                if value is not None:
656                    raise ValueError('A value must ot be specified'
657                        " when parameter is 'all'")
658                params = {'all': None}
659                break
660            params[param] = value
661        local = local and ' LOCAL' or ''
662        for param, value in params.items():
663            if value is None:
664                q = 'RESET%s %s' % (local, param)
665            else:
666                q = 'SET%s %s TO %s' % (local, param, value)
667            self._do_debug(q)
668            self.db.query(q)
669
670    def query(self, qstr, *args):
671        """Executes a SQL command string.
672
673        This method simply sends a SQL query to the database.  If the query is
674        an insert statement that inserted exactly one row into a table that
675        has OIDs, the return value is the OID of the newly inserted row.
676        If the query is an update or delete statement, or an insert statement
677        that did not insert exactly one row in a table with OIDs, then the
678        number of rows affected is returned as a string.  If it is a statement
679        that returns rows as a result (usually a select statement, but maybe
680        also an "insert/update ... returning" statement), this method returns
681        a pgqueryobject that can be accessed via getresult() or dictresult()
682        or simply printed.  Otherwise, it returns `None`.
683
684        The query can contain numbered parameters of the form $1 in place
685        of any data constant.  Arguments given after the query string will
686        be substituted for the corresponding numbered parameter.  Parameter
687        values can also be given as a single list or tuple argument.
688
689        Note that the query string must not be passed as a unicode value,
690        but you can pass arguments as unicode values if they can be decoded
691        using the current client encoding.
692
693        """
694        # Wraps shared library function for debugging.
695        if not self.db:
696            raise _int_error('Connection is not valid')
697        self._do_debug(qstr)
698        return self.db.query(qstr, args)
699
700    def pkey(self, cl, newpkey=None):
701        """This method gets or sets the primary key of a class.
702
703        Composite primary keys are represented as frozensets. Note that
704        this raises a KeyError if the table does not have a primary key.
705
706        If newpkey is set and is not a dictionary then set that
707        value as the primary key of the class.  If it is a dictionary
708        then replace the internal cache of primary keys with a copy of it.
709
710        """
711        # First see if the caller is supplying a dictionary
712        if isinstance(newpkey, dict):
713            # make sure that all classes have a namespace
714            self._pkeys = dict([
715                ('.' in cl and cl or 'public.' + cl, pkey)
716                for cl, pkey in newpkey.items()])
717            return self._pkeys
718
719        qcl = self._add_schema(cl)  # build fully qualified class name
720        # Check if the caller is supplying a new primary key for the class
721        if newpkey:
722            self._pkeys[qcl] = newpkey
723            return newpkey
724
725        # Get all the primary keys at once
726        if qcl not in self._pkeys:
727            # if not found, check again in case it was added after we started
728            self._pkeys = {}
729            if self.server_version >= 80200:
730                # the ANY syntax works correctly only with PostgreSQL >= 8.2
731                any_indkey = "= ANY (i.indkey)"
732            else:
733                any_indkey = "IN (%s)" % ', '.join(
734                    ['i.indkey[%d]' % i for i in range(16)])
735            q = ("SELECT s.nspname, r.relname, a.attname"
736                " FROM pg_class r"
737                " JOIN pg_namespace s ON s.oid = r.relnamespace"
738                " AND s.nspname NOT SIMILAR"
739                " TO 'pg/_%|information/_schema' ESCAPE '/'"
740                " JOIN pg_attribute a ON a.attrelid = r.oid"
741                " AND NOT a.attisdropped"
742                " JOIN pg_index i ON i.indrelid = r.oid"
743                " AND i.indisprimary AND a.attnum " + any_indkey)
744            for r in self.db.query(q).getresult():
745                cl, pkey = _join_parts(r[:2]), r[2]
746                self._pkeys.setdefault(cl, []).append(pkey)
747            # (only) for composite primary keys, the values will be frozensets
748            for cl, pkey in self._pkeys.items():
749                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
750            self._do_debug(self._pkeys)
751
752        # will raise an exception if primary key doesn't exist
753        return self._pkeys[qcl]
754
755    def get_databases(self):
756        """Get list of databases in the system."""
757        return [s[0] for s in
758            self.db.query('SELECT datname FROM pg_database').getresult()]
759
760    def get_relations(self, kinds=None):
761        """Get list of relations in connected database of specified kinds.
762
763        If kinds is None or empty, all kinds of relations are returned.
764        Otherwise kinds can be a string or sequence of type letters
765        specifying which kind of relations you want to list.
766
767        """
768        where = kinds and " AND r.relkind IN (%s)" % ','.join(
769            ["'%s'" % k for k in kinds]) or ''
770        q = ("SELECT s.nspname, r.relname"
771            " FROM pg_class r"
772            " JOIN pg_namespace s ON s.oid = r.relnamespace"
773            " WHERE s.nspname NOT SIMILAR"
774            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
775            " ORDER BY 1, 2") % where
776        return [_join_parts(r) for r in self.db.query(q).getresult()]
777
778    def get_tables(self):
779        """Return list of tables in connected database."""
780        return self.get_relations('r')
781
782    def get_attnames(self, cl, newattnames=None):
783        """Given the name of a table, digs out the set of attribute names.
784
785        Returns a dictionary of attribute names (the names are the keys,
786        the values are the names of the attributes' types).
787        If the optional newattnames exists, it must be a dictionary and
788        will become the new attribute names dictionary.
789
790        By default, only a limited number of simple types will be returned.
791        You can get the regular types after calling use_regtypes(True).
792
793        """
794        if isinstance(newattnames, dict):
795            self._attnames = newattnames
796            return
797        elif newattnames:
798            raise _prg_error('If supplied, newattnames must be a dictionary')
799        cl = self._split_schema(cl)  # split into schema and class
800        qcl = _join_parts(cl)  # build fully qualified name
801        # May as well cache them:
802        if qcl in self._attnames:
803            return self._attnames[qcl]
804        if qcl not in self.get_relations('rv'):
805            raise _prg_error('Class %s does not exist' % qcl)
806
807        q = ("SELECT a.attname, t.typname%s"
808            " FROM pg_class r"
809            " JOIN pg_namespace s ON r.relnamespace = s.oid"
810            " JOIN pg_attribute a ON a.attrelid = r.oid"
811            " JOIN pg_type t ON t.oid = a.atttypid"
812            " WHERE s.nspname = $1 AND r.relname = $2"
813            " AND (a.attnum > 0 OR a.attname = 'oid')"
814            " AND NOT a.attisdropped") % (
815                self._regtypes and '::regtype' or '',)
816        q = self.db.query(q, cl).getresult()
817
818        if self._regtypes:
819            t = dict(q)
820        else:
821            t = {}
822            for att, typ in q:
823                if typ.startswith('bool'):
824                    typ = 'bool'
825                elif typ.startswith('abstime'):
826                    typ = 'date'
827                elif typ.startswith('date'):
828                    typ = 'date'
829                elif typ.startswith('interval'):
830                    typ = 'date'
831                elif typ.startswith('timestamp'):
832                    typ = 'date'
833                elif typ.startswith('oid'):
834                    typ = 'int'
835                elif typ.startswith('int'):
836                    typ = 'int'
837                elif typ.startswith('float'):
838                    typ = 'float'
839                elif typ.startswith('numeric'):
840                    typ = 'num'
841                elif typ.startswith('money'):
842                    typ = 'money'
843                else:
844                    typ = 'text'
845                t[att] = typ
846
847        self._attnames[qcl] = t  # cache it
848        return self._attnames[qcl]
849
850    def use_regtypes(self, regtypes=None):
851        """Use regular type names instead of simplified type names."""
852        if regtypes is None:
853            return self._regtypes
854        else:
855            regtypes = bool(regtypes)
856            if regtypes != self._regtypes:
857                self._regtypes = regtypes
858                self._attnames.clear()
859            return regtypes
860
861    def has_table_privilege(self, cl, privilege='select'):
862        """Check whether current user has specified table privilege."""
863        qcl = self._add_schema(cl)
864        privilege = privilege.lower()
865        try:
866            return self._privileges[(qcl, privilege)]
867        except KeyError:
868            q = "SELECT has_table_privilege($1, $2)"
869            q = self.db.query(q, (qcl, privilege))
870            ret = q.getresult()[0][0] == self._make_bool(True)
871            self._privileges[(qcl, privilege)] = ret
872            return ret
873
874    def get(self, cl, arg, keyname=None):
875        """Get a row from a database table or view.
876
877        This method is the basic mechanism to get a single row.  It assumes
878        that the key specifies a unique row.  If keyname is not specified
879        then the primary key for the table is used.  If arg is a dictionary
880        then the value for the key is taken from it and it is modified to
881        include the new values, replacing existing values where necessary.
882        For a composite key, keyname can also be a sequence of key names.
883        The OID is also put into the dictionary if the table has one, but
884        in order to allow the caller to work with multiple tables, it is
885        munged as oid(schema.table).
886
887        """
888        if cl.endswith('*'):  # scan descendant tables?
889            cl = cl[:-1].rstrip()  # need parent table name
890        # build qualified class name
891        qcl = self._add_schema(cl)
892        # To allow users to work with multiple tables,
893        # we munge the name of the "oid" key
894        qoid = _oid_key(qcl)
895        if not keyname:
896            # use the primary key by default
897            try:
898                keyname = self.pkey(qcl)
899            except KeyError:
900                raise _prg_error('Class %s has no primary key' % qcl)
901        # We want the oid for later updates if that isn't the key
902        if keyname == 'oid':
903            if isinstance(arg, dict):
904                if qoid not in arg:
905                    raise _prg_error('%s not in arg' % qoid)
906            else:
907                arg = {qoid: arg}
908            where = 'oid = %s' % arg[qoid]
909            attnames = '*'
910        else:
911            attnames = self.get_attnames(qcl)
912            if isinstance(keyname, basestring):
913                keyname = (keyname,)
914            if not isinstance(arg, dict):
915                if len(keyname) > 1:
916                    raise _prg_error('Composite key needs dict as arg')
917                arg = dict([(k, arg) for k in keyname])
918            where = ' AND '.join(['%s = %s'
919                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
920            attnames = ', '.join(attnames)
921        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
922        self._do_debug(q)
923        res = self.db.query(q).dictresult()
924        if not res:
925            raise _db_error('No such record in %s where %s' % (qcl, where))
926        for att, value in res[0].items():
927            arg[att == 'oid' and qoid or att] = value
928        return arg
929
930    def insert(self, cl, d=None, **kw):
931        """Insert a row into a database table.
932
933        This method inserts a row into a table.  The name of the table must
934        be passed as the first parameter.  The other parameters are used for
935        providing the data of the row that shall be inserted into the table.
936        If a dictionary is supplied as the second parameter, it starts with
937        that.  Otherwise it uses a blank dictionary. Either way the dictionary
938        is updated from the keywords.
939
940        The dictionary is then, if possible, reloaded with the values actually
941        inserted in order to pick up values modified by rules, triggers, etc.
942
943        """
944        qcl = self._add_schema(cl)
945        qoid = _oid_key(qcl)
946        if d is None:
947            d = {}
948        d.update(kw)
949        attnames = self.get_attnames(qcl)
950        names, values = [], []
951        for n in attnames:
952            if n != 'oid' and n in d:
953                names.append('"%s"' % n)
954                values.append(self._quote(d[n], attnames[n]))
955        names, values = ', '.join(names), ', '.join(values)
956        selectable = self.has_table_privilege(qcl)
957        if selectable and self.server_version >= 80200:
958            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
959        else:
960            ret = ''
961        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
962        self._do_debug(q)
963        res = self.db.query(q)
964        if ret:
965            res = res.dictresult()
966            for att, value in res[0].items():
967                d[att == 'oid' and qoid or att] = value
968        elif isinstance(res, int):
969            d[qoid] = res
970            if selectable:
971                self.get(qcl, d, 'oid')
972        elif selectable:
973            if qoid in d:
974                self.get(qcl, d, 'oid')
975            else:
976                try:
977                    self.get(qcl, d)
978                except ProgrammingError:
979                    pass  # table has no primary key
980        return d
981
982    def update(self, cl, d=None, **kw):
983        """Update an existing row in a database table.
984
985        Similar to insert but updates an existing row.  The update is based
986        on the OID value as munged by get() or passed as keyword, or on the
987        primary key of the table.  The dictionary is modified, if possible,
988        to reflect any changes caused by the update due to triggers, rules,
989        default values, etc.
990
991        """
992        # Update always works on the oid which get() returns if available,
993        # otherwise use the primary key.  Fail if neither.
994        # Note that we only accept oid key from named args for safety.
995        qcl = self._add_schema(cl)
996        qoid = _oid_key(qcl)
997        if 'oid' in kw:
998            kw[qoid] = kw['oid']
999            del kw['oid']
1000        if d is None:
1001            d = {}
1002        d.update(kw)
1003        attnames = self.get_attnames(qcl)
1004        if qoid in d:
1005            where = 'oid = %s' % d[qoid]
1006            keyname = ()
1007        else:
1008            try:
1009                keyname = self.pkey(qcl)
1010            except KeyError:
1011                raise _prg_error('Class %s has no primary key' % qcl)
1012            if isinstance(keyname, basestring):
1013                keyname = (keyname,)
1014            try:
1015                where = ' AND '.join(['%s = %s'
1016                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1017            except KeyError:
1018                raise _prg_error('Update needs primary key or oid.')
1019        values = []
1020        for n in attnames:
1021            if n in d and n not in keyname:
1022                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
1023        if not values:
1024            return d
1025        values = ', '.join(values)
1026        selectable = self.has_table_privilege(qcl)
1027        if selectable and self.server_version >= 80200:
1028            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
1029        else:
1030            ret = ''
1031        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
1032        self._do_debug(q)
1033        res = self.db.query(q)
1034        if ret:
1035            res = res.dictresult()[0]
1036            for att, value in res.items():
1037                d[att == 'oid' and qoid or att] = value
1038        else:
1039            if selectable:
1040                if qoid in d:
1041                    self.get(qcl, d, 'oid')
1042                else:
1043                    self.get(qcl, d)
1044        return d
1045
1046    def clear(self, cl, d=None):
1047        """Clear all the attributes to values determined by the types.
1048
1049        Numeric types are set to 0, Booleans are set to false, and everything
1050        else is set to the empty string.  If the second argument is present,
1051        it is used as the row dictionary and any entries matching attribute
1052        names are cleared with everything else left unchanged.
1053
1054        """
1055        # At some point we will need a way to get defaults from a table.
1056        qcl = self._add_schema(cl)
1057        if d is None:
1058            d = {}  # empty if argument is not present
1059        attnames = self.get_attnames(qcl)
1060        for n, t in attnames.items():
1061            if n == 'oid':
1062                continue
1063            if t in ('int', 'integer', 'smallint', 'bigint',
1064                    'float', 'real', 'double precision',
1065                    'num', 'numeric', 'money'):
1066                d[n] = 0
1067            elif t in ('bool', 'boolean'):
1068                d[n] = self._make_bool(False)
1069            else:
1070                d[n] = ''
1071        return d
1072
1073    def delete(self, cl, d=None, **kw):
1074        """Delete an existing row in a database table.
1075
1076        This method deletes the row from a table.  It deletes based on the
1077        OID value as munged by get() or passed as keyword, or on the primary
1078        key of the table.  The return value is the number of deleted rows
1079        (i.e. 0 if the row did not exist and 1 if the row was deleted).
1080
1081        """
1082        # Like update, delete works on the oid.
1083        # One day we will be testing that the record to be deleted
1084        # isn't referenced somewhere (or else PostgreSQL will).
1085        # Note that we only accept oid key from named args for safety
1086        qcl = self._add_schema(cl)
1087        qoid = _oid_key(qcl)
1088        if 'oid' in kw:
1089            kw[qoid] = kw['oid']
1090            del kw['oid']
1091        if d is None:
1092            d = {}
1093        d.update(kw)
1094        if qoid in d:
1095            where = 'oid = %s' % d[qoid]
1096        else:
1097            try:
1098                keyname = self.pkey(qcl)
1099            except KeyError:
1100                raise _prg_error('Class %s has no primary key' % qcl)
1101            if isinstance(keyname, basestring):
1102                keyname = (keyname,)
1103            attnames = self.get_attnames(qcl)
1104            try:
1105                where = ' AND '.join(['%s = %s'
1106                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1107            except KeyError:
1108                raise _prg_error('Delete needs primary key or oid.')
1109        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
1110        self._do_debug(q)
1111        return int(self.db.query(q))
1112
1113    def truncate(self, table, restart=False, cascade=False, only=False):
1114        """Empty a table or set of tables.
1115
1116        This method quickly removes all rows from the given table or set
1117        of tables.  It has the same effect as an unqualified DELETE on each
1118        table, but since it does not actually scan the tables it is faster.
1119        Furthermore, it reclaims disk space immediately, rather than requiring
1120        a subsequent VACUUM operation. This is most useful on large tables.
1121
1122        If restart is set to True, sequences owned by columns of the truncated
1123        table(s) are automatically restarted.  If cascade is set to True, it
1124        also truncates all tables that have foreign-key references to any of
1125        the named tables.  If the parameter only is not set to True, all the
1126        descendant tables (if any) will also be truncated. Optionally, a '*'
1127        can be specified after the table name to explicitly indicate that
1128        descendant tables are included.
1129        """
1130        if isinstance(table, basestring):
1131            only = {table: only}
1132            table = [table]
1133        elif isinstance(table, (list, tuple)):
1134            if isinstance(only, (list, tuple)):
1135                only = dict(zip(table, only))
1136            else:
1137                only = dict.fromkeys(table, only)
1138        elif isinstance(table, (set, frozenset)):
1139            only = dict.fromkeys(table, only)
1140        else:
1141            raise TypeError('The table must be a string, list or set')
1142        if not (restart is None or isinstance(restart, (bool, int))):
1143            raise TypeError('Invalid type for the restart option')
1144        if not (cascade is None or isinstance(cascade, (bool, int))):
1145            raise TypeError('Invalid type for the cascade option')
1146        tables = []
1147        for t in table:
1148            u = only.get(t)
1149            if not (u is None or isinstance(u, (bool, int))):
1150                raise TypeError('Invalid type for the only option')
1151            if t.endswith('*'):
1152                if u:
1153                    raise ValueError(
1154                        'Contradictory table name and only options')
1155                t = t[:-1].rstrip()
1156            t = self._add_schema(t)
1157            if u:
1158                t = 'ONLY %s' % t
1159            tables.append(t)
1160        q = ['TRUNCATE', ', '.join(tables)]
1161        if restart:
1162            q.append('RESTART IDENTITY')
1163        if cascade:
1164            q.append('CASCADE')
1165        q = ' '.join(q)
1166        self._do_debug(q)
1167        return self.db.query(q)
1168
1169    def notification_handler(self,
1170            event, callback, arg_dict=None, timeout=None, stop_event=None):
1171        """Get notification handler that will run the given callback."""
1172        return NotificationHandler(self,
1173            event, callback, arg_dict, timeout, stop_event)
1174
1175
1176# if run as script, print some information
1177
if __name__ == '__main__':
    # When run as a script, print the version (presumably provided by
    # "from _pg import *") followed by this module's docstring.
    # A space is needed between the label and the version number.
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.