source: trunk/module/pg.py @ 284

Last change on this file since 284 was 284, checked in by darcy, 14 years ago

If caller supplies key dictionary, make sure that all has a namespace.
Create unit test for change.
Document change in changelog.

File size: 16.3 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.45 2006-04-14 16:13:22 darcy Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with addtional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22from types import *
23
24# Auxiliary functions which are independent from a DB connection:
25
26def _quote(d, t):
27        """Return quotes if needed."""
28        if d is None:
29                return 'NULL'
30        if t in ('int', 'seq', 'decimal'):
31                if d == '': return 'NULL'
32                return str(d)
33        if t == 'money':
34                if d == '': return 'NULL'
35                return "'%.2f'" % float(d)
36        if t == 'bool':
37                if type(d) == StringType:
38                        if d == '': return 'NULL'
39                        d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
40                else:
41                        d = not not d
42                return ("'f'", "'t'")[d]
43        if t in ('date', 'inet', 'cidr'):
44                if d == '': return 'NULL'
45        return "'%s'" % str(d).replace('\\','\\\\').replace('\'','\\\'')
46
47def _is_quoted(s):
48        """Check whether this string is a quoted identifier."""
49        s = s.replace('_', 'a')
50        return not s.isalnum() or s[:1].isdigit() or s != s.lower()
51
52def _is_unquoted(s):
53        """Check whether this string is an unquoted identifier."""
54        s = s.replace('_', 'a')
55        return s.isalnum() and not s[:1].isdigit()
56
57def _split_first_part(s):
58        """Split the first part of a dot separated string."""
59        s = s.lstrip()
60        if s[:1] == '"':
61                p = []
62                s = s.split('"', 3)[1:]
63                p.append(s[0])
64                while len(s) == 3 and s[1] == '':
65                        p.append('"')
66                        s = s[2].split('"', 2)
67                        p.append(s[0])
68                p = [''.join(p)]
69                s = '"'.join(s[1:]).lstrip()
70                if s:
71                        if s[:0] == '.':
72                                p.append(s[1:])
73                        else:
74                                s = _split_first_part(s)
75                                p[0] += s[0]
76                                if len(s) > 1:
77                                        p.append(s[1])
78        else:
79                p = s.split('.', 1)
80                s = p[0].rstrip()
81                if _is_unquoted(s):
82                        s = s.lower()
83                p[0] = s
84        return p
85
86def _split_parts(s):
87        """Split all parts of a dot separated string."""
88        q = []
89        while s:
90                s = _split_first_part(s)
91                q.append(s[0])
92                if len(s) < 2: break
93                s = s[1]
94        return q
95
96def _join_parts(s):
97        """Join all parts of a dot separated string."""
98        return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
99
100# The PostGreSQL database connection interface:
101
102class DB:
103        """Wrapper class for the _pg connection type."""
104
105        def __init__(self, *args, **kw):
106                self.db = connect(*args, **kw)
107                self.dbname = self.db.db
108                self.__attnames = {}
109                self.__pkeys = {}
110                self.__args = args, kw
111                self.debug = None # For debugging scripts, this can be set
112                        # * to a string format specification (e.g. in a CGI set to "%s<BR>"),
113                        # * to a function which takes a string argument or
114                        # * to a file object to write debug statements to.
115
116        def __getattr__(self, name):
117                # All undefined members are the same as in the underlying pg connection:
118                if self.db:
119                        return getattr(self.db, name)
120                else:
121                        raise InternalError, 'Connection is not valid'
122
123        def _do_debug(self, s):
124                """Print a debug message."""
125                if not self.debug: return
126                if isinstance(self.debug, StringType): print self.debug % s
127                if isinstance(self.debug, FunctionType): self.debug(s)
128                if isinstance(self.debug, FileType): print >> self.debug, s
129
130        def close(self):
131                """Close the database connection."""
132                # Wraps shared library function so we can track state.
133
134                if self.db:
135                        self.db.close()
136                        self.db = None
137                else:
138                        raise InternalError, 'Connection already closed'
139
140        def reopen(self):
141                """Reopen connection to the database.
142
143                Used in case we need another connection to the same database.
144                Note that we can still reopen a database that we have closed.
145
146                """
147                if self.db:
148                        self.db.close()
149                try:
150                        self.db = connect(*self.__args[0], **self.__args[1])
151                except:
152                        self.db = None
153                        raise
154
155        def query(self, qstr):
156                """Executes a SQL command string.
157
158                This method simply sends a SQL query to the database. If the query is
159                an insert statement, the return value is the OID of the newly
160                inserted row.  If it is otherwise a query that does not return a result
161                (ie. is not a some kind of SELECT statement), it returns None.
162                Otherwise, it returns a pgqueryobject that can be accessed via the
163                getresult or dictresult method or simply printed.
164
165                """
166                # Wraps shared library function for debugging.
167                if not self.db:
168                        raise InternalError, 'Connection is not valid'
169                self._do_debug(qstr)
170                return self.db.query(qstr)
171
172        def _split_schema(self, cl):
173                """Return schema and name of object separately.
174
175                This auxiliary function splits off the namespace (schema)
176                belonging to the class with the name cl. If the class name
177                is not qualified, the function is able to determine the schema
178                of the class, taking into account the current search path.
179
180                """
181                s = _split_parts(cl)
182                if len(s) > 1: # name already qualfied?
183                        # should be database.schema.table or schema.table
184                        if len(s) > 3:
185                                raise ProgrammingError, 'Too many dots in class name %s' % cl
186                        schema, cl = s[-2:]
187                else:
188                        cl = s[0]
189                        # determine search path
190                        query = 'SELECT current_schemas(TRUE)'
191                        schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
192                        if schemas: # non-empty path
193                                # search schema for this object in the current search path
194                                query = ' UNION '.join(["SELECT %d AS n, '%s' AS nspname" % s
195                                        for s in enumerate(schemas)])
196                                query = ("SELECT nspname FROM pg_class"
197                                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
198                                        " JOIN (%s) AS p USING (nspname)"
199                                        " WHERE pg_class.relname='%s'"
200                                        " ORDER BY n LIMIT 1" % (query, cl))
201                                schema = self.db.query(query).getresult()
202                                if schema: # schema found
203                                        schema = schema[0][0]
204                                else: # object not found in current search path
205                                        schema = 'public'
206                        else: # empty path
207                                schema = 'public'
208                return schema, cl
209
210        def pkey(self, cl, newpkey = None):
211                """This method gets or sets the primary key of a class.
212
213                If newpkey is set and is not a dictionary then set that
214                value as the primary key of the class.  If it is a dictionary
215                then replace the __pkeys dictionary with it.
216
217                """
218                # First see if the caller is supplying a dictionary
219                if isinstance(newpkey, DictType):
220                        # make sure that we have a namespace
221                        self.__pkeys = {}
222                        for x in newpkey.keys():
223                                if x.find('.') == -1:
224                                        self.__pkeys['public.' + x] = newpkey[x]
225                                else:
226                                        self.__pkeys[x] = newpkey[x]
227
228                        return self.__pkeys
229
230                qcl = _join_parts(self._split_schema(cl)) # build qualified name
231                if newpkey:
232                        self.__pkeys[qcl] = newpkey
233                        return newpkey
234
235                # Get all the primary keys at once
236                if self.__pkeys == {} or not self.__pkeys.has_key(qcl):
237                        # if not found, check again in case it was added after we started
238                        for r in self.db.query("SELECT pg_namespace.nspname"
239                                ",pg_class.relname,pg_attribute.attname FROM pg_class"
240                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
241                                " AND pg_namespace.nspname NOT LIKE 'pg_%'"
242                                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
243                                " AND pg_attribute.attisdropped='f'"
244                                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
245                                " AND pg_index.indisprimary='t'"
246                                " AND pg_index.indkey[0]=pg_attribute.attnum").getresult():
247                                self.__pkeys[_join_parts(r[:2])] = r[2] # build qualified name
248                        self._do_debug(self.__pkeys)
249                # will raise an exception if primary key doesn't exist
250                return self.__pkeys[qcl]
251
252        def get_databases(self):
253                """Get list of databases in the system."""
254                return [s for s, in
255                        self.db.query('SELECT datname FROM pg_database').getresult()]
256
257        def get_relations(self, kinds = None):
258                """Get list of relations in connected database of specified kinds.
259
260                        If kinds is None or empty, all kinds of relations are returned.
261                        Otherwise kinds can be a string or sequence of type letters
262                        specifying which kind of relations you want to list.
263
264                """
265                if kinds:
266                        where = "pg_class.relkind IN (%s) AND" % \
267                                                        ','.join(["'%s'" % x for x in kinds])
268                else:
269                        where = ''
270
271                return [_join_parts(s) for s in
272                        self.db.query(
273                                "SELECT pg_namespace.nspname, pg_class.relname "
274                                "FROM pg_class "
275                                "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
276                                "WHERE %s pg_class.relname !~ '^Inv' AND "
277                                        "pg_class.relname !~ '^pg_' "
278                                "ORDER BY 1,2" % where).getresult()]
279
280        def get_tables(self):
281                """Return list of tables in connected database."""
282                return self.get_relations('r')
283
284        def get_attnames(self, cl, newattnames = None):
285                """Given the name of a table, digs out the set of attribute names.
286
287                Returns a dictionary of attribute names (the names are the keys,
288                the values are the names of the attributes' types).
289                If the optional newattnames exists, it must be a dictionary and
290                will become the new attribute names dictionary.
291
292                """
293                if isinstance(newattnames, DictType):
294                        self.__attnames = newattnames
295                        return
296                elif newattnames:
297                        raise ProgrammingError, \
298                                'If supplied, newattnames must be a dictionary'
299                cl = self._split_schema(cl) # split into schema and cl
300                qcl = _join_parts(cl) # build qualified name
301                # May as well cache them:
302                if self.__attnames.has_key(qcl):
303                        return self.__attnames[qcl]
304                if qcl not in self.get_relations('rv'):
305                        raise ProgrammingError, 'Class %s does not exist' % qcl
306                t = {}
307                for att, typ in self.db.query("SELECT pg_attribute.attname"
308                        ",pg_type.typname FROM pg_class"
309                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
310                        " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
311                        " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
312                        " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
313                        " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
314                        " AND pg_attribute.attisdropped='f'"
315                                % cl).getresult():
316                        if typ.startswith('bool'):
317                                t[att] = 'bool'
318                        elif typ.startswith('oid'):
319                                t[att] = 'int'
320                        elif typ.startswith('float'):
321                                t[att] = 'decimal'
322                        elif typ.startswith('abstime'):
323                                t[att] = 'date'
324                        elif typ.startswith('date'):
325                                t[att] = 'date'
326                        elif typ.startswith('interval'):
327                                t[att] = 'date'
328                        elif typ.startswith('int'):
329                                t[att] = 'int'
330                        elif typ.startswith('timestamp'):
331                                t[att] = 'date'
332                        elif typ.startswith('money'):
333                                t[att] = 'money'
334                        else:
335                                t[att] = 'text'
336                self.__attnames[qcl] = t # cache it
337                return self.__attnames[qcl]
338
339        def get(self, cl, arg, keyname = None, view = 0):
340                """Get a tuple from a database table or view.
341
342                This method is the basic mechanism to get a single row.  It assumes
343                that the key specifies a unique row.  If keyname is not specified
344                then the primary key for the table is used.  If arg is a dictionary
345                then the value for the key is taken from it and it is modified to
346                include the new values, replacing existing values where necessary.
347                The OID is also put into the dictionary, but in order to allow the
348                caller to work with multiple tables, it is munged as oid(schema.table).
349
350                """
351                if cl.endswith('*'): # scan descendant tables?
352                        cl = cl[:-1].rstrip() # need parent table name
353                qcl = _join_parts(self._split_schema(cl)) # build qualified name
354                # To allow users to work with multiple tables,
355                # we munge the name when the key is "oid"
356                foid = 'oid(%s)' % qcl # build mangled name
357                if keyname == None: # use the primary key by default
358                        keyname = self.pkey(qcl)
359                fnames = self.get_attnames(qcl)
360                if isinstance(arg, DictType):
361                        # XXX this code is for backwards compatibility and will be
362                        # XXX removed eventually
363                        if not arg.has_key(foid):
364                                ofoid = 'oid_' + self._split_schema(cl)[-1]
365                                if arg.has_key(ofoid):
366                                        arg[foid] = arg[ofoid]
367
368                        k = arg[keyname == 'oid' and foid or keyname]
369                else:
370                        k = arg
371                        arg = {}
372                # We want the oid for later updates if that isn't the key
373                if keyname == 'oid':
374                        q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
375                elif view:
376                        q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
377                                (qcl, keyname, _quote(k, fnames[keyname]))
378                else:
379                        q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
380                                (','.join(fnames.keys()), qcl, \
381                                        keyname, _quote(k, fnames[keyname]))
382                self._do_debug(q)
383                res = self.db.query(q).dictresult()
384                if not res:
385                        raise DatabaseError, \
386                                'No such record in %s where %s=%s' % \
387                                        (qcl, keyname, _quote(k, fnames[keyname]))
388                for k, d in res[0].items():
389                        if k == 'oid':
390                                k = foid
391                        arg[k] = d
392                return arg
393
394        def insert(self, cl, a):
395                """Insert a tuple into a database table.
396
397                This method inserts values into the table specified filling in the
398                values from the dictionary.  It then reloads the dictionary with the
399                values from the database.  This causes the dictionary to be updated
400                with values that are modified by rules, triggers, etc.
401
402                Note: The method currently doesn't support insert into views
403                although PostgreSQL does.
404
405                """
406                qcl = _join_parts(self._split_schema(cl)) # build qualified name
407                foid = 'oid(%s)' % qcl # build mangled name
408                fnames = self.get_attnames(qcl)
409                t = []
410                n = []
411                for f in fnames.keys():
412                        if f != 'oid' and a.has_key(f):
413                                t.append(_quote(a[f], fnames[f]))
414                                n.append(f)
415                q = 'INSERT INTO %s (%s) VALUES (%s)' % \
416                        (qcl, ','.join(n), ','.join(t))
417                self._do_debug(q)
418                a[foid] = self.db.query(q)
419                # Reload the dictionary to catch things modified by engine.
420                # Note that get() changes 'oid' below to oid_schema_table.
421                # If no read perms (it can and does happen), return None.
422                try:
423                        return self.get(qcl, a, 'oid')
424                except:
425                        return None
426
427        def update(self, cl, a):
428                """Update an existing row in a database table.
429
430                Similar to insert but updates an existing row.  The update is based
431                on the OID value as munged by get.  The array returned is the
432                one sent modified to reflect any changes caused by the update due
433                to triggers, rules, defaults, etc.
434
435                """
436                # Update always works on the oid which get returns if available,
437                # otherwise use the primary key.  Fail if neither.
438                qcl = _join_parts(self._split_schema(cl)) # build qualified name
439                foid = 'oid(%s)' % qcl # build mangled oid
440
441                # XXX this code is for backwards compatibility and will be
442                # XXX removed eventually
443                if not a.has_key(foid):
444                        ofoid = 'oid_' + self._split_schema(cl)[-1]
445                        if a.has_key(ofoid):
446                                a[foid] = a[ofoid]
447
448                if a.has_key(foid):
449                        where = "oid=%s" % a[foid]
450                else:
451                        try:
452                                pk = self.pkey(qcl)
453                        except:
454                                raise ProgrammingError, \
455                                        'Update needs primary key or oid as %s' % foid
456                        where = "%s='%s'" % (pk, a[pk])
457                v = []
458                k = 0
459                fnames = self.get_attnames(qcl)
460                for ff in fnames.keys():
461                        if ff != 'oid' and a.has_key(ff):
462                                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
463                if v == []:
464                        return None
465                q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
466                self._do_debug(q)
467                self.db.query(q)
468                # Reload the dictionary to catch things modified by engine:
469                if a.has_key(foid):
470                        return self.get(qcl, a, 'oid')
471                else:
472                        return self.get(qcl, a)
473
474        def clear(self, cl, a = None):
475                """
476
477                This method clears all the attributes to values determined by the types.
478                Numeric types are set to 0, Booleans are set to 'f', and everything
479                else is set to the empty string.  If the array argument is present,
480                it is used as the array and any entries matching attribute names are
481                cleared with everything else left unchanged.
482
483                """
484                # At some point we will need a way to get defaults from a table.
485                if a is None: a = {} # empty if argument is not present
486                qcl = _join_parts(self._split_schema(cl)) # build qualified name
487                foid = 'oid(%s)' % qcl # build mangled oid
488                fnames = self.get_attnames(qcl)
489                for k, t in fnames.items():
490                        if k == 'oid': continue
491                        if t in ['int', 'decimal', 'seq', 'money']:
492                                a[k] = 0
493                        elif t == 'bool':
494                                a[k] = 'f'
495                        else:
496                                a[k] = ''
497                return a
498
499        def delete(self, cl, a):
500                """Delete an existing row in a database table.
501
502                This method deletes the row from a table.
503                It deletes based on the OID munged as described above.
504
505                """
506                # Like update, delete works on the oid.
507                # One day we will be testing that the record to be deleted
508                # isn't referenced somewhere (or else PostgreSQL will).
509                qcl = _join_parts(self._split_schema(cl)) # build qualified name
510                foid = 'oid(%s)' % qcl # build mangled oid
511
512                # XXX this code is for backwards compatibility and will be
513                # XXX removed eventually
514                if not a.has_key(foid):
515                        ofoid = 'oid_' + self._split_schema(cl)[-1]
516                        if a.has_key(ofoid):
517                                a[foid] = a[ofoid]
518
519                q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
520                self._do_debug(q)
521                self.db.query(q)
522
523# if run as script, print some information
524if __name__ == '__main__':
525        print 'PyGreSQL version', version
526        print
527        print __doc__
Note: See TracBrowser for help on using the repository browser.