source: trunk/module/pg.py @ 275

Last change on this file since 275 was 275, checked in by cito, 14 years ago

Minor coding improvement in backwards compatibility hack.

File size: 15.9 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.41 2006-03-13 16:57:34 cito Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with addtional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22from types import *
23
24# Auxiliary functions which are independent from a DB connection:
25
26def _quote(d, t):
27        """Return quotes if needed."""
28        if d is None:
29                return 'NULL'
30        if t in ('int', 'seq', 'decimal'):
31                if d == '': return 'NULL'
32                return str(d)
33        if t == 'money':
34                if d == '': return 'NULL'
35                return "'%.2f'" % float(d)
36        if t == 'bool':
37                if type(d) == StringType:
38                        if d == '': return 'NULL'
39                        d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
40                else:
41                        d = not not d
42                return ("'f'", "'t'")[d]
43        if t in ('date', 'inet', 'cidr'):
44                if d == '': return 'NULL'
45        return "'%s'" % str(d).replace('\\','\\\\').replace('\'','\\\'')
46
47def _is_quoted(s):
48        """Check whether this string is a quoted identifier."""
49        s = s.replace('_', 'a')
50        return not s.isalnum() or s[:1].isdigit() or s != s.lower()
51
52def _is_unquoted(s):
53        """Check whether this string is an unquoted identifier."""
54        s = s.replace('_', 'a')
55        return s.isalnum() and not s[:1].isdigit()
56
57def _split_first_part(s):
58        """Split the first part of a dot separated string."""
59        s = s.lstrip()
60        if s[:1] == '"':
61                p = []
62                s = s.split('"', 3)[1:]
63                p.append(s[0])
64                while len(s) == 3 and s[1] == '':
65                        p.append('"')
66                        s = s[2].split('"', 2)
67                        p.append(s[0])
68                p = [''.join(p)]
69                s = '"'.join(s[1:]).lstrip()
70                if s:
71                        if s[:0] == '.':
72                                p.append(s[1:])
73                        else:
74                                s = _split_first_part(s)
75                                p[0] += s[0]
76                                if len(s) > 1:
77                                        p.append(s[1])
78        else:
79                p = s.split('.', 1)
80                s = p[0].rstrip()
81                if _is_unquoted(s):
82                        s = s.lower()
83                p[0] = s
84        return p
85
86def _split_parts(s):
87        """Split all parts of a dot separated string."""
88        q = []
89        while s:
90                s = _split_first_part(s)
91                q.append(s[0])
92                if len(s) < 2: break
93                s = s[1]
94        return q
95
96def _join_parts(s):
97        """Join all parts of a dot separated string."""
98        return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
99
100# The PostGreSQL database connection interface:
101
102class DB:
103        """Wrapper class for the _pg connection type."""
104
105        def __init__(self, *args, **kw):
106                self.db = connect(*args, **kw)
107                self.dbname = self.db.db
108                self.__attnames = {}
109                self.__pkeys = {}
110                self.__args = args, kw
111                self.debug = None # For debugging scripts, this can be set
112                        # * to a string format specification (e.g. in a CGI set to "%s<BR>"),
113                        # * to a function which takes a string argument or
114                        # * to a file object to write debug statements to.
115
116        def __getattr__(self, name):
117                # All undefined members are the same as in the underlying pg connection:
118                if self.db:
119                        return getattr(self.db, name)
120                else:
121                        raise InternalError, 'Connection is not valid'
122
123        def _do_debug(self, s):
124                """Print a debug message."""
125                if not self.debug: return
126                if isinstance(self.debug, StringType): print self.debug % s
127                if isinstance(self.debug, FunctionType): self.debug(s)
128                if isinstance(self.debug, FileType): print >> self.debug, s
129
130        def close(self):
131                """Close the database connection."""
132                # Wraps shared library function so we can track state.
133
134                if self.db:
135                        self.db.close()
136                        self.db = None
137                else:
138                        raise InternalError, 'Connection already closed'
139
140        def reopen(self):
141                """Reopen connection to the database.
142
143                Used in case we need another connection to the same database.
144                Note that we can still reopen a database that we have closed.
145
146                """
147                if self.db:
148                        self.db.close()
149                try:
150                        self.db = connect(*self.__args[0], **self.__args[1])
151                except:
152                        self.db = None
153                        raise
154
155        def query(self, qstr):
156                """Executes a SQL command string.
157
158                This method simply sends a SQL query to the database. If the query is
159                an insert statement, the return value is the OID of the newly
160                inserted row.  If it is otherwise a query that does not return a result
161                (ie. is not a some kind of SELECT statement), it returns None.
162                Otherwise, it returns a pgqueryobject that can be accessed via the
163                getresult or dictresult method or simply printed.
164
165                """
166                # Wraps shared library function for debugging.
167                if not self.db:
168                        raise InternalError, 'Connection is not valid'
169                self._do_debug(qstr)
170                return self.db.query(qstr)
171
172        def _split_schema(self, cl):
173                """Return schema and name of object separately.
174
175                This auxiliary function splits off the namespace (schema)
176                belonging to the class with the name cl. If the class name
177                is not qualified, the function is able to determine the schema
178                of the class, taking into account the current search path.
179
180                """
181                s = _split_parts(cl)
182                if len(s) > 1: # name already qualfied?
183                        # should be database.schema.table or schema.table
184                        if len(s) > 3:
185                                raise ProgrammingError, 'Too many dots in class name %s' % cl
186                        schema, cl = s[-2:]
187                else:
188                        cl = s[0]
189                        # determine search path
190                        query = 'SELECT current_schemas(TRUE)'
191                        schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
192                        if schemas: # non-empty path
193                                # search schema for this object in the current search path
194                                query = ' UNION '.join(["SELECT %d AS n, '%s' AS nspname" % s
195                                        for s in enumerate(schemas)])
196                                query = ("SELECT nspname FROM pg_class"
197                                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
198                                        " JOIN (%s) AS p USING (nspname)"
199                                        " WHERE pg_class.relname='%s'"
200                                        " ORDER BY n LIMIT 1" % (query, cl))
201                                schema = self.db.query(query).getresult()
202                                if schema: # schema found
203                                        schema = schema[0][0]
204                                else: # object not found in current search path
205                                        schema = 'public'
206                        else: # empty path
207                                schema = 'public'
208                return schema, cl
209
210        def pkey(self, cl, newpkey = None):
211                """This method gets or sets the primary key of a class.
212
213                If newpkey is set and is not a dictionary then set that
214                value as the primary key of the class.  If it is a dictionary
215                then replace the __pkeys dictionary with it.
216
217                """
218                # Get all the primary keys at once
219                if isinstance(newpkey, DictType):
220                        self.__pkeys = newpkey
221                        return newpkey
222                qcl = _join_parts(self._split_schema(cl)) # build qualified name
223                if newpkey:
224                        self.__pkeys[qcl] = newpkey
225                        return newpkey
226                if self.__pkeys == {} or not self.__pkeys.has_key(qcl):
227                        # if not found, determine pkey again in case it was added after we started
228                        for r in self.db.query("SELECT pg_namespace.nspname"
229                                ",pg_class.relname,pg_attribute.attname FROM pg_class"
230                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
231                                " AND pg_namespace.nspname NOT LIKE 'pg_%'"
232                                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
233                                " AND pg_attribute.attisdropped='f'"
234                                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
235                                " AND pg_index.indisprimary='t'"
236                                " AND pg_index.indkey[0]=pg_attribute.attnum").getresult():
237                                self.__pkeys[_join_parts(r[:2])] = r[2] # build qualified name
238                        self._do_debug(self.__pkeys)
239                # will raise an exception if primary key doesn't exist
240                return self.__pkeys[qcl]
241
242        def get_databases(self):
243                """Get list of databases in the system."""
244                return [s for s, in
245                        self.db.query('SELECT datname FROM pg_database').getresult()]
246
247        def get_relations(self, kinds = None):
248                """Get list of relations in connected database of specified kinds.
249
250                        If kinds is None or empty, all kinds of relations are returned.
251                        Otherwise kinds can be a string or sequence of type letters
252                        specifying which kind of relations you want to list.
253
254                """
255                if kinds:
256                        where = "pg_class.relkind IN (%s) AND" % \
257                                                        ','.join(["'%s'" % x for x in kinds])
258                else:
259                        where = ''
260
261                return [_join_parts(s) for s in
262                        self.db.query(
263                                "SELECT pg_namespace.nspname, pg_class.relname "
264                                "FROM pg_class "
265                                "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
266                                "WHERE %s pg_class.relname !~ '^Inv' AND "
267                                        "pg_class.relname !~ '^pg_' "
268                                "ORDER BY 1,2" % where).getresult()]
269
270        def get_tables(self):
271                """Return list of tables in connected database."""
272                return self.get_relations('r')
273
274        def get_attnames(self, cl, newattnames = None):
275                """Given the name of a table, digs out the set of attribute names.
276
277                Returns a dictionary of attribute names (the names are the keys,
278                the values are the names of the attributes' types).
279                If the optional newattnames exists, it must be a dictionary and
280                will become the new attribute names dictionary.
281
282                """
283                if isinstance(newattnames, DictType):
284                        self.__attnames = newattnames
285                        return
286                elif newattnames:
287                        raise ProgrammingError, \
288                                'If supplied, newattnames must be a dictionary'
289                cl = self._split_schema(cl) # split into schema and cl
290                qcl = _join_parts(cl) # build qualified name
291                # May as well cache them:
292                if self.__attnames.has_key(qcl):
293                        return self.__attnames[qcl]
294                if qcl not in self.get_relations('rv'):
295                        raise ProgrammingError, 'Class %s does not exist' % qcl
296                t = {}
297                for att, typ in self.db.query("SELECT pg_attribute.attname"
298                        ",pg_type.typname FROM pg_class"
299                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
300                        " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
301                        " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
302                        " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
303                        " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
304                        " AND pg_attribute.attisdropped='f'"
305                                % cl).getresult():
306                        if typ.startswith('bool'):
307                                t[att] = 'bool'
308                        elif typ.startswith('int'):
309                                t[att] = 'int'
310                        elif typ.startswith('oid'):
311                                t[att] = 'int'
312                        elif typ.startswith('float'):
313                                t[att] = 'decimal'
314                        elif typ.startswith('abstime'):
315                                t[att] = 'date'
316                        elif typ.startswith('date'):
317                                t[att] = 'date'
318                        elif typ.startswith('interval'):
319                                t[att] = 'date'
320                        elif typ.startswith('timestamp'):
321                                t[att] = 'date'
322                        elif typ.startswith('money'):
323                                t[att] = 'money'
324                        else:
325                                t[att] = 'text'
326                self.__attnames[qcl] = t # cache it
327                return self.__attnames[qcl]
328
329        def get(self, cl, arg, keyname = None, view = 0):
330                """Get a tuple from a database table or view.
331
332                This method is the basic mechanism to get a single row.  It assumes
333                that the key specifies a unique row.  If keyname is not specified
334                then the primary key for the table is used.  If arg is a dictionary
335                then the value for the key is taken from it and it is modified to
336                include the new values, replacing existing values where necessary.
337                The OID is also put into the dictionary, but in order to allow the
338                caller to work with multiple tables, it is munged as oid(schema.table).
339
340                """
341                if cl.endswith('*'): # scan descendant tables?
342                        cl = cl[:-1].rstrip() # need parent table name
343                qcl = _join_parts(self._split_schema(cl)) # build qualified name
344                # To allow users to work with multiple tables,
345                # we munge the name when the key is "oid"
346                foid = 'oid(%s)' % qcl # build mangled name
347                if keyname == None: # use the primary key by default
348                        keyname = self.pkey(qcl)
349                fnames = self.get_attnames(qcl)
350                if isinstance(arg, DictType):
351                        k = arg[keyname == 'oid' and foid or keyname]
352                else:
353                        k = arg
354                        arg = {}
355                # We want the oid for later updates if that isn't the key
356                if keyname == 'oid':
357                        q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
358                elif view:
359                        q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
360                                (qcl, keyname, _quote(k, fnames[keyname]))
361                else:
362                        q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
363                                (','.join(fnames.keys()), qcl, \
364                                        keyname, _quote(k, fnames[keyname]))
365                self._do_debug(q)
366                res = self.db.query(q).dictresult()
367                if not res:
368                        raise DatabaseError, \
369                                'No such record in %s where %s=%s' % \
370                                        (qcl, keyname, _quote(k, fnames[keyname]))
371                for k, d in res[0].items():
372                        if k == 'oid':
373                                k = foid
374                        arg[k] = d
375                return arg
376
377        def insert(self, cl, a):
378                """Insert a tuple into a database table.
379
380                This method inserts values into the table specified filling in the
381                values from the dictionary.  It then reloads the dictionary with the
382                values from the database.  This causes the dictionary to be updated
383                with values that are modified by rules, triggers, etc.
384
385                Note: The method currently doesn't support insert into views
386                although PostgreSQL does.
387
388                """
389                qcl = _join_parts(self._split_schema(cl)) # build qualified name
390                foid = 'oid(%s)' % qcl # build mangled name
391                fnames = self.get_attnames(qcl)
392                t = []
393                n = []
394                for f in fnames.keys():
395                        if f != 'oid' and a.has_key(f):
396                                t.append(_quote(a[f], fnames[f]))
397                                n.append(f)
398                q = 'INSERT INTO %s (%s) VALUES (%s)' % \
399                        (qcl, ','.join(n), ','.join(t))
400                self._do_debug(q)
401                a[foid] = self.db.query(q)
402                # Reload the dictionary to catch things modified by engine.
403                # Note that get() changes 'oid' below to oid_schema_table.
404                # If no read perms (it can and does happen), return None.
405                try:
406                        return self.get(qcl, a, 'oid')
407                except:
408                        return None
409
410        def update(self, cl, a):
411                """Update an existing row in a database table.
412
413                Similar to insert but updates an existing row.  The update is based
414                on the OID value as munged by get.  The array returned is the
415                one sent modified to reflect any changes caused by the update due
416                to triggers, rules, defaults, etc.
417
418                """
419                # Update always works on the oid which get returns if available,
420                # otherwise use the primary key.  Fail if neither.
421                qcl = _join_parts(self._split_schema(cl)) # build qualified name
422                foid = 'oid(%s)' % qcl # build mangled oid
423
424                # XXX this code is for backwards compatibility and will be
425                # XXX removed eventually
426                if not a.has_key(foid):
427                        ofoid = 'oid_' + self._split_schema(cl)[-1]
428                        if a.has_key(ofoid):
429                                a[foid] = a[ofoid]
430
431                if a.has_key(foid):
432                        where = "oid=%s" % a[foid]
433                else:
434                        try:
435                                pk = self.pkey(qcl)
436                        except:
437                                raise ProgrammingError, \
438                                        'Update needs primary key or oid as %s' % foid
439                        where = "%s='%s'" % (pk, a[pk])
440                v = []
441                k = 0
442                fnames = self.get_attnames(qcl)
443                for ff in fnames.keys():
444                        if ff != 'oid' and a.has_key(ff):
445                                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
446                if v == []:
447                        return None
448                q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
449                self._do_debug(q)
450                self.db.query(q)
451                # Reload the dictionary to catch things modified by engine:
452                if a.has_key(foid):
453                        return self.get(qcl, a, 'oid')
454                else:
455                        return self.get(qcl, a)
456
457        def clear(self, cl, a = None):
458                """
459
460                This method clears all the attributes to values determined by the types.
461                Numeric types are set to 0, Booleans are set to 'f', and everything
462                else is set to the empty string.  If the array argument is present,
463                it is used as the array and any entries matching attribute names are
464                cleared with everything else left unchanged.
465
466                """
467                # At some point we will need a way to get defaults from a table.
468                if a is None: a = {} # empty if argument is not present
469                qcl = _join_parts(self._split_schema(cl)) # build qualified name
470                foid = 'oid(%s)' % qcl # build mangled oid
471                fnames = self.get_attnames(qcl)
472                for k, t in fnames.items():
473                        if k == 'oid': continue
474                        if t in ['int', 'decimal', 'seq', 'money']:
475                                a[k] = 0
476                        elif t == 'bool':
477                                a[k] = 'f'
478                        else:
479                                a[k] = ''
480                return a
481
482        def delete(self, cl, a):
483                """Delete an existing row in a database table.
484
485                This method deletes the row from a table.
486                It deletes based on the OID munged as described above.
487
488                """
489                # Like update, delete works on the oid.
490                # One day we will be testing that the record to be deleted
491                # isn't referenced somewhere (or else PostgreSQL will).
492                qcl = _join_parts(self._split_schema(cl)) # build qualified name
493                foid = 'oid(%s)' % qcl # build mangled oid
494
495                # XXX this code is for backwards compatibility and will be
496                # XXX removed eventually
497                if not a.has_key(foid):
498                        ofoid = 'oid_' + self._split_schema(cl)[-1]
499                        if a.has_key(ofoid):
500                                a[foid] = a[ofoid]
501
502                q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
503                self._do_debug(q)
504                self.db.query(q)
505
506# if run as script, print some information
507if __name__ == '__main__':
508        print 'PyGreSQL version', version
509        print
510        print __doc__
Note: See TracBrowser for help on using the repository browser.