source: trunk/module/pg.py @ 386

Last change on this file since 386 was 386, checked in by cito, 11 years ago

Added note about composite primary keys.

File size: 22.0 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.72 2008-12-04 20:28:06 cito Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with additional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22try:
23    frozenset
24except NameError: # Python < 2.4
25    from sets import ImmutableSet as frozenset
26try:
27    from decimal import Decimal
28    set_decimal(Decimal)
29except ImportError:
30    pass # Python < 2.4
31
32
33# Auxiliary functions which are independent from a DB connection:
34
35def _is_quoted(s):
36    """Check whether this string is a quoted identifier."""
37    s = s.replace('_', 'a')
38    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
39
40def _is_unquoted(s):
41    """Check whether this string is an unquoted identifier."""
42    s = s.replace('_', 'a')
43    return s.isalnum() and not s[:1].isdigit()
44
45def _split_first_part(s):
46    """Split the first part of a dot separated string."""
47    s = s.lstrip()
48    if s[:1] == '"':
49        p = []
50        s = s.split('"', 3)[1:]
51        p.append(s[0])
52        while len(s) == 3 and s[1] == '':
53            p.append('"')
54            s = s[2].split('"', 2)
55            p.append(s[0])
56        p = [''.join(p)]
57        s = '"'.join(s[1:]).lstrip()
58        if s:
59            if s[:0] == '.':
60                p.append(s[1:])
61            else:
62                s = _split_first_part(s)
63                p[0] += s[0]
64                if len(s) > 1:
65                    p.append(s[1])
66    else:
67        p = s.split('.', 1)
68        s = p[0].rstrip()
69        if _is_unquoted(s):
70            s = s.lower()
71        p[0] = s
72    return p
73
def _split_parts(s):
    """Break a dot separated string into the list of all its parts."""
    parts = []
    rest = s
    while rest:
        pieces = _split_first_part(rest)
        parts.append(pieces[0])
        if len(pieces) < 2:
            # nothing left after this part
            break
        rest = pieces[1]
    return parts
84
def _join_parts(s):
    """Build a dot separated string from parts, quoting where needed."""
    joined = []
    for part in s:
        if _is_quoted(part):
            part = '"%s"' % part
        joined.append(part)
    return '.'.join(joined)
88
89
90# The PostgreSQL database connection interface:
91
92class DB(object):
93    """Wrapper class for the _pg connection type."""
94
95    def __init__(self, *args, **kw):
96        """Create a new connection.
97
98        You can pass either the connection parameters or an existing
99        _pg or pgdb connection. This allows you to use the methods
100        of the classic pg interface with a DB-API 2 pgdb connection.
101
102        """
103        if not args and len(kw) == 1:
104            db = kw.get('db')
105        elif not kw and len(args) == 1:
106            db = args[0]
107        else:
108            db = None
109        if db:
110            if isinstance(db, DB):
111                db = db.db
112            else:
113                try:
114                    db = db._cnx
115                except AttributeError:
116                    pass
117        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
118            db = connect(*args, **kw)
119            self._closeable = 1
120        else:
121            self._closeable = 0
122        self.db = db
123        self.dbname = db.db
124        self._attnames = {}
125        self._pkeys = {}
126        self._args = args, kw
127        self.debug = None # For debugging scripts, this can be set
128            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
129            # * to a file object to write debug statements or
130            # * to a callable object which takes a string argument.
131
132    def __getattr__(self, name):
133        # All undefined members are the same as in the underlying pg connection:
134        if self.db:
135            return getattr(self.db, name)
136        else:
137            raise InternalError('Connection is not valid')
138
139    # Auxiliary methods
140
141    def _do_debug(self, s):
142        """Print a debug message."""
143        if self.debug:
144            if isinstance(self.debug, basestring):
145                print self.debug % s
146            elif isinstance(self.debug, file):
147                print >> self.debug, s
148            elif callable(self.debug):
149                self.debug(s)
150
151    def _quote_text(self, d):
152        """Quote text value."""
153        if not isinstance(d, basestring):
154            d = str(d)
155        return "'%s'" % self.escape_string(d)
156
157    _bool_true = frozenset('t true 1 y yes on'.split())
158
159    def _quote_bool(self, d):
160        """Quote boolean value."""
161        if isinstance(d, basestring):
162            if not d:
163                return 'NULL'
164            d = d.lower() in self._bool_true
165        else:
166            d = bool(d)
167        return ("'f'", "'t'")[d]
168
169    _date_literals = frozenset('current_date current_time'
170        ' current_timestamp localtime localtimestamp'.split())
171
172    def _quote_date(self, d):
173        """Quote date value."""
174        if not d:
175            return 'NULL'
176        if isinstance(d, basestring) and d.lower() in self._date_literals:
177            return d
178        return self._quote_text(d)
179
180    def _quote_num(self, d):
181        """Quote numeric value."""
182        if not d:
183            return 'NULL'
184        return str(d)
185
186    def _quote_money(self, d):
187        """Quote money value."""
188        if not d:
189            return 'NULL'
190        return "'%.2f'" % float(d)
191
192    _quote_funcs = dict( # quote methods for each type
193        text=_quote_text, bool=_quote_bool, date=_quote_date,
194        int=_quote_num, num=_quote_num, float=_quote_num,
195        money=_quote_money)
196
197    def _quote(self, d, t):
198        """Return quotes if needed."""
199        if d is None:
200            return 'NULL'
201        try:
202            quote_func = self._quote_funcs[t]
203        except KeyError:
204            quote_func = self._quote_funcs['text']
205        return quote_func(self, d)
206
207    def _split_schema(self, cl):
208        """Return schema and name of object separately.
209
210        This auxiliary function splits off the namespace (schema)
211        belonging to the class with the name cl. If the class name
212        is not qualified, the function is able to determine the schema
213        of the class, taking into account the current search path.
214
215        """
216        s = _split_parts(cl)
217        if len(s) > 1: # name already qualfied?
218            # should be database.schema.table or schema.table
219            if len(s) > 3:
220                raise ProgrammingError('Too many dots in class name %s' % cl)
221            schema, cl = s[-2:]
222        else:
223            cl = s[0]
224            # determine search path
225            query = 'SELECT current_schemas(TRUE)'
226            schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
227            if schemas: # non-empty path
228                # search schema for this object in the current search path
229                query = ' UNION '.join(
230                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
231                        % s for s in enumerate(schemas)])
232                query = ("SELECT nspname FROM pg_class"
233                    " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
234                    " JOIN (%s) AS p USING (nspname)"
235                    " WHERE pg_class.relname='%s'"
236                    " ORDER BY n LIMIT 1" % (query, cl))
237                schema = self.db.query(query).getresult()
238                if schema: # schema found
239                    schema = schema[0][0]
240                else: # object not found in current search path
241                    schema = 'public'
242            else: # empty path
243                schema = 'public'
244        return schema, cl
245
246    # Public methods
247
248    # escape_string and escape_bytea exist as methods,
249    # so we define unescape_bytea as a method as well
250    unescape_bytea = staticmethod(unescape_bytea)
251
252    def close(self):
253        """Close the database connection."""
254        # Wraps shared library function so we can track state.
255        if self._closeable:
256            if self.db:
257                self.db.close()
258                self.db = None
259            else:
260                raise InternalError('Connection already closed')
261
262    def reset(self):
263        """Reset connection with current parameters.
264
265        All derived queries and large objects derived from this connection
266        will not be usable after this call.
267
268        """
269        self.db.reset()
270
271    def reopen(self):
272        """Reopen connection to the database.
273
274        Used in case we need another connection to the same database.
275        Note that we can still reopen a database that we have closed.
276
277        """
278        # There is no such shared library function.
279        if self._closeable:
280            db = connect(*self._args[0], **self._args[1])
281            if self.db:
282                self.db.close()
283            self.db = db
284
285    def query(self, qstr):
286        """Executes a SQL command string.
287
288        This method simply sends a SQL query to the database. If the query is
289        an insert statement, the return value is the OID of the newly
290        inserted row.  If it is otherwise a query that does not return a result
291        (ie. is not a some kind of SELECT statement), it returns None.
292        Otherwise, it returns a pgqueryobject that can be accessed via the
293        getresult or dictresult method or simply printed.
294
295        """
296        # Wraps shared library function for debugging.
297        if not self.db:
298            raise InternalError('Connection is not valid')
299        self._do_debug(qstr)
300        return self.db.query(qstr)
301
302    def pkey(self, cl, newpkey=None):
303        """This method gets or sets the primary key of a class.
304
305        If newpkey is set and is not a dictionary then set that
306        value as the primary key of the class.  If it is a dictionary
307        then replace the _pkeys dictionary with a copy of it.
308
309        """
310        # First see if the caller is supplying a dictionary
311        if isinstance(newpkey, dict):
312            # make sure that all classes have a namespace
313            self._pkeys = dict([
314                ('.' in cl and cl or 'public.' + cl, pkey)
315                for cl, pkey in newpkey.iteritems()])
316            return self._pkeys
317
318        # Build qualified name for the given class
319        qcl = _join_parts(self._split_schema(cl))
320        # Check if the caller is supplying a new primary key for that class
321        if newpkey:
322            self._pkeys[qcl] = newpkey
323            return newpkey
324
325        # Get all the primary keys at once
326        if qcl not in self._pkeys:
327            # if not found, check again in case it was added after we started
328            self._pkeys = dict([
329                (_join_parts(r[:2]), r[2]) for r in self.db.query(
330                "SELECT pg_namespace.nspname, pg_class.relname"
331                    ", pg_attribute.attname FROM pg_class"
332                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
333                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
334                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
335                    " AND pg_attribute.attisdropped='f'"
336                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
337                    " AND pg_index.indisprimary='t'"
338                    # note that this gets only the first attribute
339                    # of composite primary keys
340                    " AND pg_index.indkey[0]=pg_attribute.attnum"
341                ).getresult()])
342            self._do_debug(self._pkeys)
343
344        # will raise an exception if primary key doesn't exist
345        return self._pkeys[qcl]
346
347    def get_databases(self):
348        """Get list of databases in the system."""
349        return [s[0] for s in
350            self.db.query('SELECT datname FROM pg_database').getresult()]
351
352    def get_relations(self, kinds=None):
353        """Get list of relations in connected database of specified kinds.
354
355            If kinds is None or empty, all kinds of relations are returned.
356            Otherwise kinds can be a string or sequence of type letters
357            specifying which kind of relations you want to list.
358
359        """
360        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
361            ["'%s'" % x for x in kinds]) or ''
362        return map(_join_parts, self.db.query(
363            "SELECT pg_namespace.nspname, pg_class.relname "
364            "FROM pg_class "
365            "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
366            "WHERE %s pg_class.relname !~ '^Inv' AND "
367                "pg_class.relname !~ '^pg_' "
368            "ORDER BY 1, 2" % where).getresult())
369
370    def get_tables(self):
371        """Return list of tables in connected database."""
372        return self.get_relations('r')
373
374    def get_attnames(self, cl, newattnames=None):
375        """Given the name of a table, digs out the set of attribute names.
376
377        Returns a dictionary of attribute names (the names are the keys,
378        the values are the names of the attributes' types).
379        If the optional newattnames exists, it must be a dictionary and
380        will become the new attribute names dictionary.
381
382        """
383        if isinstance(newattnames, dict):
384            self._attnames = newattnames
385            return
386        elif newattnames:
387            raise ProgrammingError(
388                'If supplied, newattnames must be a dictionary')
389        cl = self._split_schema(cl) # split into schema and cl
390        qcl = _join_parts(cl) # build qualified name
391        # May as well cache them:
392        if qcl in self._attnames:
393            return self._attnames[qcl]
394        if qcl not in self.get_relations('rv'):
395            raise ProgrammingError('Class %s does not exist' % qcl)
396        t = {}
397        for att, typ in self.db.query("SELECT pg_attribute.attname"
398            ",pg_type.typname FROM pg_class"
399            " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
400            " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
401            " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
402            " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
403            " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
404            " AND pg_attribute.attisdropped='f'"
405                % cl).getresult():
406            if typ.startswith('bool'):
407                t[att] = 'bool'
408            elif typ.startswith('abstime'):
409                t[att] = 'date'
410            elif typ.startswith('date'):
411                t[att] = 'date'
412            elif typ.startswith('interval'):
413                t[att] = 'date'
414            elif typ.startswith('timestamp'):
415                t[att] = 'date'
416            elif typ.startswith('oid'):
417                t[att] = 'int'
418            elif typ.startswith('int'):
419                t[att] = 'int'
420            elif typ.startswith('float'):
421                t[att] = 'float'
422            elif typ.startswith('numeric'):
423                t[att] = 'num'
424            elif typ.startswith('money'):
425                t[att] = 'money'
426            else:
427                t[att] = 'text'
428        self._attnames[qcl] = t # cache it
429        return self._attnames[qcl]
430
431    def get(self, cl, arg, keyname=None, view=0):
432        """Get a tuple from a database table or view.
433
434        This method is the basic mechanism to get a single row.  It assumes
435        that the key specifies a unique row.  If keyname is not specified
436        then the primary key for the table is used.  If arg is a dictionary
437        then the value for the key is taken from it and it is modified to
438        include the new values, replacing existing values where necessary.
439        The OID is also put into the dictionary if the table has one, but
440        in order to allow the caller to work with multiple tables, it is
441        munged as oid(schema.table).
442
443        """
444        if cl.endswith('*'): # scan descendant tables?
445            cl = cl[:-1].rstrip() # need parent table name
446        qcl = _join_parts(self._split_schema(cl)) # build qualified name
447        # To allow users to work with multiple tables,
448        # we munge the name when the key is "oid"
449        foid = 'oid(%s)' % qcl # build mangled name
450        if keyname is None: # use the primary key by default
451            keyname = self.pkey(qcl)
452        fnames = self.get_attnames(qcl)
453        if isinstance(arg, dict):
454            k = arg[keyname == 'oid' and foid or keyname]
455        else:
456            k = arg
457            arg = {}
458        # We want the oid for later updates if that isn't the key
459        if keyname == 'oid':
460            q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
461        elif view:
462            q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
463                (qcl, keyname, self._quote(k, fnames[keyname]))
464        else:
465            q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
466                (','.join(fnames), qcl, \
467                    keyname, self._quote(k, fnames[keyname]))
468        self._do_debug(q)
469        res = self.db.query(q).dictresult()
470        if not res:
471            raise DatabaseError('No such record in %s where %s=%s'
472                % (qcl, keyname, self._quote(k, fnames[keyname])))
473        for k, d in res[0].iteritems():
474            if k == 'oid':
475                k = foid
476            arg[k] = d
477        return arg
478
479    def insert(self, cl, d=None, return_changes=True, **kw):
480        """Insert a tuple into a database table.
481
482        This method inserts a row into a table.  If a dictionary is
483        supplied it starts with that.  Otherwise it uses a blank dictionary.
484        Either way the dictionary is updated from the keywords.
485
486        The dictionary is then reloaded with the values actually inserted
487        in order to pick up values modified by rules, triggers, etc.  If
488        the optional flag return_changes is set to False this reload will
489        be skipped.
490
491        Note: The method currently doesn't support insert into views
492        although PostgreSQL does.
493
494        """
495        if d is None:
496            a = {}
497        else:
498            a = d
499        a.update(kw)
500
501        qcl = _join_parts(self._split_schema(cl)) # build qualified name
502        foid = 'oid(%s)' % qcl # build mangled name
503        fnames = self.get_attnames(qcl)
504        t = []
505        n = []
506
507        for f in fnames:
508            if f != 'oid' and f in a:
509                t.append(self._quote(a[f], fnames[f]))
510                n.append('"%s"' % f)
511
512        q = 'INSERT INTO %s (%s) VALUES (%s)' % \
513            (qcl, ','.join(n), ','.join(t))
514        self._do_debug(q)
515        a[foid] = self.db.query(q)
516
517        # Reload the dictionary to catch things modified by engine.
518        # Note that get() changes 'oid' below to oid(schema.table).
519        if return_changes: self.get(qcl, a, 'oid')
520
521        return a
522
523    def update(self, cl, d=None, **kw):
524        """Update an existing row in a database table.
525
526        Similar to insert but updates an existing row.  The update is based
527        on the OID value as munged by get.  The array returned is the
528        one sent modified to reflect any changes caused by the update due
529        to triggers, rules, defaults, etc.
530
531        """
532        # Update always works on the oid which get returns if available,
533        # otherwise use the primary key.  Fail if neither.
534        qcl = _join_parts(self._split_schema(cl)) # build qualified name
535        foid = 'oid(%s)' % qcl # build mangled oid
536
537        # Note that we only accept oid key from named args for safety
538        if 'oid' in kw:
539            kw[foid] = kw['oid']
540            del kw['oid']
541
542        if d is None:
543            a = {}
544        else:
545            a = d
546        a.update(kw)
547
548        if foid in a:
549            where = "oid=%s" % a[foid]
550        else:
551            try:
552                pk = self.pkey(qcl)
553            except Exception:
554                raise ProgrammingError(
555                    'Update needs primary key or oid as %s' % foid)
556            where = "%s='%s'" % (pk, a[pk])
557        v = []
558        fnames = self.get_attnames(qcl)
559        for ff in fnames:
560            if ff != 'oid' and ff in a:
561                v.append('%s=%s' % (ff, self._quote(a[ff], fnames[ff])))
562        if v == []:
563            return None
564        q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
565        self._do_debug(q)
566        self.db.query(q)
567        # Reload the dictionary to catch things modified by engine:
568        if foid in a:
569            return self.get(qcl, a, 'oid')
570        else:
571            return self.get(qcl, a)
572
573    def clear(self, cl, a=None):
574        """
575
576        This method clears all the attributes to values determined by the types.
577        Numeric types are set to 0, Booleans are set to 'f', and everything
578        else is set to the empty string.  If the array argument is present,
579        it is used as the array and any entries matching attribute names are
580        cleared with everything else left unchanged.
581
582        """
583        # At some point we will need a way to get defaults from a table.
584        if a is None:
585            a = {} # empty if argument is not present
586        qcl = _join_parts(self._split_schema(cl)) # build qualified name
587        fnames = self.get_attnames(qcl)
588        for k, t in fnames.iteritems():
589            if k == 'oid':
590                continue
591            if t in ('int', 'float', 'num', 'money'):
592                a[k] = 0
593            elif t == 'bool':
594                a[k] = 'f'
595            else:
596                a[k] = ''
597        return a
598
599    def delete(self, cl, d=None, **kw):
600        """Delete an existing row in a database table.
601
602        This method deletes the row from a table.
603        It deletes based on the OID munged as described above.
604
605        """
606        # Like update, delete works on the oid.
607        # One day we will be testing that the record to be deleted
608        # isn't referenced somewhere (or else PostgreSQL will).
609        qcl = _join_parts(self._split_schema(cl)) # build qualified name
610        foid = 'oid(%s)' % qcl # build mangled oid
611
612        # Note that we only accept oid key from named args for safety
613        if 'oid' in kw:
614            kw[foid] = kw['oid']
615            del kw['oid']
616
617        if d is None:
618            a = {}
619        else:
620            a = d
621        a.update(kw)
622
623        q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
624        self._do_debug(q)
625        self.db.query(q)
626
627
628# if run as script, print some information
629
630if __name__ == '__main__':
631    print 'PyGreSQL version', version
632    print
633    print __doc__
Note: See TracBrowser for help on using the repository browser.