source: trunk/module/pg.py @ 461

Last change on this file since 461 was 461, checked in by cito, 7 years ago

Add transaction handling methods and context managers to pg.

  • Property svn:keywords set to Id
File size: 29.9 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py 461 2012-11-01 11:40:55Z cito $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with additional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22try:
23    frozenset
24except NameError:  # Python < 2.4
25    from sets import ImmutableSet as frozenset
26try:
27    from decimal import Decimal
28    set_decimal(Decimal)
29except ImportError:  # Python < 2.4
30    pass
31try:
32    from contextlib import contextmanager
33except ImportError:  # Python < 2.5
34    def contextmanager(f):
35        raise NotImplementedError
36try:
37    from collections import namedtuple
38except ImportError:  # Python < 2.6
39    namedtuple = None
40
41
42# Auxiliary functions which are independent from a DB connection:
43
44def _is_quoted(s):
45    """Check whether this string is a quoted identifier."""
46    s = s.replace('_', 'a')
47    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
48
49
50def _is_unquoted(s):
51    """Check whether this string is an unquoted identifier."""
52    s = s.replace('_', 'a')
53    return s.isalnum() and not s[:1].isdigit()
54
55
56def _split_first_part(s):
57    """Split the first part of a dot separated string."""
58    s = s.lstrip()
59    if s[:1] == '"':
60        p = []
61        s = s.split('"', 3)[1:]
62        p.append(s[0])
63        while len(s) == 3 and s[1] == '':
64            p.append('"')
65            s = s[2].split('"', 2)
66            p.append(s[0])
67        p = [''.join(p)]
68        s = '"'.join(s[1:]).lstrip()
69        if s:
70            if s[:0] == '.':
71                p.append(s[1:])
72            else:
73                s = _split_first_part(s)
74                p[0] += s[0]
75                if len(s) > 1:
76                    p.append(s[1])
77    else:
78        p = s.split('.', 1)
79        s = p[0].rstrip()
80        if _is_unquoted(s):
81            s = s.lower()
82        p[0] = s
83    return p
84
85
86def _split_parts(s):
87    """Split all parts of a dot separated string."""
88    q = []
89    while s:
90        s = _split_first_part(s)
91        q.append(s[0])
92        if len(s) < 2:
93            break
94        s = s[1]
95    return q
96
97
def _join_parts(s):
    """Join identifier parts into a dot separated string.

    Parts that need quoting (as decided by _is_quoted) are enclosed in
    double quotes.  Double quotes inside such a part are doubled, so the
    result is a valid SQL identifier that _split_parts splits back into
    the original parts.
    """
    # the quoted form is never empty, so the and/or idiom is safe here
    return '.'.join([_is_quoted(p) and '"%s"' % p.replace('"', '""') or p
        for p in s])
101
102
103def _oid_key(qcl):
104    """Build oid key from qualified class name."""
105    return 'oid(%s)' % qcl
106
107
if namedtuple:  # None when collections.namedtuple is not available

    def _namedresult(q):
        """Return all rows of the query result q as named tuples.

        The tuple type is called Row and its field names are the
        column names reported by q.listfields().
        """
        Row = namedtuple('Row', q.listfields())
        return [Row._make(values) for values in q.getresult()]

    # register the function with the _pg module
    set_namedresult(_namedresult)
116
117
def _db_error(msg, cls=DatabaseError):
    """Create (without raising) an error of class cls with message msg.

    The sqlstate attribute of the new error is set to None.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
123
124
def _int_error(msg):
    """Create an InternalError with message msg and no sqlstate."""
    return _db_error(msg, cls=InternalError)
128
129
def _prg_error(msg):
    """Create a ProgrammingError with message msg and no sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
133
134
# The PostgreSQL database connection interface:
136
class DB(object):
    """Wrapper class for the _pg connection type."""

    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        """
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                db = db.db
            else:
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        self._transaction = False
        self._args = args, kw
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements

    def __getattr__(self, name):
        """Delegate undefined attributes to the underlying connection.

        Raises InternalError if there is no valid connection.
        """
        # Read db from the instance dict directly.  If db has not been set
        # (e.g. __init__ did not complete), accessing self.db would call
        # __getattr__ again and recurse without end.
        db = self.__dict__.get('db')
        if db:
            return getattr(db, name)
        raise _int_error('Connection is not valid')

    # Context manager methods

    def __enter__(self):
        """Enter the runtime context, beginning a transaction if requested.

        A transaction is begun only if the context manager was obtained
        through the transaction property.
        """
        if self._transaction:
            self.begin()
        return self

    def __exit__(self, typ, val, tb):
        """Exit the runtime context.

        If the transaction property was used, the transaction is committed
        when the block completed without an exception and rolled back
        otherwise.  Without that, the connection is closed instead.
        Exceptions raised in the block are not suppressed.
        """
        if self._transaction:
            self._transaction = None
            if tb is None:
                self.commit()
            else:
                self.rollback()
        else:
            self.close()

    # Auxiliary methods

    def _do_debug(self, s):
        """Output a debug message if the debug attribute is set.

        A string debug value is used as format for printing the message,
        a file object gets the message written to it as a line, a
        callable is called with the message, and any other true value
        causes the message to be printed as it is.
        """
        if self.debug:
            if isinstance(self.debug, basestring):
                print(self.debug % s)
            elif isinstance(self.debug, file):
                # write to the file object that was set as debug target;
                # s is not always a string (pkey passes a dictionary)
                self.debug.write('%s\n' % (s,))
            elif callable(self.debug):
                self.debug(s)
            else:
                print(s)

    def _quote_text(self, d):
        """Quote text value."""
        if not isinstance(d, basestring):
            d = str(d)
        return "'%s'" % self.escape_string(d)

    _bool_true = frozenset('t true 1 y yes on'.split())

    def _quote_bool(self, d):
        """Quote boolean value."""
        if isinstance(d, basestring):
            if not d:
                return 'NULL'
            d = d.lower() in self._bool_true
        else:
            d = bool(d)
        return ("'f'", "'t'")[d]

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    def _quote_date(self, d):
        """Quote date value."""
        if not d:
            return 'NULL'
        if isinstance(d, basestring) and d.lower() in self._date_literals:
            return d
        return self._quote_text(d)

    def _quote_num(self, d):
        """Quote numeric value."""
        if not d and d != 0:
            return 'NULL'
        return str(d)

    def _quote_money(self, d):
        """Quote money value."""
        if d is None or d == '':
            return 'NULL'
        return "'%.2f'" % float(d)

    _quote_funcs = dict(  # quote methods for each type
        text=_quote_text, bool=_quote_bool, date=_quote_date,
        int=_quote_num, num=_quote_num, float=_quote_num,
        money=_quote_money)

    def _quote(self, d, t):
        """Return quotes if needed."""
        if d is None:
            return 'NULL'
        try:
            quote_func = self._quote_funcs[t]
        except KeyError:
            quote_func = self._quote_funcs['text']
        return quote_func(self, d)

    def _split_schema(self, cl):
        """Return schema and name of object separately.

        This auxiliary function splits off the namespace (schema)
        belonging to the class with the name cl. If the class name
        is not qualified, the function is able to determine the schema
        of the class, taking into account the current search path.

        """
        s = _split_parts(cl)
        if len(s) > 1:  # name already qualified?
            # should be database.schema.table or schema.table
            if len(s) > 3:
                raise _prg_error('Too many dots in class name %s' % cl)
            schema, cl = s[-2:]
        else:
            cl = s[0]
            # determine search path
            q = 'SELECT current_schemas(TRUE)'
            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
            if schemas:  # non-empty path
                # search schema for this object in the current search path
                q = ' UNION '.join(
                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
                        % n_schema for n_schema in enumerate(schemas)])
                q = ("SELECT nspname FROM pg_class"
                    " JOIN pg_namespace"
                    " ON pg_class.relnamespace = pg_namespace.oid"
                    " JOIN (%s) AS p USING (nspname)"
                    " WHERE pg_class.relname = '%s'"
                    " ORDER BY n LIMIT 1" % (q, cl))
                schema = self.db.query(q).getresult()
                if schema:  # schema found
                    schema = schema[0][0]
                else:  # object not found in current search path
                    schema = 'public'
            else:  # empty path
                schema = 'public'
        return schema, cl

    def _add_schema(self, cl):
        """Ensure that the class name is prefixed with a schema name."""
        return _join_parts(self._split_schema(cl))

    # Public methods

    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    unescape_bytea = staticmethod(unescape_bytea)

    def close(self):
        """Close the database connection."""
        # Wraps shared library function so we can track state.
        if self._closeable:
            if self.db:
                self.db.close()
                self.db = None
            else:
                raise _int_error('Connection already closed')

    def reset(self):
        """Reset connection with current parameters.

        All derived queries and large objects derived from this connection
        will not be usable after this call.

        """
        if self.db:
            self.db.reset()
        else:
            raise _int_error('Connection already closed')

    def reopen(self):
        """Reopen connection to the database.

        Used in case we need another connection to the same database.
        Note that we can still reopen a database that we have closed.

        """
        # There is no such shared library function.
        if self._closeable:
            db = connect(*self._args[0], **self._args[1])
            if self.db:
                self.db.close()
            self.db = db

    def begin(self, mode=None):
        """Begin a transaction, optionally with the given mode clause."""
        qstr = 'BEGIN'
        if mode:
            qstr += ' ' + mode
        return self.query(qstr)

    start = begin

    def commit(self):
        """Commit the current transaction."""
        return self.query('COMMIT')

    end = commit

    def rollback(self, name=None):
        """Rollback the current transaction.

        If a savepoint name is given, roll back only to that savepoint.
        """
        qstr = 'ROLLBACK'
        if name:
            qstr += ' TO ' + name
        return self.query(qstr)

    def savepoint(self, name=None):
        """Define a new savepoint within the current transaction."""
        qstr = 'SAVEPOINT'
        if name:
            qstr += ' ' + name
        return self.query(qstr)

    def release(self, name):
        """Destroy a previously defined savepoint."""
        return self.query('RELEASE ' + name)

    @property
    def transaction(self):
        """Return a context manager for running a transaction.

        Accessing this property marks the connection so that the next
        with-block begins a transaction and commits or rolls it back.
        """
        self._transaction = True
        return self

    def query(self, qstr, *args):
        """Executes a SQL command string.

        This method simply sends a SQL query to the database. If the query is
        an insert statement that inserted exactly one row into a table that
        has OIDs, the return value is the OID of the newly inserted row.
        If the query is an update or delete statement, or an insert statement
        that did not insert exactly one row in a table with OIDs, then the
        number of rows affected is returned as a string. If it is a statement
        that returns rows as a result (usually a select statement, but maybe
        also an "insert/update ... returning" statement), this method returns
        a pgqueryobject that can be accessed via getresult() or dictresult()
        or simply printed. Otherwise, it returns `None`.

        The query can contain numbered parameters of the form $1 in place
        of any data constant. Arguments given after the query string will
        be substituted for the corresponding numbered parameter. Parameter
        values can also be given as a single list or tuple argument.

        Note that the query string must not be passed as a unicode value,
        but you can pass arguments as unicode values if they can be decoded
        using the current client encoding.

        """
        # Wraps shared library function for debugging.
        if not self.db:
            raise _int_error('Connection is not valid')
        self._do_debug(qstr)
        return self.db.query(qstr, args)

    def pkey(self, cl, newpkey=None):
        """This method gets or sets the primary key of a class.

        Composite primary keys are represented as frozensets. Note that
        this raises an exception if the table does not have a primary key.

        If newpkey is set and is not a dictionary then set that
        value as the primary key of the class.  If it is a dictionary
        then replace the _pkeys dictionary with a copy of it.

        """
        # First see if the caller is supplying a dictionary
        if isinstance(newpkey, dict):
            # make sure that all classes have a namespace
            self._pkeys = dict([
                ('.' in cl and cl or 'public.' + cl, pkey)
                for cl, pkey in newpkey.iteritems()])
            return self._pkeys

        qcl = self._add_schema(cl)  # build fully qualified class name
        # Check if the caller is supplying a new primary key for the class
        if newpkey:
            self._pkeys[qcl] = newpkey
            return newpkey

        # Get all the primary keys at once
        if qcl not in self._pkeys:
            # if not found, check again in case it was added after we started
            self._pkeys = {}
            if self.server_version >= 80200:
                # the ANY syntax works correctly only with PostgreSQL >= 8.2
                any_indkey = "= ANY (pg_index.indkey)"
            else:
                any_indkey = "IN (%s)" % ', '.join(
                    ['pg_index.indkey[%d]' % i for i in range(16)])
            for r in self.db.query(
                "SELECT pg_namespace.nspname, pg_class.relname,"
                    " pg_attribute.attname FROM pg_class"
                " JOIN pg_namespace"
                    " ON pg_namespace.oid = pg_class.relnamespace"
                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
                    " AND pg_attribute.attisdropped = 'f'"
                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
                    " AND pg_index.indisprimary = 't'"
                    " AND pg_attribute.attnum " + any_indkey).getresult():
                cl, pkey = _join_parts(r[:2]), r[2]
                self._pkeys.setdefault(cl, []).append(pkey)
            # (only) for composite primary keys, the values will be frozensets
            for cl, pkey in self._pkeys.iteritems():
                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
            self._do_debug(self._pkeys)

        # will raise an exception if primary key doesn't exist
        return self._pkeys[qcl]

    def get_databases(self):
        """Get list of databases in the system."""
        return [s[0] for s in
            self.db.query('SELECT datname FROM pg_database').getresult()]

    def get_relations(self, kinds=None):
        """Get list of relations in connected database of specified kinds.

        If kinds is None or empty, all kinds of relations are returned.
        Otherwise kinds can be a string or sequence of type letters
        specifying which kind of relations you want to list.

        """
        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
            ["'%s'" % x for x in kinds]) or ''
        return map(_join_parts, self.db.query(
            "SELECT pg_namespace.nspname, pg_class.relname "
            "FROM pg_class "
            "JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
            "WHERE %s pg_class.relname !~ '^Inv' AND "
                "pg_class.relname !~ '^pg_' "
            "ORDER BY 1, 2" % where).getresult())

    def get_tables(self):
        """Return list of tables in connected database."""
        return self.get_relations('r')

    def get_attnames(self, cl, newattnames=None):
        """Given the name of a table, digs out the set of attribute names.

        Returns a dictionary of attribute names (the names are the keys,
        the values are the names of the attributes' types).
        If the optional newattnames exists, it must be a dictionary and
        will become the new attribute names dictionary.

        By default, only a limited number of simple types will be returned.
        You can get the regular types after calling use_regtypes(True).

        """
        if isinstance(newattnames, dict):
            self._attnames = newattnames
            return
        elif newattnames:
            raise _prg_error('If supplied, newattnames must be a dictionary')
        cl = self._split_schema(cl)  # split into schema and class
        qcl = _join_parts(cl)  # build fully qualified name
        # May as well cache them:
        if qcl in self._attnames:
            return self._attnames[qcl]
        if qcl not in self.get_relations('rv'):
            raise _prg_error('Class %s does not exist' % qcl)

        q = "SELECT pg_attribute.attname, pg_type.typname"
        if self._regtypes:
            q += "::regtype"
        q += (" FROM pg_class"
            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
            " AND pg_attribute.attisdropped = 'f'") % cl
        q = self.db.query(q).getresult()

        if self._regtypes:
            t = dict(q)
        else:
            t = {}
            for att, typ in q:
                if typ.startswith('bool'):
                    typ = 'bool'
                elif typ.startswith('abstime'):
                    typ = 'date'
                elif typ.startswith('date'):
                    typ = 'date'
                elif typ.startswith('interval'):
                    typ = 'date'
                elif typ.startswith('timestamp'):
                    typ = 'date'
                elif typ.startswith('oid'):
                    typ = 'int'
                elif typ.startswith('int'):
                    typ = 'int'
                elif typ.startswith('float'):
                    typ = 'float'
                elif typ.startswith('numeric'):
                    typ = 'num'
                elif typ.startswith('money'):
                    typ = 'money'
                else:
                    typ = 'text'
                t[att] = typ

        self._attnames[qcl] = t  # cache it
        return self._attnames[qcl]

    def use_regtypes(self, regtypes=None):
        """Use regular type names instead of simplified type names."""
        if regtypes is None:
            return self._regtypes
        else:
            regtypes = bool(regtypes)
            if regtypes != self._regtypes:
                self._regtypes = regtypes
                self._attnames.clear()
            return regtypes

    def has_table_privilege(self, cl, privilege='select'):
        """Check whether current user has specified table privilege."""
        qcl = self._add_schema(cl)
        privilege = privilege.lower()
        try:
            return self._privileges[(qcl, privilege)]
        except KeyError:
            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
            ret = self.db.query(q).getresult()[0][0] == 't'
            self._privileges[(qcl, privilege)] = ret
            return ret

    def get(self, cl, arg, keyname=None):
        """Get a tuple from a database table or view.

        This method is the basic mechanism to get a single row.  The row
        is selected by the value of the key given by keyname, which must
        specify a unique row.  If keyname is not specified then the
        primary key for the table is used.  If arg is a dictionary
        then the value for the key is taken from it and it is modified to
        include the new values, replacing existing values where necessary.
        For a composite key, keyname can also be a sequence of key names.
        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as oid(schema.table).

        """
        if cl.endswith('*'):  # scan descendant tables?
            cl = cl[:-1].rstrip()  # need parent table name
        # build qualified class name
        qcl = self._add_schema(cl)
        # To allow users to work with multiple tables,
        # we munge the name of the "oid" the key
        qoid = _oid_key(qcl)
        if not keyname:
            # use the primary key by default
            try:
                keyname = self.pkey(qcl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % qcl)
        # We want the oid for later updates if that isn't the key
        if keyname == 'oid':
            if isinstance(arg, dict):
                if qoid not in arg:
                    raise _db_error('%s not in arg' % qoid)
            else:
                arg = {qoid: arg}
            where = 'oid = $1'
            params = (arg[qoid],)
            attnames = '*'
        else:
            attnames = self.get_attnames(qcl)
            if isinstance(keyname, basestring):
                keyname = (keyname,)
            if not isinstance(arg, dict):
                if len(keyname) > 1:
                    raise _prg_error('Composite key needs dict as arg')
                arg = dict([(k, arg) for k in keyname])
            where = ' AND '.join(['%s = $%d'
                % (k, i + 1) for i, k in enumerate(keyname)])
            params = tuple(arg[k] for k in keyname)
            attnames = ', '.join(attnames)
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
        self._do_debug(q + ' %% %r' % (params,))
        res = self.db.query(q, params).dictresult()
        if not res:
            raise _db_error(
                'No such record in %s where %s %% %r' % (qcl, where, params))
        for att, value in res[0].iteritems():
            arg[att == 'oid' and qoid or att] = value
        return arg

    def insert(self, cl, d=None, **kw):
        """Insert a tuple into a database table.

        This method inserts a row into a table.  If a dictionary is
        supplied it starts with that.  Otherwise it uses a blank dictionary.
        Either way the dictionary is updated from the keywords.

        The dictionary is then, if possible, reloaded with the values actually
        inserted in order to pick up values modified by rules, triggers, etc.

        Note: The method currently doesn't support insert into views
        although PostgreSQL does.

        """
        qcl = self._add_schema(cl)
        qoid = _oid_key(qcl)
        if d is None:
            d = {}
        d.update(kw)
        attnames = self.get_attnames(qcl)
        names, values, params = [], [], []
        i = 1
        for n in attnames:
            if n != 'oid' and n in d:
                names.append('"%s"' % n)
                values.append('$%d' % (i,))
                params.append(d[n])
                i += 1
        names, values = ', '.join(names), ', '.join(values)
        selectable = self.has_table_privilege(qcl)
        if selectable and self.server_version >= 80200:
            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
        else:
            ret = ''
        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
        self._do_debug(q + " %% %r" % (params,))
        res = self.db.query(q, params)
        if ret:
            res = res.dictresult()
            for att, value in res[0].iteritems():
                d[att == 'oid' and qoid or att] = value
        elif isinstance(res, int):
            d[qoid] = res
            if selectable:
                self.get(qcl, d, 'oid')
        elif selectable:
            if qoid in d:
                self.get(qcl, d, 'oid')
            else:
                try:
                    self.get(qcl, d)
                except ProgrammingError:
                    pass  # table has no primary key
        return d

    def update(self, cl, d=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the OID value as munged by get or passed as keyword, or on the
        primary key of the table.  The dictionary is modified, if possible,
        to reflect any changes caused by the update due to triggers, rules,
        default values, etc.

        """
        # Update always works on the oid which get returns if available,
        # otherwise use the primary key.  Fail if neither.
        # Note that we only accept oid key from named args for safety
        qcl = self._add_schema(cl)
        qoid = _oid_key(qcl)
        if 'oid' in kw:
            kw[qoid] = kw['oid']
            del kw['oid']
        if d is None:
            d = {}
        d.update(kw)
        attnames = self.get_attnames(qcl)
        if qoid in d:
            where = 'oid = $1'
            params = [d[qoid]]
            keyname = ()
        else:
            try:
                keyname = self.pkey(qcl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % qcl)
            if isinstance(keyname, basestring):
                keyname = (keyname,)
            try:
                where = ' AND '.join(['%s = $%d'
                    % (k, i + 1) for i, k in enumerate(keyname)])
                params = [d[k] for k in keyname]
            except KeyError:
                raise _prg_error('Update needs primary key or oid.')
        values = []
        i = len(params)
        for n in attnames:
            if n in d and n not in keyname:
                i += 1
                values.append('%s = $%d' % (n, i))
                params.append(d[n])
        if not values:
            return d
        values = ', '.join(values)
        selectable = self.has_table_privilege(qcl)
        # same server version check as in insert (8.2.0 = 80200)
        if selectable and self.server_version >= 80200:
            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
        else:
            ret = ''
        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
        self._do_debug(q + " %% %r" % (tuple(params),))
        res = self.db.query(q, params)
        if ret:
            res = res.dictresult()
            if res:
                for att, value in res[0].iteritems():
                    d[att == 'oid' and qoid or att] = value
                return d
            # No row was returned, i.e. no row was updated.  Fall through
            # to get(), which raises the same "No such record" error as
            # in the case without RETURNING.
        if selectable:
            if qoid in d:
                self.get(qcl, d, 'oid')
            else:
                self.get(qcl, d)
        return d

    def clear(self, cl, a=None):
        """Clear all the attributes to values determined by the types.

        Numeric types are set to 0, Booleans are set to 'f', and everything
        else is set to the empty string.  If the array argument is present,
        it is used as the array and any entries matching attribute names are
        cleared with everything else left unchanged.

        """
        # At some point we will need a way to get defaults from a table.
        qcl = self._add_schema(cl)
        if a is None:
            a = {}  # empty if argument is not present
        attnames = self.get_attnames(qcl)
        for n, t in attnames.iteritems():
            if n == 'oid':
                continue
            if t in ('int', 'integer', 'smallint', 'bigint',
                    'float', 'real', 'double precision',
                    'num', 'numeric', 'money'):
                a[n] = 0
            elif t in ('bool', 'boolean'):
                a[n] = 'f'
            else:
                a[n] = ''
        return a

    def delete(self, cl, d=None, **kw):
        """Delete an existing row in a database table.

        This method deletes the row from a table.  It deletes based on the
        OID value as munged by get or passed as keyword, or on the primary
        key of the table.  The return value is the number of deleted rows
        (i.e. 0 if the row did not exist and 1 if the row was deleted).

        """
        # Like update, delete works on the oid.
        # One day we will be testing that the record to be deleted
        # isn't referenced somewhere (or else PostgreSQL will).
        # Note that we only accept oid key from named args for safety
        qcl = self._add_schema(cl)
        qoid = _oid_key(qcl)
        if 'oid' in kw:
            kw[qoid] = kw['oid']
            del kw['oid']
        if d is None:
            d = {}
        d.update(kw)
        if qoid in d:
            where = 'oid = $1'
            params = (d[qoid],)
        else:
            try:
                keyname = self.pkey(qcl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % qcl)
            if isinstance(keyname, basestring):
                keyname = (keyname,)
            try:
                where = ' AND '.join(['%s = $%d'
                    % (k, i + 1) for i, k in enumerate(keyname)])
                params = tuple(d[k] for k in keyname)
            except KeyError:
                raise _prg_error('Delete needs primary key or oid.')
        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
        self._do_debug(q + " %% %r" % (params,))
        return int(self.db.query(q, params))
855
856
857# if run as script, print some information
858
859if __name__ == '__main__':
860    print('PyGreSQL version' + version)
861    print('')
862    print(__doc__)
Note: See TracBrowser for help on using the repository browser.