source: branches/4.x/module/pg.py @ 709

Last change on this file since 709 was 709, checked in by cito, 3 years ago

get_tables() should not list the information schema tables

Since get_tables() does not return the other system tables starting with pg_,
it should not return the information schema tables either.

Also removed an ancient check for tables starting with Inv
that is not relevant any more since PostgreSQL 7.1 or so.

  • Property svn:keywords set to Id
File size: 34.4 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 709 2016-01-10 22:12:03Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from _pg import *
32
33import select
34import warnings
35try:
36    frozenset
37except NameError:  # Python < 2.4, unsupported
38    from sets import ImmutableSet as frozenset
39try:
40    from decimal import Decimal
41    set_decimal(Decimal)
42except ImportError:  # Python < 2.4, unsupported
43    Decimal = float
44try:
45    from collections import namedtuple
46except ImportError:  # Python < 2.6
47    namedtuple = None
48
49
50# Auxiliary functions that are independent of a DB connection:
51
52def _is_quoted(s):
53    """Check whether this string is a quoted identifier."""
54    s = s.replace('_', 'a')
55    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
56
57
58def _is_unquoted(s):
59    """Check whether this string is an unquoted identifier."""
60    s = s.replace('_', 'a')
61    return s.isalnum() and not s[:1].isdigit()
62
63
64def _split_first_part(s):
65    """Split the first part of a dot separated string."""
66    s = s.lstrip()
67    if s[:1] == '"':
68        p = []
69        s = s.split('"', 3)[1:]
70        p.append(s[0])
71        while len(s) == 3 and s[1] == '':
72            p.append('"')
73            s = s[2].split('"', 2)
74            p.append(s[0])
75        p = [''.join(p)]
76        s = '"'.join(s[1:]).lstrip()
77        if s:
78            if s[:0] == '.':
79                p.append(s[1:])
80            else:
81                s = _split_first_part(s)
82                p[0] += s[0]
83                if len(s) > 1:
84                    p.append(s[1])
85    else:
86        p = s.split('.', 1)
87        s = p[0].rstrip()
88        if _is_unquoted(s):
89            s = s.lower()
90        p[0] = s
91    return p
92
93
def _split_parts(s):
    """Split a dot separated string into the list of all its parts.

    Each part is normalized by _split_first_part (quotes removed from
    quoted parts, unquoted identifiers folded to lower case).
    """
    parts = []
    remaining = s
    while remaining:
        head = _split_first_part(remaining)
        parts.append(head[0])
        if len(head) < 2:  # no further dot: this was the last part
            break
        remaining = head[1]
    return parts
104
105
def _join_parts(s):
    """Join the given name parts into a dot separated string.

    Every part that cannot be written as an unquoted identifier (according
    to _is_quoted) is enclosed in double quotes. Double quotes embedded in
    such a part are doubled as SQL requires, so the result is a valid
    qualified name that _split_parts splits back into the same parts.
    """
    parts = []
    for p in s:
        if _is_quoted(p):
            p = '"%s"' % p.replace('"', '""')
        parts.append(p)
    return '.'.join(parts)
109
110
111def _oid_key(qcl):
112    """Build oid key from qualified class name."""
113    return 'oid(%s)' % qcl
114
115
if namedtuple is not None:  # collections.namedtuple is available

    def _namedresult(q):
        """Return the rows of query result q as a list of named tuples.

        The tuple type is called Row; its field names are the column
        names reported by q.listfields().
        """
        Row = namedtuple('Row', q.listfields())
        return [Row._make(values) for values in q.getresult()]

    # register the function as the handler for named results
    set_namedresult(_namedresult)
124
125
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an exception of class cls for msg.

    The sqlstate attribute of the returned exception is set to None,
    since the error is generated on the client side.
    """
    err = cls(msg)
    err.sqlstate = None
    return err
131
132
def _int_error(msg):
    """Create an InternalError for msg (with an empty sqlstate)."""
    return _db_error(msg, cls=InternalError)
136
137
def _prg_error(msg):
    """Create a ProgrammingError for msg (with an empty sqlstate)."""
    return _db_error(msg, cls=ProgrammingError)
141
142
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB instance is replaced
                   by its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Stop listening and close the connection on deletion."""
        self.close()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Sends a NOTIFY for the event, or for the stop event if stop is
        true, and returns the result of the query. Nothing is sent and
        None is returned if the handler is not listening.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        The payload parameter is only supported in PostgreSQL >= 9.0.
        It is converted to a string and escaped using the escape_string()
        method of the connection, so it may safely contain quotes.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (stop and self.stop_event or self.event)
            if payload:
                # Escape the payload: a quote in it would otherwise break
                # the string literal (and allow SQL injection).
                payload = db.escape_string('%s' % (payload,))
                q += ", '%s'" % payload
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' (payload) are inserted into <arg_dict>,
        and the callback is invoked with <arg_dict>. If the NOTIFY message
        is stop_<event>, the handler UNLISTENs both <event> and
        stop_<event> and exits. If a timeout occurs, the handler also
        UNLISTENs and exits, invoking the callback with None.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        # NOTE(review): the close parameter is not used by this method.
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            # wait until the connection socket becomes readable
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                # drain all notifications that have arrived
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
252
253
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its deprecated traditional name.

    All arguments are passed on unchanged. A DeprecationWarning is
    issued and attributed to the caller of this function.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
259
260
# The actual PostgreSQL database connection interface:
262
263class DB(object):
264    """Wrapper class for the _pg connection type."""
265
266    def __init__(self, *args, **kw):
267        """Create a new connection.
268
269        You can pass either the connection parameters or an existing
270        _pg or pgdb connection. This allows you to use the methods
271        of the classic pg interface with a DB-API 2 pgdb connection.
272
273        """
274        if not args and len(kw) == 1:
275            db = kw.get('db')
276        elif not kw and len(args) == 1:
277            db = args[0]
278        else:
279            db = None
280        if db:
281            if isinstance(db, DB):
282                db = db.db
283            else:
284                try:
285                    db = db._cnx
286                except AttributeError:
287                    pass
288        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
289            db = connect(*args, **kw)
290            self._closeable = True
291        else:
292            self._closeable = False
293        self.db = db
294        self.dbname = db.db
295        self._regtypes = False
296        self._attnames = {}
297        self._pkeys = {}
298        self._privileges = {}
299        self._args = args, kw
300        self.debug = None  # For debugging scripts, this can be set
301            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
302            # * to a file object to write debug statements or
303            # * to a callable object which takes a string argument
304            # * to any other true value to just print debug statements
305
306    def __getattr__(self, name):
307        # All undefined members are same as in underlying pg connection:
308        if self.db:
309            return getattr(self.db, name)
310        else:
311            raise _int_error('Connection is not valid')
312
313    # Context manager methods
314
315    def __enter__(self):
316        """Enter the runtime context. This will start a transaction."""
317        self.begin()
318        return self
319
320    def __exit__(self, et, ev, tb):
321        """Exit the runtime context. This will end the transaction."""
322        if et is None and ev is None and tb is None:
323            self.commit()
324        else:
325            self.rollback()
326
327    # Auxiliary methods
328
329    def _do_debug(self, s):
330        """Print a debug message."""
331        if self.debug:
332            if isinstance(self.debug, basestring):
333                print(self.debug % s)
334            elif isinstance(self.debug, file):
335                self.debug.write(s + '\n')
336            elif callable(self.debug):
337                self.debug(s)
338            else:
339                print(s)
340
341    def _make_bool(d):
342        """Get boolean value corresponding to d."""
343        if get_bool():
344            return bool(d)
345        return d and 't' or 'f'
346    _make_bool = staticmethod(_make_bool)
347
348    def _quote_text(self, d):
349        """Quote text value."""
350        if not isinstance(d, basestring):
351            d = str(d)
352        return "'%s'" % self.escape_string(d)
353
354    _bool_true = frozenset('t true 1 y yes on'.split())
355
356    def _quote_bool(self, d):
357        """Quote boolean value."""
358        if isinstance(d, basestring):
359            if not d:
360                return 'NULL'
361            d = d.lower() in self._bool_true
362        return d and "'t'" or "'f'"
363
364    _date_literals = frozenset('current_date current_time'
365        ' current_timestamp localtime localtimestamp'.split())
366
367    def _quote_date(self, d):
368        """Quote date value."""
369        if not d:
370            return 'NULL'
371        if isinstance(d, basestring) and d.lower() in self._date_literals:
372            return d
373        return self._quote_text(d)
374
375    def _quote_num(self, d):
376        """Quote numeric value."""
377        if not d and d != 0:
378            return 'NULL'
379        return str(d)
380
381    def _quote_money(self, d):
382        """Quote money value."""
383        if d is None or d == '':
384            return 'NULL'
385        if not isinstance(d, basestring):
386            d = str(d)
387        return d
388
389    _quote_funcs = dict(  # quote methods for each type
390        text=_quote_text, bool=_quote_bool, date=_quote_date,
391        int=_quote_num, num=_quote_num, float=_quote_num,
392        money=_quote_money)
393
394    def _quote(self, d, t):
395        """Return quotes if needed."""
396        if d is None:
397            return 'NULL'
398        try:
399            quote_func = self._quote_funcs[t]
400        except KeyError:
401            quote_func = self._quote_funcs['text']
402        return quote_func(self, d)
403
404    def _split_schema(self, cl):
405        """Return schema and name of object separately.
406
407        This auxiliary function splits off the namespace (schema)
408        belonging to the class with the name cl. If the class name
409        is not qualified, the function is able to determine the schema
410        of the class, taking into account the current search path.
411
412        """
413        s = _split_parts(cl)
414        if len(s) > 1:  # name already qualfied?
415            # should be database.schema.table or schema.table
416            if len(s) > 3:
417                raise _prg_error('Too many dots in class name %s' % cl)
418            schema, cl = s[-2:]
419        else:
420            cl = s[0]
421            # determine search path
422            q = 'SELECT current_schemas(TRUE)'
423            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
424            if schemas:  # non-empty path
425                # search schema for this object in the current search path
426                q = ' UNION '.join(
427                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
428                        % s for s in enumerate(schemas)])
429                q = ("SELECT nspname FROM pg_class"
430                    " JOIN pg_namespace"
431                    " ON pg_class.relnamespace = pg_namespace.oid"
432                    " JOIN (%s) AS p USING (nspname)"
433                    " WHERE pg_class.relname = '%s'"
434                    " ORDER BY n LIMIT 1" % (q, cl))
435                schema = self.db.query(q).getresult()
436                if schema:  # schema found
437                    schema = schema[0][0]
438                else:  # object not found in current search path
439                    schema = 'public'
440            else:  # empty path
441                schema = 'public'
442        return schema, cl
443
444    def _add_schema(self, cl):
445        """Ensure that the class name is prefixed with a schema name."""
446        return _join_parts(self._split_schema(cl))
447
448    # Public methods
449
    # escape_string and escape_bytea exist as methods (reached through
    # __getattr__ on the underlying connection), so we define
    # unescape_bytea as a method as well. It wraps the module-level
    # unescape_bytea function, presumably provided by the _pg star import.
    unescape_bytea = staticmethod(unescape_bytea)
453
454    def close(self):
455        """Close the database connection."""
456        # Wraps shared library function so we can track state.
457        if self._closeable:
458            if self.db:
459                self.db.close()
460                self.db = None
461            else:
462                raise _int_error('Connection already closed')
463
464    def reset(self):
465        """Reset connection with current parameters.
466
467        All derived queries and large objects derived from this connection
468        will not be usable after this call.
469
470        """
471        if self.db:
472            self.db.reset()
473        else:
474            raise _int_error('Connection already closed')
475
476    def reopen(self):
477        """Reopen connection to the database.
478
479        Used in case we need another connection to the same database.
480        Note that we can still reopen a database that we have closed.
481
482        """
483        # There is no such shared library function.
484        if self._closeable:
485            db = connect(*self._args[0], **self._args[1])
486            if self.db:
487                self.db.close()
488            self.db = db
489
490    def begin(self, mode=None):
491        """Begin a transaction."""
492        qstr = 'BEGIN'
493        if mode:
494            qstr += ' ' + mode
495        return self.query(qstr)
496
497    start = begin
498
499    def commit(self):
500        """Commit the current transaction."""
501        return self.query('COMMIT')
502
503    end = commit
504
505    def rollback(self, name=None):
506        """Rollback the current transaction."""
507        qstr = 'ROLLBACK'
508        if name:
509            qstr += ' TO ' + name
510        return self.query(qstr)
511
512    def savepoint(self, name=None):
513        """Define a new savepoint within the current transaction."""
514        qstr = 'SAVEPOINT'
515        if name:
516            qstr += ' ' + name
517        return self.query(qstr)
518
519    def release(self, name):
520        """Destroy a previously defined savepoint."""
521        return self.query('RELEASE ' + name)
522
523    def query(self, qstr, *args):
524        """Executes a SQL command string.
525
526        This method simply sends a SQL query to the database. If the query is
527        an insert statement that inserted exactly one row into a table that
528        has OIDs, the return value is the OID of the newly inserted row.
529        If the query is an update or delete statement, or an insert statement
530        that did not insert exactly one row in a table with OIDs, then the
531        numer of rows affected is returned as a string. If it is a statement
532        that returns rows as a result (usually a select statement, but maybe
533        also an "insert/update ... returning" statement), this method returns
534        a pgqueryobject that can be accessed via getresult() or dictresult()
535        or simply printed. Otherwise, it returns `None`.
536
537        The query can contain numbered parameters of the form $1 in place
538        of any data constant. Arguments given after the query string will
539        be substituted for the corresponding numbered parameter. Parameter
540        values can also be given as a single list or tuple argument.
541
542        Note that the query string must not be passed as a unicode value,
543        but you can pass arguments as unicode values if they can be decoded
544        using the current client encoding.
545
546        """
547        # Wraps shared library function for debugging.
548        if not self.db:
549            raise _int_error('Connection is not valid')
550        self._do_debug(qstr)
551        return self.db.query(qstr, args)
552
553    def pkey(self, cl, newpkey=None):
554        """This method gets or sets the primary key of a class.
555
556        Composite primary keys are represented as frozensets. Note that
557        this raises an exception if the table does not have a primary key.
558
559        If newpkey is set and is not a dictionary then set that
560        value as the primary key of the class.  If it is a dictionary
561        then replace the _pkeys dictionary with a copy of it.
562
563        """
564        # First see if the caller is supplying a dictionary
565        if isinstance(newpkey, dict):
566            # make sure that all classes have a namespace
567            self._pkeys = dict([
568                ('.' in cl and cl or 'public.' + cl, pkey)
569                for cl, pkey in newpkey.items()])
570            return self._pkeys
571
572        qcl = self._add_schema(cl)  # build fully qualified class name
573        # Check if the caller is supplying a new primary key for the class
574        if newpkey:
575            self._pkeys[qcl] = newpkey
576            return newpkey
577
578        # Get all the primary keys at once
579        if qcl not in self._pkeys:
580            # if not found, check again in case it was added after we started
581            self._pkeys = {}
582            if self.server_version >= 80200:
583                # the ANY syntax works correctly only with PostgreSQL >= 8.2
584                any_indkey = "= ANY (pg_index.indkey)"
585            else:
586                any_indkey = "IN (%s)" % ', '.join(
587                    ['pg_index.indkey[%d]' % i for i in range(16)])
588            for r in self.db.query(
589                "SELECT pg_namespace.nspname, pg_class.relname,"
590                    " pg_attribute.attname FROM pg_class"
591                " JOIN pg_namespace"
592                    " ON pg_namespace.oid = pg_class.relnamespace"
593                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
594                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
595                    " AND pg_attribute.attisdropped = 'f'"
596                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
597                    " AND pg_index.indisprimary = 't'"
598                    " AND pg_attribute.attnum " + any_indkey).getresult():
599                cl, pkey = _join_parts(r[:2]), r[2]
600                self._pkeys.setdefault(cl, []).append(pkey)
601            # (only) for composite primary keys, the values will be frozensets
602            for cl, pkey in self._pkeys.items():
603                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
604            self._do_debug(self._pkeys)
605
606        # will raise an exception if primary key doesn't exist
607        return self._pkeys[qcl]
608
609    def get_databases(self):
610        """Get list of databases in the system."""
611        return [s[0] for s in
612            self.db.query('SELECT datname FROM pg_database').getresult()]
613
614    def get_relations(self, kinds=None):
615        """Get list of relations in connected database of specified kinds.
616
617            If kinds is None or empty, all kinds of relations are returned.
618            Otherwise kinds can be a string or sequence of type letters
619            specifying which kind of relations you want to list.
620
621        """
622        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
623            ["'%s'" % x for x in kinds]) or ''
624        return [_join_parts(x) for x in self.db.query(
625            "SELECT pg_namespace.nspname, pg_class.relname"
626            " FROM pg_class "
627            " JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
628            " WHERE %s pg_namespace.nspname != 'information_schema'"
629                " AND pg_namespace.nspname !~ '^pg_' "
630            " ORDER BY 1, 2" % where).getresult()]
631
632    def get_tables(self):
633        """Return list of tables in connected database."""
634        return self.get_relations('r')
635
636    def get_attnames(self, cl, newattnames=None):
637        """Given the name of a table, digs out the set of attribute names.
638
639        Returns a dictionary of attribute names (the names are the keys,
640        the values are the names of the attributes' types).
641        If the optional newattnames exists, it must be a dictionary and
642        will become the new attribute names dictionary.
643
644        By default, only a limited number of simple types will be returned.
645        You can get the regular types after calling use_regtypes(True).
646
647        """
648        if isinstance(newattnames, dict):
649            self._attnames = newattnames
650            return
651        elif newattnames:
652            raise _prg_error('If supplied, newattnames must be a dictionary')
653        cl = self._split_schema(cl)  # split into schema and class
654        qcl = _join_parts(cl)  # build fully qualified name
655        # May as well cache them:
656        if qcl in self._attnames:
657            return self._attnames[qcl]
658        if qcl not in self.get_relations('rv'):
659            raise _prg_error('Class %s does not exist' % qcl)
660
661        q = "SELECT pg_attribute.attname, pg_type.typname"
662        if self._regtypes:
663            q += "::regtype"
664        q += (" FROM pg_class"
665            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
666            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
667            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
668            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
669            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
670            " AND pg_attribute.attisdropped = 'f'") % cl
671        q = self.db.query(q).getresult()
672
673        if self._regtypes:
674            t = dict(q)
675        else:
676            t = {}
677            for att, typ in q:
678                if typ.startswith('bool'):
679                    typ = 'bool'
680                elif typ.startswith('abstime'):
681                    typ = 'date'
682                elif typ.startswith('date'):
683                    typ = 'date'
684                elif typ.startswith('interval'):
685                    typ = 'date'
686                elif typ.startswith('timestamp'):
687                    typ = 'date'
688                elif typ.startswith('oid'):
689                    typ = 'int'
690                elif typ.startswith('int'):
691                    typ = 'int'
692                elif typ.startswith('float'):
693                    typ = 'float'
694                elif typ.startswith('numeric'):
695                    typ = 'num'
696                elif typ.startswith('money'):
697                    typ = 'money'
698                else:
699                    typ = 'text'
700                t[att] = typ
701
702        self._attnames[qcl] = t  # cache it
703        return self._attnames[qcl]
704
705    def use_regtypes(self, regtypes=None):
706        """Use regular type names instead of simplified type names."""
707        if regtypes is None:
708            return self._regtypes
709        else:
710            regtypes = bool(regtypes)
711            if regtypes != self._regtypes:
712                self._regtypes = regtypes
713                self._attnames.clear()
714            return regtypes
715
716    def has_table_privilege(self, cl, privilege='select'):
717        """Check whether current user has specified table privilege."""
718        qcl = self._add_schema(cl)
719        privilege = privilege.lower()
720        try:
721            return self._privileges[(qcl, privilege)]
722        except KeyError:
723            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
724            ret = self.db.query(q).getresult()[0][0] == self._make_bool(True)
725            self._privileges[(qcl, privilege)] = ret
726            return ret
727
728    def get(self, cl, arg, keyname=None):
729        """Get a tuple from a database table or view.
730
731        This method is the basic mechanism to get a single row.  The keyname
732        that the key specifies a unique row.  If keyname is not specified
733        then the primary key for the table is used.  If arg is a dictionary
734        then the value for the key is taken from it and it is modified to
735        include the new values, replacing existing values where necessary.
736        For a composite key, keyname can also be a sequence of key names.
737        The OID is also put into the dictionary if the table has one, but
738        in order to allow the caller to work with multiple tables, it is
739        munged as oid(schema.table).
740
741        """
742        if cl.endswith('*'):  # scan descendant tables?
743            cl = cl[:-1].rstrip()  # need parent table name
744        # build qualified class name
745        qcl = self._add_schema(cl)
746        # To allow users to work with multiple tables,
747        # we munge the name of the "oid" the key
748        qoid = _oid_key(qcl)
749        if not keyname:
750            # use the primary key by default
751            try:
752                keyname = self.pkey(qcl)
753            except KeyError:
754                raise _prg_error('Class %s has no primary key' % qcl)
755        # We want the oid for later updates if that isn't the key
756        if keyname == 'oid':
757            if isinstance(arg, dict):
758                if qoid not in arg:
759                    raise _db_error('%s not in arg' % qoid)
760            else:
761                arg = {qoid: arg}
762            where = 'oid = %s' % arg[qoid]
763            attnames = '*'
764        else:
765            attnames = self.get_attnames(qcl)
766            if isinstance(keyname, basestring):
767                keyname = (keyname,)
768            if not isinstance(arg, dict):
769                if len(keyname) > 1:
770                    raise _prg_error('Composite key needs dict as arg')
771                arg = dict([(k, arg) for k in keyname])
772            where = ' AND '.join(['%s = %s'
773                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
774            attnames = ', '.join(attnames)
775        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
776        self._do_debug(q)
777        res = self.db.query(q).dictresult()
778        if not res:
779            raise _db_error('No such record in %s where %s' % (qcl, where))
780        for att, value in res[0].items():
781            arg[att == 'oid' and qoid or att] = value
782        return arg
783
784    def insert(self, cl, d=None, **kw):
785        """Insert a tuple into a database table.
786
787        This method inserts a row into a table.  If a dictionary is
788        supplied it starts with that.  Otherwise it uses a blank dictionary.
789        Either way the dictionary is updated from the keywords.
790
791        The dictionary is then, if possible, reloaded with the values actually
792        inserted in order to pick up values modified by rules, triggers, etc.
793
794        Note: The method currently doesn't support insert into views
795        although PostgreSQL does.
796
797        """
798        qcl = self._add_schema(cl)
799        qoid = _oid_key(qcl)
800        if d is None:
801            d = {}
802        d.update(kw)
803        attnames = self.get_attnames(qcl)
804        names, values = [], []
805        for n in attnames:
806            if n != 'oid' and n in d:
807                names.append('"%s"' % n)
808                values.append(self._quote(d[n], attnames[n]))
809        names, values = ', '.join(names), ', '.join(values)
810        selectable = self.has_table_privilege(qcl)
811        if selectable and self.server_version >= 80200:
812            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
813        else:
814            ret = ''
815        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
816        self._do_debug(q)
817        res = self.db.query(q)
818        if ret:
819            res = res.dictresult()
820            for att, value in res[0].items():
821                d[att == 'oid' and qoid or att] = value
822        elif isinstance(res, int):
823            d[qoid] = res
824            if selectable:
825                self.get(qcl, d, 'oid')
826        elif selectable:
827            if qoid in d:
828                self.get(qcl, d, 'oid')
829            else:
830                try:
831                    self.get(qcl, d)
832                except ProgrammingError:
833                    pass  # table has no primary key
834        return d
835
836    def update(self, cl, d=None, **kw):
837        """Update an existing row in a database table.
838
839        Similar to insert but updates an existing row.  The update is based
840        on the OID value as munged by get or passed as keyword, or on the
841        primary key of the table.  The dictionary is modified, if possible,
842        to reflect any changes caused by the update due to triggers, rules,
843        default values, etc.
844
845        """
846        # Update always works on the oid which get returns if available,
847        # otherwise use the primary key.  Fail if neither.
848        # Note that we only accept oid key from named args for safety
849        qcl = self._add_schema(cl)
850        qoid = _oid_key(qcl)
851        if 'oid' in kw:
852            kw[qoid] = kw['oid']
853            del kw['oid']
854        if d is None:
855            d = {}
856        d.update(kw)
857        attnames = self.get_attnames(qcl)
858        if qoid in d:
859            where = 'oid = %s' % d[qoid]
860            keyname = ()
861        else:
862            try:
863                keyname = self.pkey(qcl)
864            except KeyError:
865                raise _prg_error('Class %s has no primary key' % qcl)
866            if isinstance(keyname, basestring):
867                keyname = (keyname,)
868            try:
869                where = ' AND '.join(['%s = %s'
870                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
871            except KeyError:
872                raise _prg_error('Update needs primary key or oid.')
873        values = []
874        for n in attnames:
875            if n in d and n not in keyname:
876                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
877        if not values:
878            return d
879        values = ', '.join(values)
880        selectable = self.has_table_privilege(qcl)
881        if selectable and self.server_version >= 80200:
882            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
883        else:
884            ret = ''
885        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
886        self._do_debug(q)
887        res = self.db.query(q)
888        if ret:
889            res = res.dictresult()[0]
890            for att, value in res.items():
891                d[att == 'oid' and qoid or att] = value
892        else:
893            if selectable:
894                if qoid in d:
895                    self.get(qcl, d, 'oid')
896                else:
897                    self.get(qcl, d)
898        return d
899
900    def clear(self, cl, a=None):
901        """Clear all the attributes to values determined by the types.
902
903        Numeric types are set to 0, Booleans are set to false, and everything
904        else is set to the empty string.  If the array argument is present,
905        it is used as the array and any entries matching attribute names are
906        cleared with everything else left unchanged.
907
908        """
909        # At some point we will need a way to get defaults from a table.
910        qcl = self._add_schema(cl)
911        if a is None:
912            a = {}  # empty if argument is not present
913        attnames = self.get_attnames(qcl)
914        for n, t in attnames.items():
915            if n == 'oid':
916                continue
917            if t in ('int', 'integer', 'smallint', 'bigint',
918                    'float', 'real', 'double precision',
919                    'num', 'numeric', 'money'):
920                a[n] = 0
921            elif t in ('bool', 'boolean'):
922                a[n] = self._make_bool(False)
923            else:
924                a[n] = ''
925        return a
926
927    def delete(self, cl, d=None, **kw):
928        """Delete an existing row in a database table.
929
930        This method deletes the row from a table.  It deletes based on the
931        OID value as munged by get or passed as keyword, or on the primary
932        key of the table.  The return value is the number of deleted rows
933        (i.e. 0 if the row did not exist and 1 if the row was deleted).
934
935        """
936        # Like update, delete works on the oid.
937        # One day we will be testing that the record to be deleted
938        # isn't referenced somewhere (or else PostgreSQL will).
939        # Note that we only accept oid key from named args for safety
940        qcl = self._add_schema(cl)
941        qoid = _oid_key(qcl)
942        if 'oid' in kw:
943            kw[qoid] = kw['oid']
944            del kw['oid']
945        if d is None:
946            d = {}
947        d.update(kw)
948        if qoid in d:
949            where = 'oid = %s' % d[qoid]
950        else:
951            try:
952                keyname = self.pkey(qcl)
953            except KeyError:
954                raise _prg_error('Class %s has no primary key' % qcl)
955            if isinstance(keyname, basestring):
956                keyname = (keyname,)
957            attnames = self.get_attnames(qcl)
958            try:
959                where = ' AND '.join(['%s = %s'
960                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
961            except KeyError:
962                raise _prg_error('Delete needs primary key or oid.')
963        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
964        self._do_debug(q)
965        return int(self.db.query(q))
966
967    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
968        """Get notification handler that will run the given callback."""
969        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
970
971
# if run as script, print some information

if __name__ == '__main__':
    # `version` is presumably provided by the star import from _pg;
    # keep a space between the label and the version number.
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.