source: trunk/module/pg.py @ 357

Last change on this file since 357 was 357, checked in by cito, 11 years ago

Modernize code in pg module a bit (we are requiring Py 2.3 now).

File size: 21.5 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.59 2008-11-21 11:18:32 cito Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
level wrapper class named DB with additional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
try:
    from decimal import Decimal
except ImportError:
    pass # Python < 2.4 has no decimal module
else:
    # only reached when the import succeeded; presumably makes _pg
    # hand out numeric values as Decimal objects
    set_decimal(Decimal)
27
28
29# Auxiliary functions which are independent from a DB connection:
30
31def _quote(d, t):
32    """Return quotes if needed."""
33    if d is None:
34        return 'NULL'
35    if t in ('int', 'seq', 'float', 'num'):
36        if d == '':
37            return 'NULL'
38        return str(d)
39    if t == 'money':
40        if d == '':
41            return 'NULL'
42        return "'%.2f'" % float(d)
43    if t == 'bool':
44        if isinstance(d, basestring):
45            if d == '':
46                return 'NULL'
47            d = d.lower() in ('t', 'true', '1', 'y', 'yes', 'on')
48        else:
49            d = bool(d)
50        return ("'f'", "'t'")[d]
51    if t in ('date', 'inet', 'cidr'):
52        if d == '':
53            return 'NULL'
54        if d.lower() in ('current_date', 'current_time',
55            'current_timestamp', 'localtime', 'localtimestamp'):
56            return d
57    return "'%s'" % str(d).replace("\\", "\\\\").replace("'", "''")
58
59def _is_quoted(s):
60    """Check whether this string is a quoted identifier."""
61    s = s.replace('_', 'a')
62    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
63
64def _is_unquoted(s):
65    """Check whether this string is an unquoted identifier."""
66    s = s.replace('_', 'a')
67    return s.isalnum() and not s[:1].isdigit()
68
def _split_first_part(s):
    """Split the first part of a dot separated string.

    Returns [first] when s holds only one part, or [first, rest] where
    rest is the unparsed text after the separating dot.  A part in double
    quotes keeps its case, may contain dots, and has doubled quotes ("")
    reduced to a single one; text following the closing quote up to the
    next dot is parsed the same way and appended to it.  An unquoted part
    is stripped of surrounding whitespace and folded to lower case when
    it is a plain identifier.
    """
    s = s.lstrip()
    if s[:1] == '"':
        p = []
        # s[0] is the text up to the next quote; an empty s[1] followed
        # by more text means that quote was doubled and the name goes on
        s = s.split('"', 3)[1:]
        p.append(s[0])
        while len(s) == 3 and s[1] == '':
            p.append('"')
            s = s[2].split('"', 2)
            p.append(s[0])
        p = [''.join(p)]
        # whatever follows the closing quote
        s = '"'.join(s[1:]).lstrip()
        if s:
            # s[:1], not s[:0]: the latter is always '' and made this
            # branch unreachable
            if s[:1] == '.': # a dot ends this part
                p.append(s[1:])
            else: # further characters belong to the same part
                s = _split_first_part(s)
                p[0] += s[0]
                if len(s) > 1:
                    p.append(s[1])
    else:
        p = s.split('.', 1)
        s = p[0].rstrip()
        if _is_unquoted(s):
            s = s.lower()
        p[0] = s
    return p
97
def _split_parts(s):
    """Split a dot separated string into the list of all of its parts.

    Each part is parsed by _split_first_part, so quoted parts keep their
    case and unquoted plain identifiers are folded to lower case.
    """
    parts = []
    rest = s
    while rest:
        head = _split_first_part(rest)
        parts.append(head[0])
        if len(head) < 2: # nothing after the last part
            break
        rest = head[1]
    return parts
108
def _join_parts(s):
    """Join all parts of a dot separated string.

    Parts that are not plain lowercase identifiers are put in double
    quotes.  Double quotes inside such a part are doubled, so that the
    quoted form remains a valid SQL identifier and _split_parts turns it
    back into the original part.
    """
    q = []
    for p in s:
        if _is_quoted(p):
            p = '"%s"' % p.replace('"', '""')
        q.append(p)
    return '.'.join(q)
112
113
# The PostgreSQL database connection interface:
115
116class DB(object):
117    """Wrapper class for the _pg connection type."""
118
119    def __init__(self, *args, **kw):
120        """Create a new connection.
121
122        You can pass either the connection parameters or an existing
123        _pg or pgdb connection. This allows you to use the methods
124        of the classic pg interface with a DB-API 2 pgdb connection.
125
126        """
127        if not args and len(kw) == 1:
128            db = kw.get('db')
129        elif not kw and len(args) == 1:
130            db = args[0]
131        else:
132            db = None
133        if db:
134            if isinstance(db, DB):
135                db = db.db
136            else:
137                try:
138                    db = db._cnx
139                except AttributeError:
140                    pass
141        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
142            db = connect(*args, **kw)
143            self._closeable = 1
144        else:
145            self._closeable = 0
146        self.db = db
147        self.dbname = db.db
148        self._attnames = {}
149        self._pkeys = {}
150        self._args = args, kw
151        self.debug = None # For debugging scripts, this can be set
152            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
153            # * to a file object to write debug statements or
154            # * to a callable object which takes a string argument.
155
156    def __getattr__(self, name):
157        # All undefined members are the same as in the underlying pg connection:
158        if self.db:
159            return getattr(self.db, name)
160        else:
161            raise InternalError('Connection is not valid')
162
163    # For convenience, define some module functions as static methods also:
164    escape_string, escape_bytea, unescape_bytea = map(staticmethod,
165        (escape_string, escape_bytea, unescape_bytea))
166
167    def _do_debug(self, s):
168        """Print a debug message."""
169        if self.debug:
170            if isinstance(self.debug, basestring):
171                print self.debug % s
172            elif isinstance(self.debug, file):
173                print >> self.debug, s
174            elif callable(self.debug):
175                self.debug(s)
176
177    def close(self):
178        """Close the database connection."""
179        # Wraps shared library function so we can track state.
180        if self._closeable:
181            if self.db:
182                self.db.close()
183                self.db = None
184            else:
185                raise InternalError('Connection already closed')
186
187    def reopen(self):
188        """Reopen connection to the database.
189
190        Used in case we need another connection to the same database.
191        Note that we can still reopen a database that we have closed.
192
193        """
194        # There is no such shared library function.
195        if self._closeable:
196            if self.db:
197                self.db.close()
198            try:
199                self.db = connect(*self._args[0], **self._args[1])
200            except:
201                self.db = None
202                raise
203
204    def query(self, qstr):
205        """Executes a SQL command string.
206
207        This method simply sends a SQL query to the database. If the query is
208        an insert statement, the return value is the OID of the newly
209        inserted row.  If it is otherwise a query that does not return a result
210        (ie. is not a some kind of SELECT statement), it returns None.
211        Otherwise, it returns a pgqueryobject that can be accessed via the
212        getresult or dictresult method or simply printed.
213
214        """
215        # Wraps shared library function for debugging.
216        if not self.db:
217            raise InternalError('Connection is not valid')
218        self._do_debug(qstr)
219        return self.db.query(qstr)
220
221    def _split_schema(self, cl):
222        """Return schema and name of object separately.
223
224        This auxiliary function splits off the namespace (schema)
225        belonging to the class with the name cl. If the class name
226        is not qualified, the function is able to determine the schema
227        of the class, taking into account the current search path.
228
229        """
230        s = _split_parts(cl)
231        if len(s) > 1: # name already qualfied?
232            # should be database.schema.table or schema.table
233            if len(s) > 3:
234                raise ProgrammingError('Too many dots in class name %s' % cl)
235            schema, cl = s[-2:]
236        else:
237            cl = s[0]
238            # determine search path
239            query = 'SELECT current_schemas(TRUE)'
240            schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
241            if schemas: # non-empty path
242                # search schema for this object in the current search path
243                query = ' UNION '.join(
244                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
245                        % s for s in enumerate(schemas)])
246                query = ("SELECT nspname FROM pg_class"
247                    " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
248                    " JOIN (%s) AS p USING (nspname)"
249                    " WHERE pg_class.relname='%s'"
250                    " ORDER BY n LIMIT 1" % (query, cl))
251                schema = self.db.query(query).getresult()
252                if schema: # schema found
253                    schema = schema[0][0]
254                else: # object not found in current search path
255                    schema = 'public'
256            else: # empty path
257                schema = 'public'
258        return schema, cl
259
260    def pkey(self, cl, newpkey = None):
261        """This method gets or sets the primary key of a class.
262
263        If newpkey is set and is not a dictionary then set that
264        value as the primary key of the class.  If it is a dictionary
265        then replace the _pkeys dictionary with it.
266
267        """
268        # First see if the caller is supplying a dictionary
269        if isinstance(newpkey, dict):
270            # make sure that we have a namespace
271            self._pkeys = {}
272            for x in newpkey.keys():
273                if x.find('.') == -1:
274                    self._pkeys['public.' + x] = newpkey[x]
275                else:
276                    self._pkeys[x] = newpkey[x]
277            return self._pkeys
278
279        qcl = _join_parts(self._split_schema(cl)) # build qualified name
280        if newpkey:
281            self._pkeys[qcl] = newpkey
282            return newpkey
283
284        # Get all the primary keys at once
285        if qcl not in self._pkeys:
286            # if not found, check again in case it was added after we started
287            self._pkeys = dict([
288                (_join_parts(r[:2]), r[2]) for r in self.db.query(
289                "SELECT pg_namespace.nspname, pg_class.relname"
290                    ",pg_attribute.attname FROM pg_class"
291                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
292                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
293                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
294                    " AND pg_attribute.attisdropped='f'"
295                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
296                    " AND pg_index.indisprimary='t'"
297                    " AND pg_index.indkey[0]=pg_attribute.attnum"
298                ).getresult()])
299            self._do_debug(self._pkeys)
300        # will raise an exception if primary key doesn't exist
301        return self._pkeys[qcl]
302
303    def get_databases(self):
304        """Get list of databases in the system."""
305        return [s[0] for s in
306            self.db.query('SELECT datname FROM pg_database').getresult()]
307
308    def get_relations(self, kinds = None):
309        """Get list of relations in connected database of specified kinds.
310
311            If kinds is None or empty, all kinds of relations are returned.
312            Otherwise kinds can be a string or sequence of type letters
313            specifying which kind of relations you want to list.
314
315        """
316        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
317            ["'%s'" % x for x in kinds]) or ''
318        return map(_join_parts, self.db.query(
319            "SELECT pg_namespace.nspname, pg_class.relname "
320            "FROM pg_class "
321            "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
322            "WHERE %s pg_class.relname !~ '^Inv' AND "
323                "pg_class.relname !~ '^pg_' "
324            "ORDER BY 1, 2" % where).getresult())
325
326    def get_tables(self):
327        """Return list of tables in connected database."""
328        return self.get_relations('r')
329
330    def get_attnames(self, cl, newattnames = None):
331        """Given the name of a table, digs out the set of attribute names.
332
333        Returns a dictionary of attribute names (the names are the keys,
334        the values are the names of the attributes' types).
335        If the optional newattnames exists, it must be a dictionary and
336        will become the new attribute names dictionary.
337
338        """
339        if isinstance(newattnames, dict):
340            self._attnames = newattnames
341            return
342        elif newattnames:
343            raise ProgrammingError(
344                'If supplied, newattnames must be a dictionary')
345        cl = self._split_schema(cl) # split into schema and cl
346        qcl = _join_parts(cl) # build qualified name
347        # May as well cache them:
348        if qcl in self._attnames:
349            return self._attnames[qcl]
350        if qcl not in self.get_relations('rv'):
351            raise ProgrammingError('Class %s does not exist' % qcl)
352        t = {}
353        for att, typ in self.db.query("SELECT pg_attribute.attname"
354            ",pg_type.typname FROM pg_class"
355            " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
356            " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
357            " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
358            " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
359            " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
360            " AND pg_attribute.attisdropped='f'"
361                % cl).getresult():
362            if typ.startswith('bool'):
363                t[att] = 'bool'
364            elif typ.startswith('oid'):
365                t[att] = 'int'
366            elif typ.startswith('float'):
367                t[att] = 'float'
368            elif typ.startswith('numeric'):
369                t[att] = 'num'
370            elif typ.startswith('abstime'):
371                t[att] = 'date'
372            elif typ.startswith('date'):
373                t[att] = 'date'
374            elif typ.startswith('interval'):
375                t[att] = 'date'
376            elif typ.startswith('int'):
377                t[att] = 'int'
378            elif typ.startswith('timestamp'):
379                t[att] = 'date'
380            elif typ.startswith('money'):
381                t[att] = 'money'
382            else:
383                t[att] = 'text'
384        self._attnames[qcl] = t # cache it
385        return self._attnames[qcl]
386
387    def get(self, cl, arg, keyname = None, view = 0):
388        """Get a tuple from a database table or view.
389
390        This method is the basic mechanism to get a single row.  It assumes
391        that the key specifies a unique row.  If keyname is not specified
392        then the primary key for the table is used.  If arg is a dictionary
393        then the value for the key is taken from it and it is modified to
394        include the new values, replacing existing values where necessary.
395        The OID is also put into the dictionary, but in order to allow the
396        caller to work with multiple tables, it is munged as oid(schema.table).
397
398        """
399        if cl.endswith('*'): # scan descendant tables?
400            cl = cl[:-1].rstrip() # need parent table name
401        qcl = _join_parts(self._split_schema(cl)) # build qualified name
402        # To allow users to work with multiple tables,
403        # we munge the name when the key is "oid"
404        foid = 'oid(%s)' % qcl # build mangled name
405        if keyname is None: # use the primary key by default
406            keyname = self.pkey(qcl)
407        fnames = self.get_attnames(qcl)
408        if isinstance(arg, dict):
409            # XXX this code is for backwards compatibility and will be
410            # XXX removed eventually
411            if foid not in arg:
412                ofoid = 'oid_' + self._split_schema(cl)[-1]
413                if ofoid in arg:
414                    arg[foid] = arg[ofoid]
415
416            k = arg[keyname == 'oid' and foid or keyname]
417        else:
418            k = arg
419            arg = {}
420        # We want the oid for later updates if that isn't the key
421        if keyname == 'oid':
422            q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
423        elif view:
424            q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
425                (qcl, keyname, _quote(k, fnames[keyname]))
426        else:
427            q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
428                (','.join(fnames.keys()), qcl, \
429                    keyname, _quote(k, fnames[keyname]))
430        self._do_debug(q)
431        res = self.db.query(q).dictresult()
432        if not res:
433            raise DatabaseError('No such record in %s where %s=%s'
434                % (qcl, keyname, _quote(k, fnames[keyname])))
435        for k, d in res[0].items():
436            if k == 'oid':
437                k = foid
438            arg[k] = d
439        return arg
440
441    def insert(self, cl, d = None, **kw):
442        """Insert a tuple into a database table.
443
444        This method inserts values into the table specified filling in the
445        values from the dictionary.  It then reloads the dictionary with the
446        values from the database.  This causes the dictionary to be updated
447        with values that are modified by rules, triggers, etc.
448
449        Note: The method currently doesn't support insert into views
450        although PostgreSQL does.
451
452        """
453        if d is None:
454            a = {}
455        else:
456            a = d
457        a.update(kw)
458
459        qcl = _join_parts(self._split_schema(cl)) # build qualified name
460        foid = 'oid(%s)' % qcl # build mangled name
461        fnames = self.get_attnames(qcl)
462        t = []
463        n = []
464        for f in fnames.keys():
465            if f != 'oid' and f in a:
466                t.append(_quote(a[f], fnames[f]))
467                n.append('"%s"' % f)
468        q = 'INSERT INTO %s (%s) VALUES (%s)' % \
469            (qcl, ','.join(n), ','.join(t))
470        self._do_debug(q)
471        a[foid] = self.db.query(q)
472        # Reload the dictionary to catch things modified by engine.
473        # Note that get() changes 'oid' below to oid_schema_table.
474        # If no read perms (it can and does happen), return None.
475        try:
476            return self.get(qcl, a, 'oid')
477        except:
478            return None
479
480    def update(self, cl, d = None, **kw):
481        """Update an existing row in a database table.
482
483        Similar to insert but updates an existing row.  The update is based
484        on the OID value as munged by get.  The array returned is the
485        one sent modified to reflect any changes caused by the update due
486        to triggers, rules, defaults, etc.
487
488        """
489        # Update always works on the oid which get returns if available,
490        # otherwise use the primary key.  Fail if neither.
491        qcl = _join_parts(self._split_schema(cl)) # build qualified name
492        foid = 'oid(%s)' % qcl # build mangled oid
493
494        # Note that we only accept oid key from named args for safety
495        if 'oid' in kw:
496            kw[foid] = kw['oid']
497            del kw['oid']
498
499        if d is None:
500            a = {}
501        else:
502            a = d
503        a.update(kw)
504
505        # XXX this code is for backwards compatibility and will be
506        # XXX removed eventually
507        if foid not in a:
508            ofoid = 'oid_' + self._split_schema(cl)[-1]
509            if ofoid in a:
510                a[foid] = a[ofoid]
511
512        if foid in a:
513            where = "oid=%s" % a[foid]
514        else:
515            try:
516                pk = self.pkey(qcl)
517            except:
518                raise ProgrammingError(
519                    'Update needs primary key or oid as %s' % foid)
520            where = "%s='%s'" % (pk, a[pk])
521        v = []
522        k = 0
523        fnames = self.get_attnames(qcl)
524        for ff in fnames.keys():
525            if ff != 'oid' and ff in a:
526                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
527        if v == []:
528            return None
529        q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
530        self._do_debug(q)
531        self.db.query(q)
532        # Reload the dictionary to catch things modified by engine:
533        if foid in a:
534            return self.get(qcl, a, 'oid')
535        else:
536            return self.get(qcl, a)
537
538    def clear(self, cl, a = None):
539        """
540
541        This method clears all the attributes to values determined by the types.
542        Numeric types are set to 0, Booleans are set to 'f', and everything
543        else is set to the empty string.  If the array argument is present,
544        it is used as the array and any entries matching attribute names are
545        cleared with everything else left unchanged.
546
547        """
548        # At some point we will need a way to get defaults from a table.
549        if a is None:
550            a = {} # empty if argument is not present
551        qcl = _join_parts(self._split_schema(cl)) # build qualified name
552        foid = 'oid(%s)' % qcl # build mangled oid
553        fnames = self.get_attnames(qcl)
554        for k, t in fnames.items():
555            if k == 'oid':
556                continue
557            if t in ['int', 'seq', 'float', 'num', 'money']:
558                a[k] = 0
559            elif t == 'bool':
560                a[k] = 'f'
561            else:
562                a[k] = ''
563        return a
564
565    def delete(self, cl, d = None, **kw):
566        """Delete an existing row in a database table.
567
568        This method deletes the row from a table.
569        It deletes based on the OID munged as described above."""
570
571        # Like update, delete works on the oid.
572        # One day we will be testing that the record to be deleted
573        # isn't referenced somewhere (or else PostgreSQL will).
574        qcl = _join_parts(self._split_schema(cl)) # build qualified name
575        foid = 'oid(%s)' % qcl # build mangled oid
576
577        # Note that we only accept oid key from named args for safety
578        if 'oid' in kw:
579            kw[foid] = kw['oid']
580            del kw['oid']
581
582        if d is None:
583            a = {}
584        else:
585            a = d
586        a.update(kw)
587
588        # XXX this code is for backwards compatibility and will be
589        # XXX removed eventually
590        if foid not in a:
591            ofoid = 'oid_' + self._split_schema(cl)[-1]
592            if ofoid in a:
593                a[foid] = a[ofoid]
594
595        q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
596        self._do_debug(q)
597        self.db.query(q)
598
599
600# if run as script, print some information
601
602if __name__ == '__main__':
603    print 'PyGreSQL version', version
604    print
605    print __doc__
Note: See TracBrowser for help on using the repository browser.