source: trunk/module/pg.py @ 505

Last change on this file since 505 was 505, checked in by cito, 7 years ago

Move copyright notice from module docstring to comment.

  • Property svn:keywords set to Id
File size: 32.7 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# $Id: pg.py 505 2013-01-06 17:44:25Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31import select
32
33from _pg import *
34try:
35    frozenset
36except NameError:  # Python < 2.4
37    from sets import ImmutableSet as frozenset
38try:
39    from decimal import Decimal
40    set_decimal(Decimal)
41except ImportError:  # Python < 2.4
42    Decimal = float
43try:
44    from collections import namedtuple
45except ImportError:  # Python < 2.6
46    namedtuple = None
47
48
49# Auxiliary functions which are independent from a DB connection:
50
51def _is_quoted(s):
52    """Check whether this string is a quoted identifier."""
53    s = s.replace('_', 'a')
54    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
55
56
57def _is_unquoted(s):
58    """Check whether this string is an unquoted identifier."""
59    s = s.replace('_', 'a')
60    return s.isalnum() and not s[:1].isdigit()
61
62
63def _split_first_part(s):
64    """Split the first part of a dot separated string."""
65    s = s.lstrip()
66    if s[:1] == '"':
67        p = []
68        s = s.split('"', 3)[1:]
69        p.append(s[0])
70        while len(s) == 3 and s[1] == '':
71            p.append('"')
72            s = s[2].split('"', 2)
73            p.append(s[0])
74        p = [''.join(p)]
75        s = '"'.join(s[1:]).lstrip()
76        if s:
77            if s[:0] == '.':
78                p.append(s[1:])
79            else:
80                s = _split_first_part(s)
81                p[0] += s[0]
82                if len(s) > 1:
83                    p.append(s[1])
84    else:
85        p = s.split('.', 1)
86        s = p[0].rstrip()
87        if _is_unquoted(s):
88            s = s.lower()
89        p[0] = s
90    return p
91
92
def _split_parts(s):
    """Split a dot separated string into the list of all its parts.

    The parts are taken off one at a time with _split_first_part()
    until nothing is left.  An empty string yields an empty list.

    """
    parts = []
    rest = s
    while rest:
        pieces = _split_first_part(rest)
        parts.append(pieces[0])
        if len(pieces) < 2:  # nothing left after this part
            break
        rest = pieces[1]
    return parts
103
104
def _join_parts(s):
    """Join the parts in s with dots, quoting each part where necessary.

    Parts for which _is_quoted() is true are put in double quotes.

    """
    joined = []
    for part in s:
        if _is_quoted(part):
            part = '"%s"' % part
        joined.append(part)
    return '.'.join(joined)
108
109
110def _oid_key(qcl):
111    """Build oid key from qualified class name."""
112    return 'oid(%s)' % qcl
113
114
if namedtuple:  # only if named tuples are available

    def _namedresult(q):
        """Return the rows of the query q as a list of named tuples.

        The named tuple type is called 'Row' and its fields are
        the field names reported by q.listfields().

        """
        row_type = namedtuple('Row', q.listfields())
        rows = []
        for values in q.getresult():
            rows.append(row_type(*values))
        return rows

    # register the function with the _pg module
    set_namedresult(_namedresult)
123
124
def _db_error(msg, cls=DatabaseError):
    """Create an error of type cls with an empty sqlstate attribute.

    The error is returned, not raised; by default it is a DatabaseError.

    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
130
131
def _int_error(msg):
    """Create and return an InternalError with empty sqlstate."""
    return _db_error(msg, cls=InternalError)
135
136
def _prg_error(msg):
    """Create and return a ProgrammingError with empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
140
141
class pgnotify(object):
    """Client-side handler for asynchronous PostgreSQL notifications."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Set up the notification handler.

        db       - PostgreSQL connection object.
        event    - Name of the event to LISTEN for.
        callback - Function called when a notification arrives.
        arg_dict - Dictionary passed as the argument to the callback;
                   a new empty dictionary is used if it is None.
        timeout  - Timeout in seconds, passed to select(); fractions
                   of seconds can be given as floating point number.
                   With None, the handler waits without time limit.

        """
        if arg_dict is None:
            arg_dict = {}
        self.db = db
        self.event = event
        # notifying this second event makes the handler terminate
        self.stop = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Close the handler when it is garbage collected."""
        self.close()

    def close(self):
        """Stop listening and close the connection, if there is one."""
        if not self.db:
            return
        self.unlisten()
        self.db.close()
        self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if self.listening:
            return
        for name in (self.event, self.stop):
            self.db.query('listen "%s"' % name)
        self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if not self.listening:
            return
        for name in (self.event, self.stop):
            self.db.query('unlisten "%s"' % name)
        self.listening = False

    def __call__(self):
        """Run the handler loop.

        The handler LISTENs for two NOTIFY messages, the event itself
        and the stop event named with the prefix "stop_".

        Whenever one of these is received, the keys 'pid', 'event' and
        'extra' of arg_dict are set from the notification and the
        callback is invoked with arg_dict.  After the stop event, the
        handler UNLISTENs both events and returns.

        If the timeout expires, the handler UNLISTENs, invokes the
        callback with None and returns.  A notification for any other
        event makes the handler UNLISTEN and raise a DatabaseError.

        """
        self.listen()
        watched = [self.db.fileno()]

        while True:
            ready = select.select(watched, [], [], self.timeout)[0]
            if ready == []:  # the timeout has expired
                self.unlisten()
                self.callback(None)
                break
            notice = self.db.getnotify()
            if notice is None:  # nothing for us yet, keep waiting
                continue
            event, pid, extra = notice
            if event not in (self.event, self.stop):
                self.unlisten()
                raise _db_error(
                    'listening for "%s" and "%s", but notified of "%s"'
                    % (self.event, self.stop, event))
            self.arg_dict['pid'] = pid
            self.arg_dict['event'] = event
            self.arg_dict['extra'] = extra
            self.callback(self.arg_dict)
            if event == self.stop:
                self.unlisten()
                break
226
227
228# The PostGreSQL database connection interface:
229class DB(object):
230    """Wrapper class for the _pg connection type."""
231
232    def __init__(self, *args, **kw):
233        """Create a new connection.
234
235        You can pass either the connection parameters or an existing
236        _pg or pgdb connection. This allows you to use the methods
237        of the classic pg interface with a DB-API 2 pgdb connection.
238
239        """
240        if not args and len(kw) == 1:
241            db = kw.get('db')
242        elif not kw and len(args) == 1:
243            db = args[0]
244        else:
245            db = None
246        if db:
247            if isinstance(db, DB):
248                db = db.db
249            else:
250                try:
251                    db = db._cnx
252                except AttributeError:
253                    pass
254        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
255            db = connect(*args, **kw)
256            self._closeable = True
257        else:
258            self._closeable = False
259        self.db = db
260        self.dbname = db.db
261        self._regtypes = False
262        self._attnames = {}
263        self._pkeys = {}
264        self._privileges = {}
265        self._args = args, kw
266        self.debug = None  # For debugging scripts, this can be set
267            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
268            # * to a file object to write debug statements or
269            # * to a callable object which takes a string argument
270            # * to any other true value to just print debug statements
271
272    def __getattr__(self, name):
273        # All undefined members are same as in underlying pg connection:
274        if self.db:
275            return getattr(self.db, name)
276        else:
277            raise _int_error('Connection is not valid')
278
279    # Context manager methods
280
281    def __enter__(self):
282        """Enter the runtime context. This will start a transaction."""
283        self.begin()
284        return self
285
286    def __exit__(self, et, ev, tb):
287        """Exit the runtime context. This will end the transaction."""
288        if et is None and ev is None and tb is None:
289            self.commit()
290        else:
291            self.rollback()
292
293    # Auxiliary methods
294
295    def _do_debug(self, s):
296        """Print a debug message."""
297        if self.debug:
298            if isinstance(self.debug, basestring):
299                print self.debug % s
300            elif isinstance(self.debug, file):
301                file.write(s + '\n')
302            elif callable(self.debug):
303                self.debug(s)
304            else:
305                print s
306
307    def _quote_text(self, d):
308        """Quote text value."""
309        if not isinstance(d, basestring):
310            d = str(d)
311        return "'%s'" % self.escape_string(d)
312
313    _bool_true = frozenset('t true 1 y yes on'.split())
314
315    def _quote_bool(self, d):
316        """Quote boolean value."""
317        if isinstance(d, basestring):
318            if not d:
319                return 'NULL'
320            d = d.lower() in self._bool_true
321        else:
322            d = bool(d)
323        return ("'f'", "'t'")[d]
324
325    _date_literals = frozenset('current_date current_time'
326        ' current_timestamp localtime localtimestamp'.split())
327
328    def _quote_date(self, d):
329        """Quote date value."""
330        if not d:
331            return 'NULL'
332        if isinstance(d, basestring) and d.lower() in self._date_literals:
333            return d
334        return self._quote_text(d)
335
336    def _quote_num(self, d):
337        """Quote numeric value."""
338        if not d and d != 0:
339            return 'NULL'
340        return str(d)
341
342    def _quote_money(self, d):
343        """Quote money value."""
344        if d is None or d == '':
345            return 'NULL'
346        if not isinstance(d, basestring):
347            d = str(d)
348        return d
349
350    _quote_funcs = dict(  # quote methods for each type
351        text=_quote_text, bool=_quote_bool, date=_quote_date,
352        int=_quote_num, num=_quote_num, float=_quote_num,
353        money=_quote_money)
354
355    def _quote(self, d, t):
356        """Return quotes if needed."""
357        if d is None:
358            return 'NULL'
359        try:
360            quote_func = self._quote_funcs[t]
361        except KeyError:
362            quote_func = self._quote_funcs['text']
363        return quote_func(self, d)
364
365    def _split_schema(self, cl):
366        """Return schema and name of object separately.
367
368        This auxiliary function splits off the namespace (schema)
369        belonging to the class with the name cl. If the class name
370        is not qualified, the function is able to determine the schema
371        of the class, taking into account the current search path.
372
373        """
374        s = _split_parts(cl)
375        if len(s) > 1:  # name already qualfied?
376            # should be database.schema.table or schema.table
377            if len(s) > 3:
378                raise _prg_error('Too many dots in class name %s' % cl)
379            schema, cl = s[-2:]
380        else:
381            cl = s[0]
382            # determine search path
383            q = 'SELECT current_schemas(TRUE)'
384            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
385            if schemas:  # non-empty path
386                # search schema for this object in the current search path
387                q = ' UNION '.join(
388                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
389                        % s for s in enumerate(schemas)])
390                q = ("SELECT nspname FROM pg_class"
391                    " JOIN pg_namespace"
392                    " ON pg_class.relnamespace = pg_namespace.oid"
393                    " JOIN (%s) AS p USING (nspname)"
394                    " WHERE pg_class.relname = '%s'"
395                    " ORDER BY n LIMIT 1" % (q, cl))
396                schema = self.db.query(q).getresult()
397                if schema:  # schema found
398                    schema = schema[0][0]
399                else:  # object not found in current search path
400                    schema = 'public'
401            else:  # empty path
402                schema = 'public'
403        return schema, cl
404
405    def _add_schema(self, cl):
406        """Ensure that the class name is prefixed with a schema name."""
407        return _join_parts(self._split_schema(cl))
408
409    # Public methods
410
411    # escape_string and escape_bytea exist as methods,
412    # so we define unescape_bytea as a method as well
413    unescape_bytea = staticmethod(unescape_bytea)
414
415    def close(self):
416        """Close the database connection."""
417        # Wraps shared library function so we can track state.
418        if self._closeable:
419            if self.db:
420                self.db.close()
421                self.db = None
422            else:
423                raise _int_error('Connection already closed')
424
425    def reset(self):
426        """Reset connection with current parameters.
427
428        All derived queries and large objects derived from this connection
429        will not be usable after this call.
430
431        """
432        if self.db:
433            self.db.reset()
434        else:
435            raise _int_error('Connection already closed')
436
437    def reopen(self):
438        """Reopen connection to the database.
439
440        Used in case we need another connection to the same database.
441        Note that we can still reopen a database that we have closed.
442
443        """
444        # There is no such shared library function.
445        if self._closeable:
446            db = connect(*self._args[0], **self._args[1])
447            if self.db:
448                self.db.close()
449            self.db = db
450
451    def begin(self, mode=None):
452        """Begin a transaction."""
453        qstr = 'BEGIN'
454        if mode:
455            qstr += ' ' + mode
456        return self.query(qstr)
457
458    start = begin
459
460    def commit(self):
461        """Commit the current transaction."""
462        return self.query('COMMIT')
463
464    end = commit
465
466    def rollback(self, name=None):
467        """Rollback the current transaction."""
468        qstr = 'ROLLBACK'
469        if name:
470            qstr += ' TO ' + name
471        return self.query(qstr)
472
473    def savepoint(self, name=None):
474        """Define a new savepoint within the current transaction."""
475        qstr = 'SAVEPOINT'
476        if name:
477            qstr += ' ' + name
478        return self.query(qstr)
479
480    def release(self, name):
481        """Destroy a previously defined savepoint."""
482        return self.query('RELEASE ' + name)
483
484    def query(self, qstr, *args):
485        """Executes a SQL command string.
486
487        This method simply sends a SQL query to the database. If the query is
488        an insert statement that inserted exactly one row into a table that
489        has OIDs, the return value is the OID of the newly inserted row.
490        If the query is an update or delete statement, or an insert statement
491        that did not insert exactly one row in a table with OIDs, then the
492        numer of rows affected is returned as a string. If it is a statement
493        that returns rows as a result (usually a select statement, but maybe
494        also an "insert/update ... returning" statement), this method returns
495        a pgqueryobject that can be accessed via getresult() or dictresult()
496        or simply printed. Otherwise, it returns `None`.
497
498        The query can contain numbered parameters of the form $1 in place
499        of any data constant. Arguments given after the query string will
500        be substituted for the corresponding numbered parameter. Parameter
501        values can also be given as a single list or tuple argument.
502
503        Note that the query string must not be passed as a unicode value,
504        but you can pass arguments as unicode values if they can be decoded
505        using the current client encoding.
506
507        """
508        # Wraps shared library function for debugging.
509        if not self.db:
510            raise _int_error('Connection is not valid')
511        self._do_debug(qstr)
512        return self.db.query(qstr, args)
513
514    def pkey(self, cl, newpkey=None):
515        """This method gets or sets the primary key of a class.
516
517        Composite primary keys are represented as frozensets. Note that
518        this raises an exception if the table does not have a primary key.
519
520        If newpkey is set and is not a dictionary then set that
521        value as the primary key of the class.  If it is a dictionary
522        then replace the _pkeys dictionary with a copy of it.
523
524        """
525        # First see if the caller is supplying a dictionary
526        if isinstance(newpkey, dict):
527            # make sure that all classes have a namespace
528            self._pkeys = dict([
529                ('.' in cl and cl or 'public.' + cl, pkey)
530                for cl, pkey in newpkey.iteritems()])
531            return self._pkeys
532
533        qcl = self._add_schema(cl)  # build fully qualified class name
534        # Check if the caller is supplying a new primary key for the class
535        if newpkey:
536            self._pkeys[qcl] = newpkey
537            return newpkey
538
539        # Get all the primary keys at once
540        if qcl not in self._pkeys:
541            # if not found, check again in case it was added after we started
542            self._pkeys = {}
543            if self.server_version >= 80200:
544                # the ANY syntax works correctly only with PostgreSQL >= 8.2
545                any_indkey = "= ANY (pg_index.indkey)"
546            else:
547                any_indkey = "IN (%s)" % ', '.join(
548                    ['pg_index.indkey[%d]' % i for i in range(16)])
549            for r in self.db.query(
550                "SELECT pg_namespace.nspname, pg_class.relname,"
551                    " pg_attribute.attname FROM pg_class"
552                " JOIN pg_namespace"
553                    " ON pg_namespace.oid = pg_class.relnamespace"
554                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
555                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
556                    " AND pg_attribute.attisdropped = 'f'"
557                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
558                    " AND pg_index.indisprimary = 't'"
559                    " AND pg_attribute.attnum " + any_indkey).getresult():
560                cl, pkey = _join_parts(r[:2]), r[2]
561                self._pkeys.setdefault(cl, []).append(pkey)
562            # (only) for composite primary keys, the values will be frozensets
563            for cl, pkey in self._pkeys.iteritems():
564                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
565            self._do_debug(self._pkeys)
566
567        # will raise an exception if primary key doesn't exist
568        return self._pkeys[qcl]
569
570    def get_databases(self):
571        """Get list of databases in the system."""
572        return [s[0] for s in
573            self.db.query('SELECT datname FROM pg_database').getresult()]
574
575    def get_relations(self, kinds=None):
576        """Get list of relations in connected database of specified kinds.
577
578            If kinds is None or empty, all kinds of relations are returned.
579            Otherwise kinds can be a string or sequence of type letters
580            specifying which kind of relations you want to list.
581
582        """
583        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
584            ["'%s'" % x for x in kinds]) or ''
585        return map(_join_parts, self.db.query(
586            "SELECT pg_namespace.nspname, pg_class.relname "
587            "FROM pg_class "
588            "JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
589            "WHERE %s pg_class.relname !~ '^Inv' AND "
590                "pg_class.relname !~ '^pg_' "
591            "ORDER BY 1, 2" % where).getresult())
592
593    def get_tables(self):
594        """Return list of tables in connected database."""
595        return self.get_relations('r')
596
597    def get_attnames(self, cl, newattnames=None):
598        """Given the name of a table, digs out the set of attribute names.
599
600        Returns a dictionary of attribute names (the names are the keys,
601        the values are the names of the attributes' types).
602        If the optional newattnames exists, it must be a dictionary and
603        will become the new attribute names dictionary.
604
605        By default, only a limited number of simple types will be returned.
606        You can get the regular types after calling use_regtypes(True).
607
608        """
609        if isinstance(newattnames, dict):
610            self._attnames = newattnames
611            return
612        elif newattnames:
613            raise _prg_error('If supplied, newattnames must be a dictionary')
614        cl = self._split_schema(cl)  # split into schema and class
615        qcl = _join_parts(cl)  # build fully qualified name
616        # May as well cache them:
617        if qcl in self._attnames:
618            return self._attnames[qcl]
619        if qcl not in self.get_relations('rv'):
620            raise _prg_error('Class %s does not exist' % qcl)
621
622        q = "SELECT pg_attribute.attname, pg_type.typname"
623        if self._regtypes:
624            q += "::regtype"
625        q += (" FROM pg_class"
626            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
627            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
628            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
629            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
630            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
631            " AND pg_attribute.attisdropped = 'f'") % cl
632        q = self.db.query(q).getresult()
633
634        if self._regtypes:
635            t = dict(q)
636        else:
637            t = {}
638            for att, typ in q:
639                if typ.startswith('bool'):
640                    typ = 'bool'
641                elif typ.startswith('abstime'):
642                    typ = 'date'
643                elif typ.startswith('date'):
644                    typ = 'date'
645                elif typ.startswith('interval'):
646                    typ = 'date'
647                elif typ.startswith('timestamp'):
648                    typ = 'date'
649                elif typ.startswith('oid'):
650                    typ = 'int'
651                elif typ.startswith('int'):
652                    typ = 'int'
653                elif typ.startswith('float'):
654                    typ = 'float'
655                elif typ.startswith('numeric'):
656                    typ = 'num'
657                elif typ.startswith('money'):
658                    typ = 'money'
659                else:
660                    typ = 'text'
661                t[att] = typ
662
663        self._attnames[qcl] = t  # cache it
664        return self._attnames[qcl]
665
666    def use_regtypes(self, regtypes=None):
667        """Use regular type names instead of simplified type names."""
668        if regtypes is None:
669            return self._regtypes
670        else:
671            regtypes = bool(regtypes)
672            if regtypes != self._regtypes:
673                self._regtypes = regtypes
674                self._attnames.clear()
675            return regtypes
676
677    def has_table_privilege(self, cl, privilege='select'):
678        """Check whether current user has specified table privilege."""
679        qcl = self._add_schema(cl)
680        privilege = privilege.lower()
681        try:
682            return self._privileges[(qcl, privilege)]
683        except KeyError:
684            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
685            ret = self.db.query(q).getresult()[0][0] == 't'
686            self._privileges[(qcl, privilege)] = ret
687            return ret
688
689    def get(self, cl, arg, keyname=None):
690        """Get a tuple from a database table or view.
691
692        This method is the basic mechanism to get a single row.  The keyname
693        that the key specifies a unique row.  If keyname is not specified
694        then the primary key for the table is used.  If arg is a dictionary
695        then the value for the key is taken from it and it is modified to
696        include the new values, replacing existing values where necessary.
697        For a composite key, keyname can also be a sequence of key names.
698        The OID is also put into the dictionary if the table has one, but
699        in order to allow the caller to work with multiple tables, it is
700        munged as oid(schema.table).
701
702        """
703        if cl.endswith('*'):  # scan descendant tables?
704            cl = cl[:-1].rstrip()  # need parent table name
705        # build qualified class name
706        qcl = self._add_schema(cl)
707        # To allow users to work with multiple tables,
708        # we munge the name of the "oid" the key
709        qoid = _oid_key(qcl)
710        if not keyname:
711            # use the primary key by default
712            try:
713                keyname = self.pkey(qcl)
714            except KeyError:
715                raise _prg_error('Class %s has no primary key' % qcl)
716        # We want the oid for later updates if that isn't the key
717        if keyname == 'oid':
718            if isinstance(arg, dict):
719                if qoid not in arg:
720                    raise _db_error('%s not in arg' % qoid)
721            else:
722                arg = {qoid: arg}
723            where = 'oid = %s' % arg[qoid]
724            attnames = '*'
725        else:
726            attnames = self.get_attnames(qcl)
727            if isinstance(keyname, basestring):
728                keyname = (keyname,)
729            if not isinstance(arg, dict):
730                if len(keyname) > 1:
731                    raise _prg_error('Composite key needs dict as arg')
732                arg = dict([(k, arg) for k in keyname])
733            where = ' AND '.join(['%s = %s'
734                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
735            attnames = ', '.join(attnames)
736        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
737        self._do_debug(q)
738        res = self.db.query(q).dictresult()
739        if not res:
740            raise _db_error('No such record in %s where %s' % (qcl, where))
741        for att, value in res[0].iteritems():
742            arg[att == 'oid' and qoid or att] = value
743        return arg
744
745    def insert(self, cl, d=None, **kw):
746        """Insert a tuple into a database table.
747
748        This method inserts a row into a table.  If a dictionary is
749        supplied it starts with that.  Otherwise it uses a blank dictionary.
750        Either way the dictionary is updated from the keywords.
751
752        The dictionary is then, if possible, reloaded with the values actually
753        inserted in order to pick up values modified by rules, triggers, etc.
754
755        Note: The method currently doesn't support insert into views
756        although PostgreSQL does.
757
758        """
759        qcl = self._add_schema(cl)
760        qoid = _oid_key(qcl)
761        if d is None:
762            d = {}
763        d.update(kw)
764        attnames = self.get_attnames(qcl)
765        names, values = [], []
766        for n in attnames:
767            if n != 'oid' and n in d:
768                names.append('"%s"' % n)
769                values.append(self._quote(d[n], attnames[n]))
770        names, values = ', '.join(names), ', '.join(values)
771        selectable = self.has_table_privilege(qcl)
772        if selectable and self.server_version >= 80200:
773            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
774        else:
775            ret = ''
776        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
777        self._do_debug(q)
778        res = self.db.query(q)
779        if ret:
780            res = res.dictresult()
781            for att, value in res[0].iteritems():
782                d[att == 'oid' and qoid or att] = value
783        elif isinstance(res, int):
784            d[qoid] = res
785            if selectable:
786                self.get(qcl, d, 'oid')
787        elif selectable:
788            if qoid in d:
789                self.get(qcl, d, 'oid')
790            else:
791                try:
792                    self.get(qcl, d)
793                except ProgrammingError:
794                    pass  # table has no primary key
795        return d
796
797    def update(self, cl, d=None, **kw):
798        """Update an existing row in a database table.
799
800        Similar to insert but updates an existing row.  The update is based
801        on the OID value as munged by get or passed as keyword, or on the
802        primary key of the table.  The dictionary is modified, if possible,
803        to reflect any changes caused by the update due to triggers, rules,
804        default values, etc.
805
806        """
807        # Update always works on the oid which get returns if available,
808        # otherwise use the primary key.  Fail if neither.
809        # Note that we only accept oid key from named args for safety
810        qcl = self._add_schema(cl)
811        qoid = _oid_key(qcl)
812        if 'oid' in kw:
813            kw[qoid] = kw['oid']
814            del kw['oid']
815        if d is None:
816            d = {}
817        d.update(kw)
818        attnames = self.get_attnames(qcl)
819        if qoid in d:
820            where = 'oid = %s' % d[qoid]
821            keyname = ()
822        else:
823            try:
824                keyname = self.pkey(qcl)
825            except KeyError:
826                raise _prg_error('Class %s has no primary key' % qcl)
827            if isinstance(keyname, basestring):
828                keyname = (keyname,)
829            try:
830                where = ' AND '.join(['%s = %s'
831                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
832            except KeyError:
833                raise _prg_error('Update needs primary key or oid.')
834        values = []
835        for n in attnames:
836            if n in d and n not in keyname:
837                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
838        if not values:
839            return d
840        values = ', '.join(values)
841        selectable = self.has_table_privilege(qcl)
842        if selectable and self.server_version >= 880200:
843            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
844        else:
845            ret = ''
846        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
847        self._do_debug(q)
848        res = self.db.query(q)
849        if ret:
850            res = res.dictresult()[0]
851            for att, value in res.iteritems():
852                d[att == 'oid' and qoid or att] = value
853        else:
854            if selectable:
855                if qoid in d:
856                    self.get(qcl, d, 'oid')
857                else:
858                    self.get(qcl, d)
859        return d
860
861    def clear(self, cl, a=None):
862        """Clear all the attributes to values determined by the types.
863
864        Numeric types are set to 0, Booleans are set to 'f', and everything
865        else is set to the empty string.  If the array argument is present,
866        it is used as the array and any entries matching attribute names are
867        cleared with everything else left unchanged.
868
869        """
870        # At some point we will need a way to get defaults from a table.
871        qcl = self._add_schema(cl)
872        if a is None:
873            a = {}  # empty if argument is not present
874        attnames = self.get_attnames(qcl)
875        for n, t in attnames.iteritems():
876            if n == 'oid':
877                continue
878            if t in ('int', 'integer', 'smallint', 'bigint',
879                    'float', 'real', 'double precision',
880                    'num', 'numeric', 'money'):
881                a[n] = 0
882            elif t in ('bool', 'boolean'):
883                a[n] = 'f'
884            else:
885                a[n] = ''
886        return a
887
888    def delete(self, cl, d=None, **kw):
889        """Delete an existing row in a database table.
890
891        This method deletes the row from a table.  It deletes based on the
892        OID value as munged by get or passed as keyword, or on the primary
893        key of the table.  The return value is the number of deleted rows
894        (i.e. 0 if the row did not exist and 1 if the row was deleted).
895
896        """
897        # Like update, delete works on the oid.
898        # One day we will be testing that the record to be deleted
899        # isn't referenced somewhere (or else PostgreSQL will).
900        # Note that we only accept oid key from named args for safety
901        qcl = self._add_schema(cl)
902        qoid = _oid_key(qcl)
903        if 'oid' in kw:
904            kw[qoid] = kw['oid']
905            del kw['oid']
906        if d is None:
907            d = {}
908        d.update(kw)
909        if qoid in d:
910            where = 'oid = %s' % d[qoid]
911        else:
912            try:
913                keyname = self.pkey(qcl)
914            except KeyError:
915                raise _prg_error('Class %s has no primary key' % qcl)
916            if isinstance(keyname, basestring):
917                keyname = (keyname,)
918            attnames = self.get_attnames(qcl)
919            try:
920                where = ' AND '.join(['%s = %s'
921                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
922            except KeyError:
923                raise _prg_error('Delete needs primary key or oid.')
924        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
925        self._do_debug(q)
926        return int(self.db.query(q))
927
928    def pgnotify(self, event, callback, arg_dict={}, timeout=None):
929        return pgnotify(self.db, event, callback, arg_dict, timeout)
930
# If run as a script, print the version followed by the module docstring.
# NOTE(review): version is presumably provided by "from _pg import *".

if __name__ == '__main__':
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.