source: branches/4.x/pg.py @ 761

Last change on this file since 761 was 748, checked in by cito, 4 years ago

Add method truncate() to DB wrapper class

This method can be used to quickly truncate tables.

Since this is pretty useful and will not break anything, I have
also back ported this addition to the 4.x branch.

Everything is well documented and tested, of course.

  • Property svn:keywords set to Id
File size: 42.6 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 748 2016-01-15 14:25:31Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2016 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from _pg import *
32
33import select
34import warnings
35try:
36    frozenset
37except NameError:  # Python < 2.4, unsupported
38    from sets import ImmutableSet as frozenset
39try:
40    from decimal import Decimal
41    set_decimal(Decimal)
42except ImportError:  # Python < 2.4, unsupported
43    Decimal = float
44try:
45    from collections import namedtuple
46except ImportError:  # Python < 2.6
47    namedtuple = None
48
49
50# Auxiliary functions that are independent from a DB connection:
51
52def _is_quoted(s):
53    """Check whether this string is a quoted identifier."""
54    s = s.replace('_', 'a')
55    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
56
57
58def _is_unquoted(s):
59    """Check whether this string is an unquoted identifier."""
60    s = s.replace('_', 'a')
61    return s.isalnum() and not s[:1].isdigit()
62
63
64def _split_first_part(s):
65    """Split the first part of a dot separated string."""
66    s = s.lstrip()
67    if s[:1] == '"':
68        p = []
69        s = s.split('"', 3)[1:]
70        p.append(s[0])
71        while len(s) == 3 and s[1] == '':
72            p.append('"')
73            s = s[2].split('"', 2)
74            p.append(s[0])
75        p = [''.join(p)]
76        s = '"'.join(s[1:]).lstrip()
77        if s:
78            if s[:0] == '.':
79                p.append(s[1:])
80            else:
81                s = _split_first_part(s)
82                p[0] += s[0]
83                if len(s) > 1:
84                    p.append(s[1])
85    else:
86        p = s.split('.', 1)
87        s = p[0].rstrip()
88        if _is_unquoted(s):
89            s = s.lower()
90        p[0] = s
91    return p
92
93
def _split_parts(s):
    """Split a dot separated identifier string into the list of its parts."""
    parts = []
    rest = s
    while rest:
        head_and_tail = _split_first_part(rest)
        parts.append(head_and_tail[0])
        if len(head_and_tail) < 2:
            break
        rest = head_and_tail[1]
    return parts
104
105
def _join_parts(s):
    """Join identifier parts to a dot separated string.

    Parts that need quoting according to _is_quoted() are put in double
    quotes.  Double quotes inside such a part are doubled, so the result
    is a valid SQL identifier that _split_parts() splits back correctly.
    """
    return '.'.join([_is_quoted(p) and '"%s"' % p.replace('"', '""') or p
                     for p in s])
109
110
111def _oid_key(qcl):
112    """Build oid key from qualified class name."""
113    return 'oid(%s)' % qcl
114
115
if namedtuple:

    def _namedresult(q):
        """Return all rows of the query result q as named tuples.

        The tuple type is called Row and has the result's field names.
        """
        row_type = namedtuple('Row', q.listfields())
        return [row_type._make(values) for values in q.getresult()]

    # register the function so query objects can return named results
    set_namedresult(_namedresult)
124
125
def _db_error(msg, cls=DatabaseError):
    """Create an exception of class cls (DatabaseError by default).

    The sqlstate attribute of the returned exception is set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
131
132
def _int_error(msg):
    """Create an InternalError for msg (with sqlstate set to None)."""
    return _db_error(msg, cls=InternalError)
136
137
def _prg_error(msg):
    """Create a ProgrammingError for msg (with sqlstate set to None)."""
    return _db_error(msg, cls=ProgrammingError)
141
142
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB wrapper instance is
                   replaced by the connection it wraps).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    def close(self):
        """Stop listening and close the connection.

        This is safe to call on a handler whose __init__ did not complete
        (the db attribute is then missing), and the connection is closed
        even if unlistening raises an error.
        """
        db = getattr(self, 'db', None)
        if db:
            try:
                self.unlisten()
            finally:
                db.close()
                self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        The payload parameter is only supported in PostgreSQL >= 9.0.
        It is converted to a string and escaped with the escape_string()
        method of the connection, so it may safely contain quotes.

        Nothing is sent (and None is returned) if the handler is not
        listening; otherwise the result of the notify query is returned.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (stop and self.stop_event or self.event)
            if payload:
                # the payload becomes an SQL string literal and must be
                # escaped, otherwise quotes in it would break the command
                payload = db.escape_string('%s' % (payload,))
                q += ", '%s'" % payload
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid' and 'event' are inserted into <arg_dict>, and the callback is
        invoked with <arg_dict>. If the NOTIFY message is stop_<event>, the
        handler UNLISTENs both <event> and stop_<event> and exits.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
252
253
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    Issues a DeprecationWarning (attributed to the caller) and passes
    all arguments on to NotificationHandler unchanged.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
259
260
# The actual PostgreSQL database connection interface:
262
263class DB(object):
264    """Wrapper class for the _pg connection type."""
265
    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        If exactly one argument is given (positionally, or as the keyword
        argument db), it is checked for being an existing connection:
        a DB instance is unwrapped to its db attribute, and an object with
        a _cnx attribute is unwrapped to that.  If the result has both a db
        and a query attribute, it is used as the connection and will not be
        closed by close(); otherwise a new connection is opened by calling
        connect(*args, **kw).

        """
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                db = db.db
            else:
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # duck typing: anything with db and query attributes is taken as an
        # existing connection, everything else as connection parameters
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True  # opened here, so close() may close it
        else:
            self._closeable = False  # borrowed connection, close() skips it
        self.db = db
        self.dbname = db.db
        self._regtypes = False  # see use_regtypes()
        self._attnames = {}  # cache: qualified class name -> attribute types
        self._pkeys = {}  # cache: qualified class name -> primary key(s)
        self._privileges = {}  # cache: (class name, privilege) -> bool
        self._args = args, kw  # remembered so that reopen() can reconnect
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
305
306    def __getattr__(self, name):
307        # All undefined members are same as in underlying connection:
308        if self.db:
309            return getattr(self.db, name)
310        else:
311            raise _int_error('Connection is not valid')
312
313    # Context manager methods
314
315    def __enter__(self):
316        """Enter the runtime context. This will start a transaction."""
317        self.begin()
318        return self
319
320    def __exit__(self, et, ev, tb):
321        """Exit the runtime context. This will end the transaction."""
322        if et is None and ev is None and tb is None:
323            self.commit()
324        else:
325            self.rollback()
326
327    # Auxiliary methods
328
329    def _do_debug(self, s):
330        """Print a debug message."""
331        if self.debug:
332            if isinstance(self.debug, basestring):
333                print(self.debug % s)
334            elif isinstance(self.debug, file):
335                self.debug.write(s + '\n')
336            elif callable(self.debug):
337                self.debug(s)
338            else:
339                print(s)
340
341    def _make_bool(d):
342        """Get boolean value corresponding to d."""
343        if get_bool():
344            return bool(d)
345        return d and 't' or 'f'
346    _make_bool = staticmethod(_make_bool)
347
348    def _quote_text(self, d):
349        """Quote text value."""
350        if not isinstance(d, basestring):
351            d = str(d)
352        return "'%s'" % self.escape_string(d)
353
354    _bool_true = frozenset('t true 1 y yes on'.split())
355
356    def _quote_bool(self, d):
357        """Quote boolean value."""
358        if isinstance(d, basestring):
359            if not d:
360                return 'NULL'
361            d = d.lower() in self._bool_true
362        return d and "'t'" or "'f'"
363
364    _date_literals = frozenset('current_date current_time'
365        ' current_timestamp localtime localtimestamp'.split())
366
367    def _quote_date(self, d):
368        """Quote date value."""
369        if not d:
370            return 'NULL'
371        if isinstance(d, basestring) and d.lower() in self._date_literals:
372            return d
373        return self._quote_text(d)
374
375    def _quote_num(self, d):
376        """Quote numeric value."""
377        if not d and d != 0:
378            return 'NULL'
379        return str(d)
380
381    def _quote_money(self, d):
382        """Quote money value."""
383        if d is None or d == '':
384            return 'NULL'
385        if not isinstance(d, basestring):
386            d = str(d)
387        return d
388
389    _quote_funcs = dict(  # quote methods for each type
390        text=_quote_text, bool=_quote_bool, date=_quote_date,
391        int=_quote_num, num=_quote_num, float=_quote_num,
392        money=_quote_money)
393
394    def _quote(self, d, t):
395        """Return quotes if needed."""
396        if d is None:
397            return 'NULL'
398        try:
399            quote_func = self._quote_funcs[t]
400        except KeyError:
401            quote_func = self._quote_funcs['text']
402        return quote_func(self, d)
403
404    def _split_schema(self, cl):
405        """Return schema and name of object separately.
406
407        This auxiliary function splits off the namespace (schema)
408        belonging to the class with the name cl. If the class name
409        is not qualified, the function is able to determine the schema
410        of the class, taking into account the current search path.
411
412        """
413        s = _split_parts(cl)
414        if len(s) > 1:  # name already qualified?
415            # should be database.schema.table or schema.table
416            if len(s) > 3:
417                raise _prg_error('Too many dots in class name %s' % cl)
418            schema, cl = s[-2:]
419        else:
420            cl = s[0]
421            # determine search path
422            q = 'SELECT current_schemas(TRUE)'
423            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
424            if schemas:  # non-empty path
425                # search schema for this object in the current search path
426                # (we could also use unnest with ordinality here to spare
427                # one query, but this is only possible since PostgreSQL 9.4)
428                q = ' UNION '.join(
429                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
430                        % s for s in enumerate(schemas)])
431                q = ("SELECT nspname FROM pg_class r"
432                    " JOIN pg_namespace s ON r.relnamespace = s.oid"
433                    " JOIN (%s) AS p USING (nspname)"
434                    " WHERE r.relname = $1 ORDER BY n LIMIT 1" % q)
435                schema = self.db.query(q, (cl,)).getresult()
436                if schema:  # schema found
437                    schema = schema[0][0]
438                else:  # object not found in current search path
439                    schema = 'public'
440            else:  # empty path
441                schema = 'public'
442        return schema, cl
443
444    def _add_schema(self, cl):
445        """Ensure that the class name is prefixed with a schema name."""
446        return _join_parts(self._split_schema(cl))
447
    # Public methods

    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (the module level unescape_bytea function, presumably imported from
    # _pg, is wrapped as a static method so it can be called on instances)
    unescape_bytea = staticmethod(unescape_bytea)
453
454    def close(self):
455        """Close the database connection."""
456        # Wraps shared library function so we can track state.
457        if self._closeable:
458            if self.db:
459                self.db.close()
460                self.db = None
461            else:
462                raise _int_error('Connection already closed')
463
464    def reset(self):
465        """Reset connection with current parameters.
466
467        All derived queries and large objects derived from this connection
468        will not be usable after this call.
469
470        """
471        if self.db:
472            self.db.reset()
473        else:
474            raise _int_error('Connection already closed')
475
476    def reopen(self):
477        """Reopen connection to the database.
478
479        Used in case we need another connection to the same database.
480        Note that we can still reopen a database that we have closed.
481
482        """
483        # There is no such shared library function.
484        if self._closeable:
485            db = connect(*self._args[0], **self._args[1])
486            if self.db:
487                self.db.close()
488            self.db = db
489
490    def begin(self, mode=None):
491        """Begin a transaction."""
492        qstr = 'BEGIN'
493        if mode:
494            qstr += ' ' + mode
495        return self.query(qstr)
496
497    start = begin
498
499    def commit(self):
500        """Commit the current transaction."""
501        return self.query('COMMIT')
502
503    end = commit
504
505    def rollback(self, name=None):
506        """Roll back the current transaction."""
507        qstr = 'ROLLBACK'
508        if name:
509            qstr += ' TO ' + name
510        return self.query(qstr)
511
512    def savepoint(self, name):
513        """Define a new savepoint within the current transaction."""
514        return self.query('SAVEPOINT ' + name)
515
516    def release(self, name):
517        """Destroy a previously defined savepoint."""
518        return self.query('RELEASE ' + name)
519
    def get_parameter(self, parameter):
        """Get the value of a run-time parameter.

        If the parameter is a string, the return value will also be a string
        that is the current setting of the run-time parameter with that name.

        You can get several parameters at once by passing a list, set or dict.
        When passing a list of parameter names, the return value will be a
        corresponding list of parameter settings.  When passing a set of
        parameter names, a new dict will be returned, mapping these parameter
        names to their settings.  Finally, if you pass a dict as parameter,
        its values will be set to the current parameter settings corresponding
        to its keys.

        By passing the special name 'all' as the parameter, you can get a dict
        of all existing configuration parameters.

        Raises TypeError if the parameter has an unsupported type, is empty,
        or contains a name that is not a non-blank string.
        """
        if isinstance(parameter, basestring):
            parameter = [parameter]
            values = None  # None means: return the single setting itself
        elif isinstance(parameter, (list, tuple)):
            values = []
        elif isinstance(parameter, (set, frozenset)):
            values = {}
        elif isinstance(parameter, dict):
            values = parameter  # the passed dict is filled in place
        else:
            raise TypeError(
                'The parameter must be a string, list, set or dict')
        if not parameter:
            raise TypeError('No parameter has been specified')
        # for dict results, map each normalized (stripped, lower-cased)
        # name back to the original key under which its value is stored
        if isinstance(values, dict):
            params = {}
        else:
            params = []
        for key in parameter:
            if isinstance(key, basestring):
                param = key.strip().lower()
            else:
                param = None
            if not param:
                raise TypeError('Invalid parameter')
            if param == 'all':
                # 'all' overrides any other names: return a new dict with
                # the first two columns (name, setting) of SHOW ALL
                q = 'SHOW ALL'
                values = self.db.query(q).getresult()
                values = dict(value[:2] for value in values)
                break
            if isinstance(values, dict):
                params[param] = key
            else:
                params.append(param)
        else:
            # no 'all' was given: query each parameter separately
            # NOTE(review): param is interpolated into the SQL unescaped
            for param in params:
                q = 'SHOW %s' % (param,)
                value = self.db.query(q).getresult()[0][0]
                if values is None:
                    values = value
                elif isinstance(values, list):
                    values.append(value)
                else:
                    values[params[param]] = value
        return values
582
583    def set_parameter(self, parameter, value=None, local=False):
584        """Set the value of a run-time parameter.
585
586        If the parameter and the value are strings, the run-time parameter
587        will be set to that value.  If no value or None is passed as a value,
588        then the run-time parameter will be restored to its default value.
589
590        You can set several parameters at once by passing a list of parameter
591        names, together with a single value that all parameters should be
592        set to or with a corresponding list of values.  You can also pass
593        the parameters as a set if you only provide a single value.
594        Finally, you can pass a dict with parameter names as keys.  In this
595        case, you should not pass a value, since the values for the parameters
596        will be taken from the dict.
597
598        By passing the special name 'all' as the parameter, you can reset
599        all existing settable run-time parameters to their default values.
600
601        If you set local to True, then the command takes effect for only the
602        current transaction.  After commit() or rollback(), the session-level
603        setting takes effect again.  Setting local to True will appear to
604        have no effect if it is executed outside a transaction, since the
605        transaction will end immediately.
606        """
607        if isinstance(parameter, basestring):
608            parameter = {parameter: value}
609        elif isinstance(parameter, (list, tuple)):
610            if isinstance(value, (list, tuple)):
611                parameter = dict(zip(parameter, value))
612            else:
613                parameter = dict.fromkeys(parameter, value)
614        elif isinstance(parameter, (set, frozenset)):
615            if isinstance(value, (list, tuple, set, frozenset)):
616                value = set(value)
617                if len(value) == 1:
618                    value = value.pop()
619            if not(value is None or isinstance(value, basestring)):
620                raise ValueError('A single value must be specified'
621                    ' when parameter is a set')
622            parameter = dict.fromkeys(parameter, value)
623        elif isinstance(parameter, dict):
624            if value is not None:
625                raise ValueError('A value must not be specified'
626                    ' when parameter is a dictionary')
627        else:
628            raise TypeError(
629                'The parameter must be a string, list, set or dict')
630        if not parameter:
631            raise TypeError('No parameter has been specified')
632        params = {}
633        for key, value in parameter.items():
634            if isinstance(key, basestring):
635                param = key.strip().lower()
636            else:
637                param = None
638            if not param:
639                raise TypeError('Invalid parameter')
640            if param == 'all':
641                if value is not None:
642                    raise ValueError('A value must ot be specified'
643                        " when parameter is 'all'")
644                params = {'all': None}
645                break
646            params[param] = value
647        local = local and ' LOCAL' or ''
648        for param, value in params.items():
649            if value is None:
650                q = 'RESET%s %s' % (local, param)
651            else:
652                q = 'SET%s %s TO %s' % (local, param, value)
653            self._do_debug(q)
654            self.db.query(q)
655
656    def query(self, qstr, *args):
657        """Executes a SQL command string.
658
659        This method simply sends a SQL query to the database. If the query is
660        an insert statement that inserted exactly one row into a table that
661        has OIDs, the return value is the OID of the newly inserted row.
662        If the query is an update or delete statement, or an insert statement
663        that did not insert exactly one row in a table with OIDs, then the
664        number of rows affected is returned as a string. If it is a statement
665        that returns rows as a result (usually a select statement, but maybe
666        also an "insert/update ... returning" statement), this method returns
667        a pgqueryobject that can be accessed via getresult() or dictresult()
668        or simply printed. Otherwise, it returns `None`.
669
670        The query can contain numbered parameters of the form $1 in place
671        of any data constant. Arguments given after the query string will
672        be substituted for the corresponding numbered parameter. Parameter
673        values can also be given as a single list or tuple argument.
674
675        Note that the query string must not be passed as a unicode value,
676        but you can pass arguments as unicode values if they can be decoded
677        using the current client encoding.
678
679        """
680        # Wraps shared library function for debugging.
681        if not self.db:
682            raise _int_error('Connection is not valid')
683        self._do_debug(qstr)
684        return self.db.query(qstr, args)
685
    def pkey(self, cl, newpkey=None):
        """This method gets or sets the primary key of a class.

        Composite primary keys are represented as frozensets. Note that
        this raises a KeyError if the table does not have a primary key.

        If newpkey is set and is not a dictionary then set that
        value as the primary key of the class.  If it is a dictionary
        then replace the internal cache of primary keys with a copy of it.

        """
        # First see if the caller is supplying a dictionary
        if isinstance(newpkey, dict):
            # make sure that all classes have a namespace
            # NOTE(review): names without a dot simply get 'public.' here,
            # while lookups below use _add_schema() (search path, quoting),
            # so the keys may not always match - confirm this is intended
            self._pkeys = dict([
                ('.' in cl and cl or 'public.' + cl, pkey)
                for cl, pkey in newpkey.items()])
            return self._pkeys

        qcl = self._add_schema(cl)  # build fully qualified class name
        # Check if the caller is supplying a new primary key for the class
        if newpkey:
            self._pkeys[qcl] = newpkey
            return newpkey

        # Get all the primary keys at once
        if qcl not in self._pkeys:
            # if not found, check again in case it was added after we started
            # (the whole cache is rebuilt from the catalog, which discards
            # primary keys that were set manually via newpkey)
            self._pkeys = {}
            if self.server_version >= 80200:
                # the ANY syntax works correctly only with PostgreSQL >= 8.2
                any_indkey = "= ANY (i.indkey)"
            else:
                any_indkey = "IN (%s)" % ', '.join(
                    ['i.indkey[%d]' % i for i in range(16)])
            # one row per primary key column, skipping the pg_* and
            # information_schema schemas
            q = ("SELECT s.nspname, r.relname, a.attname"
                " FROM pg_class r"
                " JOIN pg_namespace s ON s.oid = r.relnamespace"
                " AND s.nspname NOT SIMILAR"
                " TO 'pg/_%|information/_schema' ESCAPE '/'"
                " JOIN pg_attribute a ON a.attrelid = r.oid"
                " AND NOT a.attisdropped"
                " JOIN pg_index i ON i.indrelid = r.oid"
                " AND i.indisprimary AND a.attnum " + any_indkey)
            for r in self.db.query(q).getresult():
                cl, pkey = _join_parts(r[:2]), r[2]
                self._pkeys.setdefault(cl, []).append(pkey)
            # (only) for composite primary keys, the values will be frozensets
            for cl, pkey in self._pkeys.items():
                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
            self._do_debug(self._pkeys)

        # will raise an exception if primary key doesn't exist
        return self._pkeys[qcl]
740
741    def get_databases(self):
742        """Get list of databases in the system."""
743        return [s[0] for s in
744            self.db.query('SELECT datname FROM pg_database').getresult()]
745
746    def get_relations(self, kinds=None):
747        """Get list of relations in connected database of specified kinds.
748
749        If kinds is None or empty, all kinds of relations are returned.
750        Otherwise kinds can be a string or sequence of type letters
751        specifying which kind of relations you want to list.
752
753        """
754        where = kinds and " AND r.relkind IN (%s)" % ','.join(
755            ["'%s'" % k for k in kinds]) or ''
756        q = ("SELECT s.nspname, r.relname"
757            " FROM pg_class r"
758            " JOIN pg_namespace s ON s.oid = r.relnamespace"
759            " WHERE s.nspname NOT SIMILAR"
760            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
761            " ORDER BY 1, 2") % where
762        return [_join_parts(r) for r in self.db.query(q).getresult()]
763
764    def get_tables(self):
765        """Return list of tables in connected database."""
766        return self.get_relations('r')
767
768    def get_attnames(self, cl, newattnames=None):
769        """Given the name of a table, digs out the set of attribute names.
770
771        Returns a dictionary of attribute names (the names are the keys,
772        the values are the names of the attributes' types).
773        If the optional newattnames exists, it must be a dictionary and
774        will become the new attribute names dictionary.
775
776        By default, only a limited number of simple types will be returned.
777        You can get the regular types after calling use_regtypes(True).
778
779        """
780        if isinstance(newattnames, dict):
781            self._attnames = newattnames
782            return
783        elif newattnames:
784            raise _prg_error('If supplied, newattnames must be a dictionary')
785        cl = self._split_schema(cl)  # split into schema and class
786        qcl = _join_parts(cl)  # build fully qualified name
787        # May as well cache them:
788        if qcl in self._attnames:
789            return self._attnames[qcl]
790        if qcl not in self.get_relations('rv'):
791            raise _prg_error('Class %s does not exist' % qcl)
792
793        q = ("SELECT a.attname, t.typname%s"
794            " FROM pg_class r"
795            " JOIN pg_namespace s ON r.relnamespace = s.oid"
796            " JOIN pg_attribute a ON a.attrelid = r.oid"
797            " JOIN pg_type t ON t.oid = a.atttypid"
798            " WHERE s.nspname = $1 AND r.relname = $2"
799            " AND (a.attnum > 0 OR a.attname = 'oid')"
800            " AND NOT a.attisdropped") % (
801                self._regtypes and '::regtype' or '',)
802        q = self.db.query(q, cl).getresult()
803
804        if self._regtypes:
805            t = dict(q)
806        else:
807            t = {}
808            for att, typ in q:
809                if typ.startswith('bool'):
810                    typ = 'bool'
811                elif typ.startswith('abstime'):
812                    typ = 'date'
813                elif typ.startswith('date'):
814                    typ = 'date'
815                elif typ.startswith('interval'):
816                    typ = 'date'
817                elif typ.startswith('timestamp'):
818                    typ = 'date'
819                elif typ.startswith('oid'):
820                    typ = 'int'
821                elif typ.startswith('int'):
822                    typ = 'int'
823                elif typ.startswith('float'):
824                    typ = 'float'
825                elif typ.startswith('numeric'):
826                    typ = 'num'
827                elif typ.startswith('money'):
828                    typ = 'money'
829                else:
830                    typ = 'text'
831                t[att] = typ
832
833        self._attnames[qcl] = t  # cache it
834        return self._attnames[qcl]
835
836    def use_regtypes(self, regtypes=None):
837        """Use regular type names instead of simplified type names."""
838        if regtypes is None:
839            return self._regtypes
840        else:
841            regtypes = bool(regtypes)
842            if regtypes != self._regtypes:
843                self._regtypes = regtypes
844                self._attnames.clear()
845            return regtypes
846
847    def has_table_privilege(self, cl, privilege='select'):
848        """Check whether current user has specified table privilege."""
849        qcl = self._add_schema(cl)
850        privilege = privilege.lower()
851        try:
852            return self._privileges[(qcl, privilege)]
853        except KeyError:
854            q = "SELECT has_table_privilege($1, $2)"
855            q = self.db.query(q, (qcl, privilege))
856            ret = q.getresult()[0][0] == self._make_bool(True)
857            self._privileges[(qcl, privilege)] = ret
858            return ret
859
860    def get(self, cl, arg, keyname=None):
861        """Get a row from a database table or view.
862
863        This method is the basic mechanism to get a single row.  The keyname
864        that the key specifies a unique row.  If keyname is not specified
865        then the primary key for the table is used.  If arg is a dictionary
866        then the value for the key is taken from it and it is modified to
867        include the new values, replacing existing values where necessary.
868        For a composite key, keyname can also be a sequence of key names.
869        The OID is also put into the dictionary if the table has one, but
870        in order to allow the caller to work with multiple tables, it is
871        munged as oid(schema.table).
872
873        """
874        if cl.endswith('*'):  # scan descendant tables?
875            cl = cl[:-1].rstrip()  # need parent table name
876        # build qualified class name
877        qcl = self._add_schema(cl)
878        # To allow users to work with multiple tables,
879        # we munge the name of the "oid" key
880        qoid = _oid_key(qcl)
881        if not keyname:
882            # use the primary key by default
883            try:
884                keyname = self.pkey(qcl)
885            except KeyError:
886                raise _prg_error('Class %s has no primary key' % qcl)
887        # We want the oid for later updates if that isn't the key
888        if keyname == 'oid':
889            if isinstance(arg, dict):
890                if qoid not in arg:
891                    raise _db_error('%s not in arg' % qoid)
892            else:
893                arg = {qoid: arg}
894            where = 'oid = %s' % arg[qoid]
895            attnames = '*'
896        else:
897            attnames = self.get_attnames(qcl)
898            if isinstance(keyname, basestring):
899                keyname = (keyname,)
900            if not isinstance(arg, dict):
901                if len(keyname) > 1:
902                    raise _prg_error('Composite key needs dict as arg')
903                arg = dict([(k, arg) for k in keyname])
904            where = ' AND '.join(['%s = %s'
905                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
906            attnames = ', '.join(attnames)
907        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
908        self._do_debug(q)
909        res = self.db.query(q).dictresult()
910        if not res:
911            raise _db_error('No such record in %s where %s' % (qcl, where))
912        for att, value in res[0].items():
913            arg[att == 'oid' and qoid or att] = value
914        return arg
915
916    def insert(self, cl, d=None, **kw):
917        """Insert a row into a database table.
918
919        This method inserts a row into a table.  The name of the table must
920        be passed as the first parameter.  The other parameters are used for
921        providing the data of the row that shall be inserted into the table.
922        If a dictionary is supplied as the second parameter, it starts with
923        that.  Otherwise it uses a blank dictionary. Either way the dictionary
924        is updated from the keywords.
925
926        The dictionary is then, if possible, reloaded with the values actually
927        inserted in order to pick up values modified by rules, triggers, etc.
928
929        Note: The method currently doesn't support insert into views
930        although PostgreSQL does.
931
932        """
933        qcl = self._add_schema(cl)
934        qoid = _oid_key(qcl)
935        if d is None:
936            d = {}
937        d.update(kw)
938        attnames = self.get_attnames(qcl)
939        names, values = [], []
940        for n in attnames:
941            if n != 'oid' and n in d:
942                names.append('"%s"' % n)
943                values.append(self._quote(d[n], attnames[n]))
944        names, values = ', '.join(names), ', '.join(values)
945        selectable = self.has_table_privilege(qcl)
946        if selectable and self.server_version >= 80200:
947            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
948        else:
949            ret = ''
950        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
951        self._do_debug(q)
952        res = self.db.query(q)
953        if ret:
954            res = res.dictresult()
955            for att, value in res[0].items():
956                d[att == 'oid' and qoid or att] = value
957        elif isinstance(res, int):
958            d[qoid] = res
959            if selectable:
960                self.get(qcl, d, 'oid')
961        elif selectable:
962            if qoid in d:
963                self.get(qcl, d, 'oid')
964            else:
965                try:
966                    self.get(qcl, d)
967                except ProgrammingError:
968                    pass  # table has no primary key
969        return d
970
971    def update(self, cl, d=None, **kw):
972        """Update an existing row in a database table.
973
974        Similar to insert but updates an existing row.  The update is based
975        on the OID value as munged by get or passed as keyword, or on the
976        primary key of the table.  The dictionary is modified, if possible,
977        to reflect any changes caused by the update due to triggers, rules,
978        default values, etc.
979
980        """
981        # Update always works on the oid which get() returns if available,
982        # otherwise use the primary key.  Fail if neither.
983        # Note that we only accept oid key from named args for safety.
984        qcl = self._add_schema(cl)
985        qoid = _oid_key(qcl)
986        if 'oid' in kw:
987            kw[qoid] = kw['oid']
988            del kw['oid']
989        if d is None:
990            d = {}
991        d.update(kw)
992        attnames = self.get_attnames(qcl)
993        if qoid in d:
994            where = 'oid = %s' % d[qoid]
995            keyname = ()
996        else:
997            try:
998                keyname = self.pkey(qcl)
999            except KeyError:
1000                raise _prg_error('Class %s has no primary key' % qcl)
1001            if isinstance(keyname, basestring):
1002                keyname = (keyname,)
1003            try:
1004                where = ' AND '.join(['%s = %s'
1005                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1006            except KeyError:
1007                raise _prg_error('Update needs primary key or oid.')
1008        values = []
1009        for n in attnames:
1010            if n in d and n not in keyname:
1011                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
1012        if not values:
1013            return d
1014        values = ', '.join(values)
1015        selectable = self.has_table_privilege(qcl)
1016        if selectable and self.server_version >= 80200:
1017            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
1018        else:
1019            ret = ''
1020        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
1021        self._do_debug(q)
1022        res = self.db.query(q)
1023        if ret:
1024            res = res.dictresult()[0]
1025            for att, value in res.items():
1026                d[att == 'oid' and qoid or att] = value
1027        else:
1028            if selectable:
1029                if qoid in d:
1030                    self.get(qcl, d, 'oid')
1031                else:
1032                    self.get(qcl, d)
1033        return d
1034
1035    def clear(self, cl, d=None):
1036        """Clear all the attributes to values determined by the types.
1037
1038        Numeric types are set to 0, Booleans are set to false, and everything
1039        else is set to the empty string.  If the second argument is present,
1040        it is used as the row dictionary and any entries matching attribute
1041        names are cleared with everything else left unchanged.
1042
1043        """
1044        # At some point we will need a way to get defaults from a table.
1045        qcl = self._add_schema(cl)
1046        if d is None:
1047            d = {}  # empty if argument is not present
1048        attnames = self.get_attnames(qcl)
1049        for n, t in attnames.items():
1050            if n == 'oid':
1051                continue
1052            if t in ('int', 'integer', 'smallint', 'bigint',
1053                    'float', 'real', 'double precision',
1054                    'num', 'numeric', 'money'):
1055                d[n] = 0
1056            elif t in ('bool', 'boolean'):
1057                d[n] = self._make_bool(False)
1058            else:
1059                d[n] = ''
1060        return d
1061
1062    def delete(self, cl, d=None, **kw):
1063        """Delete an existing row in a database table.
1064
1065        This method deletes the row from a table.  It deletes based on the
1066        OID value as munged by get or passed as keyword, or on the primary
1067        key of the table.  The return value is the number of deleted rows
1068        (i.e. 0 if the row did not exist and 1 if the row was deleted).
1069
1070        """
1071        # Like update, delete works on the oid.
1072        # One day we will be testing that the record to be deleted
1073        # isn't referenced somewhere (or else PostgreSQL will).
1074        # Note that we only accept oid key from named args for safety
1075        qcl = self._add_schema(cl)
1076        qoid = _oid_key(qcl)
1077        if 'oid' in kw:
1078            kw[qoid] = kw['oid']
1079            del kw['oid']
1080        if d is None:
1081            d = {}
1082        d.update(kw)
1083        if qoid in d:
1084            where = 'oid = %s' % d[qoid]
1085        else:
1086            try:
1087                keyname = self.pkey(qcl)
1088            except KeyError:
1089                raise _prg_error('Class %s has no primary key' % qcl)
1090            if isinstance(keyname, basestring):
1091                keyname = (keyname,)
1092            attnames = self.get_attnames(qcl)
1093            try:
1094                where = ' AND '.join(['%s = %s'
1095                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
1096            except KeyError:
1097                raise _prg_error('Delete needs primary key or oid.')
1098        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
1099        self._do_debug(q)
1100        return int(self.db.query(q))
1101
1102    def truncate(self, table, restart=False, cascade=False, only=False):
1103        """Empty a table or set of tables.
1104
1105        This method quickly removes all rows from the given table or set
1106        of tables.  It has the same effect as an unqualified DELETE on each
1107        table, but since it does not actually scan the tables it is faster.
1108        Furthermore, it reclaims disk space immediately, rather than requiring
1109        a subsequent VACUUM operation. This is most useful on large tables.
1110
1111        If restart is set to True, sequences owned by columns of the truncated
1112        table(s) are automatically restarted.  If cascade is set to True, it
1113        also truncates all tables that have foreign-key references to any of
1114        the named tables.  If the parameter only is not set to True, all the
1115        descendant tables (if any) will also be truncated. Optionally, a '*'
1116        can be specified after the table name to explicitly indicate that
1117        descendant tables are included.
1118        """
1119        if isinstance(table, basestring):
1120            only = {table: only}
1121            table = [table]
1122        elif isinstance(table, (list, tuple)):
1123            if isinstance(only, (list, tuple)):
1124                only = dict(zip(table, only))
1125            else:
1126                only = dict.fromkeys(table, only)
1127        elif isinstance(table, (set, frozenset)):
1128            only = dict.fromkeys(table, only)
1129        else:
1130            raise TypeError('The table must be a string, list or set')
1131        if not (restart is None or isinstance(restart, (bool, int))):
1132            raise TypeError('Invalid type for the restart option')
1133        if not (cascade is None or isinstance(cascade, (bool, int))):
1134            raise TypeError('Invalid type for the cascade option')
1135        tables = []
1136        for t in table:
1137            u = only.get(t)
1138            if not (u is None or isinstance(u, (bool, int))):
1139                raise TypeError('Invalid type for the only option')
1140            if t.endswith('*'):
1141                if u:
1142                    raise ValueError(
1143                        'Contradictory table name and only options')
1144                t = t[:-1].rstrip()
1145            t = self._add_schema(t)
1146            if u:
1147                t = 'ONLY %s' % t
1148            tables.append(t)
1149        q = ['TRUNCATE', ', '.join(tables)]
1150        if restart:
1151            q.append('RESTART IDENTITY')
1152        if cascade:
1153            q.append('CASCADE')
1154        q = ' '.join(q)
1155        self._do_debug(q)
1156        return self.query(q)
1157
1158    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
1159        """Get notification handler that will run the given callback."""
1160        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
1161
1162
# if run as script, print some information

if __name__ == '__main__':
    # separate the label from the version number with a space
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.