source: trunk/module/pg.py @ 663

Last change on this file since 663 was 663, checked in by cito, 4 years ago

Fix nasty typo in update() method

This bug prevented the method from using the returning clause.

  • Property svn:keywords set to Id
File size: 34.4 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 663 2015-12-10 12:40:11Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40
try:
    basestring  # builtin string base class, present only in Python 2
except NameError:  # Python >= 3.0
    # stand-in that works as the second argument of isinstance()
    basestring = (str, bytes)

# Register the decimal type with the _pg extension module (imported above
# with "from _pg import *"); presumably numeric values are then returned
# as Decimal objects -- NOTE(review): confirm against the _pg docs.
set_decimal(Decimal)
47
48
49# Auxiliary functions which are independent from a DB connection:
50
51def _is_quoted(s):
52    """Check whether this string is a quoted identifier."""
53    s = s.replace('_', 'a')
54    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
55
56
57def _is_unquoted(s):
58    """Check whether this string is an unquoted identifier."""
59    s = s.replace('_', 'a')
60    return s.isalnum() and not s[:1].isdigit()
61
62
63def _split_first_part(s):
64    """Split the first part of a dot separated string."""
65    s = s.lstrip()
66    if s[:1] == '"':
67        p = []
68        s = s.split('"', 3)[1:]
69        p.append(s[0])
70        while len(s) == 3 and s[1] == '':
71            p.append('"')
72            s = s[2].split('"', 2)
73            p.append(s[0])
74        p = [''.join(p)]
75        s = '"'.join(s[1:]).lstrip()
76        if s:
77            if s[:0] == '.':
78                p.append(s[1:])
79            else:
80                s = _split_first_part(s)
81                p[0] += s[0]
82                if len(s) > 1:
83                    p.append(s[1])
84    else:
85        p = s.split('.', 1)
86        s = p[0].rstrip()
87        if _is_unquoted(s):
88            s = s.lower()
89        p[0] = s
90    return p
91
92
93def _split_parts(s):
94    """Split all parts of a dot separated string."""
95    q = []
96    while s:
97        s = _split_first_part(s)
98        q.append(s[0])
99        if len(s) < 2:
100            break
101        s = s[1]
102    return q
103
104
def _join_parts(s):
    """Join the parts of a name with dots, quoting parts where necessary.

    Parts that need quoting are enclosed in double quotes, with embedded
    double quotes doubled as SQL requires.  The result is therefore valid
    SQL and is split back into the same parts by _split_parts().
    """
    return '.'.join(
        '"%s"' % p.replace('"', '""') if _is_quoted(p) else p for p in s)
108
109
110def _oid_key(qcl):
111    """Build oid key from qualified class name."""
112    return 'oid(%s)' % qcl
113
114
115def _namedresult(q):
116    """Get query result as named tuples."""
117    row = namedtuple('Row', q.listfields())
118    return [row(*r) for r in q.getresult()]
119
# Hand _namedresult over to the _pg extension module; presumably it backs
# the namedresult() method of query objects -- NOTE(review): confirm in _pg.
set_namedresult(_namedresult)
121
122
def _db_error(msg, cls=DatabaseError):
    """Create (not raise) an exception of class cls for msg.

    The sqlstate attribute is set to None, since the error does not
    originate from the server.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
128
129
def _int_error(msg):
    """Create (not raise) an InternalError for msg."""
    return _db_error(msg, cls=InternalError)
133
134
def _prg_error(msg):
    """Create (not raise) a ProgrammingError for msg."""
    return _db_error(msg, cls=ProgrammingError)
138
139
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler.

    The handler LISTENs on a channel (the "event") and on a companion
    channel "stop_<event>" and passes incoming notifications to a callback.
    """

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB instance is replaced
                   by its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    @staticmethod
    def _quote_channel(name):
        """Return the channel name as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so that any channel name
        results in a valid LISTEN/UNLISTEN/NOTIFY statement.
        """
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    def close(self):
        """Stop listening and close the connection."""
        # getattr: __del__ may run on a partially initialized instance
        if getattr(self, 'db', None):
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_channel(self.event))
            self.db.query('listen %s' % self._quote_channel(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_channel(self.event))
            self.db.query(
                'unlisten %s' % self._quote_channel(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        db      - connection to send the notification on (default: the
                  handler's own connection).
        stop    - if true, notify the stop event instead of the event.
        payload - optional payload string sent with the notification.

        Nothing is sent unless the handler is listening; otherwise the
        result of the query is returned.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify %s' % self._quote_channel(
                self.stop_event if stop else self.event)
            if payload:
                # escape the payload so quotes in it cannot end the literal
                q += ", '%s'" % db.escape_string('%s' % (payload,))
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' (payload) are inserted into <arg_dict>,
        and the callback is invoked with <arg_dict>. If the NOTIFY message
        is stop_<event>, the handler UNLISTENs both <event> and stop_<event>
        and exits. If the timeout expires, the handler UNLISTENs and the
        callback is invoked with None.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        NOTE(review): the close argument is accepted but currently not
        used by this implementation.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
247
248
def pgnotify(*args, **kw):
    """Deprecated alias: create a NotificationHandler from the arguments."""
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
254
255
256# The actual PostGreSQL database connection interface:
257
258class DB(object):
259    """Wrapper class for the _pg connection type."""
260
261    def __init__(self, *args, **kw):
262        """Create a new connection.
263
264        You can pass either the connection parameters or an existing
265        _pg or pgdb connection. This allows you to use the methods
266        of the classic pg interface with a DB-API 2 pgdb connection.
267
268        """
269        if not args and len(kw) == 1:
270            db = kw.get('db')
271        elif not kw and len(args) == 1:
272            db = args[0]
273        else:
274            db = None
275        if db:
276            if isinstance(db, DB):
277                db = db.db
278            else:
279                try:
280                    db = db._cnx
281                except AttributeError:
282                    pass
283        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
284            db = connect(*args, **kw)
285            self._closeable = True
286        else:
287            self._closeable = False
288        self.db = db
289        self.dbname = db.db
290        self._regtypes = False
291        self._attnames = {}
292        self._pkeys = {}
293        self._privileges = {}
294        self._args = args, kw
295        self.debug = None  # For debugging scripts, this can be set
296            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
297            # * to a file object to write debug statements or
298            # * to a callable object which takes a string argument
299            # * to any other true value to just print debug statements
300
301    def __getattr__(self, name):
302        # All undefined members are same as in underlying connection:
303        if self.db:
304            return getattr(self.db, name)
305        else:
306            raise _int_error('Connection is not valid')
307
308    def __dir__(self):
309        # Custom dir function including the attributes of the connection:
310        attrs = set(self.__class__.__dict__)
311        attrs.update(self.__dict__)
312        attrs.update(dir(self.db))
313        return sorted(attrs)
314
315    # Context manager methods
316
317    def __enter__(self):
318        """Enter the runtime context. This will start a transaction."""
319        self.begin()
320        return self
321
322    def __exit__(self, et, ev, tb):
323        """Exit the runtime context. This will end the transaction."""
324        if et is None and ev is None and tb is None:
325            self.commit()
326        else:
327            self.rollback()
328
329    # Auxiliary methods
330
331    def _do_debug(self, s):
332        """Print a debug message."""
333        if self.debug:
334            if isinstance(self.debug, basestring):
335                print(self.debug % s)
336            elif hasattr(self.debug, 'write'):
337                self.debug.write(s + '\n')
338            elif callable(self.debug):
339                self.debug(s)
340            else:
341                print(s)
342
343    def _quote_text(self, d):
344        """Quote text value."""
345        if not isinstance(d, basestring):
346            d = str(d)
347        return "'%s'" % self.escape_string(d)
348
349    _bool_true = frozenset('t true 1 y yes on'.split())
350
351    def _quote_bool(self, d):
352        """Quote boolean value."""
353        if isinstance(d, basestring):
354            if not d:
355                return 'NULL'
356            d = d.lower() in self._bool_true
357        else:
358            d = bool(d)
359        return ("'f'", "'t'")[d]
360
361    _date_literals = frozenset('current_date current_time'
362        ' current_timestamp localtime localtimestamp'.split())
363
364    def _quote_date(self, d):
365        """Quote date value."""
366        if not d:
367            return 'NULL'
368        if isinstance(d, basestring) and d.lower() in self._date_literals:
369            return d
370        return self._quote_text(d)
371
372    def _quote_num(self, d):
373        """Quote numeric value."""
374        if not d and d != 0:
375            return 'NULL'
376        return str(d)
377
378    def _quote_money(self, d):
379        """Quote money value."""
380        if d is None or d == '':
381            return 'NULL'
382        if not isinstance(d, basestring):
383            d = str(d)
384        return d
385
386    if bytes is str:  # Python < 3.0
387        """Quote bytes value."""
388
389        def _quote_bytea(self, d):
390            return "'%s'" % self.escape_bytea(d)
391
392    else:
393
394        def _quote_bytea(self, d):
395            return "'%s'" % self.escape_bytea(d).decode('ascii')
396
397    _quote_funcs = dict(  # quote methods for each type
398        text=_quote_text, bool=_quote_bool, date=_quote_date,
399        int=_quote_num, num=_quote_num, float=_quote_num,
400        money=_quote_money, bytea=_quote_bytea)
401
402    def _quote(self, d, t):
403        """Return quotes if needed."""
404        if d is None:
405            return 'NULL'
406        try:
407            quote_func = self._quote_funcs[t]
408        except KeyError:
409            quote_func = self._quote_funcs['text']
410        return quote_func(self, d)
411
412    def _split_schema(self, cl):
413        """Return schema and name of object separately.
414
415        This auxiliary function splits off the namespace (schema)
416        belonging to the class with the name cl. If the class name
417        is not qualified, the function is able to determine the schema
418        of the class, taking into account the current search path.
419
420        """
421        s = _split_parts(cl)
422        if len(s) > 1:  # name already qualfied?
423            # should be database.schema.table or schema.table
424            if len(s) > 3:
425                raise _prg_error('Too many dots in class name %s' % cl)
426            schema, cl = s[-2:]
427        else:
428            cl = s[0]
429            # determine search path
430            q = 'SELECT current_schemas(TRUE)'
431            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
432            if schemas:  # non-empty path
433                # search schema for this object in the current search path
434                q = ' UNION '.join(
435                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
436                        % s for s in enumerate(schemas)])
437                q = ("SELECT nspname FROM pg_class"
438                    " JOIN pg_namespace"
439                    " ON pg_class.relnamespace = pg_namespace.oid"
440                    " JOIN (%s) AS p USING (nspname)"
441                    " WHERE pg_class.relname = '%s'"
442                    " ORDER BY n LIMIT 1" % (q, cl))
443                schema = self.db.query(q).getresult()
444                if schema:  # schema found
445                    schema = schema[0][0]
446                else:  # object not found in current search path
447                    schema = 'public'
448            else:  # empty path
449                schema = 'public'
450        return schema, cl
451
452    def _add_schema(self, cl):
453        """Ensure that the class name is prefixed with a schema name."""
454        return _join_parts(self._split_schema(cl))
455
456    # Public methods
457
    # escape_string and escape_bytea exist as methods,
    # so we define unescape_bytea as a method as well.
    # It is wrapped in staticmethod because the module-level function
    # (presumably provided by "from _pg import *") takes no self argument.
    unescape_bytea = staticmethod(unescape_bytea)
461
462    def close(self):
463        """Close the database connection."""
464        # Wraps shared library function so we can track state.
465        if self._closeable:
466            if self.db:
467                self.db.close()
468                self.db = None
469            else:
470                raise _int_error('Connection already closed')
471
472    def reset(self):
473        """Reset connection with current parameters.
474
475        All derived queries and large objects derived from this connection
476        will not be usable after this call.
477
478        """
479        if self.db:
480            self.db.reset()
481        else:
482            raise _int_error('Connection already closed')
483
484    def reopen(self):
485        """Reopen connection to the database.
486
487        Used in case we need another connection to the same database.
488        Note that we can still reopen a database that we have closed.
489
490        """
491        # There is no such shared library function.
492        if self._closeable:
493            db = connect(*self._args[0], **self._args[1])
494            if self.db:
495                self.db.close()
496            self.db = db
497
498    def begin(self, mode=None):
499        """Begin a transaction."""
500        qstr = 'BEGIN'
501        if mode:
502            qstr += ' ' + mode
503        return self.query(qstr)
504
505    start = begin
506
507    def commit(self):
508        """Commit the current transaction."""
509        return self.query('COMMIT')
510
511    end = commit
512
513    def rollback(self, name=None):
514        """Rollback the current transaction."""
515        qstr = 'ROLLBACK'
516        if name:
517            qstr += ' TO ' + name
518        return self.query(qstr)
519
520    def savepoint(self, name=None):
521        """Define a new savepoint within the current transaction."""
522        qstr = 'SAVEPOINT'
523        if name:
524            qstr += ' ' + name
525        return self.query(qstr)
526
527    def release(self, name):
528        """Destroy a previously defined savepoint."""
529        return self.query('RELEASE ' + name)
530
531    def query(self, qstr, *args):
532        """Executes a SQL command string.
533
534        This method simply sends a SQL query to the database. If the query is
535        an insert statement that inserted exactly one row into a table that
536        has OIDs, the return value is the OID of the newly inserted row.
537        If the query is an update or delete statement, or an insert statement
538        that did not insert exactly one row in a table with OIDs, then the
539        number of rows affected is returned as a string. If it is a statement
540        that returns rows as a result (usually a select statement, but maybe
541        also an "insert/update ... returning" statement), this method returns
542        a pgqueryobject that can be accessed via getresult() or dictresult()
543        or simply printed. Otherwise, it returns `None`.
544
545        The query can contain numbered parameters of the form $1 in place
546        of any data constant. Arguments given after the query string will
547        be substituted for the corresponding numbered parameter. Parameter
548        values can also be given as a single list or tuple argument.
549
550        """
551        # Wraps shared library function for debugging.
552        if not self.db:
553            raise _int_error('Connection is not valid')
554        self._do_debug(qstr)
555        return self.db.query(qstr, args)
556
557    def pkey(self, cl, newpkey=None):
558        """This method gets or sets the primary key of a class.
559
560        Composite primary keys are represented as frozensets. Note that
561        this raises an exception if the table does not have a primary key.
562
563        If newpkey is set and is not a dictionary then set that
564        value as the primary key of the class.  If it is a dictionary
565        then replace the _pkeys dictionary with a copy of it.
566
567        """
568        # First see if the caller is supplying a dictionary
569        if isinstance(newpkey, dict):
570            # make sure that all classes have a namespace
571            self._pkeys = dict([
572                (cl if '.' in cl else 'public.' + cl, pkey)
573                for cl, pkey in newpkey.items()])
574            return self._pkeys
575
576        qcl = self._add_schema(cl)  # build fully qualified class name
577        # Check if the caller is supplying a new primary key for the class
578        if newpkey:
579            self._pkeys[qcl] = newpkey
580            return newpkey
581
582        # Get all the primary keys at once
583        if qcl not in self._pkeys:
584            # if not found, check again in case it was added after we started
585            self._pkeys = {}
586            if self.server_version >= 80200:
587                # the ANY syntax works correctly only with PostgreSQL >= 8.2
588                any_indkey = "= ANY (pg_index.indkey)"
589            else:
590                any_indkey = "IN (%s)" % ', '.join(
591                    ['pg_index.indkey[%d]' % i for i in range(16)])
592            for r in self.db.query(
593                "SELECT pg_namespace.nspname, pg_class.relname,"
594                    " pg_attribute.attname FROM pg_class"
595                " JOIN pg_namespace"
596                    " ON pg_namespace.oid = pg_class.relnamespace"
597                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
598                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
599                    " AND pg_attribute.attisdropped = 'f'"
600                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
601                    " AND pg_index.indisprimary = 't'"
602                    " AND pg_attribute.attnum " + any_indkey).getresult():
603                cl, pkey = _join_parts(r[:2]), r[2]
604                self._pkeys.setdefault(cl, []).append(pkey)
605            # (only) for composite primary keys, the values will be frozensets
606            for cl, pkey in self._pkeys.items():
607                self._pkeys[cl] = frozenset(pkey) if len(pkey) > 1 else pkey[0]
608            self._do_debug(self._pkeys)
609
610        # will raise an exception if primary key doesn't exist
611        return self._pkeys[qcl]
612
613    def get_databases(self):
614        """Get list of databases in the system."""
615        return [s[0] for s in
616            self.db.query('SELECT datname FROM pg_database').getresult()]
617
618    def get_relations(self, kinds=None):
619        """Get list of relations in connected database of specified kinds.
620
621            If kinds is None or empty, all kinds of relations are returned.
622            Otherwise kinds can be a string or sequence of type letters
623            specifying which kind of relations you want to list.
624
625        """
626        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
627            ["'%s'" % x for x in kinds]) or ''
628        return [_join_parts(x) for x in self.db.query(
629            "SELECT pg_namespace.nspname, pg_class.relname "
630            "FROM pg_class "
631            "JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
632            "WHERE %s pg_class.relname !~ '^Inv' AND "
633                "pg_class.relname !~ '^pg_' "
634            "ORDER BY 1, 2" % where).getresult()]
635
636    def get_tables(self):
637        """Return list of tables in connected database."""
638        return self.get_relations('r')
639
640    def get_attnames(self, cl, newattnames=None):
641        """Given the name of a table, digs out the set of attribute names.
642
643        Returns a dictionary of attribute names (the names are the keys,
644        the values are the names of the attributes' types).
645        If the optional newattnames exists, it must be a dictionary and
646        will become the new attribute names dictionary.
647
648        By default, only a limited number of simple types will be returned.
649        You can get the regular types after calling use_regtypes(True).
650
651        """
652        if isinstance(newattnames, dict):
653            self._attnames = newattnames
654            return
655        elif newattnames:
656            raise _prg_error('If supplied, newattnames must be a dictionary')
657        cl = self._split_schema(cl)  # split into schema and class
658        qcl = _join_parts(cl)  # build fully qualified name
659        # May as well cache them:
660        if qcl in self._attnames:
661            return self._attnames[qcl]
662        if qcl not in self.get_relations('rv'):
663            raise _prg_error('Class %s does not exist' % qcl)
664
665        q = "SELECT pg_attribute.attname, pg_type.typname"
666        if self._regtypes:
667            q += "::regtype"
668        q += (" FROM pg_class"
669            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
670            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
671            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
672            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
673            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
674            " AND pg_attribute.attisdropped = 'f'") % cl
675        q = self.db.query(q).getresult()
676
677        if self._regtypes:
678            t = dict(q)
679        else:
680            t = {}
681            for att, typ in q:
682                if typ.startswith('bool'):
683                    typ = 'bool'
684                elif typ.startswith('abstime'):
685                    typ = 'date'
686                elif typ.startswith('date'):
687                    typ = 'date'
688                elif typ.startswith('interval'):
689                    typ = 'date'
690                elif typ.startswith('timestamp'):
691                    typ = 'date'
692                elif typ.startswith('oid'):
693                    typ = 'int'
694                elif typ.startswith('int'):
695                    typ = 'int'
696                elif typ.startswith('float'):
697                    typ = 'float'
698                elif typ.startswith('numeric'):
699                    typ = 'num'
700                elif typ.startswith('money'):
701                    typ = 'money'
702                elif typ.startswith('bytea'):
703                    typ = 'bytea'
704                else:
705                    typ = 'text'
706                t[att] = typ
707
708        self._attnames[qcl] = t  # cache it
709        return self._attnames[qcl]
710
711    def use_regtypes(self, regtypes=None):
712        """Use regular type names instead of simplified type names."""
713        if regtypes is None:
714            return self._regtypes
715        else:
716            regtypes = bool(regtypes)
717            if regtypes != self._regtypes:
718                self._regtypes = regtypes
719                self._attnames.clear()
720            return regtypes
721
722    def has_table_privilege(self, cl, privilege='select'):
723        """Check whether current user has specified table privilege."""
724        qcl = self._add_schema(cl)
725        privilege = privilege.lower()
726        try:
727            return self._privileges[(qcl, privilege)]
728        except KeyError:
729            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
730            ret = self.db.query(q).getresult()[0][0] == 't'
731            self._privileges[(qcl, privilege)] = ret
732            return ret
733
734    def get(self, cl, arg, keyname=None):
735        """Get a tuple from a database table or view.
736
737        This method is the basic mechanism to get a single row.  The keyname
738        that the key specifies a unique row.  If keyname is not specified
739        then the primary key for the table is used.  If arg is a dictionary
740        then the value for the key is taken from it and it is modified to
741        include the new values, replacing existing values where necessary.
742        For a composite key, keyname can also be a sequence of key names.
743        The OID is also put into the dictionary if the table has one, but
744        in order to allow the caller to work with multiple tables, it is
745        munged as oid(schema.table).
746
747        """
748        if cl.endswith('*'):  # scan descendant tables?
749            cl = cl[:-1].rstrip()  # need parent table name
750        # build qualified class name
751        qcl = self._add_schema(cl)
752        # To allow users to work with multiple tables,
753        # we munge the name of the "oid" the key
754        qoid = _oid_key(qcl)
755        if not keyname:
756            # use the primary key by default
757            try:
758                keyname = self.pkey(qcl)
759            except KeyError:
760                raise _prg_error('Class %s has no primary key' % qcl)
761        # We want the oid for later updates if that isn't the key
762        if keyname == 'oid':
763            if isinstance(arg, dict):
764                if qoid not in arg:
765                    raise _db_error('%s not in arg' % qoid)
766            else:
767                arg = {qoid: arg}
768            where = 'oid = %s' % arg[qoid]
769            attnames = '*'
770        else:
771            attnames = self.get_attnames(qcl)
772            if isinstance(keyname, basestring):
773                keyname = (keyname,)
774            if not isinstance(arg, dict):
775                if len(keyname) > 1:
776                    raise _prg_error('Composite key needs dict as arg')
777                arg = dict([(k, arg) for k in keyname])
778            where = ' AND '.join(['%s = %s'
779                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
780            attnames = ', '.join(attnames)
781        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
782        self._do_debug(q)
783        res = self.db.query(q).dictresult()
784        if not res:
785            raise _db_error('No such record in %s where %s' % (qcl, where))
786        for att, value in res[0].items():
787            arg[qoid if att == 'oid' else att] = value
788        return arg
789
790    def insert(self, cl, d=None, **kw):
791        """Insert a tuple into a database table.
792
793        This method inserts a row into a table.  If a dictionary is
794        supplied it starts with that.  Otherwise it uses a blank dictionary.
795        Either way the dictionary is updated from the keywords.
796
797        The dictionary is then, if possible, reloaded with the values actually
798        inserted in order to pick up values modified by rules, triggers, etc.
799
800        Note: The method currently doesn't support insert into views
801        although PostgreSQL does.
802
803        """
804        qcl = self._add_schema(cl)
805        qoid = _oid_key(qcl)
806        if d is None:
807            d = {}
808        d.update(kw)
809        attnames = self.get_attnames(qcl)
810        names, values = [], []
811        for n in attnames:
812            if n != 'oid' and n in d:
813                names.append('"%s"' % n)
814                values.append(self._quote(d[n], attnames[n]))
815        names, values = ', '.join(names), ', '.join(values)
816        selectable = self.has_table_privilege(qcl)
817        if selectable and self.server_version >= 80200:
818            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
819        else:
820            ret = ''
821        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
822        self._do_debug(q)
823        res = self.db.query(q)
824        if ret:
825            res = res.dictresult()
826            for att, value in res[0].items():
827                d[qoid if att == 'oid' else att] = value
828        elif isinstance(res, int):
829            d[qoid] = res
830            if selectable:
831                self.get(qcl, d, 'oid')
832        elif selectable:
833            if qoid in d:
834                self.get(qcl, d, 'oid')
835            else:
836                try:
837                    self.get(qcl, d)
838                except ProgrammingError:
839                    pass  # table has no primary key
840        return d
841
842    def update(self, cl, d=None, **kw):
843        """Update an existing row in a database table.
844
845        Similar to insert but updates an existing row.  The update is based
846        on the OID value as munged by get or passed as keyword, or on the
847        primary key of the table.  The dictionary is modified, if possible,
848        to reflect any changes caused by the update due to triggers, rules,
849        default values, etc.
850
851        """
852        # Update always works on the oid which get returns if available,
853        # otherwise use the primary key.  Fail if neither.
854        # Note that we only accept oid key from named args for safety
855        qcl = self._add_schema(cl)
856        qoid = _oid_key(qcl)
857        if 'oid' in kw:
858            kw[qoid] = kw['oid']
859            del kw['oid']
860        if d is None:
861            d = {}
862        d.update(kw)
863        attnames = self.get_attnames(qcl)
864        if qoid in d:
865            where = 'oid = %s' % d[qoid]
866            keyname = ()
867        else:
868            try:
869                keyname = self.pkey(qcl)
870            except KeyError:
871                raise _prg_error('Class %s has no primary key' % qcl)
872            if isinstance(keyname, basestring):
873                keyname = (keyname,)
874            try:
875                where = ' AND '.join(['%s = %s'
876                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
877            except KeyError:
878                raise _prg_error('Update needs primary key or oid.')
879        values = []
880        for n in attnames:
881            if n in d and n not in keyname:
882                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
883        if not values:
884            return d
885        values = ', '.join(values)
886        selectable = self.has_table_privilege(qcl)
887        if selectable and self.server_version >= 80200:
888            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
889        else:
890            ret = ''
891        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
892        self._do_debug(q)
893        res = self.db.query(q)
894        if ret:
895            res = res.dictresult()[0]
896            for att, value in res.items():
897                d[qoid if att == 'oid' else att] = value
898        else:
899            if selectable:
900                if qoid in d:
901                    self.get(qcl, d, 'oid')
902                else:
903                    self.get(qcl, d)
904        return d
905
906    def clear(self, cl, a=None):
907        """Clear all the attributes to values determined by the types.
908
909        Numeric types are set to 0, Booleans are set to 'f', and everything
910        else is set to the empty string.  If the array argument is present,
911        it is used as the array and any entries matching attribute names are
912        cleared with everything else left unchanged.
913
914        """
915        # At some point we will need a way to get defaults from a table.
916        qcl = self._add_schema(cl)
917        if a is None:
918            a = {}  # empty if argument is not present
919        attnames = self.get_attnames(qcl)
920        for n, t in attnames.items():
921            if n == 'oid':
922                continue
923            if t in ('int', 'integer', 'smallint', 'bigint',
924                    'float', 'real', 'double precision',
925                    'num', 'numeric', 'money'):
926                a[n] = 0
927            elif t in ('bool', 'boolean'):
928                a[n] = 'f'
929            else:
930                a[n] = ''
931        return a
932
933    def delete(self, cl, d=None, **kw):
934        """Delete an existing row in a database table.
935
936        This method deletes the row from a table.  It deletes based on the
937        OID value as munged by get or passed as keyword, or on the primary
938        key of the table.  The return value is the number of deleted rows
939        (i.e. 0 if the row did not exist and 1 if the row was deleted).
940
941        """
942        # Like update, delete works on the oid.
943        # One day we will be testing that the record to be deleted
944        # isn't referenced somewhere (or else PostgreSQL will).
945        # Note that we only accept oid key from named args for safety
946        qcl = self._add_schema(cl)
947        qoid = _oid_key(qcl)
948        if 'oid' in kw:
949            kw[qoid] = kw['oid']
950            del kw['oid']
951        if d is None:
952            d = {}
953        d.update(kw)
954        if qoid in d:
955            where = 'oid = %s' % d[qoid]
956        else:
957            try:
958                keyname = self.pkey(qcl)
959            except KeyError:
960                raise _prg_error('Class %s has no primary key' % qcl)
961            if isinstance(keyname, basestring):
962                keyname = (keyname,)
963            attnames = self.get_attnames(qcl)
964            try:
965                where = ' AND '.join(['%s = %s'
966                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
967            except KeyError:
968                raise _prg_error('Delete needs primary key or oid.')
969        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
970        self._do_debug(q)
971        return int(self.db.query(q))
972
973    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
974        """Get notification handler that will run the given callback."""
975        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
976
977
# if run as script, print some information

if __name__ == '__main__':
    # print() separates its arguments with a space; the old string
    # concatenation glued the version number onto the word "version"
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.