source: trunk/pg.py @ 719

Last change on this file since 719 was 719, checked in by cito, 4 years ago

Better checks for system catalogs

  • Property svn:keywords set to Id
File size: 34.6 KB
Line 
1#! /usr/bin/python
2#
3# pg.py
4#
5# $Id: pg.py 719 2016-01-12 12:27:13Z cito $
6#
7
8"""PyGreSQL classic interface.
9
10This pg module implements some basic database management stuff.
11It includes the _pg module and builds on it, providing the higher
12level wrapper class named DB with additional functionality.
13This is known as the "classic" ("old style") PyGreSQL interface.
14For a DB-API 2 compliant interface use the newer pgdb module.
15
16"""
17
18# Copyright (c) 1997-2013 by D'Arcy J.M. Cain.
19#
20# Contributions made by Ch. Zwerschke and others.
21#
22# The notification handler is based on pgnotify which is
23# Copyright (c) 2001 Ng Pheng Siong. All rights reserved.
24#
25# Permission to use, copy, modify, and distribute this software and its
26# documentation for any purpose and without fee is hereby granted,
27# provided that the above copyright notice appear in all copies and that
28# both that copyright notice and this permission notice appear in
29# supporting documentation.
30
31from __future__ import print_function
32
33from _pg import *
34
35import select
36import warnings
37
38from decimal import Decimal
39from collections import namedtuple
40
# Compatibility shim: on Python 2 the builtin basestring is used as is;
# on Python 3 it does not exist, so a tuple of str and bytes is used
# instead (which works as the second argument of isinstance()).
try:
    basestring
except NameError:  # Python >= 3.0
    basestring = (str, bytes)

# Register decimal.Decimal with the _pg extension module
# (presumably used for converting numeric values -- see _pg docs).
set_decimal(Decimal)
47
48
49# Auxiliary functions which are independent from a DB connection:
50
51def _is_quoted(s):
52    """Check whether this string is a quoted identifier."""
53    s = s.replace('_', 'a')
54    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
55
56
57def _is_unquoted(s):
58    """Check whether this string is an unquoted identifier."""
59    s = s.replace('_', 'a')
60    return s.isalnum() and not s[:1].isdigit()
61
62
63def _split_first_part(s):
64    """Split the first part of a dot separated string."""
65    s = s.lstrip()
66    if s[:1] == '"':
67        p = []
68        s = s.split('"', 3)[1:]
69        p.append(s[0])
70        while len(s) == 3 and s[1] == '':
71            p.append('"')
72            s = s[2].split('"', 2)
73            p.append(s[0])
74        p = [''.join(p)]
75        s = '"'.join(s[1:]).lstrip()
76        if s:
77            if s[:0] == '.':
78                p.append(s[1:])
79            else:
80                s = _split_first_part(s)
81                p[0] += s[0]
82                if len(s) > 1:
83                    p.append(s[1])
84    else:
85        p = s.split('.', 1)
86        s = p[0].rstrip()
87        if _is_unquoted(s):
88            s = s.lower()
89        p[0] = s
90    return p
91
92
93def _split_parts(s):
94    """Split all parts of a dot separated string."""
95    q = []
96    while s:
97        s = _split_first_part(s)
98        q.append(s[0])
99        if len(s) < 2:
100            break
101        s = s[1]
102    return q
103
104
def _join_parts(s):
    """Join all parts of a dot separated string.

    Parts that cannot be used as plain identifiers are enclosed in double
    quotes. Double quotes inside such parts are doubled, so that the result
    is a valid SQL identifier that _split_parts() splits back correctly.
    """
    return '.'.join(['"%s"' % p.replace('"', '""') if _is_quoted(p) else p
                     for p in s])
108
109
110def _oid_key(qcl):
111    """Build oid key from qualified class name."""
112    return 'oid(%s)' % qcl
113
114
115def _namedresult(q):
116    """Get query result as named tuples."""
117    row = namedtuple('Row', q.listfields())
118    return [row(*r) for r in q.getresult()]
119
# Register _namedresult with the _pg extension module
# (presumably so that query results can be returned as named tuples).
set_namedresult(_namedresult)
121
122
def _db_error(msg, cls=DatabaseError):
    """Return an exception of class cls (DatabaseError by default).

    The returned exception gets an empty (None) sqlstate attribute.
    """
    err = cls(msg)
    setattr(err, 'sqlstate', None)
    return err
128
129
def _int_error(msg):
    """Return an InternalError with an empty sqlstate attribute."""
    return _db_error(msg, cls=InternalError)
133
134
def _prg_error(msg):
    """Return a ProgrammingError with an empty sqlstate attribute."""
    return _db_error(msg, cls=ProgrammingError)
138
139
class NotificationHandler(object):
    """A PostgreSQL client-side asynchronous notification handler."""

    def __init__(self, db, event, callback, arg_dict=None, timeout=None):
        """Initialize the notification handler.

        db       - PostgreSQL connection object (if a DB instance is
                   passed, its underlying connection is used).
        event    - Event (notification channel) to LISTEN for.
        callback - Event callback function.
        arg_dict - A dictionary passed as the argument to the callback.
        timeout  - Timeout in seconds; a floating point number denotes
                   fractions of seconds. If it is absent or None, the
                   callers will never time out.

        """
        if isinstance(db, DB):
            db = db.db
        self.db = db
        self.event = event
        self.stop_event = 'stop_%s' % event
        self.listening = False
        self.callback = callback
        if arg_dict is None:
            arg_dict = {}
        self.arg_dict = arg_dict
        self.timeout = timeout

    def __del__(self):
        # The db attribute does not exist if __init__ failed early.
        if getattr(self, 'db', None):
            self.close()

    @staticmethod
    def _quote_ident(name):
        """Return the given name as a double-quoted SQL identifier."""
        return '"%s"' % ('%s' % (name,)).replace('"', '""')

    def close(self):
        """Stop listening and close the connection."""
        if self.db:
            self.unlisten()
            self.db.close()
            self.db = None

    def listen(self):
        """Start listening for the event and the stop event."""
        if not self.listening:
            self.db.query('listen %s' % self._quote_ident(self.event))
            self.db.query('listen %s' % self._quote_ident(self.stop_event))
            self.listening = True

    def unlisten(self):
        """Stop listening for the event and the stop event."""
        if self.listening:
            self.db.query('unlisten %s' % self._quote_ident(self.event))
            self.db.query('unlisten %s' % self._quote_ident(self.stop_event))
            self.listening = False

    def notify(self, db=None, stop=False, payload=None):
        """Generate a notification.

        The stop event is notified if stop is set, otherwise the event.
        An optional payload is sent as a properly escaped string literal.
        Nothing is sent (and None is returned) while not listening.

        Note: If the main loop is running in another thread, you must pass
        a different database connection to avoid a collision.

        """
        if not db:
            db = self.db
        if self.listening:
            q = 'notify %s' % self._quote_ident(
                self.stop_event if stop else self.event)
            if payload:
                # the payload may contain quotes that must be doubled
                payload = ('%s' % (payload,)).replace("'", "''")
                if '\\' in payload:
                    # use an escape string constant, so that backslashes
                    # are taken literally regardless of the setting of
                    # standard_conforming_strings
                    q += ", E'%s'" % payload.replace('\\', '\\\\')
                else:
                    q += ", '%s'" % payload
            return db.query(q)

    def __call__(self, close=False):
        """Invoke the notification handler.

        The handler is a loop that actually LISTENs for two NOTIFY messages:

        <event> and stop_<event>.

        When either of these NOTIFY messages are received, its associated
        'pid' and 'event' are inserted into <arg_dict>, and the callback is
        invoked with <arg_dict>. The extra data of the notification is
        inserted as 'extra'. If the NOTIFY message is stop_<event>, the
        handler UNLISTENs both <event> and stop_<event> and exits.
        If the timeout expires, the handler UNLISTENs and the callback
        is invoked with None.

        Note: If you run this loop in another thread, don't use the same
        database connection for database operations in the main thread.
        The close argument is currently not used.

        """
        self.listen()
        _ilist = [self.db.fileno()]

        while self.listening:
            ilist, _olist, _elist = select.select(_ilist, [], [], self.timeout)
            if ilist:
                while self.listening:
                    notice = self.db.getnotify()
                    if not notice:  # no more messages
                        break
                    event, pid, extra = notice
                    if event not in (self.event, self.stop_event):
                        self.unlisten()
                        raise _db_error(
                            'listening for "%s" and "%s", but notified of "%s"'
                            % (self.event, self.stop_event, event))
                    if event == self.stop_event:
                        self.unlisten()
                    self.arg_dict['pid'] = pid
                    self.arg_dict['event'] = event
                    self.arg_dict['extra'] = extra
                    self.callback(self.arg_dict)
            else:   # we timed out
                self.unlisten()
                self.callback(None)
247
248
def pgnotify(*args, **kw):
    """Same as NotificationHandler, under the traditional name.

    Deprecated: issues a DeprecationWarning pointing at the caller.
    """
    message = "pgnotify is deprecated, use NotificationHandler instead."
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    handler = NotificationHandler(*args, **kw)
    return handler
254
255
256# The actual PostGreSQL database connection interface:
257
258class DB(object):
259    """Wrapper class for the _pg connection type."""
260
261    def __init__(self, *args, **kw):
262        """Create a new connection.
263
264        You can pass either the connection parameters or an existing
265        _pg or pgdb connection. This allows you to use the methods
266        of the classic pg interface with a DB-API 2 pgdb connection.
267
268        """
269        if not args and len(kw) == 1:
270            db = kw.get('db')
271        elif not kw and len(args) == 1:
272            db = args[0]
273        else:
274            db = None
275        if db:
276            if isinstance(db, DB):
277                db = db.db
278            else:
279                try:
280                    db = db._cnx
281                except AttributeError:
282                    pass
283        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
284            db = connect(*args, **kw)
285            self._closeable = True
286        else:
287            self._closeable = False
288        self.db = db
289        self.dbname = db.db
290        self._regtypes = False
291        self._attnames = {}
292        self._pkeys = {}
293        self._privileges = {}
294        self._args = args, kw
295        self.debug = None  # For debugging scripts, this can be set
296            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
297            # * to a file object to write debug statements or
298            # * to a callable object which takes a string argument
299            # * to any other true value to just print debug statements
300
301    def __getattr__(self, name):
302        # All undefined members are same as in underlying connection:
303        if self.db:
304            return getattr(self.db, name)
305        else:
306            raise _int_error('Connection is not valid')
307
308    def __dir__(self):
309        # Custom dir function including the attributes of the connection:
310        attrs = set(self.__class__.__dict__)
311        attrs.update(self.__dict__)
312        attrs.update(dir(self.db))
313        return sorted(attrs)
314
315    # Context manager methods
316
317    def __enter__(self):
318        """Enter the runtime context. This will start a transaction."""
319        self.begin()
320        return self
321
322    def __exit__(self, et, ev, tb):
323        """Exit the runtime context. This will end the transaction."""
324        if et is None and ev is None and tb is None:
325            self.commit()
326        else:
327            self.rollback()
328
329    # Auxiliary methods
330
331    def _do_debug(self, s):
332        """Print a debug message."""
333        if self.debug:
334            if isinstance(self.debug, basestring):
335                print(self.debug % s)
336            elif hasattr(self.debug, 'write'):
337                self.debug.write(s + '\n')
338            elif callable(self.debug):
339                self.debug(s)
340            else:
341                print(s)
342
343    @staticmethod
344    def _make_bool(d):
345        """Get boolean value corresponding to d."""
346        return bool(d) if get_bool() else ('t' if d else 'f')
347
348    def _quote_text(self, d):
349        """Quote text value."""
350        if not isinstance(d, basestring):
351            d = str(d)
352        return "'%s'" % self.escape_string(d)
353
354    _bool_true = frozenset('t true 1 y yes on'.split())
355
356    def _quote_bool(self, d):
357        """Quote boolean value."""
358        if isinstance(d, basestring):
359            if not d:
360                return 'NULL'
361            d = d.lower() in self._bool_true
362        return "'t'" if d else "'f'"
363
364    _date_literals = frozenset('current_date current_time'
365        ' current_timestamp localtime localtimestamp'.split())
366
367    def _quote_date(self, d):
368        """Quote date value."""
369        if not d:
370            return 'NULL'
371        if isinstance(d, basestring) and d.lower() in self._date_literals:
372            return d
373        return self._quote_text(d)
374
375    def _quote_num(self, d):
376        """Quote numeric value."""
377        if not d and d != 0:
378            return 'NULL'
379        return str(d)
380
381    def _quote_money(self, d):
382        """Quote money value."""
383        if d is None or d == '':
384            return 'NULL'
385        if not isinstance(d, basestring):
386            d = str(d)
387        return d
388
389    if bytes is str:  # Python < 3.0
390        """Quote bytes value."""
391
392        def _quote_bytea(self, d):
393            return "'%s'" % self.escape_bytea(d)
394
395    else:
396
397        def _quote_bytea(self, d):
398            return "'%s'" % self.escape_bytea(d).decode('ascii')
399
400    _quote_funcs = dict(  # quote methods for each type
401        text=_quote_text, bool=_quote_bool, date=_quote_date,
402        int=_quote_num, num=_quote_num, float=_quote_num,
403        money=_quote_money, bytea=_quote_bytea)
404
405    def _quote(self, d, t):
406        """Return quotes if needed."""
407        if d is None:
408            return 'NULL'
409        try:
410            quote_func = self._quote_funcs[t]
411        except KeyError:
412            quote_func = self._quote_funcs['text']
413        return quote_func(self, d)
414
415    def _split_schema(self, cl):
416        """Return schema and name of object separately.
417
418        This auxiliary function splits off the namespace (schema)
419        belonging to the class with the name cl. If the class name
420        is not qualified, the function is able to determine the schema
421        of the class, taking into account the current search path.
422
423        """
424        s = _split_parts(cl)
425        if len(s) > 1:  # name already qualfied?
426            # should be database.schema.table or schema.table
427            if len(s) > 3:
428                raise _prg_error('Too many dots in class name %s' % cl)
429            schema, cl = s[-2:]
430        else:
431            cl = s[0]
432            # determine search path
433            q = 'SELECT current_schemas(TRUE)'
434            schemas = self.db.query(q).getresult()[0][0][1:-1].split(',')
435            if schemas:  # non-empty path
436                # search schema for this object in the current search path
437                q = ' UNION '.join(
438                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
439                        % s for s in enumerate(schemas)])
440                q = ("SELECT nspname FROM pg_class"
441                    " JOIN pg_namespace"
442                    " ON pg_class.relnamespace = pg_namespace.oid"
443                    " JOIN (%s) AS p USING (nspname)"
444                    " WHERE pg_class.relname = '%s'"
445                    " ORDER BY n LIMIT 1" % (q, cl))
446                schema = self.db.query(q).getresult()
447                if schema:  # schema found
448                    schema = schema[0][0]
449                else:  # object not found in current search path
450                    schema = 'public'
451            else:  # empty path
452                schema = 'public'
453        return schema, cl
454
455    def _add_schema(self, cl):
456        """Ensure that the class name is prefixed with a schema name."""
457        return _join_parts(self._split_schema(cl))
458
459    # Public methods
460
    # escape_string and escape_bytea are available as methods (they are
    # delegated to the underlying connection by __getattr__), so we expose
    # the global unescape_bytea function (from the _pg import) as a
    # static method as well.
    unescape_bytea = staticmethod(unescape_bytea)
464
465    def close(self):
466        """Close the database connection."""
467        # Wraps shared library function so we can track state.
468        if self._closeable:
469            if self.db:
470                self.db.close()
471                self.db = None
472            else:
473                raise _int_error('Connection already closed')
474
475    def reset(self):
476        """Reset connection with current parameters.
477
478        All derived queries and large objects derived from this connection
479        will not be usable after this call.
480
481        """
482        if self.db:
483            self.db.reset()
484        else:
485            raise _int_error('Connection already closed')
486
487    def reopen(self):
488        """Reopen connection to the database.
489
490        Used in case we need another connection to the same database.
491        Note that we can still reopen a database that we have closed.
492
493        """
494        # There is no such shared library function.
495        if self._closeable:
496            db = connect(*self._args[0], **self._args[1])
497            if self.db:
498                self.db.close()
499            self.db = db
500
501    def begin(self, mode=None):
502        """Begin a transaction."""
503        qstr = 'BEGIN'
504        if mode:
505            qstr += ' ' + mode
506        return self.query(qstr)
507
508    start = begin
509
510    def commit(self):
511        """Commit the current transaction."""
512        return self.query('COMMIT')
513
514    end = commit
515
516    def rollback(self, name=None):
517        """Rollback the current transaction."""
518        qstr = 'ROLLBACK'
519        if name:
520            qstr += ' TO ' + name
521        return self.query(qstr)
522
523    def savepoint(self, name=None):
524        """Define a new savepoint within the current transaction."""
525        qstr = 'SAVEPOINT'
526        if name:
527            qstr += ' ' + name
528        return self.query(qstr)
529
530    def release(self, name):
531        """Destroy a previously defined savepoint."""
532        return self.query('RELEASE ' + name)
533
534    def query(self, qstr, *args):
535        """Executes a SQL command string.
536
537        This method simply sends a SQL query to the database. If the query is
538        an insert statement that inserted exactly one row into a table that
539        has OIDs, the return value is the OID of the newly inserted row.
540        If the query is an update or delete statement, or an insert statement
541        that did not insert exactly one row in a table with OIDs, then the
542        number of rows affected is returned as a string. If it is a statement
543        that returns rows as a result (usually a select statement, but maybe
544        also an "insert/update ... returning" statement), this method returns
545        a pgqueryobject that can be accessed via getresult() or dictresult()
546        or simply printed. Otherwise, it returns `None`.
547
548        The query can contain numbered parameters of the form $1 in place
549        of any data constant. Arguments given after the query string will
550        be substituted for the corresponding numbered parameter. Parameter
551        values can also be given as a single list or tuple argument.
552
553        """
554        # Wraps shared library function for debugging.
555        if not self.db:
556            raise _int_error('Connection is not valid')
557        self._do_debug(qstr)
558        return self.db.query(qstr, args)
559
    def pkey(self, cl, newpkey=None):
        """This method gets or sets the primary key of a class.

        Composite primary keys are represented as frozensets. Note that
        this raises an exception (KeyError) if the table does not have
        a primary key.

        If newpkey is set and is not a dictionary then set that
        value as the primary key of the class.  If it is a dictionary
        then replace the _pkeys dictionary with a copy of it, where
        class names without a dot get the prefix 'public.'.

        The primary keys of all tables outside the system schemas are
        read from the database at once and cached. The cache is rebuilt
        whenever a class is not found in it, which also discards primary
        keys that have been set manually.

        """
        # First see if the caller is supplying a dictionary
        if isinstance(newpkey, dict):
            # make sure that all classes have a namespace
            # (names without a dot are assumed to be in schema public)
            self._pkeys = dict([
                (cl if '.' in cl else 'public.' + cl, pkey)
                for cl, pkey in newpkey.items()])
            return self._pkeys

        qcl = self._add_schema(cl)  # build fully qualified class name
        # Check if the caller is supplying a new primary key for the class
        if newpkey:
            self._pkeys[qcl] = newpkey
            return newpkey

        # Get all the primary keys at once
        if qcl not in self._pkeys:
            # if not found, check again in case it was added after we started
            self._pkeys = {}
            # one row per primary key column (schema, table, column name);
            # the system schemas pg_* and information_schema are skipped
            for r in self.db.query(
                "SELECT pg_namespace.nspname, pg_class.relname,"
                    " pg_attribute.attname FROM pg_class"
                " JOIN pg_namespace"
                    " ON pg_namespace.oid = pg_class.relnamespace"
                    " AND pg_namespace.nspname"
                    " NOT SIMILAR TO 'pg/_%|information/_schema' ESCAPE '/'"
                " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
                    " AND NOT pg_attribute.attisdropped"
                " JOIN pg_index ON pg_index.indrelid = pg_class.oid"
                    " AND pg_index.indisprimary"
                    " AND pg_attribute.attnum"
                        " = ANY (pg_index.indkey)").getresult():
                # note: this rebinds cl; the parameter is not needed anymore
                cl, pkey = _join_parts(r[:2]), r[2]
                self._pkeys.setdefault(cl, []).append(pkey)
            # (only) for composite primary keys, the values will be frozensets
            for cl, pkey in self._pkeys.items():
                self._pkeys[cl] = frozenset(pkey) if len(pkey) > 1 else pkey[0]
            self._do_debug(self._pkeys)

        # will raise an exception if primary key doesn't exist
        # (a KeyError, which get() and update() catch and translate)
        return self._pkeys[qcl]
611
612    def get_databases(self):
613        """Get list of databases in the system."""
614        return [s[0] for s in
615            self.db.query('SELECT datname FROM pg_database').getresult()]
616
617    def get_relations(self, kinds=None):
618        """Get list of relations in connected database of specified kinds.
619
620            If kinds is None or empty, all kinds of relations are returned.
621            Otherwise kinds can be a string or sequence of type letters
622            specifying which kind of relations you want to list.
623
624        """
625        where = "pg_class.relkind IN (%s) AND" % ','.join(
626            ["'%s'" % x for x in kinds]) if kinds else ''
627        return [_join_parts(x) for x in self.db.query(
628            "SELECT pg_namespace.nspname, pg_class.relname"
629            " FROM pg_class "
630            " JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace "
631            " WHERE %s pg_namespace.nspname"
632            " NOT SIMILAR TO 'pg/_%%|information/_schema' ESCAPE '/'"
633            " ORDER BY 1, 2" % where).getresult()]
634
635    def get_tables(self):
636        """Return list of tables in connected database."""
637        return self.get_relations('r')
638
639    def get_attnames(self, cl, newattnames=None):
640        """Given the name of a table, digs out the set of attribute names.
641
642        Returns a dictionary of attribute names (the names are the keys,
643        the values are the names of the attributes' types).
644        If the optional newattnames exists, it must be a dictionary and
645        will become the new attribute names dictionary.
646
647        By default, only a limited number of simple types will be returned.
648        You can get the regular types after calling use_regtypes(True).
649
650        """
651        if isinstance(newattnames, dict):
652            self._attnames = newattnames
653            return
654        elif newattnames:
655            raise _prg_error('If supplied, newattnames must be a dictionary')
656        cl = self._split_schema(cl)  # split into schema and class
657        qcl = _join_parts(cl)  # build fully qualified name
658        # May as well cache them:
659        if qcl in self._attnames:
660            return self._attnames[qcl]
661        if qcl not in self.get_relations('rv'):
662            raise _prg_error('Class %s does not exist' % qcl)
663
664        q = "SELECT pg_attribute.attname, pg_type.typname"
665        if self._regtypes:
666            q += "::regtype"
667        q += (" FROM pg_class"
668            " JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid"
669            " JOIN pg_attribute ON pg_attribute.attrelid = pg_class.oid"
670            " JOIN pg_type ON pg_type.oid = pg_attribute.atttypid"
671            " WHERE pg_namespace.nspname = '%s' AND pg_class.relname = '%s'"
672            " AND (pg_attribute.attnum > 0 OR pg_attribute.attname = 'oid')"
673            " AND NOT pg_attribute.attisdropped") % cl
674        q = self.db.query(q).getresult()
675
676        if self._regtypes:
677            t = dict(q)
678        else:
679            t = {}
680            for att, typ in q:
681                if typ.startswith('bool'):
682                    typ = 'bool'
683                elif typ.startswith('abstime'):
684                    typ = 'date'
685                elif typ.startswith('date'):
686                    typ = 'date'
687                elif typ.startswith('interval'):
688                    typ = 'date'
689                elif typ.startswith('timestamp'):
690                    typ = 'date'
691                elif typ.startswith('oid'):
692                    typ = 'int'
693                elif typ.startswith('int'):
694                    typ = 'int'
695                elif typ.startswith('float'):
696                    typ = 'float'
697                elif typ.startswith('numeric'):
698                    typ = 'num'
699                elif typ.startswith('money'):
700                    typ = 'money'
701                elif typ.startswith('bytea'):
702                    typ = 'bytea'
703                else:
704                    typ = 'text'
705                t[att] = typ
706
707        self._attnames[qcl] = t  # cache it
708        return self._attnames[qcl]
709
710    def use_regtypes(self, regtypes=None):
711        """Use regular type names instead of simplified type names."""
712        if regtypes is None:
713            return self._regtypes
714        else:
715            regtypes = bool(regtypes)
716            if regtypes != self._regtypes:
717                self._regtypes = regtypes
718                self._attnames.clear()
719            return regtypes
720
721    def has_table_privilege(self, cl, privilege='select'):
722        """Check whether current user has specified table privilege."""
723        qcl = self._add_schema(cl)
724        privilege = privilege.lower()
725        try:
726            return self._privileges[(qcl, privilege)]
727        except KeyError:
728            q = "SELECT has_table_privilege('%s', '%s')" % (qcl, privilege)
729            ret = self.db.query(q).getresult()[0][0] == self._make_bool(True)
730            self._privileges[(qcl, privilege)] = ret
731            return ret
732
733    def get(self, cl, arg, keyname=None):
734        """Get a row from a database table or view.
735
736        This method is the basic mechanism to get a single row.  The keyname
737        that the key specifies a unique row.  If keyname is not specified
738        then the primary key for the table is used.  If arg is a dictionary
739        then the value for the key is taken from it and it is modified to
740        include the new values, replacing existing values where necessary.
741        For a composite key, keyname can also be a sequence of key names.
742        The OID is also put into the dictionary if the table has one, but
743        in order to allow the caller to work with multiple tables, it is
744        munged as oid(schema.table).
745
746        """
747        if cl.endswith('*'):  # scan descendant tables?
748            cl = cl[:-1].rstrip()  # need parent table name
749        # build qualified class name
750        qcl = self._add_schema(cl)
751        # To allow users to work with multiple tables,
752        # we munge the name of the "oid" the key
753        qoid = _oid_key(qcl)
754        if not keyname:
755            # use the primary key by default
756            try:
757                keyname = self.pkey(qcl)
758            except KeyError:
759                raise _prg_error('Class %s has no primary key' % qcl)
760        attnames = self.get_attnames(qcl)
761        # We want the oid for later updates if that isn't the key
762        if keyname == 'oid':
763            if isinstance(arg, dict):
764                if qoid not in arg:
765                    raise _db_error('%s not in arg' % qoid)
766            else:
767                arg = {qoid: arg}
768            what = '*'
769            where = 'oid = %s' % arg[qoid]
770        else:
771            if isinstance(keyname, basestring):
772                keyname = (keyname,)
773            if not isinstance(arg, dict):
774                if len(keyname) > 1:
775                    raise _prg_error('Composite key needs dict as arg')
776                arg = dict([(k, arg) for k in keyname])
777            what = ', '.join(attnames)
778            where = ' AND '.join(['%s = %s'
779                % (k, self._quote(arg[k], attnames[k])) for k in keyname])
780        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (what, qcl, where)
781        self._do_debug(q)
782        res = self.db.query(q).dictresult()
783        if not res:
784            raise _db_error('No such record in %s where %s' % (qcl, where))
785        for n, value in res[0].items():
786            if n == 'oid':
787                n = qoid
788            elif attnames.get(n) == 'bytea':
789                value = self.unescape_bytea(value)
790            arg[n] = value
791        return arg
792
793    def insert(self, cl, d=None, **kw):
794        """Insert a row into a database table.
795
796        This method inserts a row into a table.  If a dictionary is
797        supplied it starts with that.  Otherwise it uses a blank dictionary.
798        Either way the dictionary is updated from the keywords.
799
800        The dictionary is then, if possible, reloaded with the values actually
801        inserted in order to pick up values modified by rules, triggers, etc.
802
803        Note: The method currently doesn't support insert into views
804        although PostgreSQL does.
805
806        """
807        qcl = self._add_schema(cl)
808        qoid = _oid_key(qcl)
809        if d is None:
810            d = {}
811        d.update(kw)
812        attnames = self.get_attnames(qcl)
813        names, values = [], []
814        for n in attnames:
815            if n != 'oid' and n in d:
816                names.append('"%s"' % n)
817                values.append(self._quote(d[n], attnames[n]))
818        names, values = ', '.join(names), ', '.join(values)
819        selectable = self.has_table_privilege(qcl)
820        if selectable:
821            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
822        else:
823            ret = ''
824        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (qcl, names, values, ret)
825        self._do_debug(q)
826        res = self.db.query(q)
827        if ret:
828            res = res.dictresult()[0]
829            for n, value in res.items():
830                if n == 'oid':
831                    n = qoid
832                elif attnames.get(n) == 'bytea':
833                    value = self.unescape_bytea(value)
834                d[n] = value
835        elif isinstance(res, int):
836            d[qoid] = res
837            if selectable:
838                self.get(qcl, d, 'oid')
839        elif selectable:
840            if qoid in d:
841                self.get(qcl, d, 'oid')
842            else:
843                try:
844                    self.get(qcl, d)
845                except ProgrammingError:
846                    pass  # table has no primary key
847        return d
848
849    def update(self, cl, d=None, **kw):
850        """Update an existing row in a database table.
851
852        Similar to insert but updates an existing row.  The update is based
853        on the OID value as munged by get or passed as keyword, or on the
854        primary key of the table.  The dictionary is modified, if possible,
855        to reflect any changes caused by the update due to triggers, rules,
856        default values, etc.
857
858        """
859        # Update always works on the oid which get returns if available,
860        # otherwise use the primary key.  Fail if neither.
861        # Note that we only accept oid key from named args for safety
862        qcl = self._add_schema(cl)
863        qoid = _oid_key(qcl)
864        if 'oid' in kw:
865            kw[qoid] = kw['oid']
866            del kw['oid']
867        if d is None:
868            d = {}
869        d.update(kw)
870        attnames = self.get_attnames(qcl)
871        if qoid in d:
872            where = 'oid = %s' % d[qoid]
873            keyname = ()
874        else:
875            try:
876                keyname = self.pkey(qcl)
877            except KeyError:
878                raise _prg_error('Class %s has no primary key' % qcl)
879            if isinstance(keyname, basestring):
880                keyname = (keyname,)
881            try:
882                where = ' AND '.join(['%s = %s'
883                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
884            except KeyError:
885                raise _prg_error('Update needs primary key or oid.')
886        values = []
887        for n in attnames:
888            if n in d and n not in keyname:
889                values.append('%s = %s' % (n, self._quote(d[n], attnames[n])))
890        if not values:
891            return d
892        values = ', '.join(values)
893        selectable = self.has_table_privilege(qcl)
894        if selectable:
895            ret = ' RETURNING %s*' % ('oid, ' if 'oid' in attnames else '')
896        else:
897            ret = ''
898        q = 'UPDATE %s SET %s WHERE %s%s' % (qcl, values, where, ret)
899        self._do_debug(q)
900        res = self.db.query(q)
901        if ret:
902            res = res.dictresult()[0]
903            for n, value in res.items():
904                if n == 'oid':
905                    n = qoid
906                elif attnames.get(n) == 'bytea':
907                    value = self.unescape_bytea(value)
908                d[n] = value
909        else:
910            if selectable:
911                if qoid in d:
912                    self.get(qcl, d, 'oid')
913                else:
914                    self.get(qcl, d)
915        return d
916
917    def clear(self, cl, a=None):
918        """Clear all the attributes to values determined by the types.
919
920        Numeric types are set to 0, Booleans are set to false, and everything
921        else is set to the empty string.  If the array argument is present,
922        it is used as the array and any entries matching attribute names are
923        cleared with everything else left unchanged.
924
925        """
926        # At some point we will need a way to get defaults from a table.
927        qcl = self._add_schema(cl)
928        if a is None:
929            a = {}  # empty if argument is not present
930        attnames = self.get_attnames(qcl)
931        for n, t in attnames.items():
932            if n == 'oid':
933                continue
934            if t in ('int', 'integer', 'smallint', 'bigint',
935                    'float', 'real', 'double precision',
936                    'num', 'numeric', 'money'):
937                a[n] = 0
938            elif t in ('bool', 'boolean'):
939                a[n] = self._make_bool(False)
940            else:
941                a[n] = ''
942        return a
943
944    def delete(self, cl, d=None, **kw):
945        """Delete an existing row in a database table.
946
947        This method deletes the row from a table.  It deletes based on the
948        OID value as munged by get or passed as keyword, or on the primary
949        key of the table.  The return value is the number of deleted rows
950        (i.e. 0 if the row did not exist and 1 if the row was deleted).
951
952        """
953        # Like update, delete works on the oid.
954        # One day we will be testing that the record to be deleted
955        # isn't referenced somewhere (or else PostgreSQL will).
956        # Note that we only accept oid key from named args for safety
957        qcl = self._add_schema(cl)
958        qoid = _oid_key(qcl)
959        if 'oid' in kw:
960            kw[qoid] = kw['oid']
961            del kw['oid']
962        if d is None:
963            d = {}
964        d.update(kw)
965        if qoid in d:
966            where = 'oid = %s' % d[qoid]
967        else:
968            try:
969                keyname = self.pkey(qcl)
970            except KeyError:
971                raise _prg_error('Class %s has no primary key' % qcl)
972            if isinstance(keyname, basestring):
973                keyname = (keyname,)
974            attnames = self.get_attnames(qcl)
975            try:
976                where = ' AND '.join(['%s = %s'
977                    % (k, self._quote(d[k], attnames[k])) for k in keyname])
978            except KeyError:
979                raise _prg_error('Delete needs primary key or oid.')
980        q = 'DELETE FROM %s WHERE %s' % (qcl, where)
981        self._do_debug(q)
982        return int(self.db.query(q))
983
984    def notification_handler(self, event, callback, arg_dict={}, timeout=None):
985        """Get notification handler that will run the given callback."""
986        return NotificationHandler(self.db, event, callback, arg_dict, timeout)
987
988
# if run as script, print the version and the module documentation

if __name__ == '__main__':
    # separate the label from the version number with a space
    print('PyGreSQL version ' + version)
    print('')
    print(__doc__)
Note: See TracBrowser for help on using the repository browser.