source: trunk/module/pg.py @ 706

Last change on this file since 706 was 706, checked in by cito, 4 years ago

Make sure DB methods respect the new bool option

Two DB methods assumed that booleans are always returned as strings,
which is no longer true when the set_bool() option is activated.
Added a test run with different global options to make sure that
no DB methods make such tacit assumptions about these options.

  • Property svn:keywords set to Id
File size: 34.6 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 706 2016-01-09 22:53:44Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40
# Python 2/3 compatibility: Python 3 has no builtin "basestring", so the
# tuple (str, bytes) is bound under that name instead; the rest of this
# module uses it with isinstance() to test for any kind of string.
try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# Import-time side effect: hand decimal.Decimal to the _pg extension
# (set_decimal presumably comes from the star import of _pg and sets the
# type used for numeric values — confirm against the _pg documentation).
set_decimal(Decimal)
47
48
49# Auxiliary functions which are independent from a DB connection:
50
51def _is_quoted(s):
52    """Check whether this string is a quoted identifier."""
53    s = s.replace('_', 'a')
54    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
55
56
57def _is_unquoted(s):
58    """Check whether this string is an unquoted identifier."""
59    s = s.replace('_', 'a')
60    return s.isalnum() and not s[:1].isdigit()
61
62
63def _split_first_part(s):
64    """Split the first part of a dot separated string."""
65    s = s.lstrip()
66    if s[:1] == '"':
67        p = []
68        s = s.split('"', 3)[1:]
69        p.append(s[0])
70        while len(s) == 3 and s[1] == '':
71            p.append('"')
72            s = s[2].split('"', 2)
73            p.append(s[0])
74        p = [''.join(p)]
75        s = '"'.join(s[1:]).lstrip()
76        if s:
77            if s[:0] == '.':
78                p.append(s[1:])
79            else:
80                s = _split_first_part(s)
81                p[0] += s[0]
82                if len(s) > 1:
83                    p.append(s[1])
84    else:
85        p = s.split('.', 1)
86        s = p[0].rstrip()
87        if _is_unquoted(s):
88            s = s.lower()
89        p[0] = s
90    return p
91
92
def _split_parts(s):
    """Split a dot separated string into the list of all its parts.

    Every part is normalized as done by _split_first_part().
    """
    parts = []
    rest = s
    while rest:
        head = _split_first_part(rest)
        parts.append(head[0])
        if len(head) < 2:
            break
        rest = head[1]
    return parts
103
104
def _join_parts(s):
    """Join all parts of a dot separated string.

    Parts that cannot be used as plain identifiers are put in double
    quotes, with embedded double quotes doubled, which is the inverse of
    the unquoting done by _split_first_part().
    """
    return '.'.join([
        '"%s"' % p.replace('"', '""') if _is_quoted(p) else p
        for p in s])
108
109
110def _oid_key(qcl):
111    """Build oid key from qualified class name."""
112    return 'oid(%s)' % qcl
113
114
115def _namedresult(q):
116    """Get query result as named tuples."""
117    row = namedtuple('Row', q.listfields())
118    return [row(*r) for r in q.getresult()]
119
# Import-time side effect: register _namedresult with the _pg extension
# (set_namedresult presumably comes from the star import of _pg and makes
# query objects build their named tuple results with it — confirm).
set_namedresult(_namedresult)
121
122
def _db_error(msg, cls=DatabaseError):
    """Create (without raising) an exception of class cls for msg.

    These errors are generated by this module itself rather than reported
    by the server, so the sqlstate attribute is set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
128
129
def _int_error(msg):
    """Create an InternalError for msg, with empty sqlstate."""
    return _db_error(msg, cls=InternalError)
133
134
def _prg_error(msg):
    """Create a ProgrammingError for msg, with empty sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
138
139
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB instance is replaced
                   by its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    def close(self):
        """Stop listening and close the connection.

        The connection is closed even if unlistening fails.
        """
        if self.db:
            try:
                self.unlisten()
            finally:
                self.db.close()
                self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Sends NOTIFY for the event (or for the stop event if stop is true),
        with the optional payload string, and returns the query result.
        Nothing is sent (and None is returned) if the handler is not
        listening. The payload is escaped with the escape_string() method
        of the connection, so it may contain quotes.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                q += ", '%s'" % db.escape_string(payload)
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid' and 'event' are inserted into <arg_dict>, and the callback is
        invoked with <arg_dict>. If the NOTIFY message is stop_<event>, the
        handler UNLISTENs both <event> and stop_<event> and exits.
        If the timeout expires, the handler UNLISTENs and the callback is
        invoked with None.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
247
248
def pgnotify(*args, **kw):
    """Deprecated alias that creates a NotificationHandler."""
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return NotificationHandler(*args, **kw)
254
255
256# The actual PostGreSQL database connection interface:
257
258class DB(object):
259    """Wrapper class for the _pg connection type."""
260
261    def __init__(self, *args, **kw):
262        """Create a new connection.
263
264        You can pass either the connection parameters or an existing
265        _pg or pgdb connection. This allows you to use the methods
266        of the classic pg interface with a DB-API 2 pgdb connection.
267
268        """
269        if not args and len(kw) == 1:
270            db = kw.get('db')
271        elif not kw and len(args) == 1:
272            db = args[0]
273        else:
274            db = None
275        if db:
276            if isinstance(db, DB):
277                db = db.db
278            else:
279                try:
280                    db = db._cnx
281                except AttributeError:
282                    pass
283        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
284            db = connect(*args, **kw)
285            self._closeable = True
286        else:
287            self._closeable = False
288        self.db = db
289        self.dbname = db.db
290        self._regtypes = False
291        self._attnames = {}
292        self._pkeys = {}
293        self._privileges = {}
294        self._args = args, kw
295        self.debug = None  # For debugging scripts, this can be set
296            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
297            # * to a file object to write debug statements or
298            # * to a callable object which takes a string argument
299            # * to any other true value to just print debug statements
300
301    def __getattr__(self, name):
302        # All undefined members are same as in underlying connection:
303        if self.db:
304            return getattr(self.db, name)
305        else:
306            raise _int_error('Connection is not valid')
307
308    def __dir__(self):
309        # Custom dir function including the attributes of the connection:
310        attrs = set(self.__class__.__dict__)
311        attrs.update(self.__dict__)
312        attrs.update(dir(self.db))
313        return sorted(attrs)
314
315    # Context manager methods
316
317    def __enter__(self):
318        """Enter the runtime context. This will start a transaction."""
319        self.begin()
320        return self
321
322    def __exit__(self, et, ev, tb):
323        """Exit the runtime context. This will end the transaction."""
324        if et is None and ev is None and tb is None:
325            self.commit()
326        else:
327            self.rollback()
328
329    # Auxiliary methods
330
331    def _do_debug(self, s):
332        """Print a debug message."""
333        if self.debug:
334            if isinstance(self.debug, basestring):
335                print(self.debug % s)
336            elif hasattr(self.debug, 'write'):
337                self.debug.write(s + '\n')
338            elif callable(self.debug):
339                self.debug(s)
340            else:
341                print(s)
342
343    @staticmethod
344    def _make_bool(d):
345        """Get boolean value corresponding to d."""
346        return bool(d) if get_bool() else ('t' if d else 'f')
347
348    def _quote_text(self, d):
349        """Quote text value."""
350        if not isinstance(d, basestring):
351            d = str(d)
352        return "'%s'" % self.escape_string(d)
353
354    _bool_true = frozenset('t true 1 y yes on'.split())
355
356    def _quote_bool(self, d):
357        """Quote boolean value."""
358        if isinstance(d, basestring):
359            if not d:
360                return 'NULL'
361            d = d.lower() in self._bool_true
362        return "'t'" if d else "'f'"
363
364    _date_literals = frozenset('current_date current_time'
365        ' current_timestamp localtime localtimestamp'.split())
366
367    def _quote_date(self, d):
368        """Quote date value."""
369        if not d:
370            return 'NULL'
371        if isinstance(d, basestring) and d.lower() in self._date_literals:
372            return d
373        return self._quote_text(d)
374
375    def _quote_num(self, d):
376        """Quote numeric value."""
377        if not d and d != 0:
378            return 'NULL'
379        return str(d)
380
381    def _quote_money(self, d):
382        """Quote money value."""
383        if d is None or d == '':
384            return 'NULL'
385        if not isinstance(d, basestring):
386            d = str(d)
387        return d
388
389    if bytes is str:  # Python < 3.0
390        """Quote bytes value."""
391
392        def _quote_bytea(self, d):
393            return "'%s'" % self.escape_bytea(d)
394
395    else:
396
397        def _quote_bytea(self, d):
398            return "'%s'" % self.escape_bytea(d).decode('ascii')
399
400    _quote_funcs = dict(  # quote methods for each type
401        text=_quote_text, bool=_quote_bool, date=_quote_date,
402        int=_quote_num, num=_quote_num, float=_quote_num,
403        money=_quote_money, bytea=_quote_bytea)
404
405    def _quote(self, d, t):
406        """Return quotes if needed."""
407        if d is None:
408            return 'NULL'
409        try:
410            quote_func = self._quote_funcs[t]
411        except KeyError:
412            quote_func = self._quote_funcs['text']
413        return quote_func(self, d)
414
415    def _split_schema(self, cl):
416        """Return schema and name of object separately.
417
418        This auxiliary function splits off the namespace (schema)
419        belonging to the class with the name cl. If the class name
420        is not qualified, the function is able to determine the schema
421        of the class, taking into account the current search path.
422
423        """
424        s = _split_parts(cl)
425        if len(s) > 1:  # name already qualfied?
426            # should be database.schema.table or schema.table
427            if len(s) > 3:
428                raise _prg_error('Too many dots in class name %s' % cl)
429            schema, cl = s[-2:]
430        else:
431            cl = s[0]
432            # determine search path
433            q = 'SELECT current_schemas(TRUE)'
434            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
435            if schemas:  # non-empty path
436                # search schema for this object in the current search path
437                q = ' UNION '.join(
438                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
439                        % s for s in enumerate(schemas)])
440                q = ("SELECT nspname FROM pg_class"
441                    " JOIN pg_namespace"
442                    " ON pg_class.relnamespace = pg_namespace.oid"
443                    " JOIN (%s) AS p USING (nspname)"
444                    " WHERE pg_class.relname = '%s'"
445                    " ORDER BY n LIMIT 1" % (q, cl))
446                schema = self.db.query(q).getresult()
447                if schema:  # schema found
448                    schema = schema[0][0]
449                else:  # object not found in current search path
450                    schema = 'public'
451            else:  # empty path
452                schema = 'public'
453        return schema, cl
454
455    def _add_schema(self, cl):
456        """Ensure that the class name is prefixed with a schema name."""
457        return _join_parts(self._split_schema(cl))
458
459    # Public methods
460
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well
    # (the module-level unescape_bytea, presumably from the star import of
    # _pg, is exposed as a static method so that get() and insert() can
    # call self.unescape_bytea()).
    unescape_bytea = staticmethod(unescape_bytea)
464
465    def close(self):
466        """Close the database connection."""
467        # Wraps shared library function so we can track state.
468        if self._closeable:
469            if self.db:
470                self.db.close()
471                self.db = None
472            else:
473                raise _int_error('Connection already closed')
474
475    def reset(self):
476        """Reset connection with current parameters.
477
478        All derived queries and large objects derived from this connection
479        will not be usable after this call.
480
481        """
482        if self.db:
483            self.db.reset()
484        else:
485            raise _int_error('Connection already closed')
486
487    def reopen(self):
488        """Reopen connection to the database.
489
490        Used in case we need another connection to the same database.
491        Note that we can still reopen a database that we have closed.
492
493        """
494        # There is no such shared library function.
495        if self._closeable:
496            db = connect(*self._args[0], **self._args[1])
497            if self.db:
498                self.db.close()
499            self.db = db
500
501    def begin(self, mode=None):
502        """Begin a transaction."""
503        qstr = 'BEGIN'
504        if mode:
505            qstr += ' ' + mode
506        return self.query(qstr)
507
508    start = begin
509
510    def commit(self):
511        """Commit the current transaction."""
512        return self.query('COMMIT')
513
514    end = commit
515
516    def rollback(self, name=None):
517        """Rollback the current transaction."""
518        qstr = 'ROLLBACK'
519        if name:
520            qstr += ' TO ' + name
521        return self.query(qstr)
522
523    def savepoint(self, name=None):
524        """Define a new savepoint within the current transaction."""
525        qstr = 'SAVEPOINT'
526        if name:
527            qstr += ' ' + name
528        return self.query(qstr)
529
530    def release(self, name):
531        """Destroy a previously defined savepoint."""
532        return self.query('RELEASE ' + name)
533
534    def query(self, qstr, *args):
535        """Executes a SQL command string.
536
537        This method simply sends a SQL query to the database. If the query is
538        an insert statement that inserted exactly one row into a table that
539        has OIDs, the return value is the OID of the newly inserted row.
540        If the query is an update or delete statement, or an insert statement
541        that did not insert exactly one row in a table with OIDs, then the
542        number of rows affected is returned as a string. If it is a statement
543        that returns rows as a result (usually a select statement, but maybe
544        also an "insert/update ... returning" statement), this method returns
545        a pgqueryobject that can be accessed via getresult() or dictresult()
546        or simply printed. Otherwise, it returns `None`.
547
548        The query can contain numbered parameters of the form $1 in place
549        of any data constant. Arguments given after the query string will
550        be substituted for the corresponding numbered parameter. Parameter
551        values can also be given as a single list or tuple argument.
552
553        """
554        # Wraps shared library function for debugging.
555        if not self.db:
556            raise _int_error('Connection is not valid')
557        self._do_debug(qstr)
558        return self.db.query(qstr, args)
559
560    def pkey(self, cl, newpkey=None):
561        """This method gets or sets the primary key of a class.
562
563        Composite primary keys are represented as frozensets. Note that
564        this raises an exception if the table does not have a primary key.
565
566        If newpkey is set and is not a dictionary then set that
567        value as the primary key of the class.  If it is a dictionary
568        then replace the _pkeys dictionary with a copy of it.
569
570        """
571        # First see if the caller is supplying a dictionary
572        if isinstance(newpkey, dict):
573            # make sure that all classes have a namespace
574            self._pkeys = dict([
575                (cl if '.' in cl else 'public.' + cl, pkey)
576                for cl, pkey in newpkey.items()])
577            return self._pkeys
578
579        qcl = self._add_schema(cl)  # build fully qualified class name
580        # Check if the caller is supplying a new primary key for the class
581        if newpkey:
582            self._pkeys[qcl] = newpkey
583            return newpkey
584
585        # Get all the primary keys at once
586        if qcl not in self._pkeys:
587            # if not found, check again in case it was added after we started
588            self._pkeys = {}
589            for r in self.db.query(
590                "SELECT pg_namespace.nspname, pg_class.relname,"
591                    " pg_attribute.attname FROM pg_class"
592                " JOIN pg_namespace"
593                    " ON pg_namespace.oid = pg_class.relnamespace"
594                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
595                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
596                    " AND NOT pg_attribute.attisdropped"
597                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
598                    " AND pg_index.indisprimary"
599                    " AND pg_attribute.attnum"
600                        " = ANY (pg_index.indkey)").getresult():
601                cl, pkey = _join_parts(r[:2]), r[2]
602                self._pkeys.setdefault(cl, []).append(pkey)
603            # (only) for composite primary keys, the values will be frozensets
604            for cl, pkey in self._pkeys.items():
605                self._pkeys[cl] = frozenset(pkey) if len(pkey) > 1 else pkey[0]
606            self._do_debug(self._pkeys)
607
608        # will raise an exception if primary key doesn't exist
609        return self._pkeys[qcl]
610
611    def get_databases(self):
612        """Get list of databases in the system."""
613        return [s[0] for s in
614            self.db.query('SELECT datname FROM pg_database').getresult()]
615
616    def get_relations(self, kinds=None):
617        """Get list of relations in connected database of specified kinds.
618
619            If kinds is None or empty, all kinds of relations are returned.
620            Otherwise kinds can be a string or sequence of type letters
621            specifying which kind of relations you want to list.
622
623        """
624        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
625            ["'%s'" % x for x in kinds]) or ''
626        return [_join_parts(x) for x in self.db.query(
627            "SELECT pg_namespace.nspname, pg_class.relname "
628            "FROM pg_class "
629            "JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
630            "WHERE %s pg_class.relname !~ '^Inv' AND "
631                "pg_class.relname !~ '^pg_' "
632            "ORDER BY 1, 2" % where).getresult()]
633
634    def get_tables(self):
635        """Return list of tables in connected database."""
636        return self.get_relations('r')
637
638    def get_attnames(self, cl, newattnames=None):
639        """Given the name of a table, digs out the set of attribute names.
640
641        Returns a dictionary of attribute names (the names are the keys,
642        the values are the names of the attributes' types).
643        If the optional newattnames exists, it must be a dictionary and
644        will become the new attribute names dictionary.
645
646        By default, only a limited number of simple types will be returned.
647        You can get the regular types after calling use_regtypes(True).
648
649        """
650        if isinstance(newattnames, dict):
651            self._attnames = newattnames
652            return
653        elif newattnames:
654            raise _prg_error('If supplied, newattnames must be a dictionary')
655        cl = self._split_schema(cl)  # split into schema and class
656        qcl = _join_parts(cl)  # build fully qualified name
657        # May as well cache them:
658        if qcl in self._attnames:
659            return self._attnames[qcl]
660        if qcl not in self.get_relations('rv'):
661            raise _prg_error('Class %s does not exist' % qcl)
662
663        q = "SELECT pg_attribute.attname, pg_type.typname"
664        if self._regtypes:
665            q += "::regtype"
666        q += (" FROM pg_class"
667            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
668            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
669            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
670            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
671            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
672            " AND NOT pg_attribute.attisdropped") % cl
673        q = self.db.query(q).getresult()
674
675        if self._regtypes:
676            t = dict(q)
677        else:
678            t = {}
679            for att, typ in q:
680                if typ.startswith('bool'):
681                    typ = 'bool'
682                elif typ.startswith('abstime'):
683                    typ = 'date'
684                elif typ.startswith('date'):
685                    typ = 'date'
686                elif typ.startswith('interval'):
687                    typ = 'date'
688                elif typ.startswith('timestamp'):
689                    typ = 'date'
690                elif typ.startswith('oid'):
691                    typ = 'int'
692                elif typ.startswith('int'):
693                    typ = 'int'
694                elif typ.startswith('float'):
695                    typ = 'float'
696                elif typ.startswith('numeric'):
697                    typ = 'num'
698                elif typ.startswith('money'):
699                    typ = 'money'
700                elif typ.startswith('bytea'):
701                    typ = 'bytea'
702                else:
703                    typ = 'text'
704                t[att] = typ
705
706        self._attnames[qcl] = t  # cache it
707        return self._attnames[qcl]
708
709    def use_regtypes(self, regtypes=None):
710        """Use regular type names instead of simplified type names."""
711        if regtypes is None:
712            return self._regtypes
713        else:
714            regtypes = bool(regtypes)
715            if regtypes != self._regtypes:
716                self._regtypes = regtypes
717                self._attnames.clear()
718            return regtypes
719
720    def has_table_privilege(self, cl, privilege='select'):
721        """Check whether current user has specified table privilege."""
722        qcl = self._add_schema(cl)
723        privilege = privilege.lower()
724        try:
725            return self._privileges[(qcl, privilege)]
726        except KeyError:
727            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
728            ret = self.db.query(q).getresult()[0][0] == self._make_bool(True)
729            self._privileges[(qcl, privilege)] = ret
730            return ret
731
732    def get(self, cl, arg, keyname=None):
733        """Get a tuple from a database table or view.
734
735        This method is the basic mechanism to get a single row.  The keyname
736        that the key specifies a unique row.  If keyname is not specified
737        then the primary key for the table is used.  If arg is a dictionary
738        then the value for the key is taken from it and it is modified to
739        include the new values, replacing existing values where necessary.
740        For a composite key, keyname can also be a sequence of key names.
741        The OID is also put into the dictionary if the table has one, but
742        in order to allow the caller to work with multiple tables, it is
743        munged as oid(schema.table).
744
745        """
746        if cl.endswith('*'):  # scan descendant tables?
747            cl = cl[:-1].rstrip()  # need parent table name
748        # build qualified class name
749        qcl = self._add_schema(cl)
750        # To allow users to work with multiple tables,
751        # we munge the name of the "oid" the key
752        qoid = _oid_key(qcl)
753        if not keyname:
754            # use the primary key by default
755            try:
756                keyname = self.pkey(qcl)
757            except KeyError:
758                raise _prg_error('Class %s has no primary key' % qcl)
759        attnames = self.get_attnames(qcl)
760        # We want the oid for later updates if that isn't the key
761        if keyname == 'oid':
762            if isinstance(arg, dict):
763                if qoid not in arg:
764                    raise _db_error('%s not in arg' % qoid)
765            else:
766                arg = {qoid: arg}
767            what = '*'
768            where = 'oid = %s' % arg[qoid]
769        else:
770            if isinstance(keyname, basestring):
771                keyname = (keyname,)
772            if not isinstance(arg, dict):
773                if len(keyname) > 1:
774                    raise _prg_error('Composite key needs dict as arg')
775                arg = dict([(k, arg) for k in keyname])
776            what = ', '.join(attnames)
777            where = ' AND '.join(['%s = %s'
778                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
779        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (what, qcl, where)
780        self._do_debug(q)
781        res = self.db.query(q).dictresult()
782        if not res:
783            raise _db_error('No such record in %s where %s' % (qcl, where))
784        for n, value in res[0].items():
785            if n == 'oid':
786                n = qoid
787            elif attnames.get(n) == 'bytea':
788                value = self.unescape_bytea(value)
789            arg[n] = value
790        return arg
791
792    def insert(self, cl, d=None, **kw):
793        """Insert a tuple into a database table.
794
795        This method inserts a row into a table.  If a dictionary is
796        supplied it starts with that.  Otherwise it uses a blank dictionary.
797        Either way the dictionary is updated from the keywords.
798
799        The dictionary is then, if possible, reloaded with the values actually
800        inserted in order to pick up values modified by rules, triggers, etc.
801
802        Note: The method currently doesn't support insert into views
803        although PostgreSQL does.
804
805        """
806        qcl = self._add_schema(cl)
807        qoid = _oid_key(qcl)
808        if d is None:
809            d = {}
810        d.update(kw)
811        attnames = self.get_attnames(qcl)
812        names, values = [], []
813        for n in attnames:
814            if n != 'oid' and n in d:
815                names.append('"%s"' % n)
816                values.append(self._quote(d[n], attnames[n]))
817        names, values = ', '.join(names), ', '.join(values)
818        selectable = self.has_table_privilege(qcl)
819        if selectable:
820            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
821        else:
822            ret = ''
823        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
824        self._do_debug(q)
825        res = self.db.query(q)
826        if ret:
827            res = res.dictresult()[0]
828            for n, value in res.items():
829                if n == 'oid':
830                    n = qoid
831                elif attnames.get(n) == 'bytea':
832                    value = self.unescape_bytea(value)
833                d[n] = value
834        elif isinstance(res, int):
835            d[qoid] = res
836            if selectable:
837                self.get(qcl, d, 'oid')
838        elif selectable:
839            if qoid in d:
840                self.get(qcl, d, 'oid')
841            else:
842                try:
843                    self.get(qcl, d)
844                except ProgrammingError:
845                    pass  # table has no primary key
846        return d
847
848    def update(self, cl, d=None, **kw):
849        """Update an existing row in a database table.
850
851        Similar to insert but updates an existing row.  The update is based
852        on the OID value as munged by get or passed as keyword, or on the
853        primary key of the table.  The dictionary is modified, if possible,
854        to reflect any changes caused by the update due to triggers, rules,
855        default values, etc.
856
857        """
858        # Update always works on the oid which get returns if available,
859        # otherwise use the primary key.  Fail if neither.
860        # Note that we only accept oid key from named args for safety
861        qcl = self._add_schema(cl)
862        qoid = _oid_key(qcl)
863        if 'oid' in kw:
864            kw[qoid] = kw['oid']
865            del kw['oid']
866        if d is None:
867            d = {}
868        d.update(kw)
869        attnames = self.get_attnames(qcl)
870        if qoid in d:
871            where = 'oid = %s' % d[qoid]
872            keyname = ()
873        else:
874            try:
875                keyname = self.pkey(qcl)
876            except KeyError:
877                raise _prg_error('Class %s has no primary key' % qcl)
878            if isinstance(keyname, basestring):
879                keyname = (keyname,)
880            try:
881                where = ' AND '.join(['%s = %s'
882                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
883            except KeyError:
884                raise _prg_error('Update needs primary key or oid.')
885        values = []
886        for n in attnames:
887            if n in d and n not in keyname:
888                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
889        if not values:
890            return d
891        values = ', '.join(values)
892        selectable = self.has_table_privilege(qcl)
893        if selectable:
894            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
895        else:
896            ret = ''
897        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
898        self._do_debug(q)
899        res = self.db.query(q)
900        if ret:
901            res = res.dictresult()[0]
902            for n, value in res.items():
903                if n == 'oid':
904                    n = qoid
905                elif attnames.get(n) == 'bytea':
906                    value = self.unescape_bytea(value)
907                d[n] = value
908        else:
909            if selectable:
910                if qoid in d:
911                    self.get(qcl, d, 'oid')
912                else:
913                    self.get(qcl, d)
914        return d
915
916    def clear(self, cl, a=None):
917        """Clear all the attributes to values determined by the types.
918
919        Numeric types are set to 0, Booleans are set to false, and everything
920        else is set to the empty string.  If the array argument is present,
921        it is used as the array and any entries matching attribute names are
922        cleared with everything else left unchanged.
923
924        """
925        # At some point we will need a way to get defaults from a table.
926        qcl = self._add_schema(cl)
927        if a is None:
928            a = {}  # empty if argument is not present
929        attnames = self.get_attnames(qcl)
930        for n, t in attnames.items():
931            if n == 'oid':
932                continue
933            if t in ('int', 'integer', 'smallint', 'bigint',
934                    'float', 'real', 'double precision',
935                    'num', 'numeric', 'money'):
936                a[n] = 0
937            elif t in ('bool', 'boolean'):
938                a[n] = self._make_bool(False)
939            else:
940                a[n] = ''
941        return a
942
943    def delete(self, cl, d=None, **kw):
944        """Delete an existing row in a database table.
945
946        This method deletes the row from a table.  It deletes based on the
947        OID value as munged by get or passed as keyword, or on the primary
948        key of the table.  The return value is the number of deleted rows
949        (i.e. 0 if the row did not exist and 1 if the row was deleted).
950
951        """
952        # Like update, delete works on the oid.
953        # One day we will be testing that the record to be deleted
954        # isn't referenced somewhere (or else PostgreSQL will).
955        # Note that we only accept oid key from named args for safety
956        qcl = self._add_schema(cl)
957        qoid = _oid_key(qcl)
958        if 'oid' in kw:
959            kw[qoid] = kw['oid']
960            del kw['oid']
961        if d is None:
962            d = {}
963        d.update(kw)
964        if qoid in d:
965            where = 'oid = %s' % d[qoid]
966        else:
967            try:
968                keyname = self.pkey(qcl)
969            except KeyError:
970                raise _prg_error('Class %s has no primary key' % qcl)
971            if isinstance(keyname, basestring):
972                keyname = (keyname,)
973            attnames = self.get_attnames(qcl)
974            try:
975                where = ' AND '.join(['%s = %s'
976                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
977            except KeyError:
978                raise _prg_error('Delete needs primary key or oid.')
979        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
980        self._do_debug(q)
981        return int(self.db.query(q))
982
983    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
984        """Get notification handler that will run the given callback."""
985        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
986
987
# if run as script, print some information

if __name__ == '__main__':
    # separate the label from the version number with a space
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.