source: trunk/module/pg.py @ 521

Last change on this file since 521 was 521, checked in by darcy, 4 years ago

Use hard coded path rather than env.

  • Property svn:keywords set to Id
File size: 34.2 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 521 2015-07-19 00:57:59Z darcy $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31import sys
if sys.version_info[0] == 3:
    # Python 3 has no basestring; all text is str. This must be an alias:
    # a subclass of str would not match plain str objects in isinstance().
    basestring = str
34
35from _pg import *
36
37import select
38import warnings
39try:
40    frozenset
41except NameError:  # Python < 2.4
42    from sets import ImmutableSet as frozenset
43try:
44    from decimal import Decimal
45    set_decimal(Decimal)
46except ImportError:  # Python < 2.4
47    Decimal = float
48try:
49    from collections import namedtuple
50except ImportError:  # Python < 2.6
51    namedtuple = None
52
53
54# Auxiliary functions which are independent from a DB connection:
55
56def _is_quoted(s):
57    """Check whether this string is a quoted identifier."""
58    s = s.replace('_', 'a')
59    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
60
61
62def _is_unquoted(s):
63    """Check whether this string is an unquoted identifier."""
64    s = s.replace('_', 'a')
65    return s.isalnum() and not s[:1].isdigit()
66
67
68def _split_first_part(s):
69    """Split the first part of a dot separated string."""
70    s = s.lstrip()
71    if s[:1] == '"':
72        p = []
73        s = s.split('"', 3)[1:]
74        p.append(s[0])
75        while len(s) == 3 and s[1] == '':
76            p.append('"')
77            s = s[2].split('"', 2)
78            p.append(s[0])
79        p = [''.join(p)]
80        s = '"'.join(s[1:]).lstrip()
81        if s:
82            if s[:0] == '.':
83                p.append(s[1:])
84            else:
85                s = _split_first_part(s)
86                p[0] += s[0]
87                if len(s) > 1:
88                    p.append(s[1])
89    else:
90        p = s.split('.', 1)
91        s = p[0].rstrip()
92        if _is_unquoted(s):
93            s = s.lower()
94        p[0] = s
95    return p
96
97
98def _split_parts(s):
99    """Split all parts of a dot separated string."""
100    q = []
101    while s:
102        s = _split_first_part(s)
103        q.append(s[0])
104        if len(s) < 2:
105            break
106        s = s[1]
107    return q
108
109
def _join_parts(s):
    """Join identifier parts into a dot separated string.

    Parts that cannot be written as plain identifiers are put in double
    quotes. Any double quote inside such a part is doubled, so the result
    is valid SQL and splits back into the same parts with _split_parts.
    """
    parts = []
    for p in s:
        if _is_quoted(p):
            p = '"%s"' % p.replace('"', '""')
        parts.append(p)
    return '.'.join(parts)
113
114
115def _oid_key(qcl):
116    """Build oid key from qualified class name."""
117    return 'oid(%s)' % qcl
118
119
if namedtuple:

    def _namedresult(q):
        """Get query result as a list of named tuples.

        Column names that are not valid field names (like '?column?'),
        or that occur more than once, would make namedtuple raise
        ValueError. They are replaced with positional names where
        namedtuple supports renaming.
        """
        fields = q.listfields()
        try:
            row = namedtuple('Row', fields, rename=True)
        except TypeError:  # namedtuple without the rename parameter
            row = namedtuple('Row', fields)
        return [row(*r) for r in q.getresult()]

    set_namedresult(_namedresult)
128
129
def _db_error(msg, cls=DatabaseError):
    """Create (but do not raise) an exception of class cls for msg.

    cls defaults to DatabaseError; the sqlstate attribute of the new
    exception is set to None.
    """
    error = cls(msg)
    setattr(error, 'sqlstate', None)
    return error
135
136
def _int_error(msg):
    """Create (but do not raise) an InternalError with sqlstate None."""
    return _db_error(msg, cls=InternalError)
140
141
def _prg_error(msg):
    """Create (but do not raise) a ProgrammingError with sqlstate None."""
    return _db_error(msg, cls=ProgrammingError)
145
146
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (a DB instance is unwrapped
                   to its underlying connection).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        self.close()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Sends NOTIFY for the event, or for the stop event if stop is true,
        with the optional payload. The payload is escaped with the
        escape_string method of the connection used, so it may contain
        quotes. Nothing is sent unless this handler is listening.

        Returns the result of the query, or None if not listening.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (stop and self.stop_event or self.event)
            if payload:
                if not isinstance(payload, basestring):
                    payload = str(payload)
                # the payload goes into a string literal, so escape it
                q += ", '%s'" % db.escape_string(payload)
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' (payload) are inserted into <arg_dict>,
        and the callback is invoked with <arg_dict>. If the NOTIFY message
        is stop_<event>, the handler UNLISTENs both <event> and stop_<event>
        and exits. On timeout, the handler UNLISTENs, calls the callback
        with None and exits. A notification for any other event makes the
        handler UNLISTEN and raise DatabaseError.

        Every notification available after a wake-up is handled before
        waiting again. The close parameter is currently not used.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while True:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if not ilist:  # we timed out
                self.unlisten()
                self.callback(None)
                break
            # getnotify() returns only one notification per call. Several
            # may have arrived together and may already be buffered on the
            # client side, where select() would not report them again, so
            # fetch until there are none left before waiting again.
            notice = self.db.getnotify()
            while notice is not None:
                event, pid, extra = notice
                if event not in (self.event, self.stop_event):
                    self.unlisten()
                    raise _db_error(
                        'listening for "%s" and "%s", but notified of "%s"'
                        % (self.event, self.stop_event, event))
                self.arg_dict['pid'] = pid
                self.arg_dict['event'] = event
                self.arg_dict['extra'] = extra
                self.callback(self.arg_dict)
                if event == self.stop_event:
                    self.unlisten()
                    return
                notice = self.db.getnotify()
247                    self.callback(self.arg_dict)
248                    if event == self.stop_event:
249                        self.unlisten()
250                        break
251                else:
252                    self.unlisten()
253                    raise _db_error(
254                        'listening for "%s" and "%s", but notified of "%s"'
255                        % (self.event, self.stop_event, event))
256
257
258def pgnotify(*args, **kw):
259    """Same as NotificationHandler, under the traditional name."""
260    warnings.warn("pgnotify is deprecated, use NotificationHandler instead.",
261        DeprecationWarning, stacklevel=2)
262    return NotificationHandler(*args, **kw)
263
264
# The actual PostgreSQL database connection interface:
266
267class DB(object):
268    """Wrapper class for the _pg connection type."""
269
270    def __init__(self, *args, **kw):
271        """Create a new connection.
272
273        You can pass either the connection parameters or an existing
274        _pg or pgdb connection. This allows you to use the methods
275        of the classic pg interface with a DB-API 2 pgdb connection.
276
277        """
278        if not args and len(kw) == 1:
279            db = kw.get('db')
280        elif not kw and len(args) == 1:
281            db = args[0]
282        else:
283            db = None
284        if db:
285            if isinstance(db, DB):
286                db = db.db
287            else:
288                try:
289                    db = db._cnx
290                except AttributeError:
291                    pass
292        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
293            db = connect(*args, **kw)
294            self._closeable = True
295        else:
296            self._closeable = False
297        self.db = db
298        self.dbname = db.db
299        self._regtypes = False
300        self._attnames = {}
301        self._pkeys = {}
302        self._privileges = {}
303        self._args = args, kw
304        self.debug = None  # For debugging scripts, this can be set
305            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
306            # * to a file object to write debug statements or
307            # * to a callable object which takes a string argument
308            # * to any other true value to just print debug statements
309
310    def __getattr__(self, name):
311        # All undefined members are same as in underlying pg connection:
312        if self.db:
313            return getattr(self.db, name)
314        else:
315            raise _int_error('Connection is not valid')
316
317    # Context manager methods
318
319    def __enter__(self):
320        """Enter the runtime context. This will start a transaction."""
321        self.begin()
322        return self
323
324    def __exit__(self, et, ev, tb):
325        """Exit the runtime context. This will end the transaction."""
326        if et is None and ev is None and tb is None:
327            self.commit()
328        else:
329            self.rollback()
330
331    # Auxiliary methods
332
333    def _do_debug(self, s):
334        """Print a debug message."""
335        if self.debug:
336            if isinstance(self.debug, basestring):
337                print(self.debug % s)
338            elif isinstance(self.debug, file):
339                file.write(s + '\n')
340            elif callable(self.debug):
341                self.debug(s)
342            else:
343                print(s)
344
345    def _quote_text(self, d):
346        """Quote text value."""
347        if not isinstance(d, basestring):
348            d = str(d)
349        return "'%s'" % self.escape_string(d)
350
351    _bool_true = frozenset('t true 1 y yes on'.split())
352
353    def _quote_bool(self, d):
354        """Quote boolean value."""
355        if isinstance(d, basestring):
356            if not d:
357                return 'NULL'
358            d = d.lower() in self._bool_true
359        else:
360            d = bool(d)
361        return ("'f'", "'t'")[d]
362
363    _date_literals = frozenset('current_date current_time'
364        ' current_timestamp localtime localtimestamp'.split())
365
366    def _quote_date(self, d):
367        """Quote date value."""
368        if not d:
369            return 'NULL'
370        if isinstance(d, basestring) and d.lower() in self._date_literals:
371            return d
372        return self._quote_text(d)
373
374    def _quote_num(self, d):
375        """Quote numeric value."""
376        if not d and d != 0:
377            return 'NULL'
378        return str(d)
379
380    def _quote_money(self, d):
381        """Quote money value."""
382        if d is None or d == '':
383            return 'NULL'
384        if not isinstance(d, basestring):
385            d = str(d)
386        return d
387
388    _quote_funcs = dict(  # quote methods for each type
389        text=_quote_text, bool=_quote_bool, date=_quote_date,
390        int=_quote_num, num=_quote_num, float=_quote_num,
391        money=_quote_money)
392
393    def _quote(self, d, t):
394        """Return quotes if needed."""
395        if d is None:
396            return 'NULL'
397        try:
398            quote_func = self._quote_funcs[t]
399        except KeyError:
400            quote_func = self._quote_funcs['text']
401        return quote_func(self, d)
402
403    def _split_schema(self, cl):
404        """Return schema and name of object separately.
405
406        This auxiliary function splits off the namespace (schema)
407        belonging to the class with the name cl. If the class name
408        is not qualified, the function is able to determine the schema
409        of the class, taking into account the current search path.
410
411        """
412        s = _split_parts(cl)
413        if len(s) > 1:  # name already qualfied?
414            # should be database.schema.table or schema.table
415            if len(s) > 3:
416                raise _prg_error('Too many dots in class name %s' % cl)
417            schema, cl = s[-2:]
418        else:
419            cl = s[0]
420            # determine search path
421            q = 'SELECT current_schemas(TRUE)'
422            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
423            if schemas:  # non-empty path
424                # search schema for this object in the current search path
425                q = ' UNION '.join(
426                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
427                        % s for s in enumerate(schemas)])
428                q = ("SELECT nspname FROM pg_class"
429                    " JOIN pg_namespace"
430                    " ON pg_class.relnamespace = pg_namespace.oid"
431                    " JOIN (%s) AS p USING (nspname)"
432                    " WHERE pg_class.relname = '%s'"
433                    " ORDER BY n LIMIT 1" % (q, cl))
434                schema = self.db.query(q).getresult()
435                if schema:  # schema found
436                    schema = schema[0][0]
437                else:  # object not found in current search path
438                    schema = 'public'
439            else:  # empty path
440                schema = 'public'
441        return schema, cl
442
443    def _add_schema(self, cl):
444        """Ensure that the class name is prefixed with a schema name."""
445        return _join_parts(self._split_schema(cl))
446
447    # Public methods
448
449    # escape_string and escape_bytea exist as methods,
450    # so we define unescape_bytea as a method as well
451    unescape_bytea = staticmethod(unescape_bytea)
452
453    def close(self):
454        """Close the database connection."""
455        # Wraps shared library function so we can track state.
456        if self._closeable:
457            if self.db:
458                self.db.close()
459                self.db = None
460            else:
461                raise _int_error('Connection already closed')
462
463    def reset(self):
464        """Reset connection with current parameters.
465
466        All derived queries and large objects derived from this connection
467        will not be usable after this call.
468
469        """
470        if self.db:
471            self.db.reset()
472        else:
473            raise _int_error('Connection already closed')
474
475    def reopen(self):
476        """Reopen connection to the database.
477
478        Used in case we need another connection to the same database.
479        Note that we can still reopen a database that we have closed.
480
481        """
482        # There is no such shared library function.
483        if self._closeable:
484            db = connect(*self._args[0], **self._args[1])
485            if self.db:
486                self.db.close()
487            self.db = db
488
489    def begin(self, mode=None):
490        """Begin a transaction."""
491        qstr = 'BEGIN'
492        if mode:
493            qstr += ' ' + mode
494        return self.query(qstr)
495
496    start = begin
497
498    def commit(self):
499        """Commit the current transaction."""
500        return self.query('COMMIT')
501
502    end = commit
503
504    def rollback(self, name=None):
505        """Rollback the current transaction."""
506        qstr = 'ROLLBACK'
507        if name:
508            qstr += ' TO ' + name
509        return self.query(qstr)
510
511    def savepoint(self, name=None):
512        """Define a new savepoint within the current transaction."""
513        qstr = 'SAVEPOINT'
514        if name:
515            qstr += ' ' + name
516        return self.query(qstr)
517
518    def release(self, name):
519        """Destroy a previously defined savepoint."""
520        return self.query('RELEASE ' + name)
521
522    def query(self, qstr, *args):
523        """Executes a SQL command string.
524
525        This method simply sends a SQL query to the database. If the query is
526        an insert statement that inserted exactly one row into a table that
527        has OIDs, the return value is the OID of the newly inserted row.
528        If the query is an update or delete statement, or an insert statement
529        that did not insert exactly one row in a table with OIDs, then the
530        numer of rows affected is returned as a string. If it is a statement
531        that returns rows as a result (usually a select statement, but maybe
532        also an "insert/update ... returning" statement), this method returns
533        a pgqueryobject that can be accessed via getresult() or dictresult()
534        or simply printed. Otherwise, it returns `None`.
535
536        The query can contain numbered parameters of the form $1 in place
537        of any data constant. Arguments given after the query string will
538        be substituted for the corresponding numbered parameter. Parameter
539        values can also be given as a single list or tuple argument.
540
541        Note that the query string must not be passed as a unicode value,
542        but you can pass arguments as unicode values if they can be decoded
543        using the current client encoding.
544
545        """
546        # Wraps shared library function for debugging.
547        if not self.db:
548            raise _int_error('Connection is not valid')
549        self._do_debug(qstr)
550        return self.db.query(qstr, args)
551
552    def pkey(self, cl, newpkey=None):
553        """This method gets or sets the primary key of a class.
554
555        Composite primary keys are represented as frozensets. Note that
556        this raises an exception if the table does not have a primary key.
557
558        If newpkey is set and is not a dictionary then set that
559        value as the primary key of the class.  If it is a dictionary
560        then replace the _pkeys dictionary with a copy of it.
561
562        """
563        # First see if the caller is supplying a dictionary
564        if isinstance(newpkey, dict):
565            # make sure that all classes have a namespace
566            self._pkeys = dict([
567                ('.' in cl and cl or 'public.' + cl, pkey)
568                for cl, pkey in newpkey.items()])
569            return self._pkeys
570
571        qcl = self._add_schema(cl)  # build fully qualified class name
572        # Check if the caller is supplying a new primary key for the class
573        if newpkey:
574            self._pkeys[qcl] = newpkey
575            return newpkey
576
577        # Get all the primary keys at once
578        if qcl not in self._pkeys:
579            # if not found, check again in case it was added after we started
580            self._pkeys = {}
581            if self.server_version >= 80200:
582                # the ANY syntax works correctly only with PostgreSQL >= 8.2
583                any_indkey = "= ANY (pg_index.indkey)"
584            else:
585                any_indkey = "IN (%s)" % ', '.join(
586                    ['pg_index.indkey[%d]' % i for i in range(16)])
587            for r in self.db.query(
588                "SELECT pg_namespace.nspname, pg_class.relname,"
589                    " pg_attribute.attname FROM pg_class"
590                " JOIN pg_namespace"
591                    " ON pg_namespace.oid = pg_class.relnamespace"
592                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
593                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
594                    " AND pg_attribute.attisdropped = 'f'"
595                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
596                    " AND pg_index.indisprimary = 't'"
597                    " AND pg_attribute.attnum " + any_indkey).getresult():
598                cl, pkey = _join_parts(r[:2]), r[2]
599                self._pkeys.setdefault(cl, []).append(pkey)
600            # (only) for composite primary keys, the values will be frozensets
601            for cl, pkey in self._pkeys.items():
602                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
603            self._do_debug(self._pkeys)
604
605        # will raise an exception if primary key doesn't exist
606        return self._pkeys[qcl]
607
608    def get_databases(self):
609        """Get list of databases in the system."""
610        return [s[0] for s in
611            self.db.query('SELECT datname FROM pg_database').getresult()]
612
613    def get_relations(self, kinds=None):
614        """Get list of relations in connected database of specified kinds.
615
616            If kinds is None or empty, all kinds of relations are returned.
617            Otherwise kinds can be a string or sequence of type letters
618            specifying which kind of relations you want to list.
619
620        """
621        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
622            ["'%s'" % x for x in kinds]) or ''
623        return [_join_parts(x) for x in self.db.query(
624            "SELECT pg_namespace.nspname, pg_class.relname "
625            "FROM pg_class "
626            "JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
627            "WHERE %s pg_class.relname !~ '^Inv' AND "
628                "pg_class.relname !~ '^pg_' "
629            "ORDER BY 1, 2" % where).getresult()]
630
631    def get_tables(self):
632        """Return list of tables in connected database."""
633        return self.get_relations('r')
634
635    def get_attnames(self, cl, newattnames=None):
636        """Given the name of a table, digs out the set of attribute names.
637
638        Returns a dictionary of attribute names (the names are the keys,
639        the values are the names of the attributes' types).
640        If the optional newattnames exists, it must be a dictionary and
641        will become the new attribute names dictionary.
642
643        By default, only a limited number of simple types will be returned.
644        You can get the regular types after calling use_regtypes(True).
645
646        """
647        if isinstance(newattnames, dict):
648            self._attnames = newattnames
649            return
650        elif newattnames:
651            raise _prg_error('If supplied, newattnames must be a dictionary')
652        cl = self._split_schema(cl)  # split into schema and class
653        qcl = _join_parts(cl)  # build fully qualified name
654        # May as well cache them:
655        if qcl in self._attnames:
656            return self._attnames[qcl]
657        if qcl not in self.get_relations('rv'):
658            raise _prg_error('Class %s does not exist' % qcl)
659
660        q = "SELECT pg_attribute.attname, pg_type.typname"
661        if self._regtypes:
662            q += "::regtype"
663        q += (" FROM pg_class"
664            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
665            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
666            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
667            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
668            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
669            " AND pg_attribute.attisdropped = 'f'") % cl
670        q = self.db.query(q).getresult()
671
672        if self._regtypes:
673            t = dict(q)
674        else:
675            t = {}
676            for att, typ in q:
677                if typ.startswith('bool'):
678                    typ = 'bool'
679                elif typ.startswith('abstime'):
680                    typ = 'date'
681                elif typ.startswith('date'):
682                    typ = 'date'
683                elif typ.startswith('interval'):
684                    typ = 'date'
685                elif typ.startswith('timestamp'):
686                    typ = 'date'
687                elif typ.startswith('oid'):
688                    typ = 'int'
689                elif typ.startswith('int'):
690                    typ = 'int'
691                elif typ.startswith('float'):
692                    typ = 'float'
693                elif typ.startswith('numeric'):
694                    typ = 'num'
695                elif typ.startswith('money'):
696                    typ = 'money'
697                else:
698                    typ = 'text'
699                t[att] = typ
700
701        self._attnames[qcl] = t  # cache it
702        return self._attnames[qcl]
703
704    def use_regtypes(self, regtypes=None):
705        """Use regular type names instead of simplified type names."""
706        if regtypes is None:
707            return self._regtypes
708        else:
709            regtypes = bool(regtypes)
710            if regtypes != self._regtypes:
711                self._regtypes = regtypes
712                self._attnames.clear()
713            return regtypes
714
715    def has_table_privilege(self, cl, privilege='select'):
716        """Check whether current user has specified table privilege."""
717        qcl = self._add_schema(cl)
718        privilege = privilege.lower()
719        try:
720            return self._privileges[(qcl, privilege)]
721        except KeyError:
722            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
723            ret = self.db.query(q).getresult()[0][0] == 't'
724            self._privileges[(qcl, privilege)] = ret
725            return ret
726
727    def get(self, cl, arg, keyname=None):
728        """Get a tuple from a database table or view.
729
730        This method is the basic mechanism to get a single row.  The keyname
731        that the key specifies a unique row.  If keyname is not specified
732        then the primary key for the table is used.  If arg is a dictionary
733        then the value for the key is taken from it and it is modified to
734        include the new values, replacing existing values where necessary.
735        For a composite key, keyname can also be a sequence of key names.
736        The OID is also put into the dictionary if the table has one, but
737        in order to allow the caller to work with multiple tables, it is
738        munged as oid(schema.table).
739
740        """
741        if cl.endswith('*'):  # scan descendant tables?
742            cl = cl[:-1].rstrip()  # need parent table name
743        # build qualified class name
744        qcl = self._add_schema(cl)
745        # To allow users to work with multiple tables,
746        # we munge the name of the "oid" the key
747        qoid = _oid_key(qcl)
748        if not keyname:
749            # use the primary key by default
750            try:
751                keyname = self.pkey(qcl)
752            except KeyError:
753                raise _prg_error('Class %s has no primary key' % qcl)
754        # We want the oid for later updates if that isn't the key
755        if keyname == 'oid':
756            if isinstance(arg, dict):
757                if qoid not in arg:
758                    raise _db_error('%s not in arg' % qoid)
759            else:
760                arg = {qoid: arg}
761            where = 'oid = %s' % arg[qoid]
762            attnames = '*'
763        else:
764            attnames = self.get_attnames(qcl)
765            if isinstance(keyname, basestring):
766                keyname = (keyname,)
767            if not isinstance(arg, dict):
768                if len(keyname) > 1:
769                    raise _prg_error('Composite key needs dict as arg')
770                arg = dict([(k, arg) for k in keyname])
771            where = ' AND '.join(['%s = %s'
772                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
773            attnames = ', '.join(attnames)
774        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
775        self._do_debug(q)
776        res = self.db.query(q).dictresult()
777        if not res:
778            raise _db_error('No such record in %s where %s' % (qcl, where))
779        for att, value in res[0].items():
780            arg[att == 'oid' and qoid or att] = value
781        return arg
782
783    def insert(self, cl, d=None, **kw):
784        """Insert a tuple into a database table.
785
786        This method inserts a row into a table.  If a dictionary is
787        supplied it starts with that.  Otherwise it uses a blank dictionary.
788        Either way the dictionary is updated from the keywords.
789
790        The dictionary is then, if possible, reloaded with the values actually
791        inserted in order to pick up values modified by rules, triggers, etc.
792
793        Note: The method currently doesn't support insert into views
794        although PostgreSQL does.
795
796        """
797        qcl = self._add_schema(cl)
798        qoid = _oid_key(qcl)
799        if d is None:
800            d = {}
801        d.update(kw)
802        attnames = self.get_attnames(qcl)
803        names, values = [], []
804        for n in attnames:
805            if n != 'oid' and n in d:
806                names.append('"%s"' % n)
807                values.append(self._quote(d[n], attnames[n]))
808        names, values = ', '.join(names), ', '.join(values)
809        selectable = self.has_table_privilege(qcl)
810        if selectable and self.server_version >= 80200:
811            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
812        else:
813            ret = ''
814        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
815        self._do_debug(q)
816        res = self.db.query(q)
817        if ret:
818            res = res.dictresult()
819            for att, value in res[0].items():
820                d[att == 'oid' and qoid or att] = value
821        elif isinstance(res, int):
822            d[qoid] = res
823            if selectable:
824                self.get(qcl, d, 'oid')
825        elif selectable:
826            if qoid in d:
827                self.get(qcl, d, 'oid')
828            else:
829                try:
830                    self.get(qcl, d)
831                except ProgrammingError:
832                    pass  # table has no primary key
833        return d
834
835    def update(self, cl, d=None, **kw):
836        """Update an existing row in a database table.
837
838        Similar to insert but updates an existing row.  The update is based
839        on the OID value as munged by get or passed as keyword, or on the
840        primary key of the table.  The dictionary is modified, if possible,
841        to reflect any changes caused by the update due to triggers, rules,
842        default values, etc.
843
844        """
845        # Update always works on the oid which get returns if available,
846        # otherwise use the primary key.  Fail if neither.
847        # Note that we only accept oid key from named args for safety
848        qcl = self._add_schema(cl)
849        qoid = _oid_key(qcl)
850        if 'oid' in kw:
851            kw[qoid] = kw['oid']
852            del kw['oid']
853        if d is None:
854            d = {}
855        d.update(kw)
856        attnames = self.get_attnames(qcl)
857        if qoid in d:
858            where = 'oid = %s' % d[qoid]
859            keyname = ()
860        else:
861            try:
862                keyname = self.pkey(qcl)
863            except KeyError:
864                raise _prg_error('Class %s has no primary key' % qcl)
865            if isinstance(keyname, basestring):
866                keyname = (keyname,)
867            try:
868                where = ' AND '.join(['%s = %s'
869                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
870            except KeyError:
871                raise _prg_error('Update needs primary key or oid.')
872        values = []
873        for n in attnames:
874            if n in d and n not in keyname:
875                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
876        if not values:
877            return d
878        values = ', '.join(values)
879        selectable = self.has_table_privilege(qcl)
880        if selectable and self.server_version >= 880200:
881            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
882        else:
883            ret = ''
884        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
885        self._do_debug(q)
886        res = self.db.query(q)
887        if ret:
888            res = res.dictresult()[0]
889            for att, value in res.items():
890                d[att == 'oid' and qoid or att] = value
891        else:
892            if selectable:
893                if qoid in d:
894                    self.get(qcl, d, 'oid')
895                else:
896                    self.get(qcl, d)
897        return d
898
899    def clear(self, cl, a=None):
900        """Clear all the attributes to values determined by the types.
901
902        Numeric types are set to 0, Booleans are set to 'f', and everything
903        else is set to the empty string.  If the array argument is present,
904        it is used as the array and any entries matching attribute names are
905        cleared with everything else left unchanged.
906
907        """
908        # At some point we will need a way to get defaults from a table.
909        qcl = self._add_schema(cl)
910        if a is None:
911            a = {}  # empty if argument is not present
912        attnames = self.get_attnames(qcl)
913        for n, t in attnames.items():
914            if n == 'oid':
915                continue
916            if t in ('int', 'integer', 'smallint', 'bigint',
917                    'float', 'real', 'double precision',
918                    'num', 'numeric', 'money'):
919                a[n] = 0
920            elif t in ('bool', 'boolean'):
921                a[n] = 'f'
922            else:
923                a[n] = ''
924        return a
925
926    def delete(self, cl, d=None, **kw):
927        """Delete an existing row in a database table.
928
929        This method deletes the row from a table.  It deletes based on the
930        OID value as munged by get or passed as keyword, or on the primary
931        key of the table.  The return value is the number of deleted rows
932        (i.e. 0 if the row did not exist and 1 if the row was deleted).
933
934        """
935        # Like update, delete works on the oid.
936        # One day we will be testing that the record to be deleted
937        # isn't referenced somewhere (or else PostgreSQL will).
938        # Note that we only accept oid key from named args for safety
939        qcl = self._add_schema(cl)
940        qoid = _oid_key(qcl)
941        if 'oid' in kw:
942            kw[qoid] = kw['oid']
943            del kw['oid']
944        if d is None:
945            d = {}
946        d.update(kw)
947        if qoid in d:
948            where = 'oid = %s' % d[qoid]
949        else:
950            try:
951                keyname = self.pkey(qcl)
952            except KeyError:
953                raise _prg_error('Class %s has no primary key' % qcl)
954            if isinstance(keyname, basestring):
955                keyname = (keyname,)
956            attnames = self.get_attnames(qcl)
957            try:
958                where = ' AND '.join(['%s = %s'
959                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
960            except KeyError:
961                raise _prg_error('Delete needs primary key or oid.')
962        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
963        self._do_debug(q)
964        return int(self.db.query(q))
965
966    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
967        """Get notification handler that will run the given callback."""
968        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
969
970
# If run as a script, print the module version and its documentation.

if __name__ == '__main__':
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.