source: branches/4.x/module/pg.py @ 690

Last change on this file since 690 was 690, checked in by cito, 4 years ago

Amend tests so that they can run with PostgreSQL < 9.0

Note that we do not need to make these amendments in the trunk,
because we assume PostgreSQL >= 9.0 for PyGreSQL version 5.0.

  • Property svn:keywords set to Id
File size: 34.2 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 690 2016-01-02 23:25:14Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from _pg import *
32
33import select
34import warnings
35try:
36    frozenset
37except NameError:  # Python < 2.4, unsupported
38    from sets import ImmutableSet as frozenset
39try:
40    from decimal import Decimal
41    set_decimal(Decimal)
42except ImportError:  # Python < 2.4, unsupported
43    Decimal = float
44try:
45    from collections import namedtuple
46except ImportError:  # Python < 2.6
47    namedtuple = None
48
49
50# Auxiliary functions that are independent of a DB connection:
51
def _is_quoted(s):
    """Tell whether the identifier s must be double-quoted in SQL.

    This is the case if, after treating underscores like letters, it is
    not purely alphanumeric, starts with a digit, or contains uppercase.
    """
    t = s.replace('_', 'a')
    if not t.isalnum():
        return True
    if t[:1].isdigit():
        return True
    return t != t.lower()
56
57
def _is_unquoted(s):
    """Tell whether s can be used as a plain (unquoted) SQL identifier.

    Underscores count as letters; the name must be alphanumeric and
    must not start with a digit. Case is not checked here.
    """
    t = s.replace('_', 'a')
    if not t.isalnum():
        return False
    return not t[:1].isdigit()
62
63
def _split_first_part(s):
    """Split the first part off a dot separated identifier string.

    Returns a list whose first item is the first part and whose second
    item, if present, is the unparsed remainder after the separating dot.
    A part in double quotes is returned without the quotes, where a
    doubled double quote stands for one literal double quote. A part
    without quotes is right-stripped and lowercased if it is a plain
    identifier.

    """
    s = s.lstrip()
    if s[:1] == '"':
        p = []
        s = s.split('"', 3)[1:]
        p.append(s[0])
        # an empty piece between two quotes means a doubled ("") quote
        while len(s) == 3 and s[1] == '':
            p.append('"')
            s = s[2].split('"', 2)
            p.append(s[0])
        p = [''.join(p)]
        s = '"'.join(s[1:]).lstrip()
        if s:
            if s[:1] == '.':
                # the quoted part is directly followed by the separator
                p.append(s[1:])
            else:
                # something follows the closing quote before the dot;
                # glue it onto the first part
                s = _split_first_part(s)
                p[0] += s[0]
                if len(s) > 1:
                    p.append(s[1])
    else:
        p = s.split('.', 1)
        s = p[0].rstrip()
        if _is_unquoted(s):
            s = s.lower()
        p[0] = s
    return p
92
93
def _split_parts(s):
    """Split a dot separated identifier string into all of its parts."""
    parts = []
    rest = s
    while rest:
        pieces = _split_first_part(rest)
        parts.append(pieces[0])
        if len(pieces) < 2:
            # no separating dot left
            break
        rest = pieces[1]
    return parts
104
105
def _join_parts(s):
    """Join all parts of a dot separated identifier string.

    Parts that need quoting are put in double quotes. Double quotes
    inside such a part are doubled, as SQL requires, so that the result
    can be split again by _split_parts.
    """
    joined = []
    for p in s:
        if _is_quoted(p):
            p = '"%s"' % p.replace('"', '""')
        joined.append(p)
    return '.'.join(joined)
109
110
def _oid_key(qcl):
    """Return the dictionary key holding the OID of a row of table qcl.

    The OID is stored munged as 'oid(<qualified class name>)' so that a
    caller can work with rows of several tables in one dictionary.
    """
    key_format = 'oid(%s)'
    return key_format % qcl
114
115
if namedtuple:

    def _namedresult(q):
        """Get query result as named tuples.

        Field names that are not valid Python identifiers (such as the
        '?column?' name PostgreSQL gives unnamed expressions) or that
        occur more than once are replaced by positional names like _0
        where namedtuple() supports renaming (Python >= 2.7). Without
        renaming, namedtuple() raises ValueError for such names.
        """
        fields = q.listfields()
        try:
            row = namedtuple('Row', fields, rename=True)
        except TypeError:  # Python 2.6: namedtuple has no rename parameter
            row = namedtuple('Row', fields)
        return [row(*r) for r in q.getresult()]

    # register the factory with the _pg module
    set_namedresult(_namedresult)
124
125
def _db_error(msg, cls=DatabaseError):
    """Return an exception of class cls (DatabaseError by default).

    The exception is created with the message msg and gets its sqlstate
    attribute set to None.
    """
    exc = cls(msg)
    setattr(exc, 'sqlstate', None)
    return exc
131
132
def _int_error(msg):
    """Return an InternalError with message msg and no sqlstate."""
    return _db_error(msg, cls=InternalError)
136
137
def _prg_error(msg):
    """Return a ProgrammingError with message msg and no sqlstate."""
    return _db_error(msg, cls=ProgrammingError)
141
142
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object; if a DB wrapper is
                   passed, its underlying connection is used.
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # make sure we stop listening and close the connection
        self.close()

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen "%s"' % self.event)
            self.db.query('listen "%s"' % self.stop_event)
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten "%s"' % self.event)
            self.db.query('unlisten "%s"' % self.stop_event)
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        The notification is only sent while the handler is listening;
        otherwise nothing happens and None is returned. If stop is true,
        the stop event is notified instead of the event.

        The payload parameter is only supported in PostgreSQL >= 9.0.
        It is escaped with the connection's escape_string method, so it
        may contain quotes.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify "%s"' % (stop and self.stop_event or self.event)
            if payload:
                # escape the payload so that quotes in it cannot end the
                # string literal early and break or alter the command
                if not isinstance(payload, basestring):
                    payload = str(payload)
                q += ", '%s'" % db.escape_string(payload)
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid', 'event' and 'extra' are inserted into <arg_dict>, and the
        callback is invoked with <arg_dict>. If the NOTIFY message is
        stop_<event>, the handler UNLISTENs both <event> and stop_<event>
        and exits. If the timeout expires, the handler UNLISTENs and
        invokes the callback with None.

        The close parameter is currently not used by this method.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                # drain all notifications that have arrived
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
252
253
def pgnotify(*args, **kw):
    """Deprecated alias creating a NotificationHandler from the same args."""
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
259
260
261# The actual PostGreSQL database connection interface:
262
263class DB(object):
264    """Wrapper class for the _pg connection type."""
265
266    def __init__(self, *args, **kw):
267        """Create a new connection.
268
269        You can pass either the connection parameters or an existing
270        _pg or pgdb connection. This allows you to use the methods
271        of the classic pg interface with a DB-API 2 pgdb connection.
272
273        """
274        if not args and len(kw) == 1:
275            db = kw.get('db')
276        elif not kw and len(args) == 1:
277            db = args[0]
278        else:
279            db = None
280        if db:
281            if isinstance(db, DB):
282                db = db.db
283            else:
284                try:
285                    db = db._cnx
286                except AttributeError:
287                    pass
288        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
289            db = connect(*args, **kw)
290            self._closeable = True
291        else:
292            self._closeable = False
293        self.db = db
294        self.dbname = db.db
295        self._regtypes = False
296        self._attnames = {}
297        self._pkeys = {}
298        self._privileges = {}
299        self._args = args, kw
300        self.debug = None  # For debugging scripts, this can be set
301            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
302            # * to a file object to write debug statements or
303            # * to a callable object which takes a string argument
304            # * to any other true value to just print debug statements
305
306    def __getattr__(self, name):
307        # All undefined members are same as in underlying pg connection:
308        if self.db:
309            return getattr(self.db, name)
310        else:
311            raise _int_error('Connection is not valid')
312
313    # Context manager methods
314
315    def __enter__(self):
316        """Enter the runtime context. This will start a transaction."""
317        self.begin()
318        return self
319
320    def __exit__(self, et, ev, tb):
321        """Exit the runtime context. This will end the transaction."""
322        if et is None and ev is None and tb is None:
323            self.commit()
324        else:
325            self.rollback()
326
327    # Auxiliary methods
328
329    def _do_debug(self, s):
330        """Print a debug message."""
331        if self.debug:
332            if isinstance(self.debug, basestring):
333                print(self.debug % s)
334            elif isinstance(self.debug, file):
335                self.debug.write(s + '\n')
336            elif callable(self.debug):
337                self.debug(s)
338            else:
339                print(s)
340
341    def _quote_text(self, d):
342        """Quote text value."""
343        if not isinstance(d, basestring):
344            d = str(d)
345        return "'%s'" % self.escape_string(d)
346
347    _bool_true = frozenset('t true 1 y yes on'.split())
348
349    def _quote_bool(self, d):
350        """Quote boolean value."""
351        if isinstance(d, basestring):
352            if not d:
353                return 'NULL'
354            d = d.lower() in self._bool_true
355        else:
356            d = bool(d)
357        return ("'f'", "'t'")[d]
358
359    _date_literals = frozenset('current_date current_time'
360        ' current_timestamp localtime localtimestamp'.split())
361
362    def _quote_date(self, d):
363        """Quote date value."""
364        if not d:
365            return 'NULL'
366        if isinstance(d, basestring) and d.lower() in self._date_literals:
367            return d
368        return self._quote_text(d)
369
370    def _quote_num(self, d):
371        """Quote numeric value."""
372        if not d and d != 0:
373            return 'NULL'
374        return str(d)
375
376    def _quote_money(self, d):
377        """Quote money value."""
378        if d is None or d == '':
379            return 'NULL'
380        if not isinstance(d, basestring):
381            d = str(d)
382        return d
383
384    _quote_funcs = dict(  # quote methods for each type
385        text=_quote_text, bool=_quote_bool, date=_quote_date,
386        int=_quote_num, num=_quote_num, float=_quote_num,
387        money=_quote_money)
388
389    def _quote(self, d, t):
390        """Return quotes if needed."""
391        if d is None:
392            return 'NULL'
393        try:
394            quote_func = self._quote_funcs[t]
395        except KeyError:
396            quote_func = self._quote_funcs['text']
397        return quote_func(self, d)
398
399    def _split_schema(self, cl):
400        """Return schema and name of object separately.
401
402        This auxiliary function splits off the namespace (schema)
403        belonging to the class with the name cl. If the class name
404        is not qualified, the function is able to determine the schema
405        of the class, taking into account the current search path.
406
407        """
408        s = _split_parts(cl)
409        if len(s) > 1:  # name already qualfied?
410            # should be database.schema.table or schema.table
411            if len(s) > 3:
412                raise _prg_error('Too many dots in class name %s' % cl)
413            schema, cl = s[-2:]
414        else:
415            cl = s[0]
416            # determine search path
417            q = 'SELECT current_schemas(TRUE)'
418            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
419            if schemas:  # non-empty path
420                # search schema for this object in the current search path
421                q = ' UNION '.join(
422                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
423                        % s for s in enumerate(schemas)])
424                q = ("SELECT nspname FROM pg_class"
425                    " JOIN pg_namespace"
426                    " ON pg_class.relnamespace = pg_namespace.oid"
427                    " JOIN (%s) AS p USING (nspname)"
428                    " WHERE pg_class.relname = '%s'"
429                    " ORDER BY n LIMIT 1" % (q, cl))
430                schema = self.db.query(q).getresult()
431                if schema:  # schema found
432                    schema = schema[0][0]
433                else:  # object not found in current search path
434                    schema = 'public'
435            else:  # empty path
436                schema = 'public'
437        return schema, cl
438
439    def _add_schema(self, cl):
440        """Ensure that the class name is prefixed with a schema name."""
441        return _join_parts(self._split_schema(cl))
442
443    # Public methods
444
    # escape_string and escape_bytea exist as methods (of the underlying
    # connection, reached through __getattr__), so we define
    # unescape_bytea as a method as well. It wraps the module-level
    # function of the same name (presumably provided by "from _pg import *")
    # as a static method, so it is called without a self argument.
    unescape_bytea = staticmethod(unescape_bytea)
448
449    def close(self):
450        """Close the database connection."""
451        # Wraps shared library function so we can track state.
452        if self._closeable:
453            if self.db:
454                self.db.close()
455                self.db = None
456            else:
457                raise _int_error('Connection already closed')
458
459    def reset(self):
460        """Reset connection with current parameters.
461
462        All derived queries and large objects derived from this connection
463        will not be usable after this call.
464
465        """
466        if self.db:
467            self.db.reset()
468        else:
469            raise _int_error('Connection already closed')
470
471    def reopen(self):
472        """Reopen connection to the database.
473
474        Used in case we need another connection to the same database.
475        Note that we can still reopen a database that we have closed.
476
477        """
478        # There is no such shared library function.
479        if self._closeable:
480            db = connect(*self._args[0], **self._args[1])
481            if self.db:
482                self.db.close()
483            self.db = db
484
485    def begin(self, mode=None):
486        """Begin a transaction."""
487        qstr = 'BEGIN'
488        if mode:
489            qstr += ' ' + mode
490        return self.query(qstr)
491
492    start = begin
493
494    def commit(self):
495        """Commit the current transaction."""
496        return self.query('COMMIT')
497
498    end = commit
499
500    def rollback(self, name=None):
501        """Rollback the current transaction."""
502        qstr = 'ROLLBACK'
503        if name:
504            qstr += ' TO ' + name
505        return self.query(qstr)
506
507    def savepoint(self, name=None):
508        """Define a new savepoint within the current transaction."""
509        qstr = 'SAVEPOINT'
510        if name:
511            qstr += ' ' + name
512        return self.query(qstr)
513
514    def release(self, name):
515        """Destroy a previously defined savepoint."""
516        return self.query('RELEASE ' + name)
517
518    def query(self, qstr, *args):
519        """Executes a SQL command string.
520
521        This method simply sends a SQL query to the database. If the query is
522        an insert statement that inserted exactly one row into a table that
523        has OIDs, the return value is the OID of the newly inserted row.
524        If the query is an update or delete statement, or an insert statement
525        that did not insert exactly one row in a table with OIDs, then the
526        numer of rows affected is returned as a string. If it is a statement
527        that returns rows as a result (usually a select statement, but maybe
528        also an "insert/update ... returning" statement), this method returns
529        a pgqueryobject that can be accessed via getresult() or dictresult()
530        or simply printed. Otherwise, it returns `None`.
531
532        The query can contain numbered parameters of the form $1 in place
533        of any data constant. Arguments given after the query string will
534        be substituted for the corresponding numbered parameter. Parameter
535        values can also be given as a single list or tuple argument.
536
537        Note that the query string must not be passed as a unicode value,
538        but you can pass arguments as unicode values if they can be decoded
539        using the current client encoding.
540
541        """
542        # Wraps shared library function for debugging.
543        if not self.db:
544            raise _int_error('Connection is not valid')
545        self._do_debug(qstr)
546        return self.db.query(qstr, args)
547
548    def pkey(self, cl, newpkey=None):
549        """This method gets or sets the primary key of a class.
550
551        Composite primary keys are represented as frozensets. Note that
552        this raises an exception if the table does not have a primary key.
553
554        If newpkey is set and is not a dictionary then set that
555        value as the primary key of the class.  If it is a dictionary
556        then replace the _pkeys dictionary with a copy of it.
557
558        """
559        # First see if the caller is supplying a dictionary
560        if isinstance(newpkey, dict):
561            # make sure that all classes have a namespace
562            self._pkeys = dict([
563                ('.' in cl and cl or 'public.' + cl, pkey)
564                for cl, pkey in newpkey.items()])
565            return self._pkeys
566
567        qcl = self._add_schema(cl)  # build fully qualified class name
568        # Check if the caller is supplying a new primary key for the class
569        if newpkey:
570            self._pkeys[qcl] = newpkey
571            return newpkey
572
573        # Get all the primary keys at once
574        if qcl not in self._pkeys:
575            # if not found, check again in case it was added after we started
576            self._pkeys = {}
577            if self.server_version >= 80200:
578                # the ANY syntax works correctly only with PostgreSQL >= 8.2
579                any_indkey = "= ANY (pg_index.indkey)"
580            else:
581                any_indkey = "IN (%s)" % ', '.join(
582                    ['pg_index.indkey[%d]' % i for i in range(16)])
583            for r in self.db.query(
584                "SELECT pg_namespace.nspname, pg_class.relname,"
585                    " pg_attribute.attname FROM pg_class"
586                " JOIN pg_namespace"
587                    " ON pg_namespace.oid = pg_class.relnamespace"
588                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
589                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
590                    " AND pg_attribute.attisdropped = 'f'"
591                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
592                    " AND pg_index.indisprimary = 't'"
593                    " AND pg_attribute.attnum " + any_indkey).getresult():
594                cl, pkey = _join_parts(r[:2]), r[2]
595                self._pkeys.setdefault(cl, []).append(pkey)
596            # (only) for composite primary keys, the values will be frozensets
597            for cl, pkey in self._pkeys.items():
598                self._pkeys[cl] = len(pkey) > 1 and frozenset(pkey) or pkey[0]
599            self._do_debug(self._pkeys)
600
601        # will raise an exception if primary key doesn't exist
602        return self._pkeys[qcl]
603
604    def get_databases(self):
605        """Get list of databases in the system."""
606        return [s[0] for s in
607            self.db.query('SELECT datname FROM pg_database').getresult()]
608
609    def get_relations(self, kinds=None):
610        """Get list of relations in connected database of specified kinds.
611
612            If kinds is None or empty, all kinds of relations are returned.
613            Otherwise kinds can be a string or sequence of type letters
614            specifying which kind of relations you want to list.
615
616        """
617        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
618            ["'%s'" % x for x in kinds]) or ''
619        return [_join_parts(x) for x in self.db.query(
620            "SELECT pg_namespace.nspname, pg_class.relname "
621            "FROM pg_class "
622            "JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
623            "WHERE %s pg_class.relname !~ '^Inv' AND "
624                "pg_class.relname !~ '^pg_' "
625            "ORDER BY 1, 2" % where).getresult()]
626
627    def get_tables(self):
628        """Return list of tables in connected database."""
629        return self.get_relations('r')
630
631    def get_attnames(self, cl, newattnames=None):
632        """Given the name of a table, digs out the set of attribute names.
633
634        Returns a dictionary of attribute names (the names are the keys,
635        the values are the names of the attributes' types).
636        If the optional newattnames exists, it must be a dictionary and
637        will become the new attribute names dictionary.
638
639        By default, only a limited number of simple types will be returned.
640        You can get the regular types after calling use_regtypes(True).
641
642        """
643        if isinstance(newattnames, dict):
644            self._attnames = newattnames
645            return
646        elif newattnames:
647            raise _prg_error('If supplied, newattnames must be a dictionary')
648        cl = self._split_schema(cl)  # split into schema and class
649        qcl = _join_parts(cl)  # build fully qualified name
650        # May as well cache them:
651        if qcl in self._attnames:
652            return self._attnames[qcl]
653        if qcl not in self.get_relations('rv'):
654            raise _prg_error('Class %s does not exist' % qcl)
655
656        q = "SELECT pg_attribute.attname, pg_type.typname"
657        if self._regtypes:
658            q += "::regtype"
659        q += (" FROM pg_class"
660            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
661            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
662            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
663            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
664            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
665            " AND pg_attribute.attisdropped = 'f'") % cl
666        q = self.db.query(q).getresult()
667
668        if self._regtypes:
669            t = dict(q)
670        else:
671            t = {}
672            for att, typ in q:
673                if typ.startswith('bool'):
674                    typ = 'bool'
675                elif typ.startswith('abstime'):
676                    typ = 'date'
677                elif typ.startswith('date'):
678                    typ = 'date'
679                elif typ.startswith('interval'):
680                    typ = 'date'
681                elif typ.startswith('timestamp'):
682                    typ = 'date'
683                elif typ.startswith('oid'):
684                    typ = 'int'
685                elif typ.startswith('int'):
686                    typ = 'int'
687                elif typ.startswith('float'):
688                    typ = 'float'
689                elif typ.startswith('numeric'):
690                    typ = 'num'
691                elif typ.startswith('money'):
692                    typ = 'money'
693                else:
694                    typ = 'text'
695                t[att] = typ
696
697        self._attnames[qcl] = t  # cache it
698        return self._attnames[qcl]
699
700    def use_regtypes(self, regtypes=None):
701        """Use regular type names instead of simplified type names."""
702        if regtypes is None:
703            return self._regtypes
704        else:
705            regtypes = bool(regtypes)
706            if regtypes != self._regtypes:
707                self._regtypes = regtypes
708                self._attnames.clear()
709            return regtypes
710
711    def has_table_privilege(self, cl, privilege='select'):
712        """Check whether current user has specified table privilege."""
713        qcl = self._add_schema(cl)
714        privilege = privilege.lower()
715        try:
716            return self._privileges[(qcl, privilege)]
717        except KeyError:
718            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
719            ret = self.db.query(q).getresult()[0][0] == 't'
720            self._privileges[(qcl, privilege)] = ret
721            return ret
722
723    def get(self, cl, arg, keyname=None):
724        """Get a tuple from a database table or view.
725
726        This method is the basic mechanism to get a single row.  The keyname
727        that the key specifies a unique row.  If keyname is not specified
728        then the primary key for the table is used.  If arg is a dictionary
729        then the value for the key is taken from it and it is modified to
730        include the new values, replacing existing values where necessary.
731        For a composite key, keyname can also be a sequence of key names.
732        The OID is also put into the dictionary if the table has one, but
733        in order to allow the caller to work with multiple tables, it is
734        munged as oid(schema.table).
735
736        """
737        if cl.endswith('*'):  # scan descendant tables?
738            cl = cl[:-1].rstrip()  # need parent table name
739        # build qualified class name
740        qcl = self._add_schema(cl)
741        # To allow users to work with multiple tables,
742        # we munge the name of the "oid" the key
743        qoid = _oid_key(qcl)
744        if not keyname:
745            # use the primary key by default
746            try:
747                keyname = self.pkey(qcl)
748            except KeyError:
749                raise _prg_error('Class %s has no primary key' % qcl)
750        # We want the oid for later updates if that isn't the key
751        if keyname == 'oid':
752            if isinstance(arg, dict):
753                if qoid not in arg:
754                    raise _db_error('%s not in arg' % qoid)
755            else:
756                arg = {qoid: arg}
757            where = 'oid = %s' % arg[qoid]
758            attnames = '*'
759        else:
760            attnames = self.get_attnames(qcl)
761            if isinstance(keyname, basestring):
762                keyname = (keyname,)
763            if not isinstance(arg, dict):
764                if len(keyname) > 1:
765                    raise _prg_error('Composite key needs dict as arg')
766                arg = dict([(k, arg) for k in keyname])
767            where = ' AND '.join(['%s = %s'
768                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
769            attnames = ', '.join(attnames)
770        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (attnames, qcl, where)
771        self._do_debug(q)
772        res = self.db.query(q).dictresult()
773        if not res:
774            raise _db_error('No such record in %s where %s' % (qcl, where))
775        for att, value in res[0].items():
776            arg[att == 'oid' and qoid or att] = value
777        return arg
778
779    def insert(self, cl, d=None, **kw):
780        """Insert a tuple into a database table.
781
782        This method inserts a row into a table.  If a dictionary is
783        supplied it starts with that.  Otherwise it uses a blank dictionary.
784        Either way the dictionary is updated from the keywords.
785
786        The dictionary is then, if possible, reloaded with the values actually
787        inserted in order to pick up values modified by rules, triggers, etc.
788
789        Note: The method currently doesn't support insert into views
790        although PostgreSQL does.
791
792        """
793        qcl = self._add_schema(cl)
794        qoid = _oid_key(qcl)
795        if d is None:
796            d = {}
797        d.update(kw)
798        attnames = self.get_attnames(qcl)
799        names, values = [], []
800        for n in attnames:
801            if n != 'oid' and n in d:
802                names.append('"%s"' % n)
803                values.append(self._quote(d[n], attnames[n]))
804        names, values = ', '.join(names), ', '.join(values)
805        selectable = self.has_table_privilege(qcl)
806        if selectable and self.server_version >= 80200:
807            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
808        else:
809            ret = ''
810        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
811        self._do_debug(q)
812        res = self.db.query(q)
813        if ret:
814            res = res.dictresult()
815            for att, value in res[0].items():
816                d[att == 'oid' and qoid or att] = value
817        elif isinstance(res, int):
818            d[qoid] = res
819            if selectable:
820                self.get(qcl, d, 'oid')
821        elif selectable:
822            if qoid in d:
823                self.get(qcl, d, 'oid')
824            else:
825                try:
826                    self.get(qcl, d)
827                except ProgrammingError:
828                    pass  # table has no primary key
829        return d
830
831    def update(self, cl, d=None, **kw):
832        """Update an existing row in a database table.
833
834        Similar to insert but updates an existing row.  The update is based
835        on the OID value as munged by get or passed as keyword, or on the
836        primary key of the table.  The dictionary is modified, if possible,
837        to reflect any changes caused by the update due to triggers, rules,
838        default values, etc.
839
840        """
841        # Update always works on the oid which get returns if available,
842        # otherwise use the primary key.  Fail if neither.
843        # Note that we only accept oid key from named args for safety
844        qcl = self._add_schema(cl)
845        qoid = _oid_key(qcl)
846        if 'oid' in kw:
847            kw[qoid] = kw['oid']
848            del kw['oid']
849        if d is None:
850            d = {}
851        d.update(kw)
852        attnames = self.get_attnames(qcl)
853        if qoid in d:
854            where = 'oid = %s' % d[qoid]
855            keyname = ()
856        else:
857            try:
858                keyname = self.pkey(qcl)
859            except KeyError:
860                raise _prg_error('Class %s has no primary key' % qcl)
861            if isinstance(keyname, basestring):
862                keyname = (keyname,)
863            try:
864                where = ' AND '.join(['%s = %s'
865                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
866            except KeyError:
867                raise _prg_error('Update needs primary key or oid.')
868        values = []
869        for n in attnames:
870            if n in d and n not in keyname:
871                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
872        if not values:
873            return d
874        values = ', '.join(values)
875        selectable = self.has_table_privilege(qcl)
876        if selectable and self.server_version >= 80200:
877            ret = ' RETURNING %s*' % ('oid' in attnames and 'oid, ' or '')
878        else:
879            ret = ''
880        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
881        self._do_debug(q)
882        res = self.db.query(q)
883        if ret:
884            res = res.dictresult()[0]
885            for att, value in res.items():
886                d[att == 'oid' and qoid or att] = value
887        else:
888            if selectable:
889                if qoid in d:
890                    self.get(qcl, d, 'oid')
891                else:
892                    self.get(qcl, d)
893        return d
894
895    def clear(self, cl, a=None):
896        """Clear all the attributes to values determined by the types.
897
898        Numeric types are set to 0, Booleans are set to 'f', and everything
899        else is set to the empty string.  If the array argument is present,
900        it is used as the array and any entries matching attribute names are
901        cleared with everything else left unchanged.
902
903        """
904        # At some point we will need a way to get defaults from a table.
905        qcl = self._add_schema(cl)
906        if a is None:
907            a = {}  # empty if argument is not present
908        attnames = self.get_attnames(qcl)
909        for n, t in attnames.items():
910            if n == 'oid':
911                continue
912            if t in ('int', 'integer', 'smallint', 'bigint',
913                    'float', 'real', 'double precision',
914                    'num', 'numeric', 'money'):
915                a[n] = 0
916            elif t in ('bool', 'boolean'):
917                a[n] = 'f'
918            else:
919                a[n] = ''
920        return a
921
922    def delete(self, cl, d=None, **kw):
923        """Delete an existing row in a database table.
924
925        This method deletes the row from a table.  It deletes based on the
926        OID value as munged by get or passed as keyword, or on the primary
927        key of the table.  The return value is the number of deleted rows
928        (i.e. 0 if the row did not exist and 1 if the row was deleted).
929
930        """
931        # Like update, delete works on the oid.
932        # One day we will be testing that the record to be deleted
933        # isn't referenced somewhere (or else PostgreSQL will).
934        # Note that we only accept oid key from named args for safety
935        qcl = self._add_schema(cl)
936        qoid = _oid_key(qcl)
937        if 'oid' in kw:
938            kw[qoid] = kw['oid']
939            del kw['oid']
940        if d is None:
941            d = {}
942        d.update(kw)
943        if qoid in d:
944            where = 'oid = %s' % d[qoid]
945        else:
946            try:
947                keyname = self.pkey(qcl)
948            except KeyError:
949                raise _prg_error('Class %s has no primary key' % qcl)
950            if isinstance(keyname, basestring):
951                keyname = (keyname,)
952            attnames = self.get_attnames(qcl)
953            try:
954                where = ' AND '.join(['%s = %s'
955                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
956            except KeyError:
957                raise _prg_error('Delete needs primary key or oid.')
958        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
959        self._do_debug(q)
960        return int(self.db.query(q))
961
962    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
963        """Get notification handler that will run the given callback."""
964        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
965
966
# if run as script, print the version and the module documentation

if __name__ == '__main__':
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.