source: trunk/pg.py @ 733

Last change on this file since 733 was 733, checked in by cito, 4 years ago

Assume select privilege in insert/update()

Don't query the database on every call of insert/update() to check
whether we have the select privilege on the table. We assume that we
can insert/update anyway, so it is reasonable to assume that we can
select, too. This saves one database request per call that is
superfluous under all normal circumstances. In the theoretically
possible but rare case that you can insert/update but not select,
you can always use a normal query() instead.

  • Property svn:keywords set to Id
File size: 30.2 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 733 2016-01-13 15:12:48Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40from functools import partial
41
# Common name for "any string type", usable with isinstance().
try:  # Python 2: basestring is a builtin, nothing needs to be done
    basestring
except NameError:  # Python >= 3.0: emulate it with a tuple of types
    basestring = str, bytes
46
# Register decimal.Decimal with the _pg module; presumably _pg uses it as
# the type for numeric column values (confirm against the _pg docs).
47set_decimal(Decimal)
48
49
50# Auxiliary functions that are independent from a DB connection:
51
52def _oid_key(cl):
53    """Build oid key from a class name."""
54    return 'oid(%s)' % cl
55
56
57def _simpletype(typ):
58    """Determine a simplified name a pg_type name."""
59    if typ.startswith('bool'):
60        return 'bool'
61    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
62        return 'date'
63    if typ.startswith(('cid', 'oid', 'int', 'xid')):
64        return 'int'
65    if typ.startswith('float'):
66        return 'float'
67    if typ.startswith('numeric'):
68        return 'num'
69    if typ.startswith('money'):
70        return 'money'
71    if typ.startswith('bytea'):
72        return 'bytea'
73    return 'text'
74
75
76def _namedresult(q):
77    """Get query result as named tuples."""
78    row = namedtuple('Row', q.listfields())
79    return [row(*r) for r in q.getresult()]
80
# Register _namedresult with the _pg module; presumably _pg calls it to
# build named-tuple query results (confirm against the _pg docs).
81set_namedresult(_namedresult)
82
83
def _db_error(msg, cls=DatabaseError):
    """Returns DatabaseError with empty sqlstate attribute.

    msg -- the error message
    cls -- the exception class to instantiate (DatabaseError by default)

    The new exception instance is returned, not raised; its sqlstate
    attribute is explicitly set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
89
90
def _int_error(msg):
    """Returns InternalError (not raised) with empty sqlstate attribute."""
    return _db_error(msg, cls=InternalError)
94
95
def _prg_error(msg):
    """Returns ProgrammingError (not raised) with empty sqlstate attribute."""
    return _db_error(msg, cls=ProgrammingError)
99
100
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB wrapper instance is
                   replaced by its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # The instance may be only partially initialized (e.g. if __init__
        # raised), so do not assume that the db attribute exists.
        if getattr(self, 'db', None):
            self.close()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    @staticmethod
    def _quote_name(name):
        """Quote a channel name as an SQL identifier.

        Embedded double quotes are doubled so that the name cannot
        terminate the quoted identifier early.
        """
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    @staticmethod
    def _quote_payload(payload):
        """Quote a notification payload as an SQL escape string (E'...').

        Backslashes and single quotes are doubled; inside an E'' constant
        this yields the original text regardless of the server's
        standard_conforming_strings setting.
        """
        text = '%s' % (payload,)
        return "E'%s'" % text.replace('\\', '\\\\').replace("'", "''")

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_name(self.event))
            self.db.query('listen %s' % self._quote_name(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_name(self.event))
            self.db.query('unlisten %s' % self._quote_name(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        db      - connection to send the notification on (defaults to the
                  handler's own connection)
        stop    - if true, send the stop event instead of the event
        payload - optional payload; it is properly quoted, so it may
                  contain quotes and backslashes

        Nothing is sent (and None is returned) unless the handler is
        listening; otherwise the result of db.query() is returned.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify %s' % self._quote_name(
                self.stop_event if stop else self.event)
            if payload:
                q += ', %s' % self._quote_payload(payload)
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid' and 'event' are inserted into <arg_dict>, and the callback is
        invoked with <arg_dict>. If the NOTIFY message is stop_<event>, the
        handler UNLISTENs both <event> and stop_<event> and exits.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            # Wait until the connection's socket becomes readable or the
            # timeout expires (select returns empty lists on timeout).
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                # Drain all pending notifications.
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
208
209
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler.

    Emits a DeprecationWarning attributed to the caller, then passes all
    arguments through to NotificationHandler unchanged.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
215
216
217# The actual PostGreSQL database connection interface:
218
class DB(object):
    """Wrapper class for the _pg connection type."""

    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        """
        # A single positional argument, or a single keyword argument 'db',
        # may be an existing connection instead of connection parameters.
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                db = db.db
            else:
                try:
                    # presumably a pgdb connection wrapping a _pg connection
                    db = db._cnx
                except AttributeError:
                    pass
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            # not a usable existing connection: open a new one
            db = connect(*args, **kw)
            self._closeable = True
        else:
            # the passed-in connection is not closed/reopened by us
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        self._args = args, kw
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements

    def __getattr__(self, name):
        # All undefined members are same as in underlying connection.
        # Read the connection from __dict__ directly: accessing self.db
        # here would recurse endlessly if the attribute was never set
        # (e.g. on instances created without running __init__).
        db = self.__dict__.get('db')
        if db:
            return getattr(db, name)
        else:
            raise _int_error('Connection is not valid')

    def __dir__(self):
        # Custom dir function including the attributes of the connection:
        attrs = set(self.__class__.__dict__)
        attrs.update(self.__dict__)
        attrs.update(dir(self.db))
        return sorted(attrs)

    # Context manager methods

    def __enter__(self):
        """Enter the runtime context. This will start a transaction."""
        self.begin()
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context. This will end the transaction.

        The transaction is committed if no exception occurred and
        rolled back otherwise.
        """
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    # Auxiliary methods

    def _do_debug(self, *args):
        """Print a debug message.

        Non-string arguments (such as the parameter lists passed by
        get/insert/update/delete) are converted with str() first, since
        joining them directly would raise a TypeError.
        """
        if self.debug:
            s = '\n'.join(arg if isinstance(arg, basestring) else str(arg)
                for arg in args)
            if isinstance(self.debug, basestring):
                print(self.debug % s)
            elif hasattr(self.debug, 'write'):
                self.debug.write(s + '\n')
            elif callable(self.debug):
                self.debug(s)
            else:
                print(s)

    def _escape_qualified_name(self, s):
        """Escape a qualified name.

        Escapes the name for use as an SQL identifier, unless the
        name contains a dot, in which case the name is ambiguous
        (could be a qualified name or just a name with a dot in it)
        and must be quoted manually by the caller.

        """
        if '.' not in s:
            s = self.escape_identifier(s)
        return s

    @staticmethod
    def _make_bool(d):
        """Get boolean value corresponding to d."""
        return bool(d) if get_bool() else ('t' if d else 'f')

    _bool_true_values = frozenset('t true 1 y yes on'.split())

    def _prepare_bool(self, d):
        """Prepare a boolean parameter.

        Strings are interpreted case-insensitively; an empty string
        becomes None (NULL).
        """
        if isinstance(d, basestring):
            if not d:
                return None
            d = d.lower() in self._bool_true_values
        return 't' if d else 'f'

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    def _prepare_date(self, d):
        """Prepare a date parameter.

        Raises ValueError for SQL date keywords, which makes
        _prepare_param insert them literally instead of as a parameter.
        """
        if not d:
            return None
        if isinstance(d, basestring) and d.lower() in self._date_literals:
            raise ValueError
        return d

    def _prepare_num(self, d):
        """Prepare a numeric parameter (false values other than 0 -> None)."""
        if not d and d != 0:
            return None
        return d

    def _prepare_bytea(self, d):
        """Prepare a bytea parameter."""
        return self.escape_bytea(d)

    _prepare_funcs = dict(  # quote methods for each type
        bool=_prepare_bool, date=_prepare_date,
        int=_prepare_num, num=_prepare_num, float=_prepare_num,
        money=_prepare_num, bytea=_prepare_bytea)

    def _prepare_param(self, value, typ, params):
        """Prepare and add a parameter to the list.

        Returns the placeholder ($n) for the appended parameter, or the
        value itself if it must be inserted literally (date keywords).
        """
        if value is not None and typ != 'text':
            try:
                prepare = self._prepare_funcs[typ]
            except KeyError:
                pass
            else:
                try:
                    value = prepare(self, value)
                except ValueError:
                    return value
        params.append(value)
        return '$%d' % len(params)

    @staticmethod
    def _prepare_qualified_param(cl, param):
        """Quote parameter representing a qualified name.

        Escapes the name for use as an SQL parameter, unless the
        name contains a dot, in which case the name is ambiguous
        (could be a qualified name or just a name with a dot in it)
        and must be quoted manually by the caller.

        """
        if isinstance(param, int):
            param = "$%d" % param
        if '.' not in cl:
            param = 'quote_ident(%s)' % (param,)
        return param

    def _copy_result_row(self, row, d, qoid, attnames):
        """Copy the columns of a result row (a dict) into the dict d.

        The oid column is stored under the munged key qoid, and
        non-null bytea values are unescaped.

        """
        for n, value in row.items():
            if n == 'oid':
                n = qoid
            elif attnames.get(n) == 'bytea' and value is not None:
                value = self.unescape_bytea(value)
            d[n] = value

    # Public methods

    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    unescape_bytea = staticmethod(unescape_bytea)

    def close(self):
        """Close the database connection."""
        # Wraps shared library function so we can track state.
        if self._closeable:
            if self.db:
                self.db.close()
                self.db = None
            else:
                raise _int_error('Connection already closed')

    def reset(self):
        """Reset connection with current parameters.

        All derived queries and large objects derived from this connection
        will not be usable after this call.

        """
        if self.db:
            self.db.reset()
        else:
            raise _int_error('Connection already closed')

    def reopen(self):
        """Reopen connection to the database.

        Used in case we need another connection to the same database.
        Note that we can still reopen a database that we have closed.

        """
        # There is no such shared library function.
        if self._closeable:
            db = connect(*self._args[0], **self._args[1])
            if self.db:
                self.db.close()
            self.db = db

    def begin(self, mode=None):
        """Begin a transaction."""
        qstr = 'BEGIN'
        if mode:
            qstr += ' ' + mode
        return self.query(qstr)

    start = begin

    def commit(self):
        """Commit the current transaction."""
        return self.query('COMMIT')

    end = commit

    def rollback(self, name=None):
        """Rollback the current transaction (or to the given savepoint)."""
        qstr = 'ROLLBACK'
        if name:
            qstr += ' TO ' + name
        return self.query(qstr)

    def savepoint(self, name=None):
        """Define a new savepoint within the current transaction."""
        qstr = 'SAVEPOINT'
        if name:
            qstr += ' ' + name
        return self.query(qstr)

    def release(self, name):
        """Destroy a previously defined savepoint."""
        return self.query('RELEASE ' + name)

    def query(self, qstr, *args):
        """Executes a SQL command string.

        This method simply sends a SQL query to the database. If the query is
        an insert statement that inserted exactly one row into a table that
        has OIDs, the return value is the OID of the newly inserted row.
        If the query is an update or delete statement, or an insert statement
        that did not insert exactly one row in a table with OIDs, then the
        number of rows affected is returned as a string. If it is a statement
        that returns rows as a result (usually a select statement, but maybe
        also an "insert/update ... returning" statement), this method returns
        a Query object that can be accessed via getresult() or dictresult()
        or simply printed. Otherwise, it returns `None`.

        The query can contain numbered parameters of the form $1 in place
        of any data constant. Arguments given after the query string will
        be substituted for the corresponding numbered parameter. Parameter
        values can also be given as a single list or tuple argument.

        """
        # Wraps shared library function for debugging.
        if not self.db:
            raise _int_error('Connection is not valid')
        self._do_debug(qstr)
        return self.db.query(qstr, args)

    def pkey(self, cl, flush=False):
        """Get the primary key of a class (table).

        Composite primary keys are represented as frozensets. Note that
        this raises a KeyError if the table does not have a primary key.

        If flush is set then the internal cache for primary keys will
        be flushed. This may be necessary after the database schema or
        the search path has been changed.

        """
        pkeys = self._pkeys
        if flush:
            pkeys.clear()
            self._do_debug('pkey cache has been flushed')
        try:  # cache lookup
            pkey = pkeys[cl]
        except KeyError:  # cache miss, check the database
            q = ("SELECT a.attname FROM pg_index i"
                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
                " AND a.attnum = ANY(i.indkey)"
                " AND NOT a.attisdropped"
                " WHERE i.indrelid=%s::regclass"
                " AND i.indisprimary" % self._prepare_qualified_param(cl, 1))
            pkey = self.db.query(q, (cl,)).getresult()
            if not pkey:
                raise KeyError('Class %s has no primary key' % cl)
            if len(pkey) > 1:
                pkey = frozenset([k[0] for k in pkey])
            else:
                pkey = pkey[0][0]
            pkeys[cl] = pkey  # cache it
        return pkey

    def get_databases(self):
        """Get list of databases in the system."""
        return [s[0] for s in
            self.db.query('SELECT datname FROM pg_database').getresult()]

    def get_relations(self, kinds=None):
        """Get list of relations in connected database of specified kinds.

        If kinds is None or empty, all kinds of relations are returned.
        Otherwise kinds can be a string or sequence of type letters
        specifying which kind of relations you want to list.

        """
        where = " AND r.relkind IN (%s)" % ','.join(
            ["'%s'" % k for k in kinds]) if kinds else ''
        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
            " FROM pg_class r"
            " JOIN pg_namespace s ON s.oid = r.relnamespace"
            " WHERE s.nspname NOT SIMILAR"
            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
            " ORDER BY s.nspname, r.relname") % where
        return [r[0] for r in self.db.query(q).getresult()]

    def get_tables(self):
        """Return list of tables in connected database."""
        return self.get_relations('r')

    def get_attnames(self, cl, flush=False):
        """Given the name of a table, digs out the set of attribute names.

        Returns a dictionary of attribute names (the names are the keys,
        the values are the names of the attributes' types).

        If flush is set then the internal cache for attribute names will
        be flushed. This may be necessary after the database schema or
        the search path has been changed.

        By default, only a limited number of simple types will be returned.
        You can get the regular types after calling use_regtypes(True).

        """
        attnames = self._attnames
        if flush:
            attnames.clear()
            self._do_debug('attnames cache has been flushed')
        try:  # cache lookup
            names = attnames[cl]
        except KeyError:  # cache miss, check the database
            q = ("SELECT a.attname, t.typname%s"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass"
                " AND (a.attnum > 0 OR a.attname = 'oid')"
                " AND NOT a.attisdropped") % (
                    '::regtype' if self._regtypes else '',
                    self._prepare_qualified_param(cl, 1))
            names = self.db.query(q, (cl,)).getresult()
            if not names:
                raise KeyError('Class %s does not exist' % cl)
            if self._regtypes:
                names = dict(names)
            else:
                names = dict((name, _simpletype(typ)) for name, typ in names)
            attnames[cl] = names  # cache it
        return names

    def use_regtypes(self, regtypes=None):
        """Use regular type names instead of simplified type names.

        Without argument, returns the current setting. Changing the
        setting clears the attribute names cache.
        """
        if regtypes is None:
            return self._regtypes
        else:
            regtypes = bool(regtypes)
            if regtypes != self._regtypes:
                self._regtypes = regtypes
                self._attnames.clear()
            return regtypes

    def has_table_privilege(self, cl, privilege='select'):
        """Check whether current user has specified table privilege.

        Results are cached per (table, privilege) for this instance.
        """
        privilege = privilege.lower()
        try:  # ask cache
            return self._privileges[(cl, privilege)]
        except KeyError:  # cache miss, ask the database
            q = "SELECT has_table_privilege(%s, $2)" % (
                self._prepare_qualified_param(cl, 1),)
            q = self.db.query(q, (cl, privilege))
            ret = q.getresult()[0][0] == self._make_bool(True)
            self._privileges[(cl, privilege)] = ret  # cache it
            return ret

    def get(self, cl, arg, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  The keyname
        that the key specifies a unique row.  If keyname is not specified
        then the primary key for the table is used.  If arg is a dictionary
        then the value for the key is taken from it and it is modified to
        include the new values, replacing existing values where necessary.
        For a composite key, keyname can also be a sequence of key names.
        The OID is also put into the dictionary if the table has one, but
        in order to allow the caller to work with multiple tables, it is
        munged as "oid(cl)".

        """
        if cl.endswith('*'):  # scan descendant tables?
            cl = cl[:-1].rstrip()  # need parent table name
        # To allow users to work with multiple tables,
        # we munge the name of the "oid" key
        qoid = _oid_key(cl)
        if not keyname:
            # use the primary key by default
            try:
                keyname = self.pkey(cl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % cl)
        attnames = self.get_attnames(cl)
        params = []
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        # We want the oid for later updates if that isn't the key
        if keyname == 'oid':
            if isinstance(arg, dict):
                if qoid not in arg:
                    raise _db_error('%s not in arg' % qoid)
            else:
                arg = {qoid: arg}
            what = '*'
            where = 'oid = %s' % param(arg[qoid], 'int')
        else:
            if isinstance(keyname, basestring):
                keyname = (keyname,)
            if not isinstance(arg, dict):
                if len(keyname) > 1:
                    raise _prg_error('Composite key needs dict as arg')
                arg = dict((k, arg) for k in keyname)
            what = ', '.join(col(k) for k in attnames)
            where = ' AND '.join(['%s = %s'
                % (col(k), param(arg[k], attnames[k])) for k in keyname])
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, self._escape_qualified_name(cl), where)
        self._do_debug(q, params)
        res = self.db.query(q, params).dictresult()
        if not res:
            raise _db_error('No such record in %s where %s' % (cl, where))
        # NULL bytea values are left as None (like in insert/update).
        self._copy_result_row(res[0], arg, qoid, attnames)
        return arg

    def insert(self, cl, d=None, **kw):
        """Insert a row into a database table.

        This method inserts a row into a table.  The name of the table must
        be passed as the first parameter.  The other parameters are used for
        providing the data of the row that shall be inserted into the table.
        If a dictionary is supplied as the second parameter, it starts with
        that.  Otherwise it uses a blank dictionary. Either way the dictionary
        is updated from the keywords.  If no column values are given at all,
        a row consisting of the default values is inserted.

        The dictionary is then reloaded with the values actually inserted in
        order to pick up values modified by rules, triggers, etc.

        Note: The method currently doesn't support insert into views
        although PostgreSQL does.

        """
        qoid = _oid_key(cl)
        if d is None:
            d = {}
        d.update(kw)
        attnames = self.get_attnames(cl)
        params = []
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        names, values = [], []
        for n in attnames:
            if n != 'oid' and n in d:
                names.append(col(n))
                values.append(param(d[n], attnames[n]))
        ret = 'oid, *' if 'oid' in attnames else '*'
        if names:
            q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % (
                self._escape_qualified_name(cl),
                ', '.join(names), ', '.join(values), ret)
        else:
            # "INSERT INTO t () VALUES ()" is invalid SQL
            q = 'INSERT INTO %s DEFAULT VALUES RETURNING %s' % (
                self._escape_qualified_name(cl), ret)
        self._do_debug(q, params)
        res = self.db.query(q, params).dictresult()
        # The result may be empty, e.g. when a trigger or rule suppressed
        # the insert; then d is returned as given.
        if res:
            self._copy_result_row(res[0], d, qoid, attnames)
        return d

    def update(self, cl, d=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the OID value as munged by get or passed as keyword, or on the
        primary key of the table.  The dictionary is modified to reflect
        any changes caused by the update due to triggers, rules, default
        values, etc.  If no row matches, the dictionary is returned without
        being reloaded.

        """
        # Update always works on the oid which get returns if available,
        # otherwise use the primary key.  Fail if neither.
        # Note that we only accept oid key from named args for safety
        qoid = _oid_key(cl)
        if 'oid' in kw:
            kw[qoid] = kw['oid']
            del kw['oid']
        if d is None:
            d = {}
        d.update(kw)
        attnames = self.get_attnames(cl)
        params = []
        param = partial(self._prepare_param, params=params)
        col = self.escape_identifier
        if qoid in d:
            where = 'oid = %s' % param(d[qoid], 'int')
            keyname = ()
        else:
            try:
                keyname = self.pkey(cl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % cl)
            if isinstance(keyname, basestring):
                keyname = (keyname,)
            try:
                where = ' AND '.join(['%s = %s'
                    % (col(k), param(d[k], attnames[k])) for k in keyname])
            except KeyError:
                raise _prg_error('Update needs primary key or oid.')
        values = []
        for n in attnames:
            # The oid system column cannot be assigned (insert skips it,
            # too), and the key columns are only used in the where clause.
            if n != 'oid' and n in d and n not in keyname:
                values.append('%s = %s' % (col(n), param(d[n], attnames[n])))
        if not values:
            return d
        values = ', '.join(values)
        ret = 'oid, *' if 'oid' in attnames else '*'
        q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % (
            self._escape_qualified_name(cl), values, where, ret)
        self._do_debug(q, params)
        res = self.db.query(q, params).dictresult()
        if res:  # empty when no row matched
            self._copy_result_row(res[0], d, qoid, attnames)
        return d

    def clear(self, cl, a=None):
        """Clear all the attributes to values determined by the types.

        Numeric types are set to 0, Booleans are set to false, and everything
        else is set to the empty string.  If the dictionary argument is
        present, it is used as the dictionary and any entries matching
        attribute names are cleared with everything else left unchanged.

        """
        # At some point we will need a way to get defaults from a table.
        if a is None:
            a = {}  # empty if argument is not present
        attnames = self.get_attnames(cl)
        for n, t in attnames.items():
            if n == 'oid':
                continue
            if t in ('int', 'integer', 'smallint', 'bigint',
                    'float', 'real', 'double precision',
                    'num', 'numeric', 'money'):
                a[n] = 0
            elif t in ('bool', 'boolean'):
                a[n] = self._make_bool(False)
            else:
                a[n] = ''
        return a

    def delete(self, cl, d=None, **kw):
        """Delete an existing row in a database table.

        This method deletes the row from a table.  It deletes based on the
        OID value as munged by get or passed as keyword, or on the primary
        key of the table.  The return value is the number of deleted rows
        (i.e. 0 if the row did not exist and 1 if the row was deleted).

        """
        # Like update, delete works on the oid.
        # One day we will be testing that the record to be deleted
        # isn't referenced somewhere (or else PostgreSQL will).
        # Note that we only accept oid key from named args for safety
        qoid = _oid_key(cl)
        if 'oid' in kw:
            kw[qoid] = kw['oid']
            del kw['oid']
        if d is None:
            d = {}
        d.update(kw)
        params = []
        param = partial(self._prepare_param, params=params)
        if qoid in d:
            where = 'oid = %s' % param(d[qoid], 'int')
        else:
            try:
                keyname = self.pkey(cl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % cl)
            if isinstance(keyname, basestring):
                keyname = (keyname,)
            attnames = self.get_attnames(cl)
            col = self.escape_identifier
            try:
                where = ' AND '.join(['%s = %s'
                    % (col(k), param(d[k], attnames[k])) for k in keyname])
            except KeyError:
                raise _prg_error('Delete needs primary key or oid.')
        q = 'DELETE FROM %s WHERE %s' % (
            self._escape_qualified_name(cl), where)
        self._do_debug(q, params)
        return int(self.db.query(q, params))

    def notification_handler(self, event, callback, arg_dict=None,
            timeout=None):
        """Get notification handler that will run the given callback.

        If arg_dict is not given, every handler gets its own new
        dictionary (a shared mutable default would be modified by all
        handlers created without an explicit arg_dict).

        """
        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
854
855
856# if run as script, print some information
857
if __name__ == '__main__':
    # When run as a script, print the version and the module docstring.
    # (The original message lacked the space before the version number.)
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.