source: trunk/pg.py @ 727

Last change on this file since 727 was 727, checked in by cito, 4 years ago

Improve docstring

  • Property svn:keywords set to Id
File size: 34.5 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 727 2016-01-12 16:57:21Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40from itertools import groupby
41
# Python 3 has no basestring builtin; emulate it with a tuple of both
# string types so that isinstance(x, basestring) works on either version.
# (On Python 2, bytes is an alias of str and the builtin is used as is.)
if bytes is not str:  # Python >= 3.0
    basestring = (str, bytes)

# Register decimal.Decimal with the _pg module
# (presumably used for numeric column values -- confirm in _pg).
set_decimal(Decimal)
48
49
50# Auxiliary functions that are independent of a DB connection:
51
52def _is_quoted(s):
53    """Check whether this string is a quoted identifier."""
54    s = s.replace('_', 'a')
55    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
56
57
58def _is_unquoted(s):
59    """Check whether this string is an unquoted identifier."""
60    s = s.replace('_', 'a')
61    return s.isalnum() and not s[:1].isdigit()
62
63
64def _split_first_part(s):
65    """Split the first part of a dot separated string."""
66    s = s.lstrip()
67    if s[:1] == '"':
68        p = []
69        s = s.split('"', 3)[1:]
70        p.append(s[0])
71        while len(s) == 3 and s[1] == '':
72            p.append('"')
73            s = s[2].split('"', 2)
74            p.append(s[0])
75        p = [''.join(p)]
76        s = '"'.join(s[1:]).lstrip()
77        if s:
78            if s[:0] == '.':
79                p.append(s[1:])
80            else:
81                s = _split_first_part(s)
82                p[0] += s[0]
83                if len(s) > 1:
84                    p.append(s[1])
85    else:
86        p = s.split('.', 1)
87        s = p[0].rstrip()
88        if _is_unquoted(s):
89            s = s.lower()
90        p[0] = s
91    return p
92
93
94def _split_parts(s):
95    """Split all parts of a dot separated string."""
96    q = []
97    while s:
98        s = _split_first_part(s)
99        q.append(s[0])
100        if len(s) < 2:
101            break
102        s = s[1]
103    return q
104
105
106def _join_parts(s):
107    """Join all parts of a dot separated string."""
108    return '.'.join(['"%s"' % p if _is_quoted(p) else p for p in s])
109
110
111def _oid_key(qcl):
112    """Build oid key from qualified class name."""
113    return 'oid(%s)' % qcl
114
115
116def _namedresult(q):
117    """Get query result as named tuples."""
118    row = namedtuple('Row', q.listfields())
119    return [row(*r) for r in q.getresult()]
120
# Register _namedresult with the _pg module; presumably the C extension
# calls it to build named tuple results for query objects -- confirm in _pg.
set_namedresult(_namedresult)
122
123
def _db_error(msg, cls=DatabaseError):
    """Return an exception of class cls for msg with sqlstate set to None.

    The exception is created, not raised; callers raise it themselves.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
129
130
def _int_error(msg):
    """Return (not raise) an InternalError for msg without sqlstate."""
    error_class = InternalError
    return _db_error(msg, error_class)
134
135
def _prg_error(msg):
    """Return (not raise) a ProgrammingError for msg without sqlstate."""
    error_class = ProgrammingError
    return _db_error(msg, error_class)
139
140
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB instance is replaced
                   by its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        # close() also runs from __del__, even if __init__ fails midway,
        # so keep safe defaults until setup is complete.
        self.db = None
        self.listening = False
        if isinstance(db, DB):
            db = db.db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout
        self.db = db

    def __del__(self):
        self.close()

    def close(self):
        """Stop listening and close the connection."""
        # getattr: tolerate instances whose __init__ never ran
        if getattr(self, 'db', None):
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Sends the event (or the stop event if stop is set), optionally
        with a payload, and returns the result of the query.  Nothing is
        sent (and None is returned) when the handler is not listening.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (self.stop_event if stop else self.event)
            if payload:
                # Escape backslashes and quotes so the payload cannot break
                # out of the literal; the E'' form makes the backslash
                # handling independent of standard_conforming_strings.
                payload = '%s' % (payload,)
                payload = payload.replace('\\', '\\\\').replace("'", "''")
                q += ", E'%s'" % payload
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' (payload) are inserted into <arg_dict>,
        and the callback is invoked with <arg_dict>. If the NOTIFY message
        is stop_<event>, the handler UNLISTENs both <event> and
        stop_<event> and exits.  If the timeout expires first, the handler
        UNLISTENs, calls the callback with None and exits.

        The close argument is currently not used.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
248
249
def pgnotify(*args, **kw):
    """Deprecated alias: create a NotificationHandler with the same args."""
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
255
256
257# The actual PostgreSQL database connection interface:
258
259class DB(object):
260    """Wrapper class for the _pg connection type."""
261
    def __init__(self, *args, **kw):
        """Create a new connection.

        You can pass either the connection parameters or an existing
        _pg or pgdb connection. This allows you to use the methods
        of the classic pg interface with a DB-API 2 pgdb connection.

        An existing connection is only recognized when it is passed as the
        single positional argument or as the single keyword argument 'db'.
        A DB instance is replaced by its underlying connection; any other
        object with a '_cnx' attribute (presumably a pgdb connection) is
        replaced by that attribute.  If the result does not look like a
        _pg connection (it lacks a 'db' or 'query' attribute), all
        arguments are passed on to connect() instead.

        Only connections created here via connect() will be closed by
        close() and can be reopened by reopen().

        """
        # Pick out a candidate connection object from the arguments
        if not args and len(kw) == 1:
            db = kw.get('db')
        elif not kw and len(args) == 1:
            db = args[0]
        else:
            db = None
        if db:
            if isinstance(db, DB):
                db = db.db
            else:
                try:
                    db = db._cnx
                except AttributeError:
                    pass
        # Duck-type check: a usable connection has 'db' and 'query'
        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
            db = connect(*args, **kw)
            self._closeable = True
        else:
            self._closeable = False  # owned by the caller, never closed here
        self.db = db
        self.dbname = db.db
        # Per-connection caches used by use_regtypes, get_attnames,
        # pkey and has_table_privilege
        self._regtypes = False
        self._attnames = {}
        self._pkeys = {}
        self._privileges = {}
        self._args = args, kw  # remembered so that reopen() can reconnect
        self.debug = None  # For debugging scripts, this can be set
            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
            # * to a file object to write debug statements or
            # * to a callable object which takes a string argument
            # * to any other true value to just print debug statements
301
302    def __getattr__(self, name):
303        # All undefined members are same as in underlying connection:
304        if self.db:
305            return getattr(self.db, name)
306        else:
307            raise _int_error('Connection is not valid')
308
309    def __dir__(self):
310        # Custom dir function including the attributes of the connection:
311        attrs = set(self.__class__.__dict__)
312        attrs.update(self.__dict__)
313        attrs.update(dir(self.db))
314        return sorted(attrs)
315
316    # Context manager methods
317
318    def __enter__(self):
319        """Enter the runtime context. This will start a transaction."""
320        self.begin()
321        return self
322
323    def __exit__(self, et, ev, tb):
324        """Exit the runtime context. This will end the transaction."""
325        if et is None and ev is None and tb is None:
326            self.commit()
327        else:
328            self.rollback()
329
330    # Auxiliary methods
331
332    def _do_debug(self, s):
333        """Print a debug message."""
334        if self.debug:
335            if isinstance(self.debug, basestring):
336                print(self.debug % s)
337            elif hasattr(self.debug, 'write'):
338                self.debug.write(s + '\n')
339            elif callable(self.debug):
340                self.debug(s)
341            else:
342                print(s)
343
344    @staticmethod
345    def _make_bool(d):
346        """Get boolean value corresponding to d."""
347        return bool(d) if get_bool() else ('t' if d else 'f')
348
349    def _quote_text(self, d):
350        """Quote text value."""
351        if not isinstance(d, basestring):
352            d = str(d)
353        return "'%s'" % self.escape_string(d)
354
355    _bool_true = frozenset('t true 1 y yes on'.split())
356
357    def _quote_bool(self, d):
358        """Quote boolean value."""
359        if isinstance(d, basestring):
360            if not d:
361                return 'NULL'
362            d = d.lower() in self._bool_true
363        return "'t'" if d else "'f'"
364
365    _date_literals = frozenset('current_date current_time'
366        ' current_timestamp localtime localtimestamp'.split())
367
368    def _quote_date(self, d):
369        """Quote date value."""
370        if not d:
371            return 'NULL'
372        if isinstance(d, basestring) and d.lower() in self._date_literals:
373            return d
374        return self._quote_text(d)
375
376    def _quote_num(self, d):
377        """Quote numeric value."""
378        if not d and d != 0:
379            return 'NULL'
380        return str(d)
381
382    def _quote_money(self, d):
383        """Quote money value."""
384        if d is None or d == '':
385            return 'NULL'
386        if not isinstance(d, basestring):
387            d = str(d)
388        return d
389
390    if bytes is str:  # Python < 3.0
391        """Quote bytes value."""
392
393        def _quote_bytea(self, d):
394            return "'%s'" % self.escape_bytea(d)
395
396    else:
397
398        def _quote_bytea(self, d):
399            return "'%s'" % self.escape_bytea(d).decode('ascii')
400
401    _quote_funcs = dict(  # quote methods for each type
402        text=_quote_text, bool=_quote_bool, date=_quote_date,
403        int=_quote_num, num=_quote_num, float=_quote_num,
404        money=_quote_money, bytea=_quote_bytea)
405
406    def _quote(self, d, t):
407        """Return quotes if needed."""
408        if d is None:
409            return 'NULL'
410        try:
411            quote_func = self._quote_funcs[t]
412        except KeyError:
413            quote_func = self._quote_funcs['text']
414        return quote_func(self, d)
415
416    def _split_schema(self, cl):
417        """Return schema and name of object separately.
418
419        This auxiliary function splits off the namespace (schema)
420        belonging to the class with the name cl. If the class name
421        is not qualified, the function is able to determine the schema
422        of the class, taking into account the current search path.
423
424        """
425        s = _split_parts(cl)
426        if len(s) > 1:  # name already qualified?
427            # should be database.schema.table or schema.table
428            if len(s) > 3:
429                raise _prg_error('Too many dots in class name %s' % cl)
430            schema, cl = s[-2:]
431        else:
432            cl = s[0]
433            # determine search path
434            q = 'SELECT current_schemas(TRUE)'
435            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
436            if schemas:  # non-empty path
437                # search schema for this object in the current search path
438                # (we could also use unnest with ordinality here to spare
439                # one query, but this is only possible since PostgreSQL 9.4)
440                q = ' UNION '.join(
441                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
442                        % s for s in enumerate(schemas)])
443                q = ("SELECT nspname FROM pg_class r"
444                    " JOIN pg_namespace s ON r.relnamespace = s.oid"
445                    " JOIN (%s) AS p USING (nspname)"
446                    " WHERE r.relname = $1 ORDER BY n LIMIT 1" % q)
447                schema = self.db.query(q, (cl,)).getresult()
448                if schema:  # schema found
449                    schema = schema[0][0]
450                else:  # object not found in current search path
451                    schema = 'public'
452            else:  # empty path
453                schema = 'public'
454        return schema, cl
455
456    def _add_schema(self, cl):
457        """Ensure that the class name is prefixed with a schema name."""
458        return _join_parts(self._split_schema(cl))
459
460    # Public methods
461
    # escape_string and escape_bytea exist as methods (reached through
    # __getattr__ on the connection), so we define unescape_bytea as a
    # method as well.  It is a staticmethod wrapping the module-level
    # function (presumably from _pg) and needs no connection.
    unescape_bytea = staticmethod(unescape_bytea)
465
466    def close(self):
467        """Close the database connection."""
468        # Wraps shared library function so we can track state.
469        if self._closeable:
470            if self.db:
471                self.db.close()
472                self.db = None
473            else:
474                raise _int_error('Connection already closed')
475
476    def reset(self):
477        """Reset connection with current parameters.
478
479        All derived queries and large objects derived from this connection
480        will not be usable after this call.
481
482        """
483        if self.db:
484            self.db.reset()
485        else:
486            raise _int_error('Connection already closed')
487
488    def reopen(self):
489        """Reopen connection to the database.
490
491        Used in case we need another connection to the same database.
492        Note that we can still reopen a database that we have closed.
493
494        """
495        # There is no such shared library function.
496        if self._closeable:
497            db = connect(*self._args[0], **self._args[1])
498            if self.db:
499                self.db.close()
500            self.db = db
501
502    def begin(self, mode=None):
503        """Begin a transaction."""
504        qstr = 'BEGIN'
505        if mode:
506            qstr += ' ' + mode
507        return self.query(qstr)
508
509    start = begin
510
511    def commit(self):
512        """Commit the current transaction."""
513        return self.query('COMMIT')
514
515    end = commit
516
517    def rollback(self, name=None):
518        """Rollback the current transaction."""
519        qstr = 'ROLLBACK'
520        if name:
521            qstr += ' TO ' + name
522        return self.query(qstr)
523
524    def savepoint(self, name=None):
525        """Define a new savepoint within the current transaction."""
526        qstr = 'SAVEPOINT'
527        if name:
528            qstr += ' ' + name
529        return self.query(qstr)
530
531    def release(self, name):
532        """Destroy a previously defined savepoint."""
533        return self.query('RELEASE ' + name)
534
535    def query(self, qstr, *args):
536        """Executes a SQL command string.
537
538        This method simply sends a SQL query to the database. If the query is
539        an insert statement that inserted exactly one row into a table that
540        has OIDs, the return value is the OID of the newly inserted row.
541        If the query is an update or delete statement, or an insert statement
542        that did not insert exactly one row in a table with OIDs, then the
543        number of rows affected is returned as a string. If it is a statement
544        that returns rows as a result (usually a select statement, but maybe
545        also an "insert/update ... returning" statement), this method returns
546        a Query object that can be accessed via getresult() or dictresult()
547        or simply printed. Otherwise, it returns `None`.
548
549        The query can contain numbered parameters of the form $1 in place
550        of any data constant. Arguments given after the query string will
551        be substituted for the corresponding numbered parameter. Parameter
552        values can also be given as a single list or tuple argument.
553
554        """
555        # Wraps shared library function for debugging.
556        if not self.db:
557            raise _int_error('Connection is not valid')
558        self._do_debug(qstr)
559        return self.db.query(qstr, args)
560
    def pkey(self, cl, newpkey=None):
        """This method gets or sets the primary key of a class.

        Composite primary keys are represented as frozensets. Note that
        this raises an exception if the table does not have a primary key.

        If newpkey is set and is not a dictionary then set that
        value as the primary key of the class.  If it is a dictionary
        then replace the internal cache for primary keys with a copy of it.

        Returns the primary key (a column name, or a frozenset of column
        names for a composite key), or the new cache when a dictionary
        was passed.  Raises KeyError if the class has no primary key.

        """
        add_schema = self._add_schema

        # First see if the caller is supplying a dictionary
        if isinstance(newpkey, dict):
            # make sure that all classes have a namespace
            self._pkeys = dict((add_schema(cl), pkey)
                for cl, pkey in newpkey.items())
            return self._pkeys

        qcl = add_schema(cl)  # build fully qualified class name
        # Check if the caller is supplying a new primary key for the class
        if newpkey:
            self._pkeys[qcl] = newpkey
            return newpkey

        # Get all the primary keys at once
        if qcl not in self._pkeys:
            # if not found, check again in case it was added after we started
            # (note: this reload replaces the whole cache, including keys
            # that were set manually via newpkey)
            # System schemas (pg_* and information_schema) are excluded.
            q = ("SELECT s.nspname, r.relname, a.attname"
                " FROM pg_class r"
                " JOIN pg_namespace s ON s.oid = r.relnamespace"
                " AND s.nspname NOT SIMILAR"
                " TO 'pg/_%|information/_schema' ESCAPE '/'"
                " JOIN pg_attribute a ON a.attrelid = r.oid"
                " AND NOT a.attisdropped"
                " JOIN pg_index i ON i.indrelid = r.oid"
                " AND i.indisprimary AND a.attnum = ANY (i.indkey)"
                " ORDER BY 1,2")
            rows = self.db.query(q).getresult()
            pkeys = {}
            # rows are ordered by schema and table, so groupby collects
            # all key columns of the same table into one group
            for cl, group in groupby(rows, lambda row: row[:2]):
                cl = _join_parts(cl)
                pkey = [row[2] for row in group]
                pkeys[cl] = frozenset(pkey) if len(pkey) > 1 else pkey[0]
            self._do_debug(pkeys)
            self._pkeys = pkeys

        # will raise an exception if primary key doesn't exist
        return self._pkeys[qcl]
611
612    def get_databases(self):
613        """Get list of databases in the system."""
614        return [s[0] for s in
615            self.db.query('SELECT datname FROM pg_database').getresult()]
616
617    def get_relations(self, kinds=None):
618        """Get list of relations in connected database of specified kinds.
619
620        If kinds is None or empty, all kinds of relations are returned.
621        Otherwise kinds can be a string or sequence of type letters
622        specifying which kind of relations you want to list.
623
624        """
625        where = " AND r.relkind IN (%s)" % ','.join(
626            ["'%s'" % k for k in kinds]) if kinds else ''
627        q = ("SELECT s.nspname, r.relname"
628            " FROM pg_class r"
629            " JOIN pg_namespace s ON s.oid = r.relnamespace"
630            " WHERE s.nspname NOT SIMILAR"
631            " TO 'pg/_%%|information/_schema' ESCAPE '/' %s"
632            " ORDER BY 1, 2") % where
633        return [_join_parts(r) for r in self.db.query(q).getresult()]
634
635    def get_tables(self):
636        """Return list of tables in connected database."""
637        return self.get_relations('r')
638
639    def get_attnames(self, cl, newattnames=None):
640        """Given the name of a table, digs out the set of attribute names.
641
642        Returns a dictionary of attribute names (the names are the keys,
643        the values are the names of the attributes' types).
644        If the optional newattnames exists, it must be a dictionary and
645        will become the new attribute names dictionary.
646
647        By default, only a limited number of simple types will be returned.
648        You can get the regular types after calling use_regtypes(True).
649
650        """
651        if isinstance(newattnames, dict):
652            self._attnames = newattnames
653            return
654        elif newattnames:
655            raise _prg_error('If supplied, newattnames must be a dictionary')
656        cl = self._split_schema(cl)  # split into schema and class
657        qcl = _join_parts(cl)  # build fully qualified name
658        # May as well cache them:
659        if qcl in self._attnames:
660            return self._attnames[qcl]
661        if qcl not in self.get_relations('rv'):
662            raise _prg_error('Class %s does not exist' % qcl)
663
664        q = ("SELECT a.attname, t.typname%s"
665            " FROM pg_class r"
666            " JOIN pg_namespace s ON r.relnamespace = s.oid"
667            " JOIN pg_attribute a ON a.attrelid = r.oid"
668            " JOIN pg_type t ON t.oid = a.atttypid"
669            " WHERE s.nspname = $1 AND r.relname = $2"
670            " AND (a.attnum > 0 OR a.attname = 'oid')"
671            " AND NOT a.attisdropped") % (
672                '::regtype' if self._regtypes else '',)
673        q = self.db.query(q, cl).getresult()
674
675        if self._regtypes:
676            t = dict(q)
677        else:
678            t = {}
679            for att, typ in q:
680                if typ.startswith('bool'):
681                    typ = 'bool'
682                elif typ.startswith('abstime'):
683                    typ = 'date'
684                elif typ.startswith('date'):
685                    typ = 'date'
686                elif typ.startswith('interval'):
687                    typ = 'date'
688                elif typ.startswith('timestamp'):
689                    typ = 'date'
690                elif typ.startswith('oid'):
691                    typ = 'int'
692                elif typ.startswith('int'):
693                    typ = 'int'
694                elif typ.startswith('float'):
695                    typ = 'float'
696                elif typ.startswith('numeric'):
697                    typ = 'num'
698                elif typ.startswith('money'):
699                    typ = 'money'
700                elif typ.startswith('bytea'):
701                    typ = 'bytea'
702                else:
703                    typ = 'text'
704                t[att] = typ
705
706        self._attnames[qcl] = t  # cache it
707        return self._attnames[qcl]
708
709    def use_regtypes(self, regtypes=None):
710        """Use regular type names instead of simplified type names."""
711        if regtypes is None:
712            return self._regtypes
713        else:
714            regtypes = bool(regtypes)
715            if regtypes != self._regtypes:
716                self._regtypes = regtypes
717                self._attnames.clear()
718            return regtypes
719
720    def has_table_privilege(self, cl, privilege='select'):
721        """Check whether current user has specified table privilege."""
722        qcl = self._add_schema(cl)
723        privilege = privilege.lower()
724        try:
725            return self._privileges[(qcl, privilege)]
726        except KeyError:
727            q = "SELECT has_table_privilege($1, $2)"
728            q = self.db.query(q, (qcl, privilege))
729            ret = q.getresult()[0][0] == self._make_bool(True)
730            self._privileges[(qcl, privilege)] = ret
731            return ret
732
733    def get(self, cl, arg, keyname=None):
734        """Get a row from a database table or view.
735
736        This method is the basic mechanism to get a single row.  The keyname
737        that the key specifies a unique row.  If keyname is not specified
738        then the primary key for the table is used.  If arg is a dictionary
739        then the value for the key is taken from it and it is modified to
740        include the new values, replacing existing values where necessary.
741        For a composite key, keyname can also be a sequence of key names.
742        The OID is also put into the dictionary if the table has one, but
743        in order to allow the caller to work with multiple tables, it is
744        munged as oid(schema.table).
745
746        """
747        if cl.endswith('*'):  # scan descendant tables?
748            cl = cl[:-1].rstrip()  # need parent table name
749        # build qualified class name
750        qcl = self._add_schema(cl)
751        # To allow users to work with multiple tables,
752        # we munge the name of the "oid" the key
753        qoid = _oid_key(qcl)
754        if not keyname:
755            # use the primary key by default
756            try:
757                keyname = self.pkey(qcl)
758            except KeyError:
759                raise _prg_error('Class %s has no primary key' % qcl)
760        attnames = self.get_attnames(qcl)
761        # We want the oid for later updates if that isn't the key
762        if keyname == 'oid':
763            if isinstance(arg, dict):
764                if qoid not in arg:
765                    raise _db_error('%s not in arg' % qoid)
766            else:
767                arg = {qoid: arg}
768            what = '*'
769            where = 'oid = %s' % arg[qoid]
770        else:
771            if isinstance(keyname, basestring):
772                keyname = (keyname,)
773            if not isinstance(arg, dict):
774                if len(keyname) > 1:
775                    raise _prg_error('Composite key needs dict as arg')
776                arg = dict([(k, arg) for k in keyname])
777            what = ', '.join(attnames)
778            where = ' AND '.join(['%s = %s'
779                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
780        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (what, qcl, where)
781        self._do_debug(q)
782        res = self.db.query(q).dictresult()
783        if not res:
784            raise _db_error('No such record in %s where %s' % (qcl, where))
785        for n, value in res[0].items():
786            if n == 'oid':
787                n = qoid
788            elif attnames.get(n) == 'bytea':
789                value = self.unescape_bytea(value)
790            arg[n] = value
791        return arg
792
793    def insert(self, cl, d=None, **kw):
794        """Insert a row into a database table.
795
796        This method inserts a row into a table.  The name of the table must
797        be passed as the first parameter.  The other parameters are used for
798        providing the data of the row that shall be inserted into the table.
799        If a dictionary is supplied as the second parameter, it starts with
800        that.  Otherwise it uses a blank dictionary. Either way the dictionary
801        is updated from the keywords.
802
803        The dictionary is then, if possible, reloaded with the values actually
804        inserted in order to pick up values modified by rules, triggers, etc.
805
806        Note: The method currently doesn't support insert into views
807        although PostgreSQL does.
808
809        """
810        qcl = self._add_schema(cl)
811        qoid = _oid_key(qcl)
812        if d is None:
813            d = {}
814        d.update(kw)
815        attnames = self.get_attnames(qcl)
816        names, values = [], []
817        for n in attnames:
818            if n != 'oid' and n in d:
819                names.append('"%s"' % n)
820                values.append(self._quote(d[n], attnames[n]))
821        names, values = ', '.join(names), ', '.join(values)
822        selectable = self.has_table_privilege(qcl)
823        if selectable:
824            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
825        else:
826            ret = ''
827        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
828        self._do_debug(q)
829        res = self.db.query(q)
830        if ret:
831            res = res.dictresult()[0]
832            for n, value in res.items():
833                if n == 'oid':
834                    n = qoid
835                elif attnames.get(n) == 'bytea':
836                    value = self.unescape_bytea(value)
837                d[n] = value
838        elif isinstance(res, int):
839            d[qoid] = res
840            if selectable:
841                self.get(qcl, d, 'oid')
842        elif selectable:
843            if qoid in d:
844                self.get(qcl, d, 'oid')
845            else:
846                try:
847                    self.get(qcl, d)
848                except ProgrammingError:
849                    pass  # table has no primary key
850        return d
851
852    def update(self, cl, d=None, **kw):
853        """Update an existing row in a database table.
854
855        Similar to insert but updates an existing row.  The update is based
856        on the OID value as munged by get or passed as keyword, or on the
857        primary key of the table.  The dictionary is modified, if possible,
858        to reflect any changes caused by the update due to triggers, rules,
859        default values, etc.
860
861        """
862        # Update always works on the oid which get returns if available,
863        # otherwise use the primary key.  Fail if neither.
864        # Note that we only accept oid key from named args for safety
865        qcl = self._add_schema(cl)
866        qoid = _oid_key(qcl)
867        if 'oid' in kw:
868            kw[qoid] = kw['oid']
869            del kw['oid']
870        if d is None:
871            d = {}
872        d.update(kw)
873        attnames = self.get_attnames(qcl)
874        if qoid in d:
875            where = 'oid = %s' % d[qoid]
876            keyname = ()
877        else:
878            try:
879                keyname = self.pkey(qcl)
880            except KeyError:
881                raise _prg_error('Class %s has no primary key' % qcl)
882            if isinstance(keyname, basestring):
883                keyname = (keyname,)
884            try:
885                where = ' AND '.join(['%s = %s'
886                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
887            except KeyError:
888                raise _prg_error('Update needs primary key or oid.')
889        values = []
890        for n in attnames:
891            if n in d and n not in keyname:
892                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
893        if not values:
894            return d
895        values = ', '.join(values)
896        selectable = self.has_table_privilege(qcl)
897        if selectable:
898            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
899        else:
900            ret = ''
901        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
902        self._do_debug(q)
903        res = self.db.query(q)
904        if ret:
905            res = res.dictresult()[0]
906            for n, value in res.items():
907                if n == 'oid':
908                    n = qoid
909                elif attnames.get(n) == 'bytea':
910                    value = self.unescape_bytea(value)
911                d[n] = value
912        else:
913            if selectable:
914                if qoid in d:
915                    self.get(qcl, d, 'oid')
916                else:
917                    self.get(qcl, d)
918        return d
919
920    def clear(self, cl, a=None):
921        """Clear all the attributes to values determined by the types.
922
923        Numeric types are set to 0, Booleans are set to false, and everything
924        else is set to the empty string.  If the array argument is present,
925        it is used as the array and any entries matching attribute names are
926        cleared with everything else left unchanged.
927
928        """
929        # At some point we will need a way to get defaults from a table.
930        qcl = self._add_schema(cl)
931        if a is None:
932            a = {}  # empty if argument is not present
933        attnames = self.get_attnames(qcl)
934        for n, t in attnames.items():
935            if n == 'oid':
936                continue
937            if t in ('int', 'integer', 'smallint', 'bigint',
938                    'float', 'real', 'double precision',
939                    'num', 'numeric', 'money'):
940                a[n] = 0
941            elif t in ('bool', 'boolean'):
942                a[n] = self._make_bool(False)
943            else:
944                a[n] = ''
945        return a
946
947    def delete(self, cl, d=None, **kw):
948        """Delete an existing row in a database table.
949
950        This method deletes the row from a table.  It deletes based on the
951        OID value as munged by get or passed as keyword, or on the primary
952        key of the table.  The return value is the number of deleted rows
953        (i.e. 0 if the row did not exist and 1 if the row was deleted).
954
955        """
956        # Like update, delete works on the oid.
957        # One day we will be testing that the record to be deleted
958        # isn't referenced somewhere (or else PostgreSQL will).
959        # Note that we only accept oid key from named args for safety
960        qcl = self._add_schema(cl)
961        qoid = _oid_key(qcl)
962        if 'oid' in kw:
963            kw[qoid] = kw['oid']
964            del kw['oid']
965        if d is None:
966            d = {}
967        d.update(kw)
968        if qoid in d:
969            where = 'oid = %s' % d[qoid]
970        else:
971            try:
972                keyname = self.pkey(qcl)
973            except KeyError:
974                raise _prg_error('Class %s has no primary key' % qcl)
975            if isinstance(keyname, basestring):
976                keyname = (keyname,)
977            attnames = self.get_attnames(qcl)
978            try:
979                where = ' AND '.join(['%s = %s'
980                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
981            except KeyError:
982                raise _prg_error('Delete needs primary key or oid.')
983        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
984        self._do_debug(q)
985        return int(self.db.query(q))
986
987    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
988        """Get notification handler that will run the given callback."""
989        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
990
991
# if run as script, print the version and the module documentation

if __name__ == '__main__':
    # print() separates its arguments with a space, so the version number
    # is no longer glued to the word "version"
    print('PyGreSQL version', version)
    print()
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.