source: trunk/module/pg.py @ 705

Last change on this file since 705 was 705, checked in by cito, 4 years ago

Use more idiomatic SQL for boolean attributes

  • Property svn:keywords set to Id
File size: 34.4 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 705 2016-01-09 20:49:18Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40
if bytes is not str:  # Python >= 3.0 has no basestring builtin
    basestring = (str, bytes)

# Register Decimal with the _pg module; presumably it is then used
# for the numeric values returned by queries.
set_decimal(Decimal)
47
48
49# Auxiliary functions which are independent from a DB connection:
50
51def _is_quoted(s):
52    """Check whether this string is a quoted identifier."""
53    s = s.replace('_', 'a')
54    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
55
56
57def _is_unquoted(s):
58    """Check whether this string is an unquoted identifier."""
59    s = s.replace('_', 'a')
60    return s.isalnum() and not s[:1].isdigit()
61
62
def _split_first_part(s):
    """Split the first part of a dot separated string.

    Returns a list containing the first name part and, if there is more,
    the still unparsed remainder following the separating dot.
    A part in double quotes keeps its case and may contain dots; doubled
    double quotes inside it are collapsed to one double quote. An unquoted
    part that is a plain identifier is folded to lower case.
    """
    s = s.lstrip()
    if s[:1] == '"':
        p = []
        # take the text inside the first pair of double quotes
        s = s.split('"', 3)[1:]
        p.append(s[0])
        # an empty piece between two quotes stands for an escaped ("") quote
        while len(s) == 3 and s[1] == '':
            p.append('"')
            s = s[2].split('"', 2)
            p.append(s[0])
        p = [''.join(p)]
        s = '"'.join(s[1:]).lstrip()
        if s:
            if s[:1] == '.':  # the separator follows the quoted part
                p.append(s[1:])
            else:  # more text is glued to the quoted part
                s = _split_first_part(s)
                p[0] += s[0]
                if len(s) > 1:
                    p.append(s[1])
    else:
        p = s.split('.', 1)
        s = p[0].rstrip()
        if _is_unquoted(s):
            s = s.lower()
        p[0] = s
    return p
91
92
def _split_parts(s):
    """Split all parts of a dot separated string.

    Returns the list of name parts, each normalized the way
    _split_first_part normalizes the first part.
    """
    parts = []
    rest = s
    while rest:
        first_and_rest = _split_first_part(rest)
        parts.append(first_and_rest[0])
        if len(first_and_rest) < 2:
            break
        rest = first_and_rest[1]
    return parts
103
104
def _join_parts(s):
    """Join all parts of a dot separated string.

    Parts that cannot be written as plain identifiers are enclosed in
    double quotes, and double quotes inside such parts are doubled.
    This is the inverse of the un-escaping done by _split_first_part.
    """
    return '.'.join(
        '"%s"' % p.replace('"', '""') if _is_quoted(p) else p for p in s)
108
109
110def _oid_key(qcl):
111    """Build oid key from qualified class name."""
112    return 'oid(%s)' % qcl
113
114
def _namedresult(q):
    """Get query result as named tuples.

    The field names of the query become the attribute names of the
    tuples. Field names that are not valid Python identifiers, that are
    keywords, that start with an underscore or that occur more than once
    are replaced by positional names (_0, _1, ...) instead of making the
    whole call fail with a ValueError.
    """
    row = namedtuple('Row', q.listfields(), rename=True)
    return [row(*r) for r in q.getresult()]


# Register the helper with the _pg module (presumably used there to
# build named tuple results).
set_namedresult(_namedresult)
121
122
def _db_error(msg, cls=DatabaseError):
    """Return an exception of class cls (DatabaseError by default).

    The sqlstate attribute of the exception is set to None, because
    errors built by this helper are detected by this module itself and
    do not carry an SQLSTATE code from the server.
    """
    exc = cls(msg)
    exc.sqlstate = None
    return exc
128
129
def _int_error(msg):
    """Return an InternalError with empty sqlstate for the message."""
    return _db_error(msg, cls=InternalError)
133
134
def _prg_error(msg):
    """Return a ProgrammingError with empty sqlstate for the message."""
    return _db_error(msg, cls=ProgrammingError)
138
139
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object.
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # the instance may be incomplete if __init__ did not run through
        if getattr(self, 'db', None):
            self.close()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Sends NOTIFY for the event, or for the stop event if stop is set,
        with an optional payload, over db or, if db is not given, over the
        connection of this handler. Returns the result of the query, or
        None without sending anything if the handler is not listening.
        The payload is escaped with the escape_string() method of the
        connection, so it may contain quotes.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # escape the payload so that quotes cannot break the query
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid' and 'event' are inserted into <arg_dict>, and the callback is
        invoked with <arg_dict>. If the NOTIFY message is stop_<event>, the
        handler UNLISTENs both <event> and stop_<event> and exits.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
247
248
def pgnotify(*args, **kw):
    """Deprecated alias for creating a NotificationHandler.

    Issues a DeprecationWarning attributed to the caller and returns
    NotificationHandler(*args, **kw).
    """
    warnings.warn(
        "pgnotify is deprecated, use NotificationHandler instead.",
        category=DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
254
255
256# The actual PostGreSQL database connection interface:
257
258class DB(object):
259    """Wrapper class for the _pg connection type."""
260
261    def __init__(self, *args, **kw):
262        """Create a new connection.
263
264        You can pass either the connection parameters or an existing
265        _pg or pgdb connection. This allows you to use the methods
266        of the classic pg interface with a DB-API 2 pgdb connection.
267
268        """
269        if not args and len(kw) == 1:
270            db = kw.get('db')
271        elif not kw and len(args) == 1:
272            db = args[0]
273        else:
274            db = None
275        if db:
276            if isinstance(db, DB):
277                db = db.db
278            else:
279                try:
280                    db = db._cnx
281                except AttributeError:
282                    pass
283        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
284            db = connect(*args, **kw)
285            self._closeable = True
286        else:
287            self._closeable = False
288        self.db = db
289        self.dbname = db.db
290        self._regtypes = False
291        self._attnames = {}
292        self._pkeys = {}
293        self._privileges = {}
294        self._args = args, kw
295        self.debug = None  # For debugging scripts, this can be set
296            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
297            # * to a file object to write debug statements or
298            # * to a callable object which takes a string argument
299            # * to any other true value to just print debug statements
300
301    def __getattr__(self, name):
302        # All undefined members are same as in underlying connection:
303        if self.db:
304            return getattr(self.db, name)
305        else:
306            raise _int_error('Connection is not valid')
307
308    def __dir__(self):
309        # Custom dir function including the attributes of the connection:
310        attrs = set(self.__class__.__dict__)
311        attrs.update(self.__dict__)
312        attrs.update(dir(self.db))
313        return sorted(attrs)
314
315    # Context manager methods
316
317    def __enter__(self):
318        """Enter the runtime context. This will start a transaction."""
319        self.begin()
320        return self
321
322    def __exit__(self, et, ev, tb):
323        """Exit the runtime context. This will end the transaction."""
324        if et is None and ev is None and tb is None:
325            self.commit()
326        else:
327            self.rollback()
328
329    # Auxiliary methods
330
331    def _do_debug(self, s):
332        """Print a debug message."""
333        if self.debug:
334            if isinstance(self.debug, basestring):
335                print(self.debug % s)
336            elif hasattr(self.debug, 'write'):
337                self.debug.write(s + '\n')
338            elif callable(self.debug):
339                self.debug(s)
340            else:
341                print(s)
342
343    def _quote_text(self, d):
344        """Quote text value."""
345        if not isinstance(d, basestring):
346            d = str(d)
347        return "'%s'" % self.escape_string(d)
348
349    _bool_true = frozenset('t true 1 y yes on'.split())
350
351    def _quote_bool(self, d):
352        """Quote boolean value."""
353        if isinstance(d, basestring):
354            if not d:
355                return 'NULL'
356            d = d.lower() in self._bool_true
357        else:
358            d = bool(d)
359        return ("'f'", "'t'")[d]
360
361    _date_literals = frozenset('current_date current_time'
362        ' current_timestamp localtime localtimestamp'.split())
363
364    def _quote_date(self, d):
365        """Quote date value."""
366        if not d:
367            return 'NULL'
368        if isinstance(d, basestring) and d.lower() in self._date_literals:
369            return d
370        return self._quote_text(d)
371
372    def _quote_num(self, d):
373        """Quote numeric value."""
374        if not d and d != 0:
375            return 'NULL'
376        return str(d)
377
378    def _quote_money(self, d):
379        """Quote money value."""
380        if d is None or d == '':
381            return 'NULL'
382        if not isinstance(d, basestring):
383            d = str(d)
384        return d
385
386    if bytes is str:  # Python < 3.0
387        """Quote bytes value."""
388
389        def _quote_bytea(self, d):
390            return "'%s'" % self.escape_bytea(d)
391
392    else:
393
394        def _quote_bytea(self, d):
395            return "'%s'" % self.escape_bytea(d).decode('ascii')
396
397    _quote_funcs = dict(  # quote methods for each type
398        text=_quote_text, bool=_quote_bool, date=_quote_date,
399        int=_quote_num, num=_quote_num, float=_quote_num,
400        money=_quote_money, bytea=_quote_bytea)
401
402    def _quote(self, d, t):
403        """Return quotes if needed."""
404        if d is None:
405            return 'NULL'
406        try:
407            quote_func = self._quote_funcs[t]
408        except KeyError:
409            quote_func = self._quote_funcs['text']
410        return quote_func(self, d)
411
412    def _split_schema(self, cl):
413        """Return schema and name of object separately.
414
415        This auxiliary function splits off the namespace (schema)
416        belonging to the class with the name cl. If the class name
417        is not qualified, the function is able to determine the schema
418        of the class, taking into account the current search path.
419
420        """
421        s = _split_parts(cl)
422        if len(s) > 1:  # name already qualfied?
423            # should be database.schema.table or schema.table
424            if len(s) > 3:
425                raise _prg_error('Too many dots in class name %s' % cl)
426            schema, cl = s[-2:]
427        else:
428            cl = s[0]
429            # determine search path
430            q = 'SELECT current_schemas(TRUE)'
431            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
432            if schemas:  # non-empty path
433                # search schema for this object in the current search path
434                q = ' UNION '.join(
435                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
436                        % s for s in enumerate(schemas)])
437                q = ("SELECT nspname FROM pg_class"
438                    " JOIN pg_namespace"
439                    " ON pg_class.relnamespace = pg_namespace.oid"
440                    " JOIN (%s) AS p USING (nspname)"
441                    " WHERE pg_class.relname = '%s'"
442                    " ORDER BY n LIMIT 1" % (q, cl))
443                schema = self.db.query(q).getresult()
444                if schema:  # schema found
445                    schema = schema[0][0]
446                else:  # object not found in current search path
447                    schema = 'public'
448            else:  # empty path
449                schema = 'public'
450        return schema, cl
451
452    def _add_schema(self, cl):
453        """Ensure that the class name is prefixed with a schema name."""
454        return _join_parts(self._split_schema(cl))
455
456    # Public methods
457
458    # escape_string and escape_bytea exist as methods,
459    # so we define unescape_bytea as a method as well
460    unescape_bytea = staticmethod(unescape_bytea)
461
462    def close(self):
463        """Close the database connection."""
464        # Wraps shared library function so we can track state.
465        if self._closeable:
466            if self.db:
467                self.db.close()
468                self.db = None
469            else:
470                raise _int_error('Connection already closed')
471
472    def reset(self):
473        """Reset connection with current parameters.
474
475        All derived queries and large objects derived from this connection
476        will not be usable after this call.
477
478        """
479        if self.db:
480            self.db.reset()
481        else:
482            raise _int_error('Connection already closed')
483
484    def reopen(self):
485        """Reopen connection to the database.
486
487        Used in case we need another connection to the same database.
488        Note that we can still reopen a database that we have closed.
489
490        """
491        # There is no such shared library function.
492        if self._closeable:
493            db = connect(*self._args[0], **self._args[1])
494            if self.db:
495                self.db.close()
496            self.db = db
497
498    def begin(self, mode=None):
499        """Begin a transaction."""
500        qstr = 'BEGIN'
501        if mode:
502            qstr += ' ' + mode
503        return self.query(qstr)
504
505    start = begin
506
507    def commit(self):
508        """Commit the current transaction."""
509        return self.query('COMMIT')
510
511    end = commit
512
513    def rollback(self, name=None):
514        """Rollback the current transaction."""
515        qstr = 'ROLLBACK'
516        if name:
517            qstr += ' TO ' + name
518        return self.query(qstr)
519
520    def savepoint(self, name=None):
521        """Define a new savepoint within the current transaction."""
522        qstr = 'SAVEPOINT'
523        if name:
524            qstr += ' ' + name
525        return self.query(qstr)
526
527    def release(self, name):
528        """Destroy a previously defined savepoint."""
529        return self.query('RELEASE ' + name)
530
531    def query(self, qstr, *args):
532        """Executes a SQL command string.
533
534        This method simply sends a SQL query to the database. If the query is
535        an insert statement that inserted exactly one row into a table that
536        has OIDs, the return value is the OID of the newly inserted row.
537        If the query is an update or delete statement, or an insert statement
538        that did not insert exactly one row in a table with OIDs, then the
539        number of rows affected is returned as a string. If it is a statement
540        that returns rows as a result (usually a select statement, but maybe
541        also an "insert/update ... returning" statement), this method returns
542        a pgqueryobject that can be accessed via getresult() or dictresult()
543        or simply printed. Otherwise, it returns `None`.
544
545        The query can contain numbered parameters of the form $1 in place
546        of any data constant. Arguments given after the query string will
547        be substituted for the corresponding numbered parameter. Parameter
548        values can also be given as a single list or tuple argument.
549
550        """
551        # Wraps shared library function for debugging.
552        if not self.db:
553            raise _int_error('Connection is not valid')
554        self._do_debug(qstr)
555        return self.db.query(qstr, args)
556
557    def pkey(self, cl, newpkey=None):
558        """This method gets or sets the primary key of a class.
559
560        Composite primary keys are represented as frozensets. Note that
561        this raises an exception if the table does not have a primary key.
562
563        If newpkey is set and is not a dictionary then set that
564        value as the primary key of the class.  If it is a dictionary
565        then replace the _pkeys dictionary with a copy of it.
566
567        """
568        # First see if the caller is supplying a dictionary
569        if isinstance(newpkey, dict):
570            # make sure that all classes have a namespace
571            self._pkeys = dict([
572                (cl if '.' in cl else 'public.' + cl, pkey)
573                for cl, pkey in newpkey.items()])
574            return self._pkeys
575
576        qcl = self._add_schema(cl)  # build fully qualified class name
577        # Check if the caller is supplying a new primary key for the class
578        if newpkey:
579            self._pkeys[qcl] = newpkey
580            return newpkey
581
582        # Get all the primary keys at once
583        if qcl not in self._pkeys:
584            # if not found, check again in case it was added after we started
585            self._pkeys = {}
586            for r in self.db.query(
587                "SELECT pg_namespace.nspname, pg_class.relname,"
588                    " pg_attribute.attname FROM pg_class"
589                " JOIN pg_namespace"
590                    " ON pg_namespace.oid = pg_class.relnamespace"
591                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
592                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
593                    " AND NOT pg_attribute.attisdropped"
594                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
595                    " AND pg_index.indisprimary"
596                    " AND pg_attribute.attnum"
597                        " = ANY (pg_index.indkey)").getresult():
598                cl, pkey = _join_parts(r[:2]), r[2]
599                self._pkeys.setdefault(cl, []).append(pkey)
600            # (only) for composite primary keys, the values will be frozensets
601            for cl, pkey in self._pkeys.items():
602                self._pkeys[cl] = frozenset(pkey) if len(pkey) > 1 else pkey[0]
603            self._do_debug(self._pkeys)
604
605        # will raise an exception if primary key doesn't exist
606        return self._pkeys[qcl]
607
608    def get_databases(self):
609        """Get list of databases in the system."""
610        return [s[0] for s in
611            self.db.query('SELECT datname FROM pg_database').getresult()]
612
613    def get_relations(self, kinds=None):
614        """Get list of relations in connected database of specified kinds.
615
616            If kinds is None or empty, all kinds of relations are returned.
617            Otherwise kinds can be a string or sequence of type letters
618            specifying which kind of relations you want to list.
619
620        """
621        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
622            ["'%s'" % x for x in kinds]) or ''
623        return [_join_parts(x) for x in self.db.query(
624            "SELECT pg_namespace.nspname, pg_class.relname "
625            "FROM pg_class "
626            "JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
627            "WHERE %s pg_class.relname !~ '^Inv' AND "
628                "pg_class.relname !~ '^pg_' "
629            "ORDER BY 1, 2" % where).getresult()]
630
631    def get_tables(self):
632        """Return list of tables in connected database."""
633        return self.get_relations('r')
634
635    def get_attnames(self, cl, newattnames=None):
636        """Given the name of a table, digs out the set of attribute names.
637
638        Returns a dictionary of attribute names (the names are the keys,
639        the values are the names of the attributes' types).
640        If the optional newattnames exists, it must be a dictionary and
641        will become the new attribute names dictionary.
642
643        By default, only a limited number of simple types will be returned.
644        You can get the regular types after calling use_regtypes(True).
645
646        """
647        if isinstance(newattnames, dict):
648            self._attnames = newattnames
649            return
650        elif newattnames:
651            raise _prg_error('If supplied, newattnames must be a dictionary')
652        cl = self._split_schema(cl)  # split into schema and class
653        qcl = _join_parts(cl)  # build fully qualified name
654        # May as well cache them:
655        if qcl in self._attnames:
656            return self._attnames[qcl]
657        if qcl not in self.get_relations('rv'):
658            raise _prg_error('Class %s does not exist' % qcl)
659
660        q = "SELECT pg_attribute.attname, pg_type.typname"
661        if self._regtypes:
662            q += "::regtype"
663        q += (" FROM pg_class"
664            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
665            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
666            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
667            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
668            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
669            " AND NOT pg_attribute.attisdropped") % cl
670        q = self.db.query(q).getresult()
671
672        if self._regtypes:
673            t = dict(q)
674        else:
675            t = {}
676            for att, typ in q:
677                if typ.startswith('bool'):
678                    typ = 'bool'
679                elif typ.startswith('abstime'):
680                    typ = 'date'
681                elif typ.startswith('date'):
682                    typ = 'date'
683                elif typ.startswith('interval'):
684                    typ = 'date'
685                elif typ.startswith('timestamp'):
686                    typ = 'date'
687                elif typ.startswith('oid'):
688                    typ = 'int'
689                elif typ.startswith('int'):
690                    typ = 'int'
691                elif typ.startswith('float'):
692                    typ = 'float'
693                elif typ.startswith('numeric'):
694                    typ = 'num'
695                elif typ.startswith('money'):
696                    typ = 'money'
697                elif typ.startswith('bytea'):
698                    typ = 'bytea'
699                else:
700                    typ = 'text'
701                t[att] = typ
702
703        self._attnames[qcl] = t  # cache it
704        return self._attnames[qcl]
705
706    def use_regtypes(self, regtypes=None):
707        """Use regular type names instead of simplified type names."""
708        if regtypes is None:
709            return self._regtypes
710        else:
711            regtypes = bool(regtypes)
712            if regtypes != self._regtypes:
713                self._regtypes = regtypes
714                self._attnames.clear()
715            return regtypes
716
717    def has_table_privilege(self, cl, privilege='select'):
718        """Check whether current user has specified table privilege."""
719        qcl = self._add_schema(cl)
720        privilege = privilege.lower()
721        try:
722            return self._privileges[(qcl, privilege)]
723        except KeyError:
724            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
725            ret = self.db.query(q).getresult()[0][0] == 't'
726            self._privileges[(qcl, privilege)] = ret
727            return ret
728
729    def get(self, cl, arg, keyname=None):
730        """Get a tuple from a database table or view.
731
732        This method is the basic mechanism to get a single row.  The keyname
733        that the key specifies a unique row.  If keyname is not specified
734        then the primary key for the table is used.  If arg is a dictionary
735        then the value for the key is taken from it and it is modified to
736        include the new values, replacing existing values where necessary.
737        For a composite key, keyname can also be a sequence of key names.
738        The OID is also put into the dictionary if the table has one, but
739        in order to allow the caller to work with multiple tables, it is
740        munged as oid(schema.table).
741
742        """
743        if cl.endswith('*'):  # scan descendant tables?
744            cl = cl[:-1].rstrip()  # need parent table name
745        # build qualified class name
746        qcl = self._add_schema(cl)
747        # To allow users to work with multiple tables,
748        # we munge the name of the "oid" the key
749        qoid = _oid_key(qcl)
750        if not keyname:
751            # use the primary key by default
752            try:
753                keyname = self.pkey(qcl)
754            except KeyError:
755                raise _prg_error('Class %s has no primary key' % qcl)
756        attnames = self.get_attnames(qcl)
757        # We want the oid for later updates if that isn't the key
758        if keyname == 'oid':
759            if isinstance(arg, dict):
760                if qoid not in arg:
761                    raise _db_error('%s not in arg' % qoid)
762            else:
763                arg = {qoid: arg}
764            what = '*'
765            where = 'oid = %s' % arg[qoid]
766        else:
767            if isinstance(keyname, basestring):
768                keyname = (keyname,)
769            if not isinstance(arg, dict):
770                if len(keyname) > 1:
771                    raise _prg_error('Composite key needs dict as arg')
772                arg = dict([(k, arg) for k in keyname])
773            what = ', '.join(attnames)
774            where = ' AND '.join(['%s = %s'
775                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
776        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (what, qcl, where)
777        self._do_debug(q)
778        res = self.db.query(q).dictresult()
779        if not res:
780            raise _db_error('No such record in %s where %s' % (qcl, where))
781        for n, value in res[0].items():
782            if n == 'oid':
783                n = qoid
784            elif attnames.get(n) == 'bytea':
785                value = self.unescape_bytea(value)
786            arg[n] = value
787        return arg
788
789    def insert(self, cl, d=None, **kw):
790        """Insert a tuple into a database table.
791
792        This method inserts a row into a table.  If a dictionary is
793        supplied it starts with that.  Otherwise it uses a blank dictionary.
794        Either way the dictionary is updated from the keywords.
795
796        The dictionary is then, if possible, reloaded with the values actually
797        inserted in order to pick up values modified by rules, triggers, etc.
798
799        Note: The method currently doesn't support insert into views
800        although PostgreSQL does.
801
802        """
803        qcl = self._add_schema(cl)
804        qoid = _oid_key(qcl)
805        if d is None:
806            d = {}
807        d.update(kw)
808        attnames = self.get_attnames(qcl)
809        names, values = [], []
810        for n in attnames:
811            if n != 'oid' and n in d:
812                names.append('"%s"' % n)
813                values.append(self._quote(d[n], attnames[n]))
814        names, values = ', '.join(names), ', '.join(values)
815        selectable = self.has_table_privilege(qcl)
816        if selectable:
817            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
818        else:
819            ret = ''
820        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
821        self._do_debug(q)
822        res = self.db.query(q)
823        if ret:
824            res = res.dictresult()[0]
825            for n, value in res.items():
826                if n == 'oid':
827                    n = qoid
828                elif attnames.get(n) == 'bytea':
829                    value = self.unescape_bytea(value)
830                d[n] = value
831        elif isinstance(res, int):
832            d[qoid] = res
833            if selectable:
834                self.get(qcl, d, 'oid')
835        elif selectable:
836            if qoid in d:
837                self.get(qcl, d, 'oid')
838            else:
839                try:
840                    self.get(qcl, d)
841                except ProgrammingError:
842                    pass  # table has no primary key
843        return d
844
845    def update(self, cl, d=None, **kw):
846        """Update an existing row in a database table.
847
848        Similar to insert but updates an existing row.  The update is based
849        on the OID value as munged by get or passed as keyword, or on the
850        primary key of the table.  The dictionary is modified, if possible,
851        to reflect any changes caused by the update due to triggers, rules,
852        default values, etc.
853
854        """
855        # Update always works on the oid which get returns if available,
856        # otherwise use the primary key.  Fail if neither.
857        # Note that we only accept oid key from named args for safety
858        qcl = self._add_schema(cl)
859        qoid = _oid_key(qcl)
860        if 'oid' in kw:
861            kw[qoid] = kw['oid']
862            del kw['oid']
863        if d is None:
864            d = {}
865        d.update(kw)
866        attnames = self.get_attnames(qcl)
867        if qoid in d:
868            where = 'oid = %s' % d[qoid]
869            keyname = ()
870        else:
871            try:
872                keyname = self.pkey(qcl)
873            except KeyError:
874                raise _prg_error('Class %s has no primary key' % qcl)
875            if isinstance(keyname, basestring):
876                keyname = (keyname,)
877            try:
878                where = ' AND '.join(['%s = %s'
879                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
880            except KeyError:
881                raise _prg_error('Update needs primary key or oid.')
882        values = []
883        for n in attnames:
884            if n in d and n not in keyname:
885                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
886        if not values:
887            return d
888        values = ', '.join(values)
889        selectable = self.has_table_privilege(qcl)
890        if selectable:
891            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
892        else:
893            ret = ''
894        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
895        self._do_debug(q)
896        res = self.db.query(q)
897        if ret:
898            res = res.dictresult()[0]
899            for n, value in res.items():
900                if n == 'oid':
901                    n = qoid
902                elif attnames.get(n) == 'bytea':
903                    value = self.unescape_bytea(value)
904                d[n] = value
905        else:
906            if selectable:
907                if qoid in d:
908                    self.get(qcl, d, 'oid')
909                else:
910                    self.get(qcl, d)
911        return d
912
913    def clear(self, cl, a=None):
914        """Clear all the attributes to values determined by the types.
915
916        Numeric types are set to 0, Booleans are set to 'f', and everything
917        else is set to the empty string.  If the array argument is present,
918        it is used as the array and any entries matching attribute names are
919        cleared with everything else left unchanged.
920
921        """
922        # At some point we will need a way to get defaults from a table.
923        qcl = self._add_schema(cl)
924        if a is None:
925            a = {}  # empty if argument is not present
926        attnames = self.get_attnames(qcl)
927        for n, t in attnames.items():
928            if n == 'oid':
929                continue
930            if t in ('int', 'integer', 'smallint', 'bigint',
931                    'float', 'real', 'double precision',
932                    'num', 'numeric', 'money'):
933                a[n] = 0
934            elif t in ('bool', 'boolean'):
935                a[n] = 'f'
936            else:
937                a[n] = ''
938        return a
939
940    def delete(self, cl, d=None, **kw):
941        """Delete an existing row in a database table.
942
943        This method deletes the row from a table.  It deletes based on the
944        OID value as munged by get or passed as keyword, or on the primary
945        key of the table.  The return value is the number of deleted rows
946        (i.e. 0 if the row did not exist and 1 if the row was deleted).
947
948        """
949        # Like update, delete works on the oid.
950        # One day we will be testing that the record to be deleted
951        # isn't referenced somewhere (or else PostgreSQL will).
952        # Note that we only accept oid key from named args for safety
953        qcl = self._add_schema(cl)
954        qoid = _oid_key(qcl)
955        if 'oid' in kw:
956            kw[qoid] = kw['oid']
957            del kw['oid']
958        if d is None:
959            d = {}
960        d.update(kw)
961        if qoid in d:
962            where = 'oid = %s' % d[qoid]
963        else:
964            try:
965                keyname = self.pkey(qcl)
966            except KeyError:
967                raise _prg_error('Class %s has no primary key' % qcl)
968            if isinstance(keyname, basestring):
969                keyname = (keyname,)
970            attnames = self.get_attnames(qcl)
971            try:
972                where = ' AND '.join(['%s = %s'
973                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
974            except KeyError:
975                raise _prg_error('Delete needs primary key or oid.')
976        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
977        self._do_debug(q)
978        return int(self.db.query(q))
979
980    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
981        """Get notification handler that will run the given callback."""
982        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
983
984
# if run as script, print some information

if __name__ == '__main__':
    # let print() insert the separating space (print_function is in effect)
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.