source: trunk/module/pg.py @ 296

Last change on this file since 296 was 296, checked in by cito, 13 years ago

Small bugfix (bad code that does not work in Python 2.5).

File size: 16.5 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.48 2006-05-30 23:43:30 cito Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with addtional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22from types import *
23
24# Auxiliary functions which are independent from a DB connection:
25
26def _quote(d, t):
27        """Return quotes if needed."""
28        if d is None:
29                return 'NULL'
30        if t in ('int', 'seq', 'decimal'):
31                if d == '': return 'NULL'
32                return str(d)
33        if t == 'money':
34                if d == '': return 'NULL'
35                return "'%.2f'" % float(d)
36        if t == 'bool':
37                if type(d) == StringType:
38                        if d == '': return 'NULL'
39                        d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
40                else:
41                        d = not not d
42                return ("'f'", "'t'")[d]
43        if t in ('date', 'inet', 'cidr'):
44                if d == '': return 'NULL'
45        return "'%s'" % str(d).replace("\\", "\\\\").replace("'", "''")
46
47def _is_quoted(s):
48        """Check whether this string is a quoted identifier."""
49        s = s.replace('_', 'a')
50        return not s.isalnum() or s[:1].isdigit() or s != s.lower()
51
52def _is_unquoted(s):
53        """Check whether this string is an unquoted identifier."""
54        s = s.replace('_', 'a')
55        return s.isalnum() and not s[:1].isdigit()
56
57def _split_first_part(s):
58        """Split the first part of a dot separated string."""
59        s = s.lstrip()
60        if s[:1] == '"':
61                p = []
62                s = s.split('"', 3)[1:]
63                p.append(s[0])
64                while len(s) == 3 and s[1] == '':
65                        p.append('"')
66                        s = s[2].split('"', 2)
67                        p.append(s[0])
68                p = [''.join(p)]
69                s = '"'.join(s[1:]).lstrip()
70                if s:
71                        if s[:0] == '.':
72                                p.append(s[1:])
73                        else:
74                                s = _split_first_part(s)
75                                p[0] += s[0]
76                                if len(s) > 1:
77                                        p.append(s[1])
78        else:
79                p = s.split('.', 1)
80                s = p[0].rstrip()
81                if _is_unquoted(s):
82                        s = s.lower()
83                p[0] = s
84        return p
85
86def _split_parts(s):
87        """Split all parts of a dot separated string."""
88        q = []
89        while s:
90                s = _split_first_part(s)
91                q.append(s[0])
92                if len(s) < 2: break
93                s = s[1]
94        return q
95
96def _join_parts(s):
97        """Join all parts of a dot separated string."""
98        return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
99
100# The PostGreSQL database connection interface:
101
102class DB:
103        """Wrapper class for the _pg connection type."""
104
105        def __init__(self, *args, **kw):
106                self.db = connect(*args, **kw)
107                self.dbname = self.db.db
108                self.__attnames = {}
109                self.__pkeys = {}
110                self.__args = args, kw
111                self.debug = None # For debugging scripts, this can be set
112                        # * to a string format specification (e.g. in a CGI set to "%s<BR>"),
113                        # * to a function which takes a string argument or
114                        # * to a file object to write debug statements to.
115
116        def __getattr__(self, name):
117                # All undefined members are the same as in the underlying pg connection:
118                if self.db:
119                        return getattr(self.db, name)
120                else:
121                        raise InternalError, 'Connection is not valid'
122
123        # For convenience, define some module functions as static methods also:
124        escape_string, escape_bytea, unescape_bytea = map(staticmethod,
125                (escape_string, escape_bytea, unescape_bytea))
126
127        def _do_debug(self, s):
128                """Print a debug message."""
129                if not self.debug: return
130                if isinstance(self.debug, StringType): print self.debug % s
131                if isinstance(self.debug, FunctionType): self.debug(s)
132                if isinstance(self.debug, FileType): print >> self.debug, s
133
134        def close(self):
135                """Close the database connection."""
136                # Wraps shared library function so we can track state.
137
138                if self.db:
139                        self.db.close()
140                        self.db = None
141                else:
142                        raise InternalError, 'Connection already closed'
143
144        def reopen(self):
145                """Reopen connection to the database.
146
147                Used in case we need another connection to the same database.
148                Note that we can still reopen a database that we have closed.
149
150                """
151                if self.db:
152                        self.db.close()
153                try:
154                        self.db = connect(*self.__args[0], **self.__args[1])
155                except:
156                        self.db = None
157                        raise
158
159        def query(self, qstr):
160                """Executes a SQL command string.
161
162                This method simply sends a SQL query to the database. If the query is
163                an insert statement, the return value is the OID of the newly
164                inserted row.  If it is otherwise a query that does not return a result
165                (ie. is not a some kind of SELECT statement), it returns None.
166                Otherwise, it returns a pgqueryobject that can be accessed via the
167                getresult or dictresult method or simply printed.
168
169                """
170                # Wraps shared library function for debugging.
171                if not self.db:
172                        raise InternalError, 'Connection is not valid'
173                self._do_debug(qstr)
174                return self.db.query(qstr)
175
176        def _split_schema(self, cl):
177                """Return schema and name of object separately.
178
179                This auxiliary function splits off the namespace (schema)
180                belonging to the class with the name cl. If the class name
181                is not qualified, the function is able to determine the schema
182                of the class, taking into account the current search path.
183
184                """
185                s = _split_parts(cl)
186                if len(s) > 1: # name already qualfied?
187                        # should be database.schema.table or schema.table
188                        if len(s) > 3:
189                                raise ProgrammingError, 'Too many dots in class name %s' % cl
190                        schema, cl = s[-2:]
191                else:
192                        cl = s[0]
193                        # determine search path
194                        query = 'SELECT current_schemas(TRUE)'
195                        schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
196                        if schemas: # non-empty path
197                                # search schema for this object in the current search path
198                                query = ' UNION '.join(["SELECT %d AS n, '%s' AS nspname" % s
199                                        for s in enumerate(schemas)])
200                                query = ("SELECT nspname FROM pg_class"
201                                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
202                                        " JOIN (%s) AS p USING (nspname)"
203                                        " WHERE pg_class.relname='%s'"
204                                        " ORDER BY n LIMIT 1" % (query, cl))
205                                schema = self.db.query(query).getresult()
206                                if schema: # schema found
207                                        schema = schema[0][0]
208                                else: # object not found in current search path
209                                        schema = 'public'
210                        else: # empty path
211                                schema = 'public'
212                return schema, cl
213
214        def pkey(self, cl, newpkey = None):
215                """This method gets or sets the primary key of a class.
216
217                If newpkey is set and is not a dictionary then set that
218                value as the primary key of the class.  If it is a dictionary
219                then replace the __pkeys dictionary with it.
220
221                """
222                # First see if the caller is supplying a dictionary
223                if isinstance(newpkey, DictType):
224                        # make sure that we have a namespace
225                        self.__pkeys = {}
226                        for x in newpkey.keys():
227                                if x.find('.') == -1:
228                                        self.__pkeys['public.' + x] = newpkey[x]
229                                else:
230                                        self.__pkeys[x] = newpkey[x]
231
232                        return self.__pkeys
233
234                qcl = _join_parts(self._split_schema(cl)) # build qualified name
235                if newpkey:
236                        self.__pkeys[qcl] = newpkey
237                        return newpkey
238
239                # Get all the primary keys at once
240                if self.__pkeys == {} or not self.__pkeys.has_key(qcl):
241                        # if not found, check again in case it was added after we started
242                        for r in self.db.query("SELECT pg_namespace.nspname"
243                                ",pg_class.relname,pg_attribute.attname FROM pg_class"
244                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
245                                " AND pg_namespace.nspname NOT LIKE 'pg_%'"
246                                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
247                                " AND pg_attribute.attisdropped='f'"
248                                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
249                                " AND pg_index.indisprimary='t'"
250                                " AND pg_index.indkey[0]=pg_attribute.attnum").getresult():
251                                self.__pkeys[_join_parts(r[:2])] = r[2] # build qualified name
252                        self._do_debug(self.__pkeys)
253                # will raise an exception if primary key doesn't exist
254                return self.__pkeys[qcl]
255
256        def get_databases(self):
257                """Get list of databases in the system."""
258                return [s[0] for s in
259                        self.db.query('SELECT datname FROM pg_database').getresult()]
260
261        def get_relations(self, kinds = None):
262                """Get list of relations in connected database of specified kinds.
263
264                        If kinds is None or empty, all kinds of relations are returned.
265                        Otherwise kinds can be a string or sequence of type letters
266                        specifying which kind of relations you want to list.
267
268                """
269                if kinds:
270                        where = "pg_class.relkind IN (%s) AND" % \
271                                                        ','.join(["'%s'" % x for x in kinds])
272                else:
273                        where = ''
274
275                return [_join_parts(s) for s in
276                        self.db.query(
277                                "SELECT pg_namespace.nspname, pg_class.relname "
278                                "FROM pg_class "
279                                "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
280                                "WHERE %s pg_class.relname !~ '^Inv' AND "
281                                        "pg_class.relname !~ '^pg_' "
282                                "ORDER BY 1,2" % where).getresult()]
283
284        def get_tables(self):
285                """Return list of tables in connected database."""
286                return self.get_relations('r')
287
288        def get_attnames(self, cl, newattnames = None):
289                """Given the name of a table, digs out the set of attribute names.
290
291                Returns a dictionary of attribute names (the names are the keys,
292                the values are the names of the attributes' types).
293                If the optional newattnames exists, it must be a dictionary and
294                will become the new attribute names dictionary.
295
296                """
297                if isinstance(newattnames, DictType):
298                        self.__attnames = newattnames
299                        return
300                elif newattnames:
301                        raise ProgrammingError, \
302                                'If supplied, newattnames must be a dictionary'
303                cl = self._split_schema(cl) # split into schema and cl
304                qcl = _join_parts(cl) # build qualified name
305                # May as well cache them:
306                if self.__attnames.has_key(qcl):
307                        return self.__attnames[qcl]
308                if qcl not in self.get_relations('rv'):
309                        raise ProgrammingError, 'Class %s does not exist' % qcl
310                t = {}
311                for att, typ in self.db.query("SELECT pg_attribute.attname"
312                        ",pg_type.typname FROM pg_class"
313                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
314                        " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
315                        " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
316                        " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
317                        " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
318                        " AND pg_attribute.attisdropped='f'"
319                                % cl).getresult():
320                        if typ.startswith('bool'):
321                                t[att] = 'bool'
322                        elif typ.startswith('oid'):
323                                t[att] = 'int'
324                        elif typ.startswith('float'):
325                                t[att] = 'decimal'
326                        elif typ.startswith('abstime'):
327                                t[att] = 'date'
328                        elif typ.startswith('date'):
329                                t[att] = 'date'
330                        elif typ.startswith('interval'):
331                                t[att] = 'date'
332                        elif typ.startswith('int'):
333                                t[att] = 'int'
334                        elif typ.startswith('timestamp'):
335                                t[att] = 'date'
336                        elif typ.startswith('money'):
337                                t[att] = 'money'
338                        else:
339                                t[att] = 'text'
340                self.__attnames[qcl] = t # cache it
341                return self.__attnames[qcl]
342
343        def get(self, cl, arg, keyname = None, view = 0):
344                """Get a tuple from a database table or view.
345
346                This method is the basic mechanism to get a single row.  It assumes
347                that the key specifies a unique row.  If keyname is not specified
348                then the primary key for the table is used.  If arg is a dictionary
349                then the value for the key is taken from it and it is modified to
350                include the new values, replacing existing values where necessary.
351                The OID is also put into the dictionary, but in order to allow the
352                caller to work with multiple tables, it is munged as oid(schema.table).
353
354                """
355                if cl.endswith('*'): # scan descendant tables?
356                        cl = cl[:-1].rstrip() # need parent table name
357                qcl = _join_parts(self._split_schema(cl)) # build qualified name
358                # To allow users to work with multiple tables,
359                # we munge the name when the key is "oid"
360                foid = 'oid(%s)' % qcl # build mangled name
361                if keyname == None: # use the primary key by default
362                        keyname = self.pkey(qcl)
363                fnames = self.get_attnames(qcl)
364                if isinstance(arg, DictType):
365                        # XXX this code is for backwards compatibility and will be
366                        # XXX removed eventually
367                        if not arg.has_key(foid):
368                                ofoid = 'oid_' + self._split_schema(cl)[-1]
369                                if arg.has_key(ofoid):
370                                        arg[foid] = arg[ofoid]
371
372                        k = arg[keyname == 'oid' and foid or keyname]
373                else:
374                        k = arg
375                        arg = {}
376                # We want the oid for later updates if that isn't the key
377                if keyname == 'oid':
378                        q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
379                elif view:
380                        q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
381                                (qcl, keyname, _quote(k, fnames[keyname]))
382                else:
383                        q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
384                                (','.join(fnames.keys()), qcl, \
385                                        keyname, _quote(k, fnames[keyname]))
386                self._do_debug(q)
387                res = self.db.query(q).dictresult()
388                if not res:
389                        raise DatabaseError, \
390                                'No such record in %s where %s=%s' % \
391                                        (qcl, keyname, _quote(k, fnames[keyname]))
392                for k, d in res[0].items():
393                        if k == 'oid':
394                                k = foid
395                        arg[k] = d
396                return arg
397
398        def insert(self, cl, a):
399                """Insert a tuple into a database table.
400
401                This method inserts values into the table specified filling in the
402                values from the dictionary.  It then reloads the dictionary with the
403                values from the database.  This causes the dictionary to be updated
404                with values that are modified by rules, triggers, etc.
405
406                Note: The method currently doesn't support insert into views
407                although PostgreSQL does.
408
409                """
410                qcl = _join_parts(self._split_schema(cl)) # build qualified name
411                foid = 'oid(%s)' % qcl # build mangled name
412                fnames = self.get_attnames(qcl)
413                t = []
414                n = []
415                for f in fnames.keys():
416                        if f != 'oid' and a.has_key(f):
417                                t.append(_quote(a[f], fnames[f]))
418                                n.append(f)
419                q = 'INSERT INTO %s (%s) VALUES (%s)' % \
420                        (qcl, ','.join(n), ','.join(t))
421                self._do_debug(q)
422                a[foid] = self.db.query(q)
423                # Reload the dictionary to catch things modified by engine.
424                # Note that get() changes 'oid' below to oid_schema_table.
425                # If no read perms (it can and does happen), return None.
426                try:
427                        return self.get(qcl, a, 'oid')
428                except:
429                        return None
430
431        def update(self, cl, a):
432                """Update an existing row in a database table.
433
434                Similar to insert but updates an existing row.  The update is based
435                on the OID value as munged by get.  The array returned is the
436                one sent modified to reflect any changes caused by the update due
437                to triggers, rules, defaults, etc.
438
439                """
440                # Update always works on the oid which get returns if available,
441                # otherwise use the primary key.  Fail if neither.
442                qcl = _join_parts(self._split_schema(cl)) # build qualified name
443                foid = 'oid(%s)' % qcl # build mangled oid
444
445                # XXX this code is for backwards compatibility and will be
446                # XXX removed eventually
447                if not a.has_key(foid):
448                        ofoid = 'oid_' + self._split_schema(cl)[-1]
449                        if a.has_key(ofoid):
450                                a[foid] = a[ofoid]
451
452                if a.has_key(foid):
453                        where = "oid=%s" % a[foid]
454                else:
455                        try:
456                                pk = self.pkey(qcl)
457                        except:
458                                raise ProgrammingError, \
459                                        'Update needs primary key or oid as %s' % foid
460                        where = "%s='%s'" % (pk, a[pk])
461                v = []
462                k = 0
463                fnames = self.get_attnames(qcl)
464                for ff in fnames.keys():
465                        if ff != 'oid' and a.has_key(ff):
466                                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
467                if v == []:
468                        return None
469                q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
470                self._do_debug(q)
471                self.db.query(q)
472                # Reload the dictionary to catch things modified by engine:
473                if a.has_key(foid):
474                        return self.get(qcl, a, 'oid')
475                else:
476                        return self.get(qcl, a)
477
478        def clear(self, cl, a = None):
479                """
480
481                This method clears all the attributes to values determined by the types.
482                Numeric types are set to 0, Booleans are set to 'f', and everything
483                else is set to the empty string.  If the array argument is present,
484                it is used as the array and any entries matching attribute names are
485                cleared with everything else left unchanged.
486
487                """
488                # At some point we will need a way to get defaults from a table.
489                if a is None: a = {} # empty if argument is not present
490                qcl = _join_parts(self._split_schema(cl)) # build qualified name
491                foid = 'oid(%s)' % qcl # build mangled oid
492                fnames = self.get_attnames(qcl)
493                for k, t in fnames.items():
494                        if k == 'oid': continue
495                        if t in ['int', 'decimal', 'seq', 'money']:
496                                a[k] = 0
497                        elif t == 'bool':
498                                a[k] = 'f'
499                        else:
500                                a[k] = ''
501                return a
502
503        def delete(self, cl, a):
504                """Delete an existing row in a database table.
505
506                This method deletes the row from a table.
507                It deletes based on the OID munged as described above.
508
509                """
510                # Like update, delete works on the oid.
511                # One day we will be testing that the record to be deleted
512                # isn't referenced somewhere (or else PostgreSQL will).
513                qcl = _join_parts(self._split_schema(cl)) # build qualified name
514                foid = 'oid(%s)' % qcl # build mangled oid
515
516                # XXX this code is for backwards compatibility and will be
517                # XXX removed eventually
518                if not a.has_key(foid):
519                        ofoid = 'oid_' + self._split_schema(cl)[-1]
520                        if a.has_key(ofoid):
521                                a[foid] = a[ofoid]
522
523                q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
524                self._do_debug(q)
525                self.db.query(q)
526
527# if run as script, print some information
528if __name__ == '__main__':
529        print 'PyGreSQL version', version
530        print
531        print __doc__
Note: See TracBrowser for help on using the repository browser.