source: trunk/module/pg.py @ 308

Last change on this file since 308 was 308, checked in by cito, 13 years ago

Do not quote SQL constants for date and time.

File size: 16.6 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.49 2006-11-24 18:18:01 cito Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with addtional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22from types import *
23
24# Auxiliary functions which are independent from a DB connection:
25
26def _quote(d, t):
27        """Return quotes if needed."""
28        if d is None:
29                return 'NULL'
30        if t in ('int', 'seq', 'decimal'):
31                if d == '': return 'NULL'
32                return str(d)
33        if t == 'money':
34                if d == '': return 'NULL'
35                return "'%.2f'" % float(d)
36        if t == 'bool':
37                if type(d) == StringType:
38                        if d == '': return 'NULL'
39                        d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
40                else:
41                        d = not not d
42                return ("'f'", "'t'")[d]
43        if t in ('date', 'inet', 'cidr'):
44                if d == '': return 'NULL'
45                if d.lower() in ('current_date', 'current_time',
46                        'current_timestamp', 'localtime', 'localtimestamp'):
47                        return d
48        return "'%s'" % str(d).replace("\\", "\\\\").replace("'", "''")
49
50def _is_quoted(s):
51        """Check whether this string is a quoted identifier."""
52        s = s.replace('_', 'a')
53        return not s.isalnum() or s[:1].isdigit() or s != s.lower()
54
55def _is_unquoted(s):
56        """Check whether this string is an unquoted identifier."""
57        s = s.replace('_', 'a')
58        return s.isalnum() and not s[:1].isdigit()
59
60def _split_first_part(s):
61        """Split the first part of a dot separated string."""
62        s = s.lstrip()
63        if s[:1] == '"':
64                p = []
65                s = s.split('"', 3)[1:]
66                p.append(s[0])
67                while len(s) == 3 and s[1] == '':
68                        p.append('"')
69                        s = s[2].split('"', 2)
70                        p.append(s[0])
71                p = [''.join(p)]
72                s = '"'.join(s[1:]).lstrip()
73                if s:
74                        if s[:0] == '.':
75                                p.append(s[1:])
76                        else:
77                                s = _split_first_part(s)
78                                p[0] += s[0]
79                                if len(s) > 1:
80                                        p.append(s[1])
81        else:
82                p = s.split('.', 1)
83                s = p[0].rstrip()
84                if _is_unquoted(s):
85                        s = s.lower()
86                p[0] = s
87        return p
88
89def _split_parts(s):
90        """Split all parts of a dot separated string."""
91        q = []
92        while s:
93                s = _split_first_part(s)
94                q.append(s[0])
95                if len(s) < 2: break
96                s = s[1]
97        return q
98
99def _join_parts(s):
100        """Join all parts of a dot separated string."""
101        return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
102
103# The PostGreSQL database connection interface:
104
105class DB:
106        """Wrapper class for the _pg connection type."""
107
108        def __init__(self, *args, **kw):
109                self.db = connect(*args, **kw)
110                self.dbname = self.db.db
111                self.__attnames = {}
112                self.__pkeys = {}
113                self.__args = args, kw
114                self.debug = None # For debugging scripts, this can be set
115                        # * to a string format specification (e.g. in a CGI set to "%s<BR>"),
116                        # * to a function which takes a string argument or
117                        # * to a file object to write debug statements to.
118
119        def __getattr__(self, name):
120                # All undefined members are the same as in the underlying pg connection:
121                if self.db:
122                        return getattr(self.db, name)
123                else:
124                        raise InternalError, 'Connection is not valid'
125
126        # For convenience, define some module functions as static methods also:
127        escape_string, escape_bytea, unescape_bytea = map(staticmethod,
128                (escape_string, escape_bytea, unescape_bytea))
129
130        def _do_debug(self, s):
131                """Print a debug message."""
132                if not self.debug: return
133                if isinstance(self.debug, StringType): print self.debug % s
134                if isinstance(self.debug, FunctionType): self.debug(s)
135                if isinstance(self.debug, FileType): print >> self.debug, s
136
137        def close(self):
138                """Close the database connection."""
139                # Wraps shared library function so we can track state.
140
141                if self.db:
142                        self.db.close()
143                        self.db = None
144                else:
145                        raise InternalError, 'Connection already closed'
146
147        def reopen(self):
148                """Reopen connection to the database.
149
150                Used in case we need another connection to the same database.
151                Note that we can still reopen a database that we have closed.
152
153                """
154                if self.db:
155                        self.db.close()
156                try:
157                        self.db = connect(*self.__args[0], **self.__args[1])
158                except:
159                        self.db = None
160                        raise
161
162        def query(self, qstr):
163                """Executes a SQL command string.
164
165                This method simply sends a SQL query to the database. If the query is
166                an insert statement, the return value is the OID of the newly
167                inserted row.  If it is otherwise a query that does not return a result
168                (ie. is not a some kind of SELECT statement), it returns None.
169                Otherwise, it returns a pgqueryobject that can be accessed via the
170                getresult or dictresult method or simply printed.
171
172                """
173                # Wraps shared library function for debugging.
174                if not self.db:
175                        raise InternalError, 'Connection is not valid'
176                self._do_debug(qstr)
177                return self.db.query(qstr)
178
179        def _split_schema(self, cl):
180                """Return schema and name of object separately.
181
182                This auxiliary function splits off the namespace (schema)
183                belonging to the class with the name cl. If the class name
184                is not qualified, the function is able to determine the schema
185                of the class, taking into account the current search path.
186
187                """
188                s = _split_parts(cl)
189                if len(s) > 1: # name already qualfied?
190                        # should be database.schema.table or schema.table
191                        if len(s) > 3:
192                                raise ProgrammingError, 'Too many dots in class name %s' % cl
193                        schema, cl = s[-2:]
194                else:
195                        cl = s[0]
196                        # determine search path
197                        query = 'SELECT current_schemas(TRUE)'
198                        schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
199                        if schemas: # non-empty path
200                                # search schema for this object in the current search path
201                                query = ' UNION '.join(["SELECT %d AS n, '%s' AS nspname" % s
202                                        for s in enumerate(schemas)])
203                                query = ("SELECT nspname FROM pg_class"
204                                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
205                                        " JOIN (%s) AS p USING (nspname)"
206                                        " WHERE pg_class.relname='%s'"
207                                        " ORDER BY n LIMIT 1" % (query, cl))
208                                schema = self.db.query(query).getresult()
209                                if schema: # schema found
210                                        schema = schema[0][0]
211                                else: # object not found in current search path
212                                        schema = 'public'
213                        else: # empty path
214                                schema = 'public'
215                return schema, cl
216
217        def pkey(self, cl, newpkey = None):
218                """This method gets or sets the primary key of a class.
219
220                If newpkey is set and is not a dictionary then set that
221                value as the primary key of the class.  If it is a dictionary
222                then replace the __pkeys dictionary with it.
223
224                """
225                # First see if the caller is supplying a dictionary
226                if isinstance(newpkey, DictType):
227                        # make sure that we have a namespace
228                        self.__pkeys = {}
229                        for x in newpkey.keys():
230                                if x.find('.') == -1:
231                                        self.__pkeys['public.' + x] = newpkey[x]
232                                else:
233                                        self.__pkeys[x] = newpkey[x]
234
235                        return self.__pkeys
236
237                qcl = _join_parts(self._split_schema(cl)) # build qualified name
238                if newpkey:
239                        self.__pkeys[qcl] = newpkey
240                        return newpkey
241
242                # Get all the primary keys at once
243                if self.__pkeys == {} or not self.__pkeys.has_key(qcl):
244                        # if not found, check again in case it was added after we started
245                        for r in self.db.query("SELECT pg_namespace.nspname"
246                                ",pg_class.relname,pg_attribute.attname FROM pg_class"
247                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
248                                " AND pg_namespace.nspname NOT LIKE 'pg_%'"
249                                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
250                                " AND pg_attribute.attisdropped='f'"
251                                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
252                                " AND pg_index.indisprimary='t'"
253                                " AND pg_index.indkey[0]=pg_attribute.attnum").getresult():
254                                self.__pkeys[_join_parts(r[:2])] = r[2] # build qualified name
255                        self._do_debug(self.__pkeys)
256                # will raise an exception if primary key doesn't exist
257                return self.__pkeys[qcl]
258
259        def get_databases(self):
260                """Get list of databases in the system."""
261                return [s[0] for s in
262                        self.db.query('SELECT datname FROM pg_database').getresult()]
263
264        def get_relations(self, kinds = None):
265                """Get list of relations in connected database of specified kinds.
266
267                        If kinds is None or empty, all kinds of relations are returned.
268                        Otherwise kinds can be a string or sequence of type letters
269                        specifying which kind of relations you want to list.
270
271                """
272                if kinds:
273                        where = "pg_class.relkind IN (%s) AND" % \
274                                                        ','.join(["'%s'" % x for x in kinds])
275                else:
276                        where = ''
277
278                return [_join_parts(s) for s in
279                        self.db.query(
280                                "SELECT pg_namespace.nspname, pg_class.relname "
281                                "FROM pg_class "
282                                "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
283                                "WHERE %s pg_class.relname !~ '^Inv' AND "
284                                        "pg_class.relname !~ '^pg_' "
285                                "ORDER BY 1,2" % where).getresult()]
286
287        def get_tables(self):
288                """Return list of tables in connected database."""
289                return self.get_relations('r')
290
291        def get_attnames(self, cl, newattnames = None):
292                """Given the name of a table, digs out the set of attribute names.
293
294                Returns a dictionary of attribute names (the names are the keys,
295                the values are the names of the attributes' types).
296                If the optional newattnames exists, it must be a dictionary and
297                will become the new attribute names dictionary.
298
299                """
300                if isinstance(newattnames, DictType):
301                        self.__attnames = newattnames
302                        return
303                elif newattnames:
304                        raise ProgrammingError, \
305                                'If supplied, newattnames must be a dictionary'
306                cl = self._split_schema(cl) # split into schema and cl
307                qcl = _join_parts(cl) # build qualified name
308                # May as well cache them:
309                if self.__attnames.has_key(qcl):
310                        return self.__attnames[qcl]
311                if qcl not in self.get_relations('rv'):
312                        raise ProgrammingError, 'Class %s does not exist' % qcl
313                t = {}
314                for att, typ in self.db.query("SELECT pg_attribute.attname"
315                        ",pg_type.typname FROM pg_class"
316                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
317                        " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
318                        " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
319                        " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
320                        " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
321                        " AND pg_attribute.attisdropped='f'"
322                                % cl).getresult():
323                        if typ.startswith('bool'):
324                                t[att] = 'bool'
325                        elif typ.startswith('oid'):
326                                t[att] = 'int'
327                        elif typ.startswith('float'):
328                                t[att] = 'decimal'
329                        elif typ.startswith('abstime'):
330                                t[att] = 'date'
331                        elif typ.startswith('date'):
332                                t[att] = 'date'
333                        elif typ.startswith('interval'):
334                                t[att] = 'date'
335                        elif typ.startswith('int'):
336                                t[att] = 'int'
337                        elif typ.startswith('timestamp'):
338                                t[att] = 'date'
339                        elif typ.startswith('money'):
340                                t[att] = 'money'
341                        else:
342                                t[att] = 'text'
343                self.__attnames[qcl] = t # cache it
344                return self.__attnames[qcl]
345
346        def get(self, cl, arg, keyname = None, view = 0):
347                """Get a tuple from a database table or view.
348
349                This method is the basic mechanism to get a single row.  It assumes
350                that the key specifies a unique row.  If keyname is not specified
351                then the primary key for the table is used.  If arg is a dictionary
352                then the value for the key is taken from it and it is modified to
353                include the new values, replacing existing values where necessary.
354                The OID is also put into the dictionary, but in order to allow the
355                caller to work with multiple tables, it is munged as oid(schema.table).
356
357                """
358                if cl.endswith('*'): # scan descendant tables?
359                        cl = cl[:-1].rstrip() # need parent table name
360                qcl = _join_parts(self._split_schema(cl)) # build qualified name
361                # To allow users to work with multiple tables,
362                # we munge the name when the key is "oid"
363                foid = 'oid(%s)' % qcl # build mangled name
364                if keyname == None: # use the primary key by default
365                        keyname = self.pkey(qcl)
366                fnames = self.get_attnames(qcl)
367                if isinstance(arg, DictType):
368                        # XXX this code is for backwards compatibility and will be
369                        # XXX removed eventually
370                        if not arg.has_key(foid):
371                                ofoid = 'oid_' + self._split_schema(cl)[-1]
372                                if arg.has_key(ofoid):
373                                        arg[foid] = arg[ofoid]
374
375                        k = arg[keyname == 'oid' and foid or keyname]
376                else:
377                        k = arg
378                        arg = {}
379                # We want the oid for later updates if that isn't the key
380                if keyname == 'oid':
381                        q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
382                elif view:
383                        q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
384                                (qcl, keyname, _quote(k, fnames[keyname]))
385                else:
386                        q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
387                                (','.join(fnames.keys()), qcl, \
388                                        keyname, _quote(k, fnames[keyname]))
389                self._do_debug(q)
390                res = self.db.query(q).dictresult()
391                if not res:
392                        raise DatabaseError, \
393                                'No such record in %s where %s=%s' % \
394                                        (qcl, keyname, _quote(k, fnames[keyname]))
395                for k, d in res[0].items():
396                        if k == 'oid':
397                                k = foid
398                        arg[k] = d
399                return arg
400
401        def insert(self, cl, a):
402                """Insert a tuple into a database table.
403
404                This method inserts values into the table specified filling in the
405                values from the dictionary.  It then reloads the dictionary with the
406                values from the database.  This causes the dictionary to be updated
407                with values that are modified by rules, triggers, etc.
408
409                Note: The method currently doesn't support insert into views
410                although PostgreSQL does.
411
412                """
413                qcl = _join_parts(self._split_schema(cl)) # build qualified name
414                foid = 'oid(%s)' % qcl # build mangled name
415                fnames = self.get_attnames(qcl)
416                t = []
417                n = []
418                for f in fnames.keys():
419                        if f != 'oid' and a.has_key(f):
420                                t.append(_quote(a[f], fnames[f]))
421                                n.append(f)
422                q = 'INSERT INTO %s (%s) VALUES (%s)' % \
423                        (qcl, ','.join(n), ','.join(t))
424                self._do_debug(q)
425                a[foid] = self.db.query(q)
426                # Reload the dictionary to catch things modified by engine.
427                # Note that get() changes 'oid' below to oid_schema_table.
428                # If no read perms (it can and does happen), return None.
429                try:
430                        return self.get(qcl, a, 'oid')
431                except:
432                        return None
433
434        def update(self, cl, a):
435                """Update an existing row in a database table.
436
437                Similar to insert but updates an existing row.  The update is based
438                on the OID value as munged by get.  The array returned is the
439                one sent modified to reflect any changes caused by the update due
440                to triggers, rules, defaults, etc.
441
442                """
443                # Update always works on the oid which get returns if available,
444                # otherwise use the primary key.  Fail if neither.
445                qcl = _join_parts(self._split_schema(cl)) # build qualified name
446                foid = 'oid(%s)' % qcl # build mangled oid
447
448                # XXX this code is for backwards compatibility and will be
449                # XXX removed eventually
450                if not a.has_key(foid):
451                        ofoid = 'oid_' + self._split_schema(cl)[-1]
452                        if a.has_key(ofoid):
453                                a[foid] = a[ofoid]
454
455                if a.has_key(foid):
456                        where = "oid=%s" % a[foid]
457                else:
458                        try:
459                                pk = self.pkey(qcl)
460                        except:
461                                raise ProgrammingError, \
462                                        'Update needs primary key or oid as %s' % foid
463                        where = "%s='%s'" % (pk, a[pk])
464                v = []
465                k = 0
466                fnames = self.get_attnames(qcl)
467                for ff in fnames.keys():
468                        if ff != 'oid' and a.has_key(ff):
469                                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
470                if v == []:
471                        return None
472                q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
473                self._do_debug(q)
474                self.db.query(q)
475                # Reload the dictionary to catch things modified by engine:
476                if a.has_key(foid):
477                        return self.get(qcl, a, 'oid')
478                else:
479                        return self.get(qcl, a)
480
481        def clear(self, cl, a = None):
482                """
483
484                This method clears all the attributes to values determined by the types.
485                Numeric types are set to 0, Booleans are set to 'f', and everything
486                else is set to the empty string.  If the array argument is present,
487                it is used as the array and any entries matching attribute names are
488                cleared with everything else left unchanged.
489
490                """
491                # At some point we will need a way to get defaults from a table.
492                if a is None: a = {} # empty if argument is not present
493                qcl = _join_parts(self._split_schema(cl)) # build qualified name
494                foid = 'oid(%s)' % qcl # build mangled oid
495                fnames = self.get_attnames(qcl)
496                for k, t in fnames.items():
497                        if k == 'oid': continue
498                        if t in ['int', 'decimal', 'seq', 'money']:
499                                a[k] = 0
500                        elif t == 'bool':
501                                a[k] = 'f'
502                        else:
503                                a[k] = ''
504                return a
505
506        def delete(self, cl, a):
507                """Delete an existing row in a database table.
508
509                This method deletes the row from a table.
510                It deletes based on the OID munged as described above.
511
512                """
513                # Like update, delete works on the oid.
514                # One day we will be testing that the record to be deleted
515                # isn't referenced somewhere (or else PostgreSQL will).
516                qcl = _join_parts(self._split_schema(cl)) # build qualified name
517                foid = 'oid(%s)' % qcl # build mangled oid
518
519                # XXX this code is for backwards compatibility and will be
520                # XXX removed eventually
521                if not a.has_key(foid):
522                        ofoid = 'oid_' + self._split_schema(cl)[-1]
523                        if a.has_key(ofoid):
524                                a[foid] = a[ofoid]
525
526                q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
527                self._do_debug(q)
528                self.db.query(q)
529
530# if run as script, print some information
531if __name__ == '__main__':
532        print 'PyGreSQL version', version
533        print
534        print __doc__
Note: See TracBrowser for help on using the repository browser.