source: trunk/pg.py @ 730

Last change on this file since 730 was 730, checked in by cito, 4 years ago

Use query parameters instead of inline values

The single row methods of the DB wrapper class created queries with inline values
instead of passing them separately as parameters, even though our query method
does have this capability. Using query parameters also spares us a lot of quoting
and escaping that is necessary when passing values inline.

  • Property svn:keywords set to Id
File size: 30.5 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 730 2016-01-13 01:58:54Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40from functools import partial
41
# Python 3 has no basestring builtin; define it as a tuple of both string
# types so that isinstance(x, basestring) checks below work on any version.
try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# NOTE(review): set_decimal comes from the _pg extension; presumably this
# makes it return numeric values as Decimal instances - confirm.
set_decimal(Decimal)
48
49
# Auxiliary functions that are independent of a DB connection:
51
52def _quote_class_name(cl):
53    """Quote a class name.
54
55    Class names are always quoted unless they contain a dot.
56    In this ambiguous case quotes must be added manually.
57
58    """
59    if '.' not in cl:
60        cl = '"%s"' % cl
61    return cl
62
63
64def _quote_class_param(cl, param):
65    """Quote parameter representing a class name.
66
67    The parameter is automatically quoted unless the class name contains a dot.
68    In this ambiguous case quotes must be added manually.
69
70    """
71    if isinstance(param, int):
72        param = "$%d" % param
73    if '.' not in cl:
74        param = 'quote_ident(%s)' % (param,)
75    return param
76
77
78def _oid_key(cl):
79    """Build oid key from qualified class name."""
80    return 'oid(%s)' % cl
81
82
83def _simpletype(typ):
84    """Determine a simplified name a pg_type name."""
85    if typ.startswith('bool'):
86        return 'bool'
87    if typ.startswith(('abstime', 'date', 'interval', 'timestamp')):
88        return 'date'
89    if typ.startswith(('cid', 'oid', 'int', 'xid')):
90        return 'int'
91    if typ.startswith('float'):
92        return 'float'
93    if typ.startswith('numeric'):
94        return 'num'
95    if typ.startswith('money'):
96        return 'money'
97    if typ.startswith('bytea'):
98        return 'bytea'
99    return 'text'
100
101
102def _namedresult(q):
103    """Get query result as named tuples."""
104    row = namedtuple('Row', q.listfields())
105    return [row(*r) for r in q.getresult()]
106
# Register _namedresult with the _pg module; presumably it is used as the
# row factory for Query.namedresult() - confirm against the _pg docs.
set_namedresult(_namedresult)
108
109
def _db_error(msg, cls=DatabaseError):
    """Create an error of class cls (DatabaseError by default) for msg.

    The sqlstate attribute of the returned error is set to None, since
    errors built here are raised client-side by this module.

    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
115
116
def _int_error(msg):
    """Create an InternalError for msg with sqlstate set to None."""
    return _db_error(msg, cls=InternalError)
120
121
def _prg_error(msg):
    """Create a ProgrammingError for msg with sqlstate set to None."""
    return _db_error(msg, cls=ProgrammingError)
125
126
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object.
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    @staticmethod
    def _quote_ident(name):
        """Quote a channel name as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so that every name results
        in exactly one identifier in the LISTEN/UNLISTEN/NOTIFY commands.

        """
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    def close(self):
        """Stop listening and close the connection."""
        # getattr() because __del__ may run on an instance whose __init__
        # failed (e.g. wrong arguments) before the db attribute was set.
        if getattr(self, 'db', None):
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_ident(self.event))
            self.db.query('listen %s' % self._quote_ident(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_ident(self.event))
            self.db.query('unlisten %s' % self._quote_ident(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        If stop is set, the stop event is notified instead of the event.
        An optional payload is sent along with the notification; it is
        escaped with the connection's escape_string method, so it may
        safely contain quotes.  Nothing is sent unless we are listening.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify %s' % self._quote_ident(
                self.stop_event if stop else self.event)
            if payload:
                # NOTIFY cannot take bind parameters, so the payload must be
                # inlined as a string literal; escape it instead of pasting
                # it verbatim, which broke on (or could inject via) quotes.
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid' and 'event' are inserted into <arg_dict>, and the callback is
        invoked with <arg_dict>. If the NOTIFY message is stop_<event>, the
        handler UNLISTENs both <event> and stop_<event> and exits.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            # wait until the connection socket becomes readable or we time out
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                # drain all notifications that have arrived so far
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
234
235
def pgnotify(*args, **kw):
    """Create a NotificationHandler under its deprecated traditional name.

    Emits a DeprecationWarning attributed to the caller, then passes all
    arguments through to NotificationHandler unchanged.

    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
241
242
# The actual PostgreSQL database connection interface:
244
class DB(object):
    """Wrapper class for the _pg connection type."""

    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        """
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                db = db.db
            else:
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False
        self.db = db
        self.dbname = db.db
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        self._args = args, kw
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements

    def __getattr__(self, name):
        # All undefined members are same as in underlying connection.
        # The connection is read from __dict__ directly: reading self.db
        # here would re-enter __getattr__ and recurse endlessly whenever
        # the attribute is not set (e.g. if connect() failed in __init__).
        db = self.__dict__.get('db')
        if db:
            return getattr(db, name)
        raise _int_error('Connection is not valid')

    def __dir__(self):
        # Custom dir function including the attributes of the connection:
        attrs = set(self.__class__.__dict__)
        attrs.update(self.__dict__)
        attrs.update(dir(self.db))
        return sorted(attrs)

    # Context manager methods

    def __enter__(self):
        """Enter the runtime context. This will start a transaction."""
        self.begin()
        return self

    def __exit__(self, et, ev, tb):
        """Exit the runtime context. This will end the transaction."""
        if et is None and ev is None and tb is None:
            self.commit()
        else:
            self.rollback()

    # Auxiliary methods

    def _do_debug(self, *args):
        """Print a debug message.

        The arguments are converted to strings and joined with newlines,
        so that non-string values such as the parameter lists logged by
        get(), insert(), update() and delete() do not raise a TypeError.

        """
        if self.debug:
            s = '\n'.join('%s' % (arg,) for arg in args)
            if isinstance(self.debug, basestring):
                print(self.debug % s)
            elif hasattr(self.debug, 'write'):
                self.debug.write(s + '\n')
            elif callable(self.debug):
                self.debug(s)
            else:
                print(s)

    @staticmethod
    def _quote_ident(name):
        """Quote an identifier such as a column name for use in SQL.

        Embedded double quotes are doubled, so every name results in
        exactly one quoted identifier with its case preserved.

        """
        return '"%s"' % name.replace('"', '""')

    @staticmethod
    def _make_bool(d):
        """Get boolean value corresponding to d."""
        return bool(d) if get_bool() else ('t' if d else 'f')

    _bool_true_values = frozenset('t true 1 y yes on'.split())

    def _prepare_bool(self, d):
        """Prepare a boolean parameter."""
        if isinstance(d, basestring):
            if not d:
                return None
            d = d.lower() in self._bool_true_values
        return 't' if d else 'f'

    _date_literals = frozenset('current_date current_time'
        ' current_timestamp localtime localtimestamp'.split())

    def _prepare_date(self, d):
        """Prepare a date parameter."""
        if not d:
            return None
        if isinstance(d, basestring) and d.lower() in self._date_literals:
            # signals _prepare_param to inline the known SQL literal
            raise ValueError
        return d

    def _prepare_num(self, d):
        """Prepare a numeric parameter."""
        if not d and d != 0:
            return None
        return d

    def _prepare_bytea(self, d):
        """Prepare a bytea parameter by escaping it via the connection."""
        return self.escape_bytea(d)

    _prepare_funcs = dict(  # quote methods for each type
        bool=_prepare_bool, date=_prepare_date,
        int=_prepare_num, num=_prepare_num, float=_prepare_num,
        money=_prepare_num, bytea=_prepare_bytea)

    def _prepare_param(self, value, typ, params):
        """Prepare and add a parameter to the list.

        Returns the placeholder ($n) for the added parameter, or the value
        itself if it is a date literal that must be inlined into the SQL.

        """
        if value is not None and typ != 'text':
            try:
                prepare = self._prepare_funcs[typ]
            except KeyError:
                pass
            else:
                try:
                    value = prepare(self, value)
                except ValueError:
                    return value
        params.append(value)
        return '$%d' % len(params)

    # Public methods

    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    unescape_bytea = staticmethod(unescape_bytea)

    def close(self):
        """Close the database connection."""
        # Wraps shared library function so we can track state.
        if self._closeable:
            if self.db:
                self.db.close()
                self.db = None
            else:
                raise _int_error('Connection already closed')

    def reset(self):
        """Reset connection with current parameters.

        All derived queries and large objects derived from this connection
        will not be usable after this call.

        """
        if self.db:
            self.db.reset()
        else:
            raise _int_error('Connection already closed')

    def reopen(self):
        """Reopen connection to the database.

        Used in case we need another connection to the same database.
        Note that we can still reopen a database that we have closed.

        """
        # There is no such shared library function.
        if self._closeable:
            db = connect(*self._args[0], **self._args[1])
            if self.db:
                self.db.close()
            self.db = db

    def begin(self, mode=None):
        """Begin a transaction."""
        qstr = 'BEGIN'
        if mode:
            qstr += ' ' + mode
        return self.query(qstr)

    start = begin

    def commit(self):
        """Commit the current transaction."""
        return self.query('COMMIT')

    end = commit

    def rollback(self, name=None):
        """Rollback the current transaction."""
        qstr = 'ROLLBACK'
        if name:
            qstr += ' TO ' + name
        return self.query(qstr)

    def savepoint(self, name=None):
        """Define a new savepoint within the current transaction."""
        qstr = 'SAVEPOINT'
        if name:
            qstr += ' ' + name
        return self.query(qstr)

    def release(self, name):
        """Destroy a previously defined savepoint."""
        return self.query('RELEASE ' + name)

    def query(self, qstr, *args):
        """Executes a SQL command string.

        This method simply sends a SQL query to the database. If the query is
        an insert statement that inserted exactly one row into a table that
        has OIDs, the return value is the OID of the newly inserted row.
        If the query is an update or delete statement, or an insert statement
        that did not insert exactly one row in a table with OIDs, then the
        number of rows affected is returned as a string. If it is a statement
        that returns rows as a result (usually a select statement, but maybe
        also an "insert/update ... returning" statement), this method returns
        a Query object that can be accessed via getresult() or dictresult()
        or simply printed. Otherwise, it returns `None`.

        The query can contain numbered parameters of the form $1 in place
        of any data constant. Arguments given after the query string will
        be substituted for the corresponding numbered parameter. Parameter
        values can also be given as a single list or tuple argument.

        """
        # Wraps shared library function for debugging.
        if not self.db:
            raise _int_error('Connection is not valid')
        self._do_debug(qstr)
        return self.db.query(qstr, args)

    def pkey(self, cl, flush=False):
        """Get the primary key of a class.

        Composite primary keys are represented as frozensets. Note that
        this raises a KeyError if the table does not have a primary key.

        If flush is set then the internal cache for primary keys will
        be flushed. This may be necessary after the database schema or
        the search path has been changed.

        """
        pkeys = self._pkeys
        if flush:
            pkeys.clear()
            self._do_debug('pkey cache has been flushed')
        try:  # cache lookup
            pkey = pkeys[cl]
        except KeyError:  # cache miss, check the database
            q = ("SELECT a.attname FROM pg_index i"
                " JOIN pg_attribute a ON a.attrelid = i.indrelid"
                " AND a.attnum = ANY(i.indkey)"
                " AND NOT a.attisdropped"
                " WHERE i.indrelid=%s::regclass"
                " AND i.indisprimary" % _quote_class_param(cl, 1))
            pkey = self.db.query(q, (cl,)).getresult()
            if not pkey:
                raise KeyError('Class %s has no primary key' % cl)
            if len(pkey) > 1:
                pkey = frozenset([k[0] for k in pkey])
            else:
                pkey = pkey[0][0]
            pkeys[cl] = pkey  # cache it
        return pkey

    def get_databases(self):
        """Get list of databases in the system."""
        return [s[0] for s in
            self.db.query('SELECT datname FROM pg_database').getresult()]

    def get_relations(self, kinds=None):
        """Get list of relations in connected database of specified kinds.

        If kinds is None or empty, all kinds of relations are returned.
        Otherwise kinds can be a string or sequence of type letters
        specifying which kind of relations you want to list.

        """
        where = " AND r.relkind IN (%s)" % ','.join(
            ["'%s'" % k for k in kinds]) if kinds else ''
        q = ("SELECT quote_ident(s.nspname)||'.'||quote_ident(r.relname)"
            " FROM pg_class r"
            " JOIN pg_namespace s ON s.oid = r.relnamespace"
            " WHERE s.nspname NOT SIMILAR"
            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
            " ORDER BY s.nspname, r.relname") % where
        return [r[0] for r in self.db.query(q).getresult()]

    def get_tables(self):
        """Return list of tables in connected database."""
        return self.get_relations('r')

    def get_attnames(self, cl, flush=False):
        """Given the name of a table, digs out the set of attribute names.

        Returns a dictionary of attribute names (the names are the keys,
        the values are the names of the attributes' types).

        If flush is set then the internal cache for attribute names will
        be flushed. This may be necessary after the database schema or
        the search path has been changed.

        By default, only a limited number of simple types will be returned.
        You can get the regular types after calling use_regtypes(True).

        """
        attnames = self._attnames
        if flush:
            attnames.clear()
            self._do_debug('attnames cache has been flushed')
        try:  # cache lookup
            names = attnames[cl]
        except KeyError:  # cache miss, check the database
            q = ("SELECT a.attname, t.typname%s"
                " FROM pg_attribute a"
                " JOIN pg_type t ON t.oid = a.atttypid"
                " WHERE a.attrelid = %s::regclass"
                " AND (a.attnum > 0 OR a.attname = 'oid')"
                " AND NOT a.attisdropped") % (
                    '::regtype' if self._regtypes else '',
                    _quote_class_param(cl, 1))
            names = self.db.query(q, (cl,)).getresult()
            if not names:
                raise KeyError('Class %s does not exist' % cl)
            if self._regtypes:
                names = dict(names)
            else:
                names = dict((name, _simpletype(typ)) for name, typ in names)
            attnames[cl] = names  # cache it
        return names

    def use_regtypes(self, regtypes=None):
        """Use regular type names instead of simplified type names."""
        if regtypes is None:
            return self._regtypes
        else:
            regtypes = bool(regtypes)
            if regtypes != self._regtypes:
                self._regtypes = regtypes
                self._attnames.clear()
            return regtypes

    def has_table_privilege(self, cl, privilege='select'):
        """Check whether current user has specified table privilege."""
        privilege = privilege.lower()
        try:  # ask cache
            return self._privileges[(cl, privilege)]
        except KeyError:  # cache miss, ask the database
            q = "SELECT has_table_privilege(%s, $2)" % (
                _quote_class_param(cl, 1),)
            q = self.db.query(q, (cl, privilege))
            ret = q.getresult()[0][0] == self._make_bool(True)
            self._privileges[(cl, privilege)] = ret  # cache it
            return ret

    def get(self, cl, arg, keyname=None):
        """Get a row from a database table or view.

        This method is the basic mechanism to get a single row.  The keyname
        specifies the column that uniquely identifies the row; for a
        composite key, keyname can also be a sequence of column names.  If
        keyname is not specified then the primary key for the table is used.
        If arg is a dictionary then the value for the key is taken from it
        and it is modified to include the new values, replacing existing
        values where necessary.  The OID is also put into the dictionary if
        the table has one, but in order to allow the caller to work with
        multiple tables, it is munged as "oid(cl)".

        """
        if cl.endswith('*'):  # scan descendant tables?
            cl = cl[:-1].rstrip()  # need parent table name
        # To allow users to work with multiple tables,
        # we munge the name of the "oid" key
        qoid = _oid_key(cl)
        if not keyname:
            # use the primary key by default
            try:
                keyname = self.pkey(cl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % cl)
        attnames = self.get_attnames(cl)
        params = []
        param = partial(self._prepare_param, params=params)
        if keyname == 'oid':
            if isinstance(arg, dict):
                if qoid not in arg:
                    raise _db_error('%s not in arg' % qoid)
            else:
                arg = {qoid: arg}
            what = '*'
            where = 'oid = %s' % param(arg[qoid], 'int')
        else:
            if isinstance(keyname, basestring):
                keyname = (keyname,)
            if not isinstance(arg, dict):
                if len(keyname) > 1:
                    raise _prg_error('Composite key needs dict as arg')
                arg = dict([(k, arg) for k in keyname])
            # quote the column names (as insert() does) so that mixed-case
            # or otherwise special names are not folded or misparsed
            what = ', '.join(map(self._quote_ident, attnames))
            where = ' AND '.join(['%s = %s' % (
                self._quote_ident(k), param(arg[k], attnames[k]))
                for k in keyname])
        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
            what, _quote_class_name(cl), where)
        self._do_debug(q, params)
        res = self.db.query(q, params).dictresult()
        if not res:
            # the key values are passed as parameters, not inlined in
            # the where clause, so list them explicitly in the message
            raise _db_error('No such record in %s where %s with %s'
                % (cl, where, params))
        for n, value in res[0].items():
            if n == 'oid':
                n = qoid
            elif attnames.get(n) == 'bytea' and value is not None:
                value = self.unescape_bytea(value)
            arg[n] = value
        return arg

    def insert(self, cl, d=None, **kw):
        """Insert a row into a database table.

        This method inserts a row into a table.  The name of the table must
        be passed as the first parameter.  The other parameters are used for
        providing the data of the row that shall be inserted into the table.
        If a dictionary is supplied as the second parameter, it starts with
        that.  Otherwise it uses a blank dictionary. Either way the dictionary
        is updated from the keywords.  If no columns are given at all, a row
        with default values is inserted.

        The dictionary is then, if possible, reloaded with the values actually
        inserted in order to pick up values modified by rules, triggers, etc.

        Note: The method currently doesn't support insert into views
        although PostgreSQL does.

        """
        qoid = _oid_key(cl)
        if d is None:
            d = {}
        d.update(kw)
        attnames = self.get_attnames(cl)
        params = []
        param = partial(self._prepare_param, params=params)
        names, values = [], []
        for n in attnames:
            if n != 'oid' and n in d:
                names.append(self._quote_ident(n))
                values.append(param(d[n], attnames[n]))
        if names:
            columns = '(%s) VALUES (%s)' % (', '.join(names), ', '.join(values))
        else:
            # "INSERT INTO t () VALUES ()" is not valid SQL
            columns = 'DEFAULT VALUES'
        selectable = self.has_table_privilege(cl)
        if selectable:
            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
        else:
            ret = ''
        q = 'INSERT INTO %s %s%s' % (_quote_class_name(cl), columns, ret)
        self._do_debug(q, params)
        res = self.db.query(q, params)
        if ret:
            res = res.dictresult()[0]
            for n, value in res.items():
                if n == 'oid':
                    n = qoid
                elif attnames.get(n) == 'bytea' and value is not None:
                    value = self.unescape_bytea(value)
                d[n] = value
        elif isinstance(res, int):
            d[qoid] = res
            if selectable:
                self.get(cl, d, 'oid')
        elif selectable:
            if qoid in d:
                self.get(cl, d, 'oid')
            else:
                try:
                    self.get(cl, d)
                except ProgrammingError:
                    pass  # table has no primary key
        return d

    def update(self, cl, d=None, **kw):
        """Update an existing row in a database table.

        Similar to insert but updates an existing row.  The update is based
        on the OID value as munged by get or passed as keyword, or on the
        primary key of the table.  The dictionary is modified, if possible,
        to reflect any changes caused by the update due to triggers, rules,
        default values, etc.

        """
        # Update always works on the oid which get returns if available,
        # otherwise use the primary key.  Fail if neither.
        # Note that we only accept oid key from named args for safety
        qoid = _oid_key(cl)
        if 'oid' in kw:
            kw[qoid] = kw['oid']
            del kw['oid']
        if d is None:
            d = {}
        d.update(kw)
        attnames = self.get_attnames(cl)
        params = []
        param = partial(self._prepare_param, params=params)
        if qoid in d:
            where = 'oid = %s' % param(d[qoid], 'int')
            keyname = ()
        else:
            try:
                keyname = self.pkey(cl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % cl)
            if isinstance(keyname, basestring):
                keyname = (keyname,)
            try:
                where = ' AND '.join(['%s = %s' % (
                    self._quote_ident(k), param(d[k], attnames[k]))
                    for k in keyname])
            except KeyError:
                raise _prg_error('Update needs primary key or oid.')
        values = []
        for n in attnames:
            if n in d and n not in keyname:
                values.append('%s = %s' % (
                    self._quote_ident(n), param(d[n], attnames[n])))
        if not values:
            return d
        values = ', '.join(values)
        selectable = self.has_table_privilege(cl)
        if selectable:
            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
        else:
            ret = ''
        q = 'UPDATE %s SET %s WHERE %s%s' % (
            _quote_class_name(cl), values, where, ret)
        self._do_debug(q, params)
        res = self.db.query(q, params)
        if ret:
            res = res.dictresult()[0]
            for n, value in res.items():
                if n == 'oid':
                    n = qoid
                elif attnames.get(n) == 'bytea' and value is not None:
                    value = self.unescape_bytea(value)
                d[n] = value
        else:
            if selectable:
                if qoid in d:
                    self.get(cl, d, 'oid')
                else:
                    self.get(cl, d)
        return d

    def clear(self, cl, a=None):
        """Clear all the attributes to values determined by the types.

        Numeric types are set to 0, Booleans are set to false, and everything
        else is set to the empty string.  If the dictionary argument a is
        present, it is used as the dictionary and any entries matching
        attribute names are cleared with everything else left unchanged.

        """
        # At some point we will need a way to get defaults from a table.
        if a is None:
            a = {}  # empty if argument is not present
        attnames = self.get_attnames(cl)
        for n, t in attnames.items():
            if n == 'oid':
                continue
            if t in ('int', 'integer', 'smallint', 'bigint',
                    'float', 'real', 'double precision',
                    'num', 'numeric', 'money'):
                a[n] = 0
            elif t in ('bool', 'boolean'):
                a[n] = self._make_bool(False)
            else:
                a[n] = ''
        return a

    def delete(self, cl, d=None, **kw):
        """Delete an existing row in a database table.

        This method deletes the row from a table.  It deletes based on the
        OID value as munged by get or passed as keyword, or on the primary
        key of the table.  The return value is the number of deleted rows
        (i.e. 0 if the row did not exist and 1 if the row was deleted).

        """
        # Like update, delete works on the oid.
        # One day we will be testing that the record to be deleted
        # isn't referenced somewhere (or else PostgreSQL will).
        # Note that we only accept oid key from named args for safety
        qoid = _oid_key(cl)
        if 'oid' in kw:
            kw[qoid] = kw['oid']
            del kw['oid']
        if d is None:
            d = {}
        d.update(kw)
        params = []
        param = partial(self._prepare_param, params=params)
        if qoid in d:
            where = 'oid = %s' % param(d[qoid], 'int')
        else:
            try:
                keyname = self.pkey(cl)
            except KeyError:
                raise _prg_error('Class %s has no primary key' % cl)
            if isinstance(keyname, basestring):
                keyname = (keyname,)
            attnames = self.get_attnames(cl)
            try:
                where = ' AND '.join(['%s = %s' % (
                    self._quote_ident(k), param(d[k], attnames[k]))
                    for k in keyname])
            except KeyError:
                raise _prg_error('Delete needs primary key or oid.')
        q = 'DELETE FROM %s WHERE %s' % (_quote_class_name(cl), where)
        self._do_debug(q, params)
        return int(self.db.query(q, params))

    def notification_handler(self, event, callback,
            arg_dict=None, timeout=None):
        """Get notification handler that will run the given callback.

        If no arg_dict is given, every handler gets its own new dictionary.
        A shared mutable default would leak the 'pid', 'event' and 'extra'
        entries that one handler stores into those of all other handlers.

        """
        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
873
874
875# if run as script, print some information
876
if __name__ == '__main__':
    # When run as a script, show the version and the module documentation.
    # (A space was missing between "version" and the version number.)
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.