source: trunk/module/pg.py @ 232

Last change on this file since 232 was 232, checked in by darcy, 14 years ago

Change get_tables method to get_relations and allow any type of relation
to be retrieved.

Make new get_tables method that uses get_relations.

Use new method in get_attnames method to get attributes of views as well.

File size: 15.4 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.36 2006-01-22 12:52:45 darcy Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with addtional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22from types import *
23
24# Auxiliary functions which are independent from a DB connection:
25
26def _quote(d, t):
27        """Return quotes if needed."""
28        if d is None:
29                return 'NULL'
30        if t in ('int', 'seq', 'decimal'):
31                if d == '': return 'NULL'
32                return str(d)
33        if t == 'money':
34                if d == '': return 'NULL'
35                return "'%.2f'" % float(d)
36        if t == 'bool':
37                if type(d) == StringType:
38                        if d == '': return 'NULL'
39                        d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
40                else:
41                        d = not not d
42                return ("'f'", "'t'")[d]
43        if t in ('date', 'inet', 'cidr'):
44                if d == '': return 'NULL'
45        return "'%s'" % str(d).replace('\\','\\\\').replace('\'','\\\'')
46
47def _is_quoted(s):
48        """Check whether this string is a quoted identifier."""
49        s = s.replace('_', 'a')
50        return not s.isalnum() or s[:1].isdigit() or s != s.lower()
51
52def _is_unquoted(s):
53        """Check whether this string is an unquoted identifier."""
54        s = s.replace('_', 'a')
55        return s.isalnum() and not s[:1].isdigit()
56
57def _split_first_part(s):
58        """Split the first part of a dot separated string."""
59        s = s.lstrip()
60        if s[:1] == '"':
61                p = []
62                s = s.split('"', 3)[1:]
63                p.append(s[0])
64                while len(s) == 3 and s[1] == '':
65                        p.append('"')
66                        s = s[2].split('"', 2)
67                        p.append(s[0])
68                p = [''.join(p)]
69                s = '"'.join(s[1:]).lstrip()
70                if s:
71                        if s[:0] == '.':
72                                p.append(s[1:])
73                        else:
74                                s = _split_first_part(s)
75                                p[0] += s[0]
76                                if len(s) > 1:
77                                        p.append(s[1])
78        else:
79                p = s.split('.', 1)
80                s = p[0].rstrip()
81                if _is_unquoted(s):
82                        s = s.lower()
83                p[0] = s
84        return p
85
86def _split_parts(s):
87        """Split all parts of a dot separated string."""
88        q = []
89        while s:
90                s = _split_first_part(s)
91                q.append(s[0])
92                if len(s) < 2: break
93                s = s[1]
94        return q
95
96def _join_parts(s):
97        """Join all parts of a dot separated string."""
98        return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
99
100# The PostGreSQL database connection interface:
101
102class DB:
103        """Wrapper class for the _pg connection type."""
104
105        def __init__(self, *args, **kw):
106                self.db = connect(*args, **kw)
107                self.dbname = self.db.db
108                self.__attnames = {}
109                self.__pkeys = {}
110                self.__args = args, kw
111                self.debug = None # For debugging scripts, this can be set
112                        # * to a string format specification (e.g. in a CGI set to "%s<BR>"),
113                        # * to a function which takes a string argument or
114                        # * to a file object to write debug statements to.
115
116        def __getattr__(self, name):
117                # All undefined members are the same as in the underlying pg connection:
118                if self.db:
119                        return getattr(self.db, name)
120                else:
121                        raise InternalError, 'Connection is not valid'
122
123        def _do_debug(self, s):
124                """Print a debug message."""
125                if not self.debug: return
126                if isinstance(self.debug, StringType): print self.debug % s
127                if isinstance(self.debug, FunctionType): self.debug(s)
128                if isinstance(self.debug, FileType): print >> self.debug, s
129
130        def close(self):
131                """Close the database connection."""
132                # Wraps shared library function so we can track state.
133
134                if self.db:
135                        self.db.close()
136                        self.db = None
137                else:
138                        raise InternalError, 'Connection already closed'
139
140        def reopen(self):
141                """Reopen connection to the database.
142
143                Used in case we need another connection to the same database.
144                Note that we can still reopen a database that we have closed.
145
146                """
147                if self.db:
148                        self.db.close()
149                try:
150                        self.db = connect(*self.__args[0], **self.__args[1])
151                except:
152                        self.db = None
153                        raise
154
155        def query(self, qstr):
156                """Executes a SQL command string.
157
158                This method simply sends a SQL query to the database. If the query is
159                an insert statement, the return value is the OID of the newly
160                inserted row.  If it is otherwise a query that does not return a result
161                (ie. is not a some kind of SELECT statement), it returns None.
162                Otherwise, it returns a pgqueryobject that can be accessed via the
163                getresult or dictresult method or simply printed.
164
165                """
166                # Wraps shared library function for debugging.
167                if not self.db:
168                        raise InternalError, 'Connection is not valid'
169                self._do_debug(qstr)
170                return self.db.query(qstr)
171
172        def _split_schema(self, cl):
173                """Return schema and name of object separately.
174
175                This auxiliary function splits off the namespace (schema)
176                belonging to the class with the name cl. If the class name
177                is not qualified, the function is able to determine the schema
178                of the class, taking into account the current search path.
179
180                """
181                s = _split_parts(cl)
182                if len(s) > 1: # name already qualfied?
183                        # should be database.schema.table or schema.table
184                        if len(s) > 3:
185                                raise ProgrammingError, 'Too many dots in class name %s' % cl
186                        schema, cl = s[-2:]
187                else:
188                        cl = s[0]
189                        # determine search path
190                        query = 'SELECT current_schemas(TRUE)'
191                        schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
192                        if schemas: # non-empty path
193                                # search schema for this object in the current search path
194                                query = ' UNION '.join(["SELECT %d AS n, '%s' AS nspname" % s
195                                        for s in enumerate(schemas)])
196                                query = ("SELECT nspname FROM pg_class"
197                                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
198                                        " JOIN (%s) AS p USING (nspname)"
199                                        " WHERE pg_class.relname='%s'"
200                                        " ORDER BY n LIMIT 1" % (query, cl))
201                                schema = self.db.query(query).getresult()
202                                if schema: # schema found
203                                        schema = schema[0][0]
204                                else: # object not found in current search path
205                                        schema = 'public'
206                        else: # empty path
207                                schema = 'public'
208                return schema, cl
209
210        def pkey(self, cl, newpkey = None):
211                """This method gets or sets the primary key of a class.
212
213                If newpkey is set and is not a dictionary then set that
214                value as the primary key of the class.  If it is a dictionary
215                then replace the __pkeys dictionary with it.
216
217                """
218                # Get all the primary keys at once
219                if isinstance(newpkey, DictType):
220                        self.__pkeys = newpkey
221                        return newpkey
222                qcl = _join_parts(self._split_schema(cl)) # build qualified name
223                if newpkey:
224                        self.__pkeys[qcl] = newpkey
225                        return newpkey
226                if self.__pkeys == {} or not self.__pkeys.has_key(qcl):
227                        # if not found, determine pkey again in case it was added after we started
228                        for r in self.db.query("SELECT pg_namespace.nspname"
229                                ",pg_class.relname,pg_attribute.attname FROM pg_class"
230                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
231                                " AND pg_namespace.nspname NOT LIKE 'pg_%'"
232                                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
233                                " AND pg_attribute.attisdropped='f'"
234                                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
235                                " AND pg_index.indisprimary='t'"
236                                " AND pg_index.indkey[0]=pg_attribute.attnum").getresult():
237                                self.__pkeys[_join_parts(r[:2])] = r[2] # build qualified name
238                        self._do_debug(self.__pkeys)
239                # will raise an exception if primary key doesn't exist
240                return self.__pkeys[qcl]
241
242        def get_databases(self):
243                """Get list of databases in the system."""
244                return [s for s, in
245                        self.db.query('SELECT datname FROM pg_database').getresult()]
246
247        def get_relations(self, typ = None):
248                """Get list of relations in connected database of specified types.
249                        If type is None, all relations are returned.
250                        Otherwise type can be a string or sequence of type letters."""
251
252                if typ:
253                        where = "pg_class.relkind IN (%s) AND" % \
254                                                        ','.join(["'%s'" % x for x in typ])
255                else:
256                        where = ''
257
258                return [_join_parts(s) for s in
259                        self.db.query(
260                                "SELECT pg_namespace.nspname, pg_class.relname "
261                                "FROM pg_class "
262                                "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
263                                "WHERE %s pg_class.relname !~ '^Inv' AND "
264                                        "pg_class.relname !~ '^pg_' "
265                                "ORDER BY 1,2" % where).getresult()]
266
267        def get_tables(self):
268                """Return list of tables in connected database."""
269
270                return self.get_relations('r')
271
272        def get_attnames(self, cl, newattnames = None):
273                """Given the name of a table, digs out the set of attribute names.
274
275                Returns a dictionary of attribute names (the names are the keys,
276                the values are the names of the attributes' types).
277                If the optional newattnames exists, it must be a dictionary and
278                will become the new attribute names dictionary.
279
280                """
281                if isinstance(newattnames, DictType):
282                        self.__attnames = newattnames
283                        return
284                elif newattnames:
285                        raise ProgrammingError, \
286                                'If supplied, newattnames must be a dictionary'
287                cl = self._split_schema(cl) # split into schema and cl
288                qcl = _join_parts(cl) # build qualified name
289                # May as well cache them:
290                if self.__attnames.has_key(qcl):
291                        return self.__attnames[qcl]
292                if qcl not in self.get_relations('rv'):
293                        raise ProgrammingError, 'Class %s does not exist' % qcl
294                t = {}
295                for att, typ in self.db.query("SELECT pg_attribute.attname"
296                        ",pg_type.typname FROM pg_class"
297                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
298                        " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
299                        " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
300                        " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
301                        " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
302                        " AND pg_attribute.attisdropped='f'"
303                                % cl).getresult():
304                        if typ.startswith('bool'):
305                                t[att] = 'bool'
306                        elif typ.startswith('int'):
307                                t[att] = 'int'
308                        elif typ.startswith('oid'):
309                                t[att] = 'int'
310                        elif typ.startswith('float'):
311                                t[att] = 'decimal'
312                        elif typ.startswith('abstime'):
313                                t[att] = 'date'
314                        elif typ.startswith('date'):
315                                t[att] = 'date'
316                        elif typ.startswith('interval'):
317                                t[att] = 'date'
318                        elif typ.startswith('timestamp'):
319                                t[att] = 'date'
320                        elif typ.startswith('money'):
321                                t[att] = 'money'
322                        else:
323                                t[att] = 'text'
324                self.__attnames[qcl] = t # cache it
325                return self.__attnames[qcl]
326
327        def get(self, cl, arg, keyname = None, view = 0):
328                """Get a tuple from a database table or view.
329
330                This method is the basic mechanism to get a single row.  It assumes
331                that the key specifies a unique row.  If keyname is not specified
332                then the primary key for the table is used.  If arg is a dictionary
333                then the value for the key is taken from it and it is modified to
334                include the new values, replacing existing values where necessary.
335                The OID is also put into the dictionary, but in order to allow the
336                caller to work with multiple tables, it is munged as oid(schema.table).
337
338                """
339                if cl.endswith('*'): # scan descendant tables?
340                        cl = cl[:-1].rstrip() # need parent table name
341                qcl = _join_parts(self._split_schema(cl)) # build qualified name
342                # To allow users to work with multiple tables,
343                # we munge the name when the key is "oid"
344                foid = 'oid(%s)' % qcl # build mangled name
345                if keyname == None: # use the primary key by default
346                        keyname = self.pkey(qcl)
347                fnames = self.get_attnames(qcl)
348                if isinstance(arg, DictType):
349                        k = arg[keyname == 'oid' and foid or keyname]
350                else:
351                        k = arg
352                        arg = {}
353                # We want the oid for later updates if that isn't the key
354                if keyname == 'oid':
355                        q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
356                elif view:
357                        q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
358                                (qcl, keyname, _quote(k, fnames[keyname]))
359                else:
360                        q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
361                                (','.join(fnames.keys()), qcl, \
362                                        keyname, _quote(k, fnames[keyname]))
363                self._do_debug(q)
364                res = self.db.query(q).dictresult()
365                if not res:
366                        raise DatabaseError, \
367                                'No such record in %s where %s=%s' % \
368                                        (qcl, keyname, _quote(k, fnames[keyname]))
369                for k, d in res[0].items():
370                        if k == 'oid':
371                                k = foid
372                        arg[k] = d
373                return arg
374
375        def insert(self, cl, a):
376                """Insert a tuple into a database table.
377
378                This method inserts values into the table specified filling in the
379                values from the dictionary.  It then reloads the dictionary with the
380                values from the database.  This causes the dictionary to be updated
381                with values that are modified by rules, triggers, etc.
382
383                Note: The method currently doesn't support insert into views
384                although PostgreSQL does.
385
386                """
387                qcl = _join_parts(self._split_schema(cl)) # build qualified name
388                foid = 'oid(%s)' % qcl # build mangled name
389                fnames = self.get_attnames(qcl)
390                t = []
391                n = []
392                for f in fnames.keys():
393                        if f != 'oid' and a.has_key(f):
394                                t.append(_quote(a[f], fnames[f]))
395                                n.append(f)
396                q = 'INSERT INTO %s (%s) VALUES (%s)' % \
397                        (qcl, ','.join(n), ','.join(t))
398                self._do_debug(q)
399                a[foid] = self.db.query(q)
400                # Reload the dictionary to catch things modified by engine.
401                # Note that get() changes 'oid' below to oid_schema_table.
402                # If no read perms (it can and does happen), return None.
403                try:
404                        return self.get(qcl, a, 'oid')
405                except:
406                        return None
407
408        def update(self, cl, a):
409                """Update an existing row in a database table.
410
411                Similar to insert but updates an existing row.  The update is based
412                on the OID value as munged by get.  The array returned is the
413                one sent modified to reflect any changes caused by the update due
414                to triggers, rules, defaults, etc.
415
416                """
417                # Update always works on the oid which get returns if available,
418                # otherwise use the primary key.  Fail if neither.
419                qcl = _join_parts(self._split_schema(cl)) # build qualified name
420                foid = 'oid(%s)' % qcl # build mangled oid
421                if a.has_key(foid):
422                        where = "oid=%s" % a[foid]
423                else:
424                        try:
425                                pk = self.pkey(qcl)
426                        except:
427                                raise ProgrammingError, \
428                                        'Update needs primary key or oid as %s' % foid
429                        where = "%s='%s'" % (pk, a[pk])
430                v = []
431                k = 0
432                fnames = self.get_attnames(qcl)
433                for ff in fnames.keys():
434                        if ff != 'oid' and a.has_key(ff):
435                                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
436                if v == []:
437                        return None
438                q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
439                self._do_debug(q)
440                self.db.query(q)
441                # Reload the dictionary to catch things modified by engine:
442                if a.has_key(foid):
443                        return self.get(qcl, a, 'oid')
444                else:
445                        return self.get(qcl, a)
446
447        def clear(self, cl, a = None):
448                """
449
450                This method clears all the attributes to values determined by the types.
451                Numeric types are set to 0, Booleans are set to 'f', dates are set
452                to 'now()' and everything else is set to the empty string.
453                If the array argument is present, it is used as the array and any entries
454                matching attribute names are cleared with everything else left unchanged.
455
456                """
457                # At some point we will need a way to get defaults from a table.
458                if a is None: a = {} # empty if argument is not present
459                qcl = _join_parts(self._split_schema(cl)) # build qualified name
460                foid = 'oid(%s)' % qcl # build mangled oid
461                fnames = self.get_attnames(qcl)
462                for k, t in fnames.items():
463                        if k == 'oid': continue
464                        if t in ['int', 'decimal', 'seq', 'money']:
465                                a[k] = 0
466                        elif t == 'bool':
467                                a[k] = 'f'
468                        elif t == 'date':
469                                a[k] = 'now()'
470                        else:
471                                a[k] = ''
472                return a
473
474        def delete(self, cl, a):
475                """Delete an existing row in a database table.
476
477                This method deletes the row from a table.
478                It deletes based on the OID munged as described above.
479
480                """
481                # Like update, delete works on the oid.
482                # One day we will be testing that the record to be deleted
483                # isn't referenced somewhere (or else PostgreSQL will).
484                qcl = _join_parts(self._split_schema(cl)) # build qualified name
485                foid = 'oid(%s)' % qcl # build mangled oid
486                q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
487                self._do_debug(q)
488                self.db.query(q)
489
490# if run as script, print some information
491if __name__ == '__main__':
492        print 'PyGreSQL version', version
493        print
494        print __doc__
Note: See TracBrowser for help on using the repository browser.