source: trunk/module/pg.py @ 553

Last change on this file since 553 was 553, checked in by cito, 4 years ago

Require at least Python 2.6 for the trunk (5.x)

Support for even older Python versions is maintained in the 4.x branch.
The goal for 5.x is to be a single-source code for both Python 2 and 3,
and this is only possible by dropping support for Python 2.5 and older.
For instance, the new except .. as syntax works only since Python 2.6.
Otherwise we would need to use 2to3 and things would be very ugly.
Note that Python 2.6 is now 7 years old. We may want to drop Python 2.6
as well at some point if it turns out to be a burden.

  • Property svn:keywords set to Id
File size: 33.9 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 553 2015-11-20 11:16:03Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31import sys
# Python 3 has no basestring type.  Make the name an alias of str, so
# that the isinstance(..., basestring) checks in this module accept
# ordinary strings.  Note that a *subclass* of str would not work here,
# because plain str objects are not instances of such a subclass.
if sys.version_info[0] >= 3:
    basestring = str
34
35from _pg import *
36
37import select
38import warnings
39
40from decimal import Decimal
41from collections import namedtuple
42
# Register decimal.Decimal with the C extension.  set_decimal is not
# defined in this file, so presumably it comes from `from _pg import *`;
# presumably Decimal is then the type used for numeric values in query
# results -- confirm against the _pg documentation.
set_decimal(Decimal)
44
45
46# Auxiliary functions which are independent from a DB connection:
47
48def _is_quoted(s):
49    """Check whether this string is a quoted identifier."""
50    s = s.replace('_', 'a')
51    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
52
53
54def _is_unquoted(s):
55    """Check whether this string is an unquoted identifier."""
56    s = s.replace('_', 'a')
57    return s.isalnum() and not s[:1].isdigit()
58
59
60def _split_first_part(s):
61    """Split the first part of a dot separated string."""
62    s = s.lstrip()
63    if s[:1] == '"':
64        p = []
65        s = s.split('"', 3)[1:]
66        p.append(s[0])
67        while len(s) == 3 and s[1] == '':
68            p.append('"')
69            s = s[2].split('"', 2)
70            p.append(s[0])
71        p = [''.join(p)]
72        s = '"'.join(s[1:]).lstrip()
73        if s:
74            if s[:0] == '.':
75                p.append(s[1:])
76            else:
77                s = _split_first_part(s)
78                p[0] += s[0]
79                if len(s) > 1:
80                    p.append(s[1])
81    else:
82        p = s.split('.', 1)
83        s = p[0].rstrip()
84        if _is_unquoted(s):
85            s = s.lower()
86        p[0] = s
87    return p
88
89
90def _split_parts(s):
91    """Split all parts of a dot separated string."""
92    q = []
93    while s:
94        s = _split_first_part(s)
95        q.append(s[0])
96        if len(s) < 2:
97            break
98        s = s[1]
99    return q
100
101
def _join_parts(s):
    """Join all parts of a name into a dot separated string.

    Parts that need quoting (according to _is_quoted) are put in double
    quotes.  Double quotes inside such a part are doubled, as required
    for SQL quoted identifiers; _split_first_part undoes this again.
    """
    return '.'.join(['"%s"' % p.replace('"', '""') if _is_quoted(p) else p
                     for p in s])
105
106
107def _oid_key(qcl):
108    """Build oid key from qualified class name."""
109    return 'oid(%s)' % qcl
110
111
if namedtuple:

    def _namedresult(q):
        """Return the rows of the query result q as named tuples.

        The tuple type is called Row; its field names are taken from
        q.listfields().
        """
        row_type = namedtuple('Row', q.listfields())
        rows = q.getresult()
        return [row_type(*values) for values in rows]

    # hand the converter over to the C extension
    set_namedresult(_namedresult)
120
121
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an error of class cls with message msg.

    The sqlstate attribute of the error is set to None, i.e. no SQLSTATE
    code is attached to it.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
127
128
def _int_error(msg):
    """Create an InternalError (with sqlstate None) for message msg."""
    return _db_error(msg, cls=InternalError)
132
133
def _prg_error(msg):
    """Create a ProgrammingError (with sqlstate None) for message msg."""
    return _db_error(msg, cls=ProgrammingError)
137
138
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (for a DB instance,
                   its underlying connection is used).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    def close(self):
        """Stop listening and close the connection.

        Note that this closes the connection the handler was created with,
        even if it was taken from a DB instance.

        """
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Sends a NOTIFY for the stop event if stop is true, otherwise for
        the event, with the given payload if one is passed.  Nothing is
        sent (and None is returned) if the handler is not listening.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # The payload is sent as an SQL string literal, so it must
                # be escaped (e.g. single quotes inside the payload).
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' (the payload) are inserted into
        <arg_dict>, and the callback is invoked with <arg_dict>. If the
        NOTIFY message is stop_<event>, the handler UNLISTENs both <event>
        and stop_<event> and exits.  If nothing arrives within the timeout,
        the handler UNLISTENs, invokes the callback with None and exits.
        A notification for any other channel makes the handler UNLISTEN
        and raise a DatabaseError.

        The close argument is accepted but currently not used.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while True:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if not ilist:  # we timed out
                self.unlisten()
                self.callback(None)
                return
            # getnotify() returns at most one notification per call, but
            # several may have been received at once.  Drain all of them:
            # data already read from the socket will not make select()
            # report it as readable again, so queued notifications would
            # otherwise only be seen after the next incoming message.
            while True:
                notice = self.db.getnotify()
                if notice is None:
                    break
                event, pid, extra = notice
                if event not in (self.event, self.stop_event):
                    self.unlisten()
                    raise _db_error(
                        'listening for "%s" and "%s", but notified of "%s"'
                        % (self.event, self.stop_event, event))
                self.arg_dict['pid'] = pid
                self.arg_dict['event'] = event
                self.arg_dict['extra'] = extra
                self.callback(self.arg_dict)
                if event == self.stop_event:
                    self.unlisten()
                    return
248
249
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    All arguments are passed on to NotificationHandler unchanged; a
    DeprecationWarning is issued first.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
255
256
# The actual PostgreSQL database connection interface:
258
259class DB(object):
260    """Wrapper class for the _pg connection type."""
261
262    def __init__(self, *args, **kw):
263        """Create a new connection.
264
265        You can pass either the connection parameters or an existing
266        _pg or pgdb connection. This allows you to use the methods
267        of the classic pg interface with a DB-API 2 pgdb connection.
268
269        """
270        if not args and len(kw) == 1:
271            db = kw.get('db')
272        elif not kw and len(args) == 1:
273            db = args[0]
274        else:
275            db = None
276        if db:
277            if isinstance(db, DB):
278                db = db.db
279            else:
280                try:
281                    db = db._cnx
282                except AttributeError:
283                    pass
284        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
285            db = connect(*args, **kw)
286            self._closeable = True
287        else:
288            self._closeable = False
289        self.db = db
290        self.dbname = db.db
291        self._regtypes = False
292        self._attnames = {}
293        self._pkeys = {}
294        self._privileges = {}
295        self._args = args, kw
296        self.debug = None  # For debugging scripts, this can be set
297            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
298            # * to a file object to write debug statements or
299            # * to a callable object which takes a string argument
300            # * to any other true value to just print debug statements
301
302    def __getattr__(self, name):
303        # All undefined members are same as in underlying pg connection:
304        if self.db:
305            return getattr(self.db, name)
306        else:
307            raise _int_error('Connection is not valid')
308
309    # Context manager methods
310
311    def __enter__(self):
312        """Enter the runtime context. This will start a transaction."""
313        self.begin()
314        return self
315
316    def __exit__(self, et, ev, tb):
317        """Exit the runtime context. This will end the transaction."""
318        if et is None and ev is None and tb is None:
319            self.commit()
320        else:
321            self.rollback()
322
323    # Auxiliary methods
324
325    def _do_debug(self, s):
326        """Print a debug message."""
327        if self.debug:
328            if isinstance(self.debug, basestring):
329                print(self.debug % s)
330            elif isinstance(self.debug, file):
331                file.write(s + '\n')
332            elif callable(self.debug):
333                self.debug(s)
334            else:
335                print(s)
336
337    def _quote_text(self, d):
338        """Quote text value."""
339        if not isinstance(d, basestring):
340            d = str(d)
341        return "'%s'" % self.escape_string(d)
342
343    _bool_true = frozenset('t true 1 y yes on'.split())
344
345    def _quote_bool(self, d):
346        """Quote boolean value."""
347        if isinstance(d, basestring):
348            if not d:
349                return 'NULL'
350            d = d.lower() in self._bool_true
351        else:
352            d = bool(d)
353        return ("'f'", "'t'")[d]
354
355    _date_literals = frozenset('current_date current_time'
356        ' current_timestamp localtime localtimestamp'.split())
357
358    def _quote_date(self, d):
359        """Quote date value."""
360        if not d:
361            return 'NULL'
362        if isinstance(d, basestring) and d.lower() in self._date_literals:
363            return d
364        return self._quote_text(d)
365
366    def _quote_num(self, d):
367        """Quote numeric value."""
368        if not d and d != 0:
369            return 'NULL'
370        return str(d)
371
372    def _quote_money(self, d):
373        """Quote money value."""
374        if d is None or d == '':
375            return 'NULL'
376        if not isinstance(d, basestring):
377            d = str(d)
378        return d
379
380    _quote_funcs = dict(  # quote methods for each type
381        text=_quote_text, bool=_quote_bool, date=_quote_date,
382        int=_quote_num, num=_quote_num, float=_quote_num,
383        money=_quote_money)
384
385    def _quote(self, d, t):
386        """Return quotes if needed."""
387        if d is None:
388            return 'NULL'
389        try:
390            quote_func = self._quote_funcs[t]
391        except KeyError:
392            quote_func = self._quote_funcs['text']
393        return quote_func(self, d)
394
395    def _split_schema(self, cl):
396        """Return schema and name of object separately.
397
398        This auxiliary function splits off the namespace (schema)
399        belonging to the class with the name cl. If the class name
400        is not qualified, the function is able to determine the schema
401        of the class, taking into account the current search path.
402
403        """
404        s = _split_parts(cl)
405        if len(s) > 1:  # name already qualfied?
406            # should be database.schema.table or schema.table
407            if len(s) > 3:
408                raise _prg_error('Too many dots in class name %s' % cl)
409            schema, cl = s[-2:]
410        else:
411            cl = s[0]
412            # determine search path
413            q = 'SELECT current_schemas(TRUE)'
414            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
415            if schemas:  # non-empty path
416                # search schema for this object in the current search path
417                q = ' UNION '.join(
418                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
419                        % s for s in enumerate(schemas)])
420                q = ("SELECT nspname FROM pg_class"
421                    " JOIN pg_namespace"
422                    " ON pg_class.relnamespace = pg_namespace.oid"
423                    " JOIN (%s) AS p USING (nspname)"
424                    " WHERE pg_class.relname = '%s'"
425                    " ORDER BY n LIMIT 1" % (q, cl))
426                schema = self.db.query(q).getresult()
427                if schema:  # schema found
428                    schema = schema[0][0]
429                else:  # object not found in current search path
430                    schema = 'public'
431            else:  # empty path
432                schema = 'public'
433        return schema, cl
434
435    def _add_schema(self, cl):
436        """Ensure that the class name is prefixed with a schema name."""
437        return _join_parts(self._split_schema(cl))
438
439    # Public methods
440
    # escape_string and escape_bytea exist as methods (presumably of the
    # underlying _pg connection, reached via __getattr__),
    # so we define unescape_bytea as a method as well.  The module-level
    # unescape_bytea function (presumably from `from _pg import *`) is
    # wrapped in staticmethod so that it is called without the DB instance.
    unescape_bytea = staticmethod(unescape_bytea)
444
445    def close(self):
446        """Close the database connection."""
447        # Wraps shared library function so we can track state.
448        if self._closeable:
449            if self.db:
450                self.db.close()
451                self.db = None
452            else:
453                raise _int_error('Connection already closed')
454
455    def reset(self):
456        """Reset connection with current parameters.
457
458        All derived queries and large objects derived from this connection
459        will not be usable after this call.
460
461        """
462        if self.db:
463            self.db.reset()
464        else:
465            raise _int_error('Connection already closed')
466
467    def reopen(self):
468        """Reopen connection to the database.
469
470        Used in case we need another connection to the same database.
471        Note that we can still reopen a database that we have closed.
472
473        """
474        # There is no such shared library function.
475        if self._closeable:
476            db = connect(*self._args[0], **self._args[1])
477            if self.db:
478                self.db.close()
479            self.db = db
480
481    def begin(self, mode=None):
482        """Begin a transaction."""
483        qstr = 'BEGIN'
484        if mode:
485            qstr += ' ' + mode
486        return self.query(qstr)
487
488    start = begin
489
490    def commit(self):
491        """Commit the current transaction."""
492        return self.query('COMMIT')
493
494    end = commit
495
496    def rollback(self, name=None):
497        """Rollback the current transaction."""
498        qstr = 'ROLLBACK'
499        if name:
500            qstr += ' TO ' + name
501        return self.query(qstr)
502
503    def savepoint(self, name=None):
504        """Define a new savepoint within the current transaction."""
505        qstr = 'SAVEPOINT'
506        if name:
507            qstr += ' ' + name
508        return self.query(qstr)
509
510    def release(self, name):
511        """Destroy a previously defined savepoint."""
512        return self.query('RELEASE ' + name)
513
514    def query(self, qstr, *args):
515        """Executes a SQL command string.
516
517        This method simply sends a SQL query to the database. If the query is
518        an insert statement that inserted exactly one row into a table that
519        has OIDs, the return value is the OID of the newly inserted row.
520        If the query is an update or delete statement, or an insert statement
521        that did not insert exactly one row in a table with OIDs, then the
522        numer of rows affected is returned as a string. If it is a statement
523        that returns rows as a result (usually a select statement, but maybe
524        also an "insert/update ... returning" statement), this method returns
525        a pgqueryobject that can be accessed via getresult() or dictresult()
526        or simply printed. Otherwise, it returns `None`.
527
528        The query can contain numbered parameters of the form $1 in place
529        of any data constant. Arguments given after the query string will
530        be substituted for the corresponding numbered parameter. Parameter
531        values can also be given as a single list or tuple argument.
532
533        Note that the query string must not be passed as a unicode value,
534        but you can pass arguments as unicode values if they can be decoded
535        using the current client encoding.
536
537        """
538        # Wraps shared library function for debugging.
539        if not self.db:
540            raise _int_error('Connection is not valid')
541        self._do_debug(qstr)
542        return self.db.query(qstr, args)
543
    def pkey(self, cl, newpkey=None):
        """This method gets or sets the primary key of a class.

        Composite primary keys are represented as frozensets. Note that
        this raises an exception if the table does not have a primary key.

        If newpkey is set and is not a dictionary then set that
        value as the primary key of the class.  If it is a dictionary
        then replace the _pkeys dictionary with a copy of it.

        cl      - name of the class (table), qualified or unqualified
        newpkey - optional new primary key for the class, or a dictionary
                  mapping class names to primary keys

        Returns the primary key of the class (a column name, or a
        frozenset of column names for a composite key), or the new
        _pkeys dictionary if newpkey is a dictionary.
        Raises KeyError if no primary key is known for the class.

        """
        # First see if the caller is supplying a dictionary
        if isinstance(newpkey, dict):
            # make sure that all classes have a namespace
            # (names without a dot are put into the public schema; unlike
            # _add_schema, this does not look at the search path or quote)
            self._pkeys = dict([
                ('.' in cl and cl or 'public.' + cl, pkey)
                for cl, pkey in newpkey.items()])
            return self._pkeys

        qcl = self._add_schema(cl)  # build fully qualified class name
        # Check if the caller is supplying a new primary key for the class
        if newpkey:
            self._pkeys[qcl] = newpkey
            return newpkey

        # Get all the primary keys at once
        if qcl not in self._pkeys:
            # if not found, check again in case it was added after we started
            # NOTE(review): resetting the cache also discards primary keys
            # that were set manually via newpkey for other classes --
            # confirm that this is intended.
            self._pkeys = {}
            if self.server_version >= 80200:
                # the ANY syntax works correctly only with PostgreSQL >= 8.2
                any_indkey = "= ANY (pg_index.indkey)"
            else:
                # older servers: compare with the first 16 index columns
                any_indkey = "IN (%s)" % ', '.join(
                    ['pg_index.indkey[%d]' % i for i in range(16)])
            # One row (schema, table, column) per primary key column of all
            # tables in schemas whose names do not match LIKE 'pg_%' (note
            # that '_' is a LIKE wildcard for any single character).
            for r in self.db.query(
                "SELECT pg_namespace.nspname, pg_class.relname,"
                    " pg_attribute.attname FROM pg_class"
                " JOIN pg_namespace"
                    " ON pg_namespace.oid = pg_class.relnamespace"
                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
                    " AND pg_attribute.attisdropped = 'f'"
                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
                    " AND pg_index.indisprimary = 't'"
                    " AND pg_attribute.attnum " + any_indkey).getresult():
                # collect the key columns per qualified class name
                cl, pkey = _join_parts(r[:2]), r[2]
                self._pkeys.setdefault(cl, []).append(pkey)
            # (only) for composite primary keys, the values will be frozensets
            for cl, pkey in self._pkeys.items():
                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
            self._do_debug(self._pkeys)

        # will raise an exception if primary key doesn't exist
        return self._pkeys[qcl]
599
600    def get_databases(self):
601        """Get list of databases in the system."""
602        return [s[0] for s in
603            self.db.query('SELECT datname FROM pg_database').getresult()]
604
605    def get_relations(self, kinds=None):
606        """Get list of relations in connected database of specified kinds.
607
608            If kinds is None or empty, all kinds of relations are returned.
609            Otherwise kinds can be a string or sequence of type letters
610            specifying which kind of relations you want to list.
611
612        """
613        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
614            ["'%s'" % x for x in kinds]) or ''
615        return [_join_parts(x) for x in self.db.query(
616            "SELECT pg_namespace.nspname, pg_class.relname "
617            "FROM pg_class "
618            "JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
619            "WHERE %s pg_class.relname !~ '^Inv' AND "
620                "pg_class.relname !~ '^pg_' "
621            "ORDER BY 1, 2" % where).getresult()]
622
623    def get_tables(self):
624        """Return list of tables in connected database."""
625        return self.get_relations('r')
626
627    def get_attnames(self, cl, newattnames=None):
628        """Given the name of a table, digs out the set of attribute names.
629
630        Returns a dictionary of attribute names (the names are the keys,
631        the values are the names of the attributes' types).
632        If the optional newattnames exists, it must be a dictionary and
633        will become the new attribute names dictionary.
634
635        By default, only a limited number of simple types will be returned.
636        You can get the regular types after calling use_regtypes(True).
637
638        """
639        if isinstance(newattnames, dict):
640            self._attnames = newattnames
641            return
642        elif newattnames:
643            raise _prg_error('If supplied, newattnames must be a dictionary')
644        cl = self._split_schema(cl)  # split into schema and class
645        qcl = _join_parts(cl)  # build fully qualified name
646        # May as well cache them:
647        if qcl in self._attnames:
648            return self._attnames[qcl]
649        if qcl not in self.get_relations('rv'):
650            raise _prg_error('Class %s does not exist' % qcl)
651
652        q = "SELECT pg_attribute.attname, pg_type.typname"
653        if self._regtypes:
654            q += "::regtype"
655        q += (" FROM pg_class"
656            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
657            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
658            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
659            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
660            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
661            " AND pg_attribute.attisdropped = 'f'") % cl
662        q = self.db.query(q).getresult()
663
664        if self._regtypes:
665            t = dict(q)
666        else:
667            t = {}
668            for att, typ in q:
669                if typ.startswith('bool'):
670                    typ = 'bool'
671                elif typ.startswith('abstime'):
672                    typ = 'date'
673                elif typ.startswith('date'):
674                    typ = 'date'
675                elif typ.startswith('interval'):
676                    typ = 'date'
677                elif typ.startswith('timestamp'):
678                    typ = 'date'
679                elif typ.startswith('oid'):
680                    typ = 'int'
681                elif typ.startswith('int'):
682                    typ = 'int'
683                elif typ.startswith('float'):
684                    typ = 'float'
685                elif typ.startswith('numeric'):
686                    typ = 'num'
687                elif typ.startswith('money'):
688                    typ = 'money'
689                else:
690                    typ = 'text'
691                t[att] = typ
692
693        self._attnames[qcl] = t  # cache it
694        return self._attnames[qcl]
695
696    def use_regtypes(self, regtypes=None):
697        """Use regular type names instead of simplified type names."""
698        if regtypes is None:
699            return self._regtypes
700        else:
701            regtypes = bool(regtypes)
702            if regtypes != self._regtypes:
703                self._regtypes = regtypes
704                self._attnames.clear()
705            return regtypes
706
707    def has_table_privilege(self, cl, privilege='select'):
708        """Check whether current user has specified table privilege."""
709        qcl = self._add_schema(cl)
710        privilege = privilege.lower()
711        try:
712            return self._privileges[(qcl, privilege)]
713        except KeyError:
714            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
715            ret = self.db.query(q).getresult()[0][0] == 't'
716            self._privileges[(qcl, privilege)] = ret
717            return ret
718
719    def get(self, cl, arg, keyname=None):
720        """Get a tuple from a database table or view.
721
722        This method is the basic mechanism to get a single row.  The keyname
723        that the key specifies a unique row.  If keyname is not specified
724        then the primary key for the table is used.  If arg is a dictionary
725        then the value for the key is taken from it and it is modified to
726        include the new values, replacing existing values where necessary.
727        For a composite key, keyname can also be a sequence of key names.
728        The OID is also put into the dictionary if the table has one, but
729        in order to allow the caller to work with multiple tables, it is
730        munged as oid(schema.table).
731
732        """
733        if cl.endswith('*'):  # scan descendant tables?
734            cl = cl[:-1].rstrip()  # need parent table name
735        # build qualified class name
736        qcl = self._add_schema(cl)
737        # To allow users to work with multiple tables,
738        # we munge the name of the "oid" the key
739        qoid = _oid_key(qcl)
740        if not keyname:
741            # use the primary key by default
742            try:
743                keyname = self.pkey(qcl)
744            except KeyError:
745                raise _prg_error('Class %s has no primary key' % qcl)
746        # We want the oid for later updates if that isn't the key
747        if keyname == 'oid':
748            if isinstance(arg, dict):
749                if qoid not in arg:
750                    raise _db_error('%s not in arg' % qoid)
751            else:
752                arg = {qoid: arg}
753            where = 'oid = %s' % arg[qoid]
754            attnames = '*'
755        else:
756            attnames = self.get_attnames(qcl)
757            if isinstance(keyname, basestring):
758                keyname = (keyname,)
759            if not isinstance(arg, dict):
760                if len(keyname) > 1:
761                    raise _prg_error('Composite key needs dict as arg')
762                arg = dict([(k, arg) for k in keyname])
763            where = ' AND '.join(['%s = %s'
764                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
765            attnames = ', '.join(attnames)
766        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
767        self._do_debug(q)
768        res = self.db.query(q).dictresult()
769        if not res:
770            raise _db_error('No such record in %s where %s' % (qcl, where))
771        for att, value in res[0].items():
772            arg[att == 'oid' and qoid or att] = value
773        return arg
774
775    def insert(self, cl, d=None, **kw):
776        """Insert a tuple into a database table.
777
778        This method inserts a row into a table.  If a dictionary is
779        supplied it starts with that.  Otherwise it uses a blank dictionary.
780        Either way the dictionary is updated from the keywords.
781
782        The dictionary is then, if possible, reloaded with the values actually
783        inserted in order to pick up values modified by rules, triggers, etc.
784
785        Note: The method currently doesn't support insert into views
786        although PostgreSQL does.
787
788        """
789        qcl = self._add_schema(cl)
790        qoid = _oid_key(qcl)
791        if d is None:
792            d = {}
793        d.update(kw)
794        attnames = self.get_attnames(qcl)
795        names, values = [], []
796        for n in attnames:
797            if n != 'oid' and n in d:
798                names.append('"%s"' % n)
799                values.append(self._quote(d[n], attnames[n]))
800        names, values = ', '.join(names), ', '.join(values)
801        selectable = self.has_table_privilege(qcl)
802        if selectable and self.server_version >= 80200:
803            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
804        else:
805            ret = ''
806        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
807        self._do_debug(q)
808        res = self.db.query(q)
809        if ret:
810            res = res.dictresult()
811            for att, value in res[0].items():
812                d[att == 'oid' and qoid or att] = value
813        elif isinstance(res, int):
814            d[qoid] = res
815            if selectable:
816                self.get(qcl, d, 'oid')
817        elif selectable:
818            if qoid in d:
819                self.get(qcl, d, 'oid')
820            else:
821                try:
822                    self.get(qcl, d)
823                except ProgrammingError:
824                    pass  # table has no primary key
825        return d
826
827    def update(self, cl, d=None, **kw):
828        """Update an existing row in a database table.
829
830        Similar to insert but updates an existing row.  The update is based
831        on the OID value as munged by get or passed as keyword, or on the
832        primary key of the table.  The dictionary is modified, if possible,
833        to reflect any changes caused by the update due to triggers, rules,
834        default values, etc.
835
836        """
837        # Update always works on the oid which get returns if available,
838        # otherwise use the primary key.  Fail if neither.
839        # Note that we only accept oid key from named args for safety
840        qcl = self._add_schema(cl)
841        qoid = _oid_key(qcl)
842        if 'oid' in kw:
843            kw[qoid] = kw['oid']
844            del kw['oid']
845        if d is None:
846            d = {}
847        d.update(kw)
848        attnames = self.get_attnames(qcl)
849        if qoid in d:
850            where = 'oid = %s' % d[qoid]
851            keyname = ()
852        else:
853            try:
854                keyname = self.pkey(qcl)
855            except KeyError:
856                raise _prg_error('Class %s has no primary key' % qcl)
857            if isinstance(keyname, basestring):
858                keyname = (keyname,)
859            try:
860                where = ' AND '.join(['%s = %s'
861                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
862            except KeyError:
863                raise _prg_error('Update needs primary key or oid.')
864        values = []
865        for n in attnames:
866            if n in d and n not in keyname:
867                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
868        if not values:
869            return d
870        values = ', '.join(values)
871        selectable = self.has_table_privilege(qcl)
872        if selectable and self.server_version >= 880200:
873            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
874        else:
875            ret = ''
876        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
877        self._do_debug(q)
878        res = self.db.query(q)
879        if ret:
880            res = res.dictresult()[0]
881            for att, value in res.items():
882                d[att == 'oid' and qoid or att] = value
883        else:
884            if selectable:
885                if qoid in d:
886                    self.get(qcl, d, 'oid')
887                else:
888                    self.get(qcl, d)
889        return d
890
891    def clear(self, cl, a=None):
892        """Clear all the attributes to values determined by the types.
893
894        Numeric types are set to 0, Booleans are set to 'f', and everything
895        else is set to the empty string.  If the array argument is present,
896        it is used as the array and any entries matching attribute names are
897        cleared with everything else left unchanged.
898
899        """
900        # At some point we will need a way to get defaults from a table.
901        qcl = self._add_schema(cl)
902        if a is None:
903            a = {}  # empty if argument is not present
904        attnames = self.get_attnames(qcl)
905        for n, t in attnames.items():
906            if n == 'oid':
907                continue
908            if t in ('int', 'integer', 'smallint', 'bigint',
909                    'float', 'real', 'double precision',
910                    'num', 'numeric', 'money'):
911                a[n] = 0
912            elif t in ('bool', 'boolean'):
913                a[n] = 'f'
914            else:
915                a[n] = ''
916        return a
917
918    def delete(self, cl, d=None, **kw):
919        """Delete an existing row in a database table.
920
921        This method deletes the row from a table.  It deletes based on the
922        OID value as munged by get or passed as keyword, or on the primary
923        key of the table.  The return value is the number of deleted rows
924        (i.e. 0 if the row did not exist and 1 if the row was deleted).
925
926        """
927        # Like update, delete works on the oid.
928        # One day we will be testing that the record to be deleted
929        # isn't referenced somewhere (or else PostgreSQL will).
930        # Note that we only accept oid key from named args for safety
931        qcl = self._add_schema(cl)
932        qoid = _oid_key(qcl)
933        if 'oid' in kw:
934            kw[qoid] = kw['oid']
935            del kw['oid']
936        if d is None:
937            d = {}
938        d.update(kw)
939        if qoid in d:
940            where = 'oid = %s' % d[qoid]
941        else:
942            try:
943                keyname = self.pkey(qcl)
944            except KeyError:
945                raise _prg_error('Class %s has no primary key' % qcl)
946            if isinstance(keyname, basestring):
947                keyname = (keyname,)
948            attnames = self.get_attnames(qcl)
949            try:
950                where = ' AND '.join(['%s = %s'
951                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
952            except KeyError:
953                raise _prg_error('Delete needs primary key or oid.')
954        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
955        self._do_debug(q)
956        return int(self.db.query(q))
957
958    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
959        """Get notification handler that will run the given callback."""
960        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
961
962
# if run as script, print some information

if __name__ == '__main__':
    # separate label and version number with a space
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.