source: trunk/module/pg.py @ 282

Last change on this file since 282 was 282, checked in by darcy, 14 years ago

Add another backwards compatibility code for OID mangling.

File size: 16.1 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.43 2006-04-13 17:48:05 darcy Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with addtional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22from types import *
23
24# Auxiliary functions which are independent from a DB connection:
25
26def _quote(d, t):
27        """Return quotes if needed."""
28        if d is None:
29                return 'NULL'
30        if t in ('int', 'seq', 'decimal'):
31                if d == '': return 'NULL'
32                return str(d)
33        if t == 'money':
34                if d == '': return 'NULL'
35                return "'%.2f'" % float(d)
36        if t == 'bool':
37                if type(d) == StringType:
38                        if d == '': return 'NULL'
39                        d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
40                else:
41                        d = not not d
42                return ("'f'", "'t'")[d]
43        if t in ('date', 'inet', 'cidr'):
44                if d == '': return 'NULL'
45        return "'%s'" % str(d).replace('\\','\\\\').replace('\'','\\\'')
46
47def _is_quoted(s):
48        """Check whether this string is a quoted identifier."""
49        s = s.replace('_', 'a')
50        return not s.isalnum() or s[:1].isdigit() or s != s.lower()
51
52def _is_unquoted(s):
53        """Check whether this string is an unquoted identifier."""
54        s = s.replace('_', 'a')
55        return s.isalnum() and not s[:1].isdigit()
56
57def _split_first_part(s):
58        """Split the first part of a dot separated string."""
59        s = s.lstrip()
60        if s[:1] == '"':
61                p = []
62                s = s.split('"', 3)[1:]
63                p.append(s[0])
64                while len(s) == 3 and s[1] == '':
65                        p.append('"')
66                        s = s[2].split('"', 2)
67                        p.append(s[0])
68                p = [''.join(p)]
69                s = '"'.join(s[1:]).lstrip()
70                if s:
71                        if s[:0] == '.':
72                                p.append(s[1:])
73                        else:
74                                s = _split_first_part(s)
75                                p[0] += s[0]
76                                if len(s) > 1:
77                                        p.append(s[1])
78        else:
79                p = s.split('.', 1)
80                s = p[0].rstrip()
81                if _is_unquoted(s):
82                        s = s.lower()
83                p[0] = s
84        return p
85
86def _split_parts(s):
87        """Split all parts of a dot separated string."""
88        q = []
89        while s:
90                s = _split_first_part(s)
91                q.append(s[0])
92                if len(s) < 2: break
93                s = s[1]
94        return q
95
96def _join_parts(s):
97        """Join all parts of a dot separated string."""
98        return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
99
100# The PostGreSQL database connection interface:
101
102class DB:
103        """Wrapper class for the _pg connection type."""
104
105        def __init__(self, *args, **kw):
106                self.db = connect(*args, **kw)
107                self.dbname = self.db.db
108                self.__attnames = {}
109                self.__pkeys = {}
110                self.__args = args, kw
111                self.debug = None # For debugging scripts, this can be set
112                        # * to a string format specification (e.g. in a CGI set to "%s<BR>"),
113                        # * to a function which takes a string argument or
114                        # * to a file object to write debug statements to.
115
116        def __getattr__(self, name):
117                # All undefined members are the same as in the underlying pg connection:
118                if self.db:
119                        return getattr(self.db, name)
120                else:
121                        raise InternalError, 'Connection is not valid'
122
123        def _do_debug(self, s):
124                """Print a debug message."""
125                if not self.debug: return
126                if isinstance(self.debug, StringType): print self.debug % s
127                if isinstance(self.debug, FunctionType): self.debug(s)
128                if isinstance(self.debug, FileType): print >> self.debug, s
129
130        def close(self):
131                """Close the database connection."""
132                # Wraps shared library function so we can track state.
133
134                if self.db:
135                        self.db.close()
136                        self.db = None
137                else:
138                        raise InternalError, 'Connection already closed'
139
140        def reopen(self):
141                """Reopen connection to the database.
142
143                Used in case we need another connection to the same database.
144                Note that we can still reopen a database that we have closed.
145
146                """
147                if self.db:
148                        self.db.close()
149                try:
150                        self.db = connect(*self.__args[0], **self.__args[1])
151                except:
152                        self.db = None
153                        raise
154
155        def query(self, qstr):
156                """Executes a SQL command string.
157
158                This method simply sends a SQL query to the database. If the query is
159                an insert statement, the return value is the OID of the newly
160                inserted row.  If it is otherwise a query that does not return a result
161                (ie. is not a some kind of SELECT statement), it returns None.
162                Otherwise, it returns a pgqueryobject that can be accessed via the
163                getresult or dictresult method or simply printed.
164
165                """
166                # Wraps shared library function for debugging.
167                if not self.db:
168                        raise InternalError, 'Connection is not valid'
169                self._do_debug(qstr)
170                return self.db.query(qstr)
171
172        def _split_schema(self, cl):
173                """Return schema and name of object separately.
174
175                This auxiliary function splits off the namespace (schema)
176                belonging to the class with the name cl. If the class name
177                is not qualified, the function is able to determine the schema
178                of the class, taking into account the current search path.
179
180                """
181                s = _split_parts(cl)
182                if len(s) > 1: # name already qualfied?
183                        # should be database.schema.table or schema.table
184                        if len(s) > 3:
185                                raise ProgrammingError, 'Too many dots in class name %s' % cl
186                        schema, cl = s[-2:]
187                else:
188                        cl = s[0]
189                        # determine search path
190                        query = 'SELECT current_schemas(TRUE)'
191                        schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
192                        if schemas: # non-empty path
193                                # search schema for this object in the current search path
194                                query = ' UNION '.join(["SELECT %d AS n, '%s' AS nspname" % s
195                                        for s in enumerate(schemas)])
196                                query = ("SELECT nspname FROM pg_class"
197                                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
198                                        " JOIN (%s) AS p USING (nspname)"
199                                        " WHERE pg_class.relname='%s'"
200                                        " ORDER BY n LIMIT 1" % (query, cl))
201                                schema = self.db.query(query).getresult()
202                                if schema: # schema found
203                                        schema = schema[0][0]
204                                else: # object not found in current search path
205                                        schema = 'public'
206                        else: # empty path
207                                schema = 'public'
208                return schema, cl
209
210        def pkey(self, cl, newpkey = None):
211                """This method gets or sets the primary key of a class.
212
213                If newpkey is set and is not a dictionary then set that
214                value as the primary key of the class.  If it is a dictionary
215                then replace the __pkeys dictionary with it.
216
217                """
218                # Get all the primary keys at once
219                if isinstance(newpkey, DictType):
220                        self.__pkeys = newpkey
221                        return newpkey
222                qcl = _join_parts(self._split_schema(cl)) # build qualified name
223                if newpkey:
224                        self.__pkeys[qcl] = newpkey
225                        return newpkey
226                if self.__pkeys == {} or not self.__pkeys.has_key(qcl):
227                        # if not found, determine pkey again in case it was added after we started
228                        for r in self.db.query("SELECT pg_namespace.nspname"
229                                ",pg_class.relname,pg_attribute.attname FROM pg_class"
230                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
231                                " AND pg_namespace.nspname NOT LIKE 'pg_%'"
232                                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
233                                " AND pg_attribute.attisdropped='f'"
234                                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
235                                " AND pg_index.indisprimary='t'"
236                                " AND pg_index.indkey[0]=pg_attribute.attnum").getresult():
237                                self.__pkeys[_join_parts(r[:2])] = r[2] # build qualified name
238                        self._do_debug(self.__pkeys)
239                # will raise an exception if primary key doesn't exist
240                return self.__pkeys[qcl]
241
242        def get_databases(self):
243                """Get list of databases in the system."""
244                return [s for s, in
245                        self.db.query('SELECT datname FROM pg_database').getresult()]
246
247        def get_relations(self, kinds = None):
248                """Get list of relations in connected database of specified kinds.
249
250                        If kinds is None or empty, all kinds of relations are returned.
251                        Otherwise kinds can be a string or sequence of type letters
252                        specifying which kind of relations you want to list.
253
254                """
255                if kinds:
256                        where = "pg_class.relkind IN (%s) AND" % \
257                                                        ','.join(["'%s'" % x for x in kinds])
258                else:
259                        where = ''
260
261                return [_join_parts(s) for s in
262                        self.db.query(
263                                "SELECT pg_namespace.nspname, pg_class.relname "
264                                "FROM pg_class "
265                                "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
266                                "WHERE %s pg_class.relname !~ '^Inv' AND "
267                                        "pg_class.relname !~ '^pg_' "
268                                "ORDER BY 1,2" % where).getresult()]
269
270        def get_tables(self):
271                """Return list of tables in connected database."""
272                return self.get_relations('r')
273
274        def get_attnames(self, cl, newattnames = None):
275                """Given the name of a table, digs out the set of attribute names.
276
277                Returns a dictionary of attribute names (the names are the keys,
278                the values are the names of the attributes' types).
279                If the optional newattnames exists, it must be a dictionary and
280                will become the new attribute names dictionary.
281
282                """
283                if isinstance(newattnames, DictType):
284                        self.__attnames = newattnames
285                        return
286                elif newattnames:
287                        raise ProgrammingError, \
288                                'If supplied, newattnames must be a dictionary'
289                cl = self._split_schema(cl) # split into schema and cl
290                qcl = _join_parts(cl) # build qualified name
291                # May as well cache them:
292                if self.__attnames.has_key(qcl):
293                        return self.__attnames[qcl]
294                if qcl not in self.get_relations('rv'):
295                        raise ProgrammingError, 'Class %s does not exist' % qcl
296                t = {}
297                for att, typ in self.db.query("SELECT pg_attribute.attname"
298                        ",pg_type.typname FROM pg_class"
299                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
300                        " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
301                        " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
302                        " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
303                        " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
304                        " AND pg_attribute.attisdropped='f'"
305                                % cl).getresult():
306                        if typ.startswith('bool'):
307                                t[att] = 'bool'
308                        elif typ.startswith('oid'):
309                                t[att] = 'int'
310                        elif typ.startswith('float'):
311                                t[att] = 'decimal'
312                        elif typ.startswith('abstime'):
313                                t[att] = 'date'
314                        elif typ.startswith('date'):
315                                t[att] = 'date'
316                        elif typ.startswith('interval'):
317                                t[att] = 'date'
318                        elif typ.startswith('int'):
319                                t[att] = 'int'
320                        elif typ.startswith('timestamp'):
321                                t[att] = 'date'
322                        elif typ.startswith('money'):
323                                t[att] = 'money'
324                        else:
325                                t[att] = 'text'
326                self.__attnames[qcl] = t # cache it
327                return self.__attnames[qcl]
328
329        def get(self, cl, arg, keyname = None, view = 0):
330                """Get a tuple from a database table or view.
331
332                This method is the basic mechanism to get a single row.  It assumes
333                that the key specifies a unique row.  If keyname is not specified
334                then the primary key for the table is used.  If arg is a dictionary
335                then the value for the key is taken from it and it is modified to
336                include the new values, replacing existing values where necessary.
337                The OID is also put into the dictionary, but in order to allow the
338                caller to work with multiple tables, it is munged as oid(schema.table).
339
340                """
341                if cl.endswith('*'): # scan descendant tables?
342                        cl = cl[:-1].rstrip() # need parent table name
343                qcl = _join_parts(self._split_schema(cl)) # build qualified name
344                # To allow users to work with multiple tables,
345                # we munge the name when the key is "oid"
346                foid = 'oid(%s)' % qcl # build mangled name
347
348                # XXX this code is for backwards compatibility and will be
349                # XXX removed eventually
350                if not a.has_key(foid):
351                        ofoid = 'oid_' + self._split_schema(cl)[-1]
352                        if a.has_key(ofoid):
353                                a[foid] = a[ofoid]
354
355                if keyname == None: # use the primary key by default
356                        keyname = self.pkey(qcl)
357                fnames = self.get_attnames(qcl)
358                if isinstance(arg, DictType):
359                        k = arg[keyname == 'oid' and foid or keyname]
360                else:
361                        k = arg
362                        arg = {}
363                # We want the oid for later updates if that isn't the key
364                if keyname == 'oid':
365                        q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
366                elif view:
367                        q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
368                                (qcl, keyname, _quote(k, fnames[keyname]))
369                else:
370                        q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
371                                (','.join(fnames.keys()), qcl, \
372                                        keyname, _quote(k, fnames[keyname]))
373                self._do_debug(q)
374                res = self.db.query(q).dictresult()
375                if not res:
376                        raise DatabaseError, \
377                                'No such record in %s where %s=%s' % \
378                                        (qcl, keyname, _quote(k, fnames[keyname]))
379                for k, d in res[0].items():
380                        if k == 'oid':
381                                k = foid
382                        arg[k] = d
383                return arg
384
385        def insert(self, cl, a):
386                """Insert a tuple into a database table.
387
388                This method inserts values into the table specified filling in the
389                values from the dictionary.  It then reloads the dictionary with the
390                values from the database.  This causes the dictionary to be updated
391                with values that are modified by rules, triggers, etc.
392
393                Note: The method currently doesn't support insert into views
394                although PostgreSQL does.
395
396                """
397                qcl = _join_parts(self._split_schema(cl)) # build qualified name
398                foid = 'oid(%s)' % qcl # build mangled name
399                fnames = self.get_attnames(qcl)
400                t = []
401                n = []
402                for f in fnames.keys():
403                        if f != 'oid' and a.has_key(f):
404                                t.append(_quote(a[f], fnames[f]))
405                                n.append(f)
406                q = 'INSERT INTO %s (%s) VALUES (%s)' % \
407                        (qcl, ','.join(n), ','.join(t))
408                self._do_debug(q)
409                a[foid] = self.db.query(q)
410                # Reload the dictionary to catch things modified by engine.
411                # Note that get() changes 'oid' below to oid_schema_table.
412                # If no read perms (it can and does happen), return None.
413                try:
414                        return self.get(qcl, a, 'oid')
415                except:
416                        return None
417
418        def update(self, cl, a):
419                """Update an existing row in a database table.
420
421                Similar to insert but updates an existing row.  The update is based
422                on the OID value as munged by get.  The array returned is the
423                one sent modified to reflect any changes caused by the update due
424                to triggers, rules, defaults, etc.
425
426                """
427                # Update always works on the oid which get returns if available,
428                # otherwise use the primary key.  Fail if neither.
429                qcl = _join_parts(self._split_schema(cl)) # build qualified name
430                foid = 'oid(%s)' % qcl # build mangled oid
431
432                # XXX this code is for backwards compatibility and will be
433                # XXX removed eventually
434                if not a.has_key(foid):
435                        ofoid = 'oid_' + self._split_schema(cl)[-1]
436                        if a.has_key(ofoid):
437                                a[foid] = a[ofoid]
438
439                if a.has_key(foid):
440                        where = "oid=%s" % a[foid]
441                else:
442                        try:
443                                pk = self.pkey(qcl)
444                        except:
445                                raise ProgrammingError, \
446                                        'Update needs primary key or oid as %s' % foid
447                        where = "%s='%s'" % (pk, a[pk])
448                v = []
449                k = 0
450                fnames = self.get_attnames(qcl)
451                for ff in fnames.keys():
452                        if ff != 'oid' and a.has_key(ff):
453                                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
454                if v == []:
455                        return None
456                q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
457                self._do_debug(q)
458                self.db.query(q)
459                # Reload the dictionary to catch things modified by engine:
460                if a.has_key(foid):
461                        return self.get(qcl, a, 'oid')
462                else:
463                        return self.get(qcl, a)
464
465        def clear(self, cl, a = None):
466                """
467
468                This method clears all the attributes to values determined by the types.
469                Numeric types are set to 0, Booleans are set to 'f', and everything
470                else is set to the empty string.  If the array argument is present,
471                it is used as the array and any entries matching attribute names are
472                cleared with everything else left unchanged.
473
474                """
475                # At some point we will need a way to get defaults from a table.
476                if a is None: a = {} # empty if argument is not present
477                qcl = _join_parts(self._split_schema(cl)) # build qualified name
478                foid = 'oid(%s)' % qcl # build mangled oid
479                fnames = self.get_attnames(qcl)
480                for k, t in fnames.items():
481                        if k == 'oid': continue
482                        if t in ['int', 'decimal', 'seq', 'money']:
483                                a[k] = 0
484                        elif t == 'bool':
485                                a[k] = 'f'
486                        else:
487                                a[k] = ''
488                return a
489
490        def delete(self, cl, a):
491                """Delete an existing row in a database table.
492
493                This method deletes the row from a table.
494                It deletes based on the OID munged as described above.
495
496                """
497                # Like update, delete works on the oid.
498                # One day we will be testing that the record to be deleted
499                # isn't referenced somewhere (or else PostgreSQL will).
500                qcl = _join_parts(self._split_schema(cl)) # build qualified name
501                foid = 'oid(%s)' % qcl # build mangled oid
502
503                # XXX this code is for backwards compatibility and will be
504                # XXX removed eventually
505                if not a.has_key(foid):
506                        ofoid = 'oid_' + self._split_schema(cl)[-1]
507                        if a.has_key(ofoid):
508                                a[foid] = a[ofoid]
509
510                q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
511                self._do_debug(q)
512                self.db.query(q)
513
514# if run as script, print some information
515if __name__ == '__main__':
516        print 'PyGreSQL version', version
517        print
518        print __doc__
Note: See TracBrowser for help on using the repository browser.