source: trunk/pg.py @ 723

Last change on this file since 723 was 723, checked in by cito, 4 years ago

Minor fixes and additions in the docs

  • Property svn:keywords set to Id
File size: 34.3 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 723 2016-01-12 13:44:48Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40
# Python 2 has a common base class for str and unicode; on Python 3 it is
# emulated with a tuple of types that can be used with isinstance().
try:
    basestring
except NameError:  # Python >= 3.0
    basestring = str, bytes

# Register the Decimal class with the _pg extension module
# (presumably used for numeric values returned from queries).
set_decimal(Decimal)
47
48
49# Auxiliary functions that are independent from a DB connection:
50
51def _is_quoted(s):
52    """Check whether this string is a quoted identifier."""
53    s = s.replace('_', 'a')
54    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
55
56
57def _is_unquoted(s):
58    """Check whether this string is an unquoted identifier."""
59    s = s.replace('_', 'a')
60    return s.isalnum() and not s[:1].isdigit()
61
62
63def _split_first_part(s):
64    """Split the first part of a dot separated string."""
65    s = s.lstrip()
66    if s[:1] == '"':
67        p = []
68        s = s.split('"', 3)[1:]
69        p.append(s[0])
70        while len(s) == 3 and s[1] == '':
71            p.append('"')
72            s = s[2].split('"', 2)
73            p.append(s[0])
74        p = [''.join(p)]
75        s = '"'.join(s[1:]).lstrip()
76        if s:
77            if s[:0] == '.':
78                p.append(s[1:])
79            else:
80                s = _split_first_part(s)
81                p[0] += s[0]
82                if len(s) > 1:
83                    p.append(s[1])
84    else:
85        p = s.split('.', 1)
86        s = p[0].rstrip()
87        if _is_unquoted(s):
88            s = s.lower()
89        p[0] = s
90    return p
91
92
93def _split_parts(s):
94    """Split all parts of a dot separated string."""
95    q = []
96    while s:
97        s = _split_first_part(s)
98        q.append(s[0])
99        if len(s) < 2:
100            break
101        s = s[1]
102    return q
103
104
def _join_parts(s):
    """Join all parts of a dot separated string.

    Parts that cannot be written as plain identifiers are put in double
    quotes, and double quotes inside them are doubled.  This produces
    valid SQL and lets _split_parts() restore the original parts.
    """
    return '.'.join(['"%s"' % p.replace('"', '""') if _is_quoted(p) else p
                     for p in s])
108
109
110def _oid_key(qcl):
111    """Build oid key from qualified class name."""
112    return 'oid(%s)' % qcl
113
114
115def _namedresult(q):
116    """Get query result as named tuples."""
117    row = namedtuple('Row', q.listfields())
118    return [row(*r) for r in q.getresult()]
119
# Let the _pg extension module build named tuple results via _namedresult.
set_namedresult(_namedresult)
121
122
def _db_error(msg, cls=DatabaseError):
    """Return (not raise) an exception of class cls with sqlstate None.

    cls defaults to DatabaseError.  The sqlstate attribute is set to None
    because errors built here come from this module, not from an SQLSTATE
    code reported by the server.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
128
129
def _int_error(msg):
    """Return (not raise) an InternalError with the message msg."""
    return _db_error(msg, cls=InternalError)
133
134
def _prg_error(msg):
    """Return (not raise) a ProgrammingError with the message msg."""
    return _db_error(msg, cls=ProgrammingError)
138
139
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB instance is
                   replaced with its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # The db attribute is missing if __init__ did not get that far.
        if getattr(self, 'db', None):
            self.close()

    @staticmethod
    def _quote_ident(name):
        """Return name as a double-quoted SQL identifier.

        Double quotes inside the name are doubled, so any channel name
        can be used safely.
        """
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    @staticmethod
    def _quote_literal(value):
        """Return value as an SQL escape string constant (E'...').

        Backslashes and single quotes are doubled, so the value is sent
        verbatim whatever the standard_conforming_strings setting.
        """
        value = '%s' % (value,)
        return "E'%s'" % value.replace('\\', '\\\\').replace("'", "''")

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_ident(self.event))
            self.db.query('listen %s' % self._quote_ident(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_ident(self.event))
            self.db.query('unlisten %s' % self._quote_ident(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Send a NOTIFY for the stop event if stop is set, otherwise for the
        event, together with the optional payload.  Return the result of
        the query, or None without sending anything if the handler is not
        listening.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify %s' % self._quote_ident(
                self.stop_event if stop else self.event)
            if payload:
                q += ', %s' % self._quote_literal(payload)
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages is received, its associated
        'pid', 'event' and 'extra' (presumably the payload) are inserted
        into <arg_dict>, and the callback is invoked with <arg_dict>. If
        the NOTIFY message is stop_<event>, the handler UNLISTENs both
        <event> and stop_<event> and exits.  If the timeout expires first,
        the handler UNLISTENs, calls the callback with None and exits.
        A notification for any other event makes the handler UNLISTEN and
        raise a DatabaseError.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        NOTE(review): the close argument is currently not used.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                # drain all notifications that are currently available
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()  # ends both loops after the callback
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
247
248
def pgnotify(*args, **kw):
    """Same as NotificationHandler, under the traditional name.

    Deprecated: emit a DeprecationWarning attributed to the caller and
    return a NotificationHandler built from the same arguments.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
254
255
# The actual PostgreSQL database connection interface:
257
258class DB(object):
259    """Wrapper class for the _pg connection type."""
260
261    def __init__(self, *args, **kw):
262        """Create a new connection.
263
264        You can pass either the connection parameters or an existing
265        _pg or pgdb connection. This allows you to use the methods
266        of the classic pg interface with a DB-API 2 pgdb connection.
267
268        """
269        if not args and len(kw) == 1:
270            db = kw.get('db')
271        elif not kw and len(args) == 1:
272            db = args[0]
273        else:
274            db = None
275        if db:
276            if isinstance(db, DB):
277                db = db.db
278            else:
279                try:
280                    db = db._cnx
281                except AttributeError:
282                    pass
283        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
284            db = connect(*args, **kw)
285            self._closeable = True
286        else:
287            self._closeable = False
288        self.db = db
289        self.dbname = db.db
290        self._regtypes = False
291        self._attnames = {}
292        self._pkeys = {}
293        self._privileges = {}
294        self._args = args, kw
295        self.debug = None  # For debugging scripts, this can be set
296            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
297            # * to a file object to write debug statements or
298            # * to a callable object which takes a string argument
299            # * to any other true value to just print debug statements
300
301    def __getattr__(self, name):
302        # All undefined members are same as in underlying connection:
303        if self.db:
304            return getattr(self.db, name)
305        else:
306            raise _int_error('Connection is not valid')
307
308    def __dir__(self):
309        # Custom dir function including the attributes of the connection:
310        attrs = set(self.__class__.__dict__)
311        attrs.update(self.__dict__)
312        attrs.update(dir(self.db))
313        return sorted(attrs)
314
315    # Context manager methods
316
317    def __enter__(self):
318        """Enter the runtime context. This will start a transaction."""
319        self.begin()
320        return self
321
322    def __exit__(self, et, ev, tb):
323        """Exit the runtime context. This will end the transaction."""
324        if et is None and ev is None and tb is None:
325            self.commit()
326        else:
327            self.rollback()
328
329    # Auxiliary methods
330
331    def _do_debug(self, s):
332        """Print a debug message."""
333        if self.debug:
334            if isinstance(self.debug, basestring):
335                print(self.debug % s)
336            elif hasattr(self.debug, 'write'):
337                self.debug.write(s + '\n')
338            elif callable(self.debug):
339                self.debug(s)
340            else:
341                print(s)
342
343    @staticmethod
344    def _make_bool(d):
345        """Get boolean value corresponding to d."""
346        return bool(d) if get_bool() else ('t' if d else 'f')
347
348    def _quote_text(self, d):
349        """Quote text value."""
350        if not isinstance(d, basestring):
351            d = str(d)
352        return "'%s'" % self.escape_string(d)
353
354    _bool_true = frozenset('t true 1 y yes on'.split())
355
356    def _quote_bool(self, d):
357        """Quote boolean value."""
358        if isinstance(d, basestring):
359            if not d:
360                return 'NULL'
361            d = d.lower() in self._bool_true
362        return "'t'" if d else "'f'"
363
364    _date_literals = frozenset('current_date current_time'
365        ' current_timestamp localtime localtimestamp'.split())
366
367    def _quote_date(self, d):
368        """Quote date value."""
369        if not d:
370            return 'NULL'
371        if isinstance(d, basestring) and d.lower() in self._date_literals:
372            return d
373        return self._quote_text(d)
374
375    def _quote_num(self, d):
376        """Quote numeric value."""
377        if not d and d != 0:
378            return 'NULL'
379        return str(d)
380
381    def _quote_money(self, d):
382        """Quote money value."""
383        if d is None or d == '':
384            return 'NULL'
385        if not isinstance(d, basestring):
386            d = str(d)
387        return d
388
389    if bytes is str:  # Python < 3.0
390        """Quote bytes value."""
391
392        def _quote_bytea(self, d):
393            return "'%s'" % self.escape_bytea(d)
394
395    else:
396
397        def _quote_bytea(self, d):
398            return "'%s'" % self.escape_bytea(d).decode('ascii')
399
400    _quote_funcs = dict(  # quote methods for each type
401        text=_quote_text, bool=_quote_bool, date=_quote_date,
402        int=_quote_num, num=_quote_num, float=_quote_num,
403        money=_quote_money, bytea=_quote_bytea)
404
405    def _quote(self, d, t):
406        """Return quotes if needed."""
407        if d is None:
408            return 'NULL'
409        try:
410            quote_func = self._quote_funcs[t]
411        except KeyError:
412            quote_func = self._quote_funcs['text']
413        return quote_func(self, d)
414
415    def _split_schema(self, cl):
416        """Return schema and name of object separately.
417
418        This auxiliary function splits off the namespace (schema)
419        belonging to the class with the name cl. If the class name
420        is not qualified, the function is able to determine the schema
421        of the class, taking into account the current search path.
422
423        """
424        s = _split_parts(cl)
425        if len(s) > 1:  # name already qualified?
426            # should be database.schema.table or schema.table
427            if len(s) > 3:
428                raise _prg_error('Too many dots in class name %s' % cl)
429            schema, cl = s[-2:]
430        else:
431            cl = s[0]
432            # determine search path
433            q = 'SELECT current_schemas(TRUE)'
434            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
435            if schemas:  # non-empty path
436                # search schema for this object in the current search path
437                q = ' UNION '.join(
438                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
439                        % s for s in enumerate(schemas)])
440                q = ("SELECT nspname FROM pg_class"
441                    " JOIN pg_namespace"
442                    " ON pg_class.relnamespace = pg_namespace.oid"
443                    " JOIN (%s) AS p USING (nspname)"
444                    " WHERE pg_class.relname = '%s'"
445                    " ORDER BY n LIMIT 1" % (q, cl))
446                schema = self.db.query(q).getresult()
447                if schema:  # schema found
448                    schema = schema[0][0]
449                else:  # object not found in current search path
450                    schema = 'public'
451            else:  # empty path
452                schema = 'public'
453        return schema, cl
454
455    def _add_schema(self, cl):
456        """Ensure that the class name is prefixed with a schema name."""
457        return _join_parts(self._split_schema(cl))
458
459    # Public methods
460
461    # escape_string and escape_bytea exist as methods,
462    # so we define unescape_bytea as a method as well
463    unescape_bytea = staticmethod(unescape_bytea)
464
465    def close(self):
466        """Close the database connection."""
467        # Wraps shared library function so we can track state.
468        if self._closeable:
469            if self.db:
470                self.db.close()
471                self.db = None
472            else:
473                raise _int_error('Connection already closed')
474
475    def reset(self):
476        """Reset connection with current parameters.
477
478        All derived queries and large objects derived from this connection
479        will not be usable after this call.
480
481        """
482        if self.db:
483            self.db.reset()
484        else:
485            raise _int_error('Connection already closed')
486
487    def reopen(self):
488        """Reopen connection to the database.
489
490        Used in case we need another connection to the same database.
491        Note that we can still reopen a database that we have closed.
492
493        """
494        # There is no such shared library function.
495        if self._closeable:
496            db = connect(*self._args[0], **self._args[1])
497            if self.db:
498                self.db.close()
499            self.db = db
500
501    def begin(self, mode=None):
502        """Begin a transaction."""
503        qstr = 'BEGIN'
504        if mode:
505            qstr += ' ' + mode
506        return self.query(qstr)
507
508    start = begin
509
510    def commit(self):
511        """Commit the current transaction."""
512        return self.query('COMMIT')
513
514    end = commit
515
516    def rollback(self, name=None):
517        """Rollback the current transaction."""
518        qstr = 'ROLLBACK'
519        if name:
520            qstr += ' TO ' + name
521        return self.query(qstr)
522
523    def savepoint(self, name=None):
524        """Define a new savepoint within the current transaction."""
525        qstr = 'SAVEPOINT'
526        if name:
527            qstr += ' ' + name
528        return self.query(qstr)
529
530    def release(self, name):
531        """Destroy a previously defined savepoint."""
532        return self.query('RELEASE ' + name)
533
534    def query(self, qstr, *args):
535        """Executes a SQL command string.
536
537        This method simply sends a SQL query to the database. If the query is
538        an insert statement that inserted exactly one row into a table that
539        has OIDs, the return value is the OID of the newly inserted row.
540        If the query is an update or delete statement, or an insert statement
541        that did not insert exactly one row in a table with OIDs, then the
542        number of rows affected is returned as a string. If it is a statement
543        that returns rows as a result (usually a select statement, but maybe
544        also an "insert/update ... returning" statement), this method returns
545        a Query object that can be accessed via getresult() or dictresult()
546        or simply printed. Otherwise, it returns `None`.
547
548        The query can contain numbered parameters of the form $1 in place
549        of any data constant. Arguments given after the query string will
550        be substituted for the corresponding numbered parameter. Parameter
551        values can also be given as a single list or tuple argument.
552
553        """
554        # Wraps shared library function for debugging.
555        if not self.db:
556            raise _int_error('Connection is not valid')
557        self._do_debug(qstr)
558        return self.db.query(qstr, args)
559
560    def pkey(self, cl, newpkey=None):
561        """This method gets or sets the primary key of a class.
562
563        Composite primary keys are represented as frozensets. Note that
564        this raises an exception if the table does not have a primary key.
565
566        If newpkey is set and is not a dictionary then set that
567        value as the primary key of the class.  If it is a dictionary
568        then replace the _pkeys dictionary with a copy of it.
569
570        """
571        # First see if the caller is supplying a dictionary
572        if isinstance(newpkey, dict):
573            # make sure that all classes have a namespace
574            self._pkeys = dict([
575                (cl if '.' in cl else 'public.' + cl, pkey)
576                for cl, pkey in newpkey.items()])
577            return self._pkeys
578
579        qcl = self._add_schema(cl)  # build fully qualified class name
580        # Check if the caller is supplying a new primary key for the class
581        if newpkey:
582            self._pkeys[qcl] = newpkey
583            return newpkey
584
585        # Get all the primary keys at once
586        if qcl not in self._pkeys:
587            # if not found, check again in case it was added after we started
588            self._pkeys = {}
589            q = ("SELECT s.nspname, r.relname, a.attname"
590                " FROM pg_class r"
591                " JOIN pg_namespace s ON s.oid = r.relnamespace"
592                " AND s.nspname NOT SIMILAR"
593                " TO 'pg/_%|information/_schema' ESCAPE '/'"
594                " JOIN pg_attribute a ON a.attrelid = r.oid"
595                " AND NOT a.attisdropped"
596                " JOIN pg_index i ON i.indrelid = r.oid"
597                " AND i.indisprimary AND a.attnum = ANY (i.indkey)")
598            for r in self.db.query(q).getresult():
599                cl, pkey = _join_parts(r[:2]), r[2]
600                self._pkeys.setdefault(cl, []).append(pkey)
601            # (only) for composite primary keys, the values will be frozensets
602            for cl, pkey in self._pkeys.items():
603                self._pkeys[cl] = frozenset(pkey) if len(pkey) > 1 else pkey[0]
604            self._do_debug(self._pkeys)
605
606        # will raise an exception if primary key doesn't exist
607        return self._pkeys[qcl]
608
609    def get_databases(self):
610        """Get list of databases in the system."""
611        return [s[0] for s in
612            self.db.query('SELECT datname FROM pg_database').getresult()]
613
614    def get_relations(self, kinds=None):
615        """Get list of relations in connected database of specified kinds.
616
617            If kinds is None or empty, all kinds of relations are returned.
618            Otherwise kinds can be a string or sequence of type letters
619            specifying which kind of relations you want to list.
620
621        """
622        where = " AND r.relkind IN (%s)" % ','.join(
623            ["'%s'" % k for k in kinds]) if kinds else ''
624        q = ("SELECT s.nspname, r.relname"
625            " FROM pg_class r"
626            " JOIN pg_namespace s ON s.oid = r.relnamespace"
627            " WHERE s.nspname NOT SIMILAR"
628            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
629            " ORDER BY 1, 2") % where
630        return [_join_parts(r) for r in self.db.query(q).getresult()]
631
632    def get_tables(self):
633        """Return list of tables in connected database."""
634        return self.get_relations('r')
635
636    def get_attnames(self, cl, newattnames=None):
637        """Given the name of a table, digs out the set of attribute names.
638
639        Returns a dictionary of attribute names (the names are the keys,
640        the values are the names of the attributes' types).
641        If the optional newattnames exists, it must be a dictionary and
642        will become the new attribute names dictionary.
643
644        By default, only a limited number of simple types will be returned.
645        You can get the regular types after calling use_regtypes(True).
646
647        """
648        if isinstance(newattnames, dict):
649            self._attnames = newattnames
650            return
651        elif newattnames:
652            raise _prg_error('If supplied, newattnames must be a dictionary')
653        cl = self._split_schema(cl)  # split into schema and class
654        qcl = _join_parts(cl)  # build fully qualified name
655        # May as well cache them:
656        if qcl in self._attnames:
657            return self._attnames[qcl]
658        if qcl not in self.get_relations('rv'):
659            raise _prg_error('Class %s does not exist' % qcl)
660
661        q = ("SELECT a.attname, t.typname%s"
662            " FROM pg_class r"
663            " JOIN pg_namespace s ON r.relnamespace = s.oid"
664            " JOIN pg_attribute a ON a.attrelid = r.oid"
665            " JOIN pg_type t ON t.oid = a.atttypid"
666            " WHERE s.nspname = $1 AND r.relname = $2"
667            " AND (a.attnum > 0 OR a.attname = 'oid')"
668            " AND NOT a.attisdropped") % (
669                '::regtype' if self._regtypes else '',)
670        q = self.db.query(q, cl).getresult()
671
672        if self._regtypes:
673            t = dict(q)
674        else:
675            t = {}
676            for att, typ in q:
677                if typ.startswith('bool'):
678                    typ = 'bool'
679                elif typ.startswith('abstime'):
680                    typ = 'date'
681                elif typ.startswith('date'):
682                    typ = 'date'
683                elif typ.startswith('interval'):
684                    typ = 'date'
685                elif typ.startswith('timestamp'):
686                    typ = 'date'
687                elif typ.startswith('oid'):
688                    typ = 'int'
689                elif typ.startswith('int'):
690                    typ = 'int'
691                elif typ.startswith('float'):
692                    typ = 'float'
693                elif typ.startswith('numeric'):
694                    typ = 'num'
695                elif typ.startswith('money'):
696                    typ = 'money'
697                elif typ.startswith('bytea'):
698                    typ = 'bytea'
699                else:
700                    typ = 'text'
701                t[att] = typ
702
703        self._attnames[qcl] = t  # cache it
704        return self._attnames[qcl]
705
706    def use_regtypes(self, regtypes=None):
707        """Use regular type names instead of simplified type names."""
708        if regtypes is None:
709            return self._regtypes
710        else:
711            regtypes = bool(regtypes)
712            if regtypes != self._regtypes:
713                self._regtypes = regtypes
714                self._attnames.clear()
715            return regtypes
716
717    def has_table_privilege(self, cl, privilege='select'):
718        """Check whether current user has specified table privilege."""
719        qcl = self._add_schema(cl)
720        privilege = privilege.lower()
721        try:
722            return self._privileges[(qcl, privilege)]
723        except KeyError:
724            q = "SELECT has_table_privilege($1, $2)"
725            q = self.db.query(q, (qcl, privilege))
726            ret = q.getresult()[0][0] == self._make_bool(True)
727            self._privileges[(qcl, privilege)] = ret
728            return ret
729
730    def get(self, cl, arg, keyname=None):
731        """Get a row from a database table or view.
732
733        This method is the basic mechanism to get a single row.  The keyname
734        that the key specifies a unique row.  If keyname is not specified
735        then the primary key for the table is used.  If arg is a dictionary
736        then the value for the key is taken from it and it is modified to
737        include the new values, replacing existing values where necessary.
738        For a composite key, keyname can also be a sequence of key names.
739        The OID is also put into the dictionary if the table has one, but
740        in order to allow the caller to work with multiple tables, it is
741        munged as oid(schema.table).
742
743        """
744        if cl.endswith('*'):  # scan descendant tables?
745            cl = cl[:-1].rstrip()  # need parent table name
746        # build qualified class name
747        qcl = self._add_schema(cl)
748        # To allow users to work with multiple tables,
749        # we munge the name of the "oid" the key
750        qoid = _oid_key(qcl)
751        if not keyname:
752            # use the primary key by default
753            try:
754                keyname = self.pkey(qcl)
755            except KeyError:
756                raise _prg_error('Class %s has no primary key' % qcl)
757        attnames = self.get_attnames(qcl)
758        # We want the oid for later updates if that isn't the key
759        if keyname == 'oid':
760            if isinstance(arg, dict):
761                if qoid not in arg:
762                    raise _db_error('%s not in arg' % qoid)
763            else:
764                arg = {qoid: arg}
765            what = '*'
766            where = 'oid = %s' % arg[qoid]
767        else:
768            if isinstance(keyname, basestring):
769                keyname = (keyname,)
770            if not isinstance(arg, dict):
771                if len(keyname) > 1:
772                    raise _prg_error('Composite key needs dict as arg')
773                arg = dict([(k, arg) for k in keyname])
774            what = ', '.join(attnames)
775            where = ' AND '.join(['%s = %s'
776                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
777        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (what, qcl, where)
778        self._do_debug(q)
779        res = self.db.query(q).dictresult()
780        if not res:
781            raise _db_error('No such record in %s where %s' % (qcl, where))
782        for n, value in res[0].items():
783            if n == 'oid':
784                n = qoid
785            elif attnames.get(n) == 'bytea':
786                value = self.unescape_bytea(value)
787            arg[n] = value
788        return arg
789
790    def insert(self, cl, d=None, **kw):
791        """Insert a row into a database table.
792
793        This method inserts a row into a table.  If a dictionary is
794        supplied it starts with that.  Otherwise it uses a blank dictionary.
795        Either way the dictionary is updated from the keywords.
796
797        The dictionary is then, if possible, reloaded with the values actually
798        inserted in order to pick up values modified by rules, triggers, etc.
799
800        Note: The method currently doesn't support insert into views
801        although PostgreSQL does.
802
803        """
804        qcl = self._add_schema(cl)
805        qoid = _oid_key(qcl)
806        if d is None:
807            d = {}
808        d.update(kw)
809        attnames = self.get_attnames(qcl)
810        names, values = [], []
811        for n in attnames:
812            if n != 'oid' and n in d:
813                names.append('"%s"' % n)
814                values.append(self._quote(d[n], attnames[n]))
815        names, values = ', '.join(names), ', '.join(values)
816        selectable = self.has_table_privilege(qcl)
817        if selectable:
818            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
819        else:
820            ret = ''
821        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
822        self._do_debug(q)
823        res = self.db.query(q)
824        if ret:
825            res = res.dictresult()[0]
826            for n, value in res.items():
827                if n == 'oid':
828                    n = qoid
829                elif attnames.get(n) == 'bytea':
830                    value = self.unescape_bytea(value)
831                d[n] = value
832        elif isinstance(res, int):
833            d[qoid] = res
834            if selectable:
835                self.get(qcl, d, 'oid')
836        elif selectable:
837            if qoid in d:
838                self.get(qcl, d, 'oid')
839            else:
840                try:
841                    self.get(qcl, d)
842                except ProgrammingError:
843                    pass  # table has no primary key
844        return d
845
846    def update(self, cl, d=None, **kw):
847        """Update an existing row in a database table.
848
849        Similar to insert but updates an existing row.  The update is based
850        on the OID value as munged by get or passed as keyword, or on the
851        primary key of the table.  The dictionary is modified, if possible,
852        to reflect any changes caused by the update due to triggers, rules,
853        default values, etc.
854
855        """
856        # Update always works on the oid which get returns if available,
857        # otherwise use the primary key.  Fail if neither.
858        # Note that we only accept oid key from named args for safety
859        qcl = self._add_schema(cl)
860        qoid = _oid_key(qcl)
861        if 'oid' in kw:
862            kw[qoid] = kw['oid']
863            del kw['oid']
864        if d is None:
865            d = {}
866        d.update(kw)
867        attnames = self.get_attnames(qcl)
868        if qoid in d:
869            where = 'oid = %s' % d[qoid]
870            keyname = ()
871        else:
872            try:
873                keyname = self.pkey(qcl)
874            except KeyError:
875                raise _prg_error('Class %s has no primary key' % qcl)
876            if isinstance(keyname, basestring):
877                keyname = (keyname,)
878            try:
879                where = ' AND '.join(['%s = %s'
880                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
881            except KeyError:
882                raise _prg_error('Update needs primary key or oid.')
883        values = []
884        for n in attnames:
885            if n in d and n not in keyname:
886                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
887        if not values:
888            return d
889        values = ', '.join(values)
890        selectable = self.has_table_privilege(qcl)
891        if selectable:
892            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
893        else:
894            ret = ''
895        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
896        self._do_debug(q)
897        res = self.db.query(q)
898        if ret:
899            res = res.dictresult()[0]
900            for n, value in res.items():
901                if n == 'oid':
902                    n = qoid
903                elif attnames.get(n) == 'bytea':
904                    value = self.unescape_bytea(value)
905                d[n] = value
906        else:
907            if selectable:
908                if qoid in d:
909                    self.get(qcl, d, 'oid')
910                else:
911                    self.get(qcl, d)
912        return d
913
914    def clear(self, cl, a=None):
915        """Clear all the attributes to values determined by the types.
916
917        Numeric types are set to 0, Booleans are set to false, and everything
918        else is set to the empty string.  If the array argument is present,
919        it is used as the array and any entries matching attribute names are
920        cleared with everything else left unchanged.
921
922        """
923        # At some point we will need a way to get defaults from a table.
924        qcl = self._add_schema(cl)
925        if a is None:
926            a = {}  # empty if argument is not present
927        attnames = self.get_attnames(qcl)
928        for n, t in attnames.items():
929            if n == 'oid':
930                continue
931            if t in ('int', 'integer', 'smallint', 'bigint',
932                    'float', 'real', 'double precision',
933                    'num', 'numeric', 'money'):
934                a[n] = 0
935            elif t in ('bool', 'boolean'):
936                a[n] = self._make_bool(False)
937            else:
938                a[n] = ''
939        return a
940
941    def delete(self, cl, d=None, **kw):
942        """Delete an existing row in a database table.
943
944        This method deletes the row from a table.  It deletes based on the
945        OID value as munged by get or passed as keyword, or on the primary
946        key of the table.  The return value is the number of deleted rows
947        (i.e. 0 if the row did not exist and 1 if the row was deleted).
948
949        """
950        # Like update, delete works on the oid.
951        # One day we will be testing that the record to be deleted
952        # isn't referenced somewhere (or else PostgreSQL will).
953        # Note that we only accept oid key from named args for safety
954        qcl = self._add_schema(cl)
955        qoid = _oid_key(qcl)
956        if 'oid' in kw:
957            kw[qoid] = kw['oid']
958            del kw['oid']
959        if d is None:
960            d = {}
961        d.update(kw)
962        if qoid in d:
963            where = 'oid = %s' % d[qoid]
964        else:
965            try:
966                keyname = self.pkey(qcl)
967            except KeyError:
968                raise _prg_error('Class %s has no primary key' % qcl)
969            if isinstance(keyname, basestring):
970                keyname = (keyname,)
971            attnames = self.get_attnames(qcl)
972            try:
973                where = ' AND '.join(['%s = %s'
974                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
975            except KeyError:
976                raise _prg_error('Delete needs primary key or oid.')
977        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
978        self._do_debug(q)
979        return int(self.db.query(q))
980
981    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
982        """Get notification handler that will run the given callback."""
983        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
984
985
# if run as script, print some information

if __name__ == '__main__':
    # print() separates its arguments with a space
    print('PyGreSQL version', version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.