source: trunk/pg.py @ 729

Last change on this file since 729 was 729, checked in by cito, 4 years ago

Simplify caching and handling of class names

The caches now use the class names as keys as they are passed in.
We do not automatically calculate the qualified name any more,
since this causes too much overhead. Also, we no longer fill the pkey cache
pro-actively with tables from all possible schemas.

Most of the internal auxiliary functions for handling class names
could be discarded by making good use of quote_ident and regclass.

  • Property svn:keywords set to Id
File size: 30.4 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 729 2016-01-12 21:29:07Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40from itertools import groupby
41
try:
    basestring  # Python 2: the common base class of str and unicode
except NameError:  # Python >= 3.0
    # a tuple of types works as the second argument of isinstance()
    basestring = (str, bytes)

# set_decimal is not defined in this file; presumably it comes from the
# _pg star import and selects the type used for numeric values -- TODO confirm
set_decimal(Decimal)
48
49
# Auxiliary functions that are independent of a DB connection:
51
def _quote_class_name(cl):
    """Quote a class name.

    Class names are always quoted unless they contain a dot.
    In this ambiguous case (a schema-qualified name, or a name that
    itself contains a dot) quotes must be added manually.

    Double quotes inside an unqualified name are doubled, as required
    for quoted SQL identifiers, so that the result refers to the same
    table as the quote_ident() call built by _quote_class_param.

    """
    if '.' in cl:
        return cl
    return '"%s"' % cl.replace('"', '""')
62
63
def _quote_class_param(cl, param):
    """Build the SQL expression for a parameter holding a class name.

    An integer param is turned into the positional placeholder "$<n>".
    Unless the class name cl contains a dot, the placeholder is wrapped
    in quote_ident() so that the server quotes the name. Dotted names are
    ambiguous and are used as they are; quotes must be added manually.

    """
    placeholder = "$%d" % param if isinstance(param, int) else param
    if '.' in cl:
        return placeholder
    return 'quote_ident(%s)' % (placeholder,)
76
77
def _oid_key(cl):
    """Return the munged dictionary key "oid(<cl>)" for the oid of class cl."""
    return 'oid({0})'.format(cl)
81
82
def _simpletype(typ):
    """Determine a simplified type name for a pg_type name.

    The result is one of 'bool', 'date', 'int', 'float', 'num', 'money',
    'bytea' or, as the fallback, 'text'; these are the keys used by the
    quoting dispatch of the DB class. The prefixes are tried in order,
    so that e.g. 'interval' maps to 'date' before 'int' is considered.

    """
    prefix_table = (
        (('bool',), 'bool'),
        (('abstime', 'date', 'interval', 'timestamp'), 'date'),
        (('cid', 'oid', 'int', 'xid'), 'int'),
        (('float',), 'float'),
        (('numeric',), 'num'),
        (('money',), 'money'),
        (('bytea',), 'bytea'),
    )
    for prefixes, simple_name in prefix_table:
        if typ.startswith(prefixes):
            return simple_name
    return 'text'
100
101
def _namedresult(q):
    """Get query result as named tuples.

    Column names that are not valid Python identifiers, or that occur
    more than once in the result, are replaced by positional field names
    such as '_1' (namedtuple's rename option). Without this, namedtuple
    would raise a ValueError for such queries.

    """
    row = namedtuple('Row', q.listfields(), rename=True)
    return [row(*r) for r in q.getresult()]
106
# Hand the named tuple factory above to the extension module.
# set_namedresult is not defined in this file; presumably it comes from the
# _pg star import and is used to build named-tuple results -- TODO confirm.
set_namedresult(_namedresult)
108
109
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an exception of class cls for msg.

    cls defaults to DatabaseError. The sqlstate attribute of the new
    exception is set to None.

    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
115
116
def _int_error(msg):
    """Create (but do not raise) an InternalError for msg via _db_error."""
    return _db_error(msg, cls=InternalError)
120
121
def _prg_error(msg):
    """Create (but do not raise) a ProgrammingError for msg via _db_error."""
    return _db_error(msg, cls=ProgrammingError)
125
126
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object; a DB instance is
                   replaced by the connection it wraps.
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback
                   (a new empty dictionary if absent or None).
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        """Close the handler when it is garbage collected."""
        # __init__ may have failed before self.db was assigned;
        # do not raise AttributeError from the finalizer in that case.
        if getattr(self, 'db', None):
            self.close()

    def close(self):
        """Stop listening and close the connection.

        The connection is closed even if unlistening fails.

        """
        if self.db:
            try:
                self.unlisten()
            finally:
                self.db.close()
                self.db = None

    @staticmethod
    def _quote_event(event):
        """Quote an event (channel) name as an SQL identifier.

        Embedded double quotes are doubled so that they cannot end the
        quoted identifier early.

        """
        return '"%s"' % ('%s' % (event,)).replace('"', '""')

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_event(self.event))
            self.db.query('listen %s' % self._quote_event(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_event(self.event))
            self.db.query('unlisten %s' % self._quote_event(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        db      - connection used to send the notification; defaults to
                  the handler's own connection.
        stop    - if true, notify the stop event instead of the event.
        payload - optional payload; it is escaped with the connection's
                  escape_string() before being put into a string literal.

        Nothing is sent, and None is returned, unless the handler is
        currently listening. Otherwise the result of the query is returned.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify %s' % self._quote_event(
                self.stop_event if stop else self.event)
            if payload:
                # escape quotes so the payload cannot end the literal early
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' values are inserted into <arg_dict>, and
        the callback is invoked with <arg_dict>. If the NOTIFY message is
        stop_<event>, the handler UNLISTENs both <event> and stop_<event>
        and exits. On a timeout, the handler UNLISTENs, invokes the callback
        with None and exits.

        If close is true, the handler is closed (including its connection)
        when the loop exits normally.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        watched = [self.db.fileno()]

        while self.listening:
            readable = select.select(watched, [], [], self.timeout)[0]
            if readable:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
        if close:
            self.close()
234
235
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    Issues a DeprecationWarning attributed to the caller, then passes
    all arguments on to NotificationHandler.

    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
241
242
# The actual PostgreSQL database connection interface:
244
class DB(object):
    """Wrapper class for the _pg connection type.

    Attributes that are not defined by this class are delegated to the
    wrapped connection by __getattr__.
    """
247
248    def __init__(self, *args, **kw):
249        """Create a new connection.
250
251        You can pass either the connection parameters or an existing
252        _pg or pgdb connection. This allows you to use the methods
253        of the classic pg interface with a DB-API 2 pgdb connection.
254
255        """
256        if not args and len(kw) == 1:
257            db = kw.get('db')
258        elif not kw and len(args) == 1:
259            db = args[0]
260        else:
261            db = None
262        if db:
263            if isinstance(db, DB):
264                db = db.db
265            else:
266                try:
267                    db = db._cnx
268                except AttributeError:
269                    pass
270        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
271            db = connect(*args, **kw)
272            self._closeable = True
273        else:
274            self._closeable = False
275        self.db = db
276        self.dbname = db.db
277        self._regtypes = False
278        self._attnames = {}
279        self._pkeys = {}
280        self._privileges = {}
281        self._args = args, kw
282        self.debug = None  # For debugging scripts, this can be set
283            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
284            # * to a file object to write debug statements or
285            # * to a callable object which takes a string argument
286            # * to any other true value to just print debug statements
287
288    def __getattr__(self, name):
289        # All undefined members are same as in underlying connection:
290        if self.db:
291            return getattr(self.db, name)
292        else:
293            raise _int_error('Connection is not valid')
294
295    def __dir__(self):
296        # Custom dir function including the attributes of the connection:
297        attrs = set(self.__class__.__dict__)
298        attrs.update(self.__dict__)
299        attrs.update(dir(self.db))
300        return sorted(attrs)
301
302    # Context manager methods
303
304    def __enter__(self):
305        """Enter the runtime context. This will start a transaction."""
306        self.begin()
307        return self
308
309    def __exit__(self, et, ev, tb):
310        """Exit the runtime context. This will end the transaction."""
311        if et is None and ev is None and tb is None:
312            self.commit()
313        else:
314            self.rollback()
315
316    # Auxiliary methods
317
318    def _do_debug(self, s):
319        """Print a debug message."""
320        if self.debug:
321            if isinstance(self.debug, basestring):
322                print(self.debug % s)
323            elif hasattr(self.debug, 'write'):
324                self.debug.write(s + '\n')
325            elif callable(self.debug):
326                self.debug(s)
327            else:
328                print(s)
329
330    @staticmethod
331    def _make_bool(d):
332        """Get boolean value corresponding to d."""
333        return bool(d) if get_bool() else ('t' if d else 'f')
334
335    def _quote_text(self, d):
336        """Quote text value."""
337        if not isinstance(d, basestring):
338            d = str(d)
339        return "'%s'" % self.escape_string(d)
340
341    _bool_true = frozenset('t true 1 y yes on'.split())
342
343    def _quote_bool(self, d):
344        """Quote boolean value."""
345        if isinstance(d, basestring):
346            if not d:
347                return 'NULL'
348            d = d.lower() in self._bool_true
349        return "'t'" if d else "'f'"
350
351    _date_literals = frozenset('current_date current_time'
352        ' current_timestamp localtime localtimestamp'.split())
353
354    def _quote_date(self, d):
355        """Quote date value."""
356        if not d:
357            return 'NULL'
358        if isinstance(d, basestring) and d.lower() in self._date_literals:
359            return d
360        return self._quote_text(d)
361
362    def _quote_num(self, d):
363        """Quote numeric value."""
364        if not d and d != 0:
365            return 'NULL'
366        return str(d)
367
368    def _quote_money(self, d):
369        """Quote money value."""
370        if d is None or d == '':
371            return 'NULL'
372        if not isinstance(d, basestring):
373            d = str(d)
374        return d
375
376    if bytes is str:  # Python < 3.0
377        """Quote bytes value."""
378
379        def _quote_bytea(self, d):
380            return "'%s'" % self.escape_bytea(d)
381
382    else:
383
384        def _quote_bytea(self, d):
385            return "'%s'" % self.escape_bytea(d).decode('ascii')
386
387    _quote_funcs = dict(  # quote methods for each type
388        text=_quote_text, bool=_quote_bool, date=_quote_date,
389        int=_quote_num, num=_quote_num, float=_quote_num,
390        money=_quote_money, bytea=_quote_bytea)
391
392    def _quote(self, d, t):
393        """Return quotes if needed."""
394        if d is None:
395            return 'NULL'
396        try:
397            quote_func = self._quote_funcs[t]
398        except KeyError:
399            quote_func = self._quote_funcs['text']
400        return quote_func(self, d)
401
402    # Public methods
403
    # escape_string and escape_bytea exist as methods (they are reached on
    # the wrapped connection through __getattr__), so we define
    # unescape_bytea as a method as well. The right-hand side is the
    # module-level function, which is not defined in this file (presumably
    # it comes from the _pg star import).
    unescape_bytea = staticmethod(unescape_bytea)
407
408    def close(self):
409        """Close the database connection."""
410        # Wraps shared library function so we can track state.
411        if self._closeable:
412            if self.db:
413                self.db.close()
414                self.db = None
415            else:
416                raise _int_error('Connection already closed')
417
418    def reset(self):
419        """Reset connection with current parameters.
420
421        All derived queries and large objects derived from this connection
422        will not be usable after this call.
423
424        """
425        if self.db:
426            self.db.reset()
427        else:
428            raise _int_error('Connection already closed')
429
430    def reopen(self):
431        """Reopen connection to the database.
432
433        Used in case we need another connection to the same database.
434        Note that we can still reopen a database that we have closed.
435
436        """
437        # There is no such shared library function.
438        if self._closeable:
439            db = connect(*self._args[0], **self._args[1])
440            if self.db:
441                self.db.close()
442            self.db = db
443
444    def begin(self, mode=None):
445        """Begin a transaction."""
446        qstr = 'BEGIN'
447        if mode:
448            qstr += ' ' + mode
449        return self.query(qstr)
450
451    start = begin
452
453    def commit(self):
454        """Commit the current transaction."""
455        return self.query('COMMIT')
456
457    end = commit
458
459    def rollback(self, name=None):
460        """Rollback the current transaction."""
461        qstr = 'ROLLBACK'
462        if name:
463            qstr += ' TO ' + name
464        return self.query(qstr)
465
466    def savepoint(self, name=None):
467        """Define a new savepoint within the current transaction."""
468        qstr = 'SAVEPOINT'
469        if name:
470            qstr += ' ' + name
471        return self.query(qstr)
472
473    def release(self, name):
474        """Destroy a previously defined savepoint."""
475        return self.query('RELEASE ' + name)
476
477    def query(self, qstr, *args):
478        """Executes a SQL command string.
479
480        This method simply sends a SQL query to the database. If the query is
481        an insert statement that inserted exactly one row into a table that
482        has OIDs, the return value is the OID of the newly inserted row.
483        If the query is an update or delete statement, or an insert statement
484        that did not insert exactly one row in a table with OIDs, then the
485        number of rows affected is returned as a string. If it is a statement
486        that returns rows as a result (usually a select statement, but maybe
487        also an "insert/update ... returning" statement), this method returns
488        a Query object that can be accessed via getresult() or dictresult()
489        or simply printed. Otherwise, it returns `None`.
490
491        The query can contain numbered parameters of the form $1 in place
492        of any data constant. Arguments given after the query string will
493        be substituted for the corresponding numbered parameter. Parameter
494        values can also be given as a single list or tuple argument.
495
496        """
497        # Wraps shared library function for debugging.
498        if not self.db:
499            raise _int_error('Connection is not valid')
500        self._do_debug(qstr)
501        return self.db.query(qstr, args)
502
503    def pkey(self, cl, flush=False):
504        """This method gets or sets the primary key of a class.
505
506        Composite primary keys are represented as frozensets. Note that
507        this raises a KeyError if the table does not have a primary key.
508
509        If flush is set then the internal cache for primary keys will
510        be flushed. This may be necessary after the database schema or
511        the search path has been changed.
512
513        """
514        pkeys = self._pkeys
515        if flush:
516            pkeys.clear()
517            self._do_debug('pkey cache has been flushed')
518        try:  # cache lookup
519            pkey = pkeys[cl]
520        except KeyError:  # cache miss, check the database
521            q = ("SELECT a.attname FROM pg_index i"
522                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
523                " AND a.attnum = ANY(i.indkey)"
524                " AND NOT a.attisdropped"
525                " WHERE i.indrelid=%s::regclass"
526                " AND i.indisprimary" % _quote_class_param(cl, 1))
527            pkey = self.db.query(q, (cl,)).getresult()
528            if not pkey:
529                raise KeyError('Class %s has no primary key' % cl)
530            if len(pkey) > 1:
531                pkey = frozenset([k[0] for k in pkey])
532            else:
533                pkey = pkey[0][0]
534            pkeys[cl] = pkey  # cache it
535        return pkey
536
537    def get_databases(self):
538        """Get list of databases in the system."""
539        return [s[0] for s in
540            self.db.query('SELECT datname FROM pg_database').getresult()]
541
542    def get_relations(self, kinds=None):
543        """Get list of relations in connected database of specified kinds.
544
545        If kinds is None or empty, all kinds of relations are returned.
546        Otherwise kinds can be a string or sequence of type letters
547        specifying which kind of relations you want to list.
548
549        """
550        where = " AND r.relkind IN (%s)" % ','.join(
551            ["'%s'" % k for k in kinds]) if kinds else ''
552        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
553            " FROM pg_class r"
554            " JOIN pg_namespace s ON s.oid = r.relnamespace"
555            " WHERE s.nspname NOT SIMILAR"
556            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
557            " ORDER BY s.nspname, r.relname") % where
558        return [r[0] for r in self.db.query(q).getresult()]
559
560    def get_tables(self):
561        """Return list of tables in connected database."""
562        return self.get_relations('r')
563
564    def get_attnames(self, cl, flush=False):
565        """Given the name of a table, digs out the set of attribute names.
566
567        Returns a dictionary of attribute names (the names are the keys,
568        the values are the names of the attributes' types).
569
570        If the optional newattnames exists, it must be a dictionary and
571        will become the new attribute names dictionary.
572
573        By default, only a limited number of simple types will be returned.
574        You can get the regular types after calling use_regtypes(True).
575
576        """
577        attnames = self._attnames
578        if flush:
579            attnames.clear()
580            self._do_debug('pkey cache has been flushed')
581
582        try:  # cache lookup
583            names = attnames[cl]
584        except KeyError:  # cache miss, check the database
585            q = ("SELECT a.attname, t.typname%s"
586                " FROM pg_attribute a"
587                " JOIN pg_type t ON t.oid = a.atttypid"
588                " WHERE a.attrelid = %s::regclass"
589                " AND (a.attnum > 0 OR a.attname = 'oid')"
590                " AND NOT a.attisdropped") % (
591                    '::regtype' if self._regtypes else '',
592                    _quote_class_param(cl, 1))
593            names = self.db.query(q, (cl,)).getresult()
594            if not names:
595                raise KeyError('Class %s does not exist' % cl)
596            if self._regtypes:
597                names = dict(names)
598            else:
599                names = dict((name, _simpletype(typ)) for name, typ in names)
600            attnames[cl] = names  # cache it
601        return names
602
603    def use_regtypes(self, regtypes=None):
604        """Use regular type names instead of simplified type names."""
605        if regtypes is None:
606            return self._regtypes
607        else:
608            regtypes = bool(regtypes)
609            if regtypes != self._regtypes:
610                self._regtypes = regtypes
611                self._attnames.clear()
612            return regtypes
613
614    def has_table_privilege(self, cl, privilege='select'):
615        """Check whether current user has specified table privilege."""
616        privilege = privilege.lower()
617        try:  # ask cache
618            return self._privileges[(cl, privilege)]
619        except KeyError:  # cache miss, ask the database
620            q = "SELECT has_table_privilege(%s, $2)" % (
621                _quote_class_param(cl, 1),)
622            q = self.db.query(q, (cl, privilege))
623            ret = q.getresult()[0][0] == self._make_bool(True)
624            self._privileges[(cl, privilege)] = ret  # cache it
625            return ret
626
627    def get(self, cl, arg, keyname=None):
628        """Get a row from a database table or view.
629
630        This method is the basic mechanism to get a single row.  The keyname
631        that the key specifies a unique row.  If keyname is not specified
632        then the primary key for the table is used.  If arg is a dictionary
633        then the value for the key is taken from it and it is modified to
634        include the new values, replacing existing values where necessary.
635        For a composite key, keyname can also be a sequence of key names.
636        The OID is also put into the dictionary if the table has one, but
637        in order to allow the caller to work with multiple tables, it is
638        munged as "oid(cl)".
639
640        """
641        if cl.endswith('*'):  # scan descendant tables?
642            cl = cl[:-1].rstrip()  # need parent table name
643        # build qualified class name
644        # To allow users to work with multiple tables,
645        # we munge the name of the "oid" key
646        qoid = _oid_key(cl)
647        if not keyname:
648            # use the primary key by default
649            try:
650                keyname = self.pkey(cl)
651            except KeyError:
652                raise _prg_error('Class %s has no primary key' % cl)
653        attnames = self.get_attnames(cl)
654        # We want the oid for later updates if that isn't the key
655        if keyname == 'oid':
656            if isinstance(arg, dict):
657                if qoid not in arg:
658                    raise _db_error('%s not in arg' % qoid)
659            else:
660                arg = {qoid: arg}
661            what = '*'
662            where = 'oid = %s' % arg[qoid]
663        else:
664            if isinstance(keyname, basestring):
665                keyname = (keyname,)
666            if not isinstance(arg, dict):
667                if len(keyname) > 1:
668                    raise _prg_error('Composite key needs dict as arg')
669                arg = dict([(k, arg) for k in keyname])
670            what = ', '.join(attnames)
671            where = ' AND '.join(['%s = %s'
672                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
673        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
674            what, _quote_class_name(cl), where)
675        self._do_debug(q)
676        res = self.db.query(q).dictresult()
677        if not res:
678            raise _db_error('No such record in %s where %s' % (cl, where))
679        for n, value in res[0].items():
680            if n == 'oid':
681                n = qoid
682            elif attnames.get(n) == 'bytea':
683                value = self.unescape_bytea(value)
684            arg[n] = value
685        return arg
686
687    def insert(self, cl, d=None, **kw):
688        """Insert a row into a database table.
689
690        This method inserts a row into a table.  The name of the table must
691        be passed as the first parameter.  The other parameters are used for
692        providing the data of the row that shall be inserted into the table.
693        If a dictionary is supplied as the second parameter, it starts with
694        that.  Otherwise it uses a blank dictionary. Either way the dictionary
695        is updated from the keywords.
696
697        The dictionary is then, if possible, reloaded with the values actually
698        inserted in order to pick up values modified by rules, triggers, etc.
699
700        Note: The method currently doesn't support insert into views
701        although PostgreSQL does.
702
703        """
704        qoid = _oid_key(cl)
705        if d is None:
706            d = {}
707        d.update(kw)
708        attnames = self.get_attnames(cl)
709        names, values = [], []
710        for n in attnames:
711            if n != 'oid' and n in d:
712                names.append('"%s"' % n)
713                values.append(self._quote(d[n], attnames[n]))
714        names, values = ', '.join(names), ', '.join(values)
715        selectable = self.has_table_privilege(cl)
716        if selectable:
717            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
718        else:
719            ret = ''
720        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (
721            _quote_class_name(cl), names, values, ret)
722        self._do_debug(q)
723        res = self.db.query(q)
724        if ret:
725            res = res.dictresult()[0]
726            for n, value in res.items():
727                if n == 'oid':
728                    n = qoid
729                elif attnames.get(n) == 'bytea':
730                    value = self.unescape_bytea(value)
731                d[n] = value
732        elif isinstance(res, int):
733            d[qoid] = res
734            if selectable:
735                self.get(cl, d, 'oid')
736        elif selectable:
737            if qoid in d:
738                self.get(cl, d, 'oid')
739            else:
740                try:
741                    self.get(cl, d)
742                except ProgrammingError:
743                    pass  # table has no primary key
744        return d
745
746    def update(self, cl, d=None, **kw):
747        """Update an existing row in a database table.
748
749        Similar to insert but updates an existing row.  The update is based
750        on the OID value as munged by get or passed as keyword, or on the
751        primary key of the table.  The dictionary is modified, if possible,
752        to reflect any changes caused by the update due to triggers, rules,
753        default values, etc.
754
755        """
756        # Update always works on the oid which get returns if available,
757        # otherwise use the primary key.  Fail if neither.
758        # Note that we only accept oid key from named args for safety
759        qoid = _oid_key(cl)
760        if 'oid' in kw:
761            kw[qoid] = kw['oid']
762            del kw['oid']
763        if d is None:
764            d = {}
765        d.update(kw)
766        attnames = self.get_attnames(cl)
767        if qoid in d:
768            where = 'oid = %s' % d[qoid]
769            keyname = ()
770        else:
771            try:
772                keyname = self.pkey(cl)
773            except KeyError:
774                raise _prg_error('Class %s has no primary key' % cl)
775            if isinstance(keyname, basestring):
776                keyname = (keyname,)
777            try:
778                where = ' AND '.join(['%s = %s'
779                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
780            except KeyError:
781                raise _prg_error('Update needs primary key or oid.')
782        values = []
783        for n in attnames:
784            if n in d and n not in keyname:
785                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
786        if not values:
787            return d
788        values = ', '.join(values)
789        selectable = self.has_table_privilege(cl)
790        if selectable:
791            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
792        else:
793            ret = ''
794        q = 'UPDATE %s SET %s WHERE %s%s' % (
795            _quote_class_name(cl), values, where, ret)
796        self._do_debug(q)
797        res = self.db.query(q)
798        if ret:
799            res = res.dictresult()[0]
800            for n, value in res.items():
801                if n == 'oid':
802                    n = qoid
803                elif attnames.get(n) == 'bytea':
804                    value = self.unescape_bytea(value)
805                d[n] = value
806        else:
807            if selectable:
808                if qoid in d:
809                    self.get(cl, d, 'oid')
810                else:
811                    self.get(cl, d)
812        return d
813
814    def clear(self, cl, a=None):
815        """Clear all the attributes to values determined by the types.
816
817        Numeric types are set to 0, Booleans are set to false, and everything
818        else is set to the empty string.  If the array argument is present,
819        it is used as the array and any entries matching attribute names are
820        cleared with everything else left unchanged.
821
822        """
823        # At some point we will need a way to get defaults from a table.
824        if a is None:
825            a = {}  # empty if argument is not present
826        attnames = self.get_attnames(cl)
827        for n, t in attnames.items():
828            if n == 'oid':
829                continue
830            if t in ('int', 'integer', 'smallint', 'bigint',
831                    'float', 'real', 'double precision',
832                    'num', 'numeric', 'money'):
833                a[n] = 0
834            elif t in ('bool', 'boolean'):
835                a[n] = self._make_bool(False)
836            else:
837                a[n] = ''
838        return a
839
840    def delete(self, cl, d=None, **kw):
841        """Delete an existing row in a database table.
842
843        This method deletes the row from a table.  It deletes based on the
844        OID value as munged by get or passed as keyword, or on the primary
845        key of the table.  The return value is the number of deleted rows
846        (i.e. 0 if the row did not exist and 1 if the row was deleted).
847
848        """
849        # Like update, delete works on the oid.
850        # One day we will be testing that the record to be deleted
851        # isn't referenced somewhere (or else PostgreSQL will).
852        # Note that we only accept oid key from named args for safety
853        qoid = _oid_key(cl)
854        if 'oid' in kw:
855            kw[qoid] = kw['oid']
856            del kw['oid']
857        if d is None:
858            d = {}
859        d.update(kw)
860        if qoid in d:
861            where = 'oid = %s' % d[qoid]
862        else:
863            try:
864                keyname = self.pkey(cl)
865            except KeyError:
866                raise _prg_error('Class %s has no primary key' % cl)
867            if isinstance(keyname, basestring):
868                keyname = (keyname,)
869            attnames = self.get_attnames(cl)
870            try:
871                where = ' AND '.join(['%s = %s'
872                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
873            except KeyError:
874                raise _prg_error('Delete needs primary key or oid.')
875        q = 'DELETE FROM %s WHERE %s' % (_quote_class_name(cl), where)
876        self._do_debug(q)
877        return int(self.db.query(q))
878
879    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
880        """Get notification handler that will run the given callback."""
881        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
882
883
884# if run as script, print some information
885
if __name__ == '__main__':
    # version is not defined in this file; presumably from the _pg star import
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.