source: trunk/module/pg.py @ 320

Last change on this file since 320 was 320, checked in by cito, 13 years ago

In some cases, PostgreSQL stumbled over unknown types in _split_schema(). Made these types explicit.

File size: 17.0 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.53 2007-03-02 20:35:55 cito Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with addtional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22from types import *
23
24# Auxiliary functions which are independent from a DB connection:
25
26def _quote(d, t):
27        """Return quotes if needed."""
28        if d is None:
29                return 'NULL'
30        if t in ('int', 'seq', 'decimal'):
31                if d == '': return 'NULL'
32                return str(d)
33        if t == 'money':
34                if d == '': return 'NULL'
35                return "'%.2f'" % float(d)
36        if t == 'bool':
37                if type(d) == StringType:
38                        if d == '': return 'NULL'
39                        d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
40                else:
41                        d = not not d
42                return ("'f'", "'t'")[d]
43        if t in ('date', 'inet', 'cidr'):
44                if d == '': return 'NULL'
45                if d.lower() in ('current_date', 'current_time',
46                        'current_timestamp', 'localtime', 'localtimestamp'):
47                        return d
48        return "'%s'" % str(d).replace("\\", "\\\\").replace("'", "''")
49
50def _is_quoted(s):
51        """Check whether this string is a quoted identifier."""
52        s = s.replace('_', 'a')
53        return not s.isalnum() or s[:1].isdigit() or s != s.lower()
54
55def _is_unquoted(s):
56        """Check whether this string is an unquoted identifier."""
57        s = s.replace('_', 'a')
58        return s.isalnum() and not s[:1].isdigit()
59
60def _split_first_part(s):
61        """Split the first part of a dot separated string."""
62        s = s.lstrip()
63        if s[:1] == '"':
64                p = []
65                s = s.split('"', 3)[1:]
66                p.append(s[0])
67                while len(s) == 3 and s[1] == '':
68                        p.append('"')
69                        s = s[2].split('"', 2)
70                        p.append(s[0])
71                p = [''.join(p)]
72                s = '"'.join(s[1:]).lstrip()
73                if s:
74                        if s[:0] == '.':
75                                p.append(s[1:])
76                        else:
77                                s = _split_first_part(s)
78                                p[0] += s[0]
79                                if len(s) > 1:
80                                        p.append(s[1])
81        else:
82                p = s.split('.', 1)
83                s = p[0].rstrip()
84                if _is_unquoted(s):
85                        s = s.lower()
86                p[0] = s
87        return p
88
89def _split_parts(s):
90        """Split all parts of a dot separated string."""
91        q = []
92        while s:
93                s = _split_first_part(s)
94                q.append(s[0])
95                if len(s) < 2: break
96                s = s[1]
97        return q
98
99def _join_parts(s):
100        """Join all parts of a dot separated string."""
101        return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
102
103# The PostGreSQL database connection interface:
104
105class DB:
106        """Wrapper class for the _pg connection type."""
107
108        def __init__(self, *args, **kw):
109                self.db = connect(*args, **kw)
110                self.dbname = self.db.db
111                self.__attnames = {}
112                self.__pkeys = {}
113                self.__args = args, kw
114                self.debug = None # For debugging scripts, this can be set
115                        # * to a string format specification (e.g. in CGI set to "%s<BR>"),
116                        # * to a function which takes a string argument or
117                        # * to a file object to write debug statements to.
118
119        def __getattr__(self, name):
120                # All undefined members are the same as in the underlying pg connection:
121                if self.db:
122                        return getattr(self.db, name)
123                else:
124                        raise InternalError, 'Connection is not valid'
125
126        # For convenience, define some module functions as static methods also:
127        escape_string, escape_bytea, unescape_bytea = map(staticmethod,
128                (escape_string, escape_bytea, unescape_bytea))
129
130        def _do_debug(self, s):
131                """Print a debug message."""
132                if not self.debug: return
133                if isinstance(self.debug, StringType): print self.debug % s
134                if isinstance(self.debug, FunctionType): self.debug(s)
135                if isinstance(self.debug, FileType): print >> self.debug, s
136
137        def close(self):
138                """Close the database connection."""
139                # Wraps shared library function so we can track state.
140
141                if self.db:
142                        self.db.close()
143                        self.db = None
144                else:
145                        raise InternalError, 'Connection already closed'
146
147        def reopen(self):
148                """Reopen connection to the database.
149
150                Used in case we need another connection to the same database.
151                Note that we can still reopen a database that we have closed.
152
153                """
154                if self.db:
155                        self.db.close()
156                try:
157                        self.db = connect(*self.__args[0], **self.__args[1])
158                except:
159                        self.db = None
160                        raise
161
162        def query(self, qstr):
163                """Executes a SQL command string.
164
165                This method simply sends a SQL query to the database. If the query is
166                an insert statement, the return value is the OID of the newly
167                inserted row.  If it is otherwise a query that does not return a result
168                (ie. is not a some kind of SELECT statement), it returns None.
169                Otherwise, it returns a pgqueryobject that can be accessed via the
170                getresult or dictresult method or simply printed.
171
172                """
173                # Wraps shared library function for debugging.
174                if not self.db:
175                        raise InternalError, 'Connection is not valid'
176                self._do_debug(qstr)
177                return self.db.query(qstr)
178
179        def _split_schema(self, cl):
180                """Return schema and name of object separately.
181
182                This auxiliary function splits off the namespace (schema)
183                belonging to the class with the name cl. If the class name
184                is not qualified, the function is able to determine the schema
185                of the class, taking into account the current search path.
186
187                """
188                s = _split_parts(cl)
189                if len(s) > 1: # name already qualfied?
190                        # should be database.schema.table or schema.table
191                        if len(s) > 3:
192                                raise ProgrammingError, 'Too many dots in class name %s' % cl
193                        schema, cl = s[-2:]
194                else:
195                        cl = s[0]
196                        # determine search path
197                        query = 'SELECT current_schemas(TRUE)'
198                        schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
199                        if schemas: # non-empty path
200                                # search schema for this object in the current search path
201                                query = ' UNION '.join(
202                                        ["SELECT %d::integer AS n, '%s'::name AS nspname"
203                                                % s for s in enumerate(schemas)])
204                                query = ("SELECT nspname FROM pg_class"
205                                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
206                                        " JOIN (%s) AS p USING (nspname)"
207                                        " WHERE pg_class.relname='%s'"
208                                        " ORDER BY n LIMIT 1" % (query, cl))
209                                schema = self.db.query(query).getresult()
210                                if schema: # schema found
211                                        schema = schema[0][0]
212                                else: # object not found in current search path
213                                        schema = 'public'
214                        else: # empty path
215                                schema = 'public'
216                return schema, cl
217
218        def pkey(self, cl, newpkey = None):
219                """This method gets or sets the primary key of a class.
220
221                If newpkey is set and is not a dictionary then set that
222                value as the primary key of the class.  If it is a dictionary
223                then replace the __pkeys dictionary with it.
224
225                """
226                # First see if the caller is supplying a dictionary
227                if isinstance(newpkey, DictType):
228                        # make sure that we have a namespace
229                        self.__pkeys = {}
230                        for x in newpkey.keys():
231                                if x.find('.') == -1:
232                                        self.__pkeys['public.' + x] = newpkey[x]
233                                else:
234                                        self.__pkeys[x] = newpkey[x]
235                        return self.__pkeys
236
237                qcl = _join_parts(self._split_schema(cl)) # build qualified name
238                if newpkey:
239                        self.__pkeys[qcl] = newpkey
240                        return newpkey
241
242                # Get all the primary keys at once
243                if self.__pkeys == {} or not self.__pkeys.has_key(qcl):
244                        # if not found, check again in case it was added after we started
245                        self.__pkeys = dict([
246                                (_join_parts(r[:2]), r[2]) for r in self.db.query(
247                                "SELECT pg_namespace.nspname, pg_class.relname"
248                                        ",pg_attribute.attname FROM pg_class"
249                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
250                                        " AND pg_namespace.nspname NOT LIKE 'pg_%'"
251                                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
252                                        " AND pg_attribute.attisdropped='f'"
253                                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
254                                        " AND pg_index.indisprimary='t'"
255                                        " AND pg_index.indkey[0]=pg_attribute.attnum"
256                                ).getresult()])
257                        self._do_debug(self.__pkeys)
258                # will raise an exception if primary key doesn't exist
259                return self.__pkeys[qcl]
260
261        def get_databases(self):
262                """Get list of databases in the system."""
263                return [s[0] for s in
264                        self.db.query('SELECT datname FROM pg_database').getresult()]
265
266        def get_relations(self, kinds = None):
267                """Get list of relations in connected database of specified kinds.
268
269                        If kinds is None or empty, all kinds of relations are returned.
270                        Otherwise kinds can be a string or sequence of type letters
271                        specifying which kind of relations you want to list.
272
273                """
274                where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
275                        ["'%s'" % x for x in kinds]) or ''
276                return map(_join_parts, self.db.query(
277                        "SELECT pg_namespace.nspname, pg_class.relname "
278                        "FROM pg_class "
279                        "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
280                        "WHERE %s pg_class.relname !~ '^Inv' AND "
281                                "pg_class.relname !~ '^pg_' "
282                        "ORDER BY 1, 2" % where).getresult())
283
284        def get_tables(self):
285                """Return list of tables in connected database."""
286                return self.get_relations('r')
287
288        def get_attnames(self, cl, newattnames = None):
289                """Given the name of a table, digs out the set of attribute names.
290
291                Returns a dictionary of attribute names (the names are the keys,
292                the values are the names of the attributes' types).
293                If the optional newattnames exists, it must be a dictionary and
294                will become the new attribute names dictionary.
295
296                """
297                if isinstance(newattnames, DictType):
298                        self.__attnames = newattnames
299                        return
300                elif newattnames:
301                        raise ProgrammingError, \
302                                'If supplied, newattnames must be a dictionary'
303                cl = self._split_schema(cl) # split into schema and cl
304                qcl = _join_parts(cl) # build qualified name
305                # May as well cache them:
306                if self.__attnames.has_key(qcl):
307                        return self.__attnames[qcl]
308                if qcl not in self.get_relations('rv'):
309                        raise ProgrammingError, 'Class %s does not exist' % qcl
310                t = {}
311                for att, typ in self.db.query("SELECT pg_attribute.attname"
312                        ",pg_type.typname FROM pg_class"
313                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
314                        " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
315                        " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
316                        " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
317                        " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
318                        " AND pg_attribute.attisdropped='f'"
319                                % cl).getresult():
320                        if typ.startswith('bool'):
321                                t[att] = 'bool'
322                        elif typ.startswith('oid'):
323                                t[att] = 'int'
324                        elif typ.startswith('float'):
325                                t[att] = 'decimal'
326                        elif typ.startswith('abstime'):
327                                t[att] = 'date'
328                        elif typ.startswith('date'):
329                                t[att] = 'date'
330                        elif typ.startswith('interval'):
331                                t[att] = 'date'
332                        elif typ.startswith('int'):
333                                t[att] = 'int'
334                        elif typ.startswith('timestamp'):
335                                t[att] = 'date'
336                        elif typ.startswith('money'):
337                                t[att] = 'money'
338                        else:
339                                t[att] = 'text'
340                self.__attnames[qcl] = t # cache it
341                return self.__attnames[qcl]
342
343        def get(self, cl, arg, keyname = None, view = 0):
344                """Get a tuple from a database table or view.
345
346                This method is the basic mechanism to get a single row.  It assumes
347                that the key specifies a unique row.  If keyname is not specified
348                then the primary key for the table is used.  If arg is a dictionary
349                then the value for the key is taken from it and it is modified to
350                include the new values, replacing existing values where necessary.
351                The OID is also put into the dictionary, but in order to allow the
352                caller to work with multiple tables, it is munged as oid(schema.table).
353
354                """
355                if cl.endswith('*'): # scan descendant tables?
356                        cl = cl[:-1].rstrip() # need parent table name
357                qcl = _join_parts(self._split_schema(cl)) # build qualified name
358                # To allow users to work with multiple tables,
359                # we munge the name when the key is "oid"
360                foid = 'oid(%s)' % qcl # build mangled name
361                if keyname == None: # use the primary key by default
362                        keyname = self.pkey(qcl)
363                fnames = self.get_attnames(qcl)
364                if isinstance(arg, DictType):
365                        # XXX this code is for backwards compatibility and will be
366                        # XXX removed eventually
367                        if not arg.has_key(foid):
368                                ofoid = 'oid_' + self._split_schema(cl)[-1]
369                                if arg.has_key(ofoid):
370                                        arg[foid] = arg[ofoid]
371
372                        k = arg[keyname == 'oid' and foid or keyname]
373                else:
374                        k = arg
375                        arg = {}
376                # We want the oid for later updates if that isn't the key
377                if keyname == 'oid':
378                        q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
379                elif view:
380                        q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
381                                (qcl, keyname, _quote(k, fnames[keyname]))
382                else:
383                        q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
384                                (','.join(fnames.keys()), qcl, \
385                                        keyname, _quote(k, fnames[keyname]))
386                self._do_debug(q)
387                res = self.db.query(q).dictresult()
388                if not res:
389                        raise DatabaseError, \
390                                'No such record in %s where %s=%s' % \
391                                        (qcl, keyname, _quote(k, fnames[keyname]))
392                for k, d in res[0].items():
393                        if k == 'oid':
394                                k = foid
395                        arg[k] = d
396                return arg
397
398        def insert(self, cl, d = None, **kw):
399                """Insert a tuple into a database table.
400
401                This method inserts values into the table specified filling in the
402                values from the dictionary.  It then reloads the dictionary with the
403                values from the database.  This causes the dictionary to be updated
404                with values that are modified by rules, triggers, etc.
405
406                Note: The method currently doesn't support insert into views
407                although PostgreSQL does.
408
409                """
410                if d is None: a = {}
411                else: a = d
412                a.update(kw)
413
414                qcl = _join_parts(self._split_schema(cl)) # build qualified name
415                foid = 'oid(%s)' % qcl # build mangled name
416                fnames = self.get_attnames(qcl)
417                t = []
418                n = []
419                for f in fnames.keys():
420                        if f != 'oid' and a.has_key(f):
421                                t.append(_quote(a[f], fnames[f]))
422                                n.append(f)
423                q = 'INSERT INTO %s (%s) VALUES (%s)' % \
424                        (qcl, ','.join(n), ','.join(t))
425                self._do_debug(q)
426                a[foid] = self.db.query(q)
427                # Reload the dictionary to catch things modified by engine.
428                # Note that get() changes 'oid' below to oid_schema_table.
429                # If no read perms (it can and does happen), return None.
430                try:
431                        return self.get(qcl, a, 'oid')
432                except:
433                        return None
434
435        def update(self, cl, d = None, **kw):
436                """Update an existing row in a database table.
437
438                Similar to insert but updates an existing row.  The update is based
439                on the OID value as munged by get.  The array returned is the
440                one sent modified to reflect any changes caused by the update due
441                to triggers, rules, defaults, etc.
442
443                """
444                # Update always works on the oid which get returns if available,
445                # otherwise use the primary key.  Fail if neither.
446                qcl = _join_parts(self._split_schema(cl)) # build qualified name
447                foid = 'oid(%s)' % qcl # build mangled oid
448
449                # Note that we only accept oid key from named args for safety
450                if kw.has_key('oid'):
451                        kw[foid] = kw['oid']
452                        del kw['oid']
453
454                if d is None: a = {}
455                else: a = d
456                a.update(kw)
457
458                # XXX this code is for backwards compatibility and will be
459                # XXX removed eventually
460                if not a.has_key(foid):
461                        ofoid = 'oid_' + self._split_schema(cl)[-1]
462                        if a.has_key(ofoid):
463                                a[foid] = a[ofoid]
464
465                if a.has_key(foid):
466                        where = "oid=%s" % a[foid]
467                else:
468                        try:
469                                pk = self.pkey(qcl)
470                        except:
471                                raise ProgrammingError, \
472                                        'Update needs primary key or oid as %s' % foid
473                        where = "%s='%s'" % (pk, a[pk])
474                v = []
475                k = 0
476                fnames = self.get_attnames(qcl)
477                for ff in fnames.keys():
478                        if ff != 'oid' and a.has_key(ff):
479                                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
480                if v == []:
481                        return None
482                q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
483                self._do_debug(q)
484                self.db.query(q)
485                # Reload the dictionary to catch things modified by engine:
486                if a.has_key(foid):
487                        return self.get(qcl, a, 'oid')
488                else:
489                        return self.get(qcl, a)
490
491        def clear(self, cl, a = None):
492                """
493
494                This method clears all the attributes to values determined by the types.
495                Numeric types are set to 0, Booleans are set to 'f', and everything
496                else is set to the empty string.  If the array argument is present,
497                it is used as the array and any entries matching attribute names are
498                cleared with everything else left unchanged.
499
500                """
501                # At some point we will need a way to get defaults from a table.
502                if a is None: a = {} # empty if argument is not present
503                qcl = _join_parts(self._split_schema(cl)) # build qualified name
504                foid = 'oid(%s)' % qcl # build mangled oid
505                fnames = self.get_attnames(qcl)
506                for k, t in fnames.items():
507                        if k == 'oid': continue
508                        if t in ['int', 'decimal', 'seq', 'money']:
509                                a[k] = 0
510                        elif t == 'bool':
511                                a[k] = 'f'
512                        else:
513                                a[k] = ''
514                return a
515
516        def delete(self, cl, d = None, **kw):
517                """Delete an existing row in a database table.
518
519                This method deletes the row from a table.
520                It deletes based on the OID munged as described above."""
521
522                # Like update, delete works on the oid.
523                # One day we will be testing that the record to be deleted
524                # isn't referenced somewhere (or else PostgreSQL will).
525                qcl = _join_parts(self._split_schema(cl)) # build qualified name
526                foid = 'oid(%s)' % qcl # build mangled oid
527
528                # Note that we only accept oid key from named args for safety
529                if kw.has_key('oid'):
530                        kw[foid] = kw['oid']
531                        del kw['oid']
532
533                if d is None: a = {}
534                else: a = d
535                a.update(kw)
536
537                # XXX this code is for backwards compatibility and will be
538                # XXX removed eventually
539                if not a.has_key(foid):
540                        ofoid = 'oid_' + self._split_schema(cl)[-1]
541                        if a.has_key(ofoid):
542                                a[foid] = a[ofoid]
543
544                q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
545                self._do_debug(q)
546                self.db.query(q)
547
548# if run as script, print some information
549if __name__ == '__main__':
550        print 'PyGreSQL version', version
551        print
552        print __doc__
Note: See TracBrowser for help on using the repository browser.