source: trunk/module/pg.py @ 222

Last change on this file since 222 was 222, checked in by cito, 14 years ago

Added shebang line as requested by Devrim Gunduz
for building PostgreSQL RPM packages.

File size: 15.2 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.34 2005-11-18 13:51:02 cito Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with addtional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22from types import *
23
24# Auxiliary functions which are independent from a DB connection:
25
26def _quote(d, t):
27        """Return quotes if needed."""
28        if d is None:
29                return 'NULL'
30        if t in ('int', 'seq', 'decimal'):
31                if d == '': return 'NULL'
32                return str(d)
33        if t == 'money':
34                if d == '': return 'NULL'
35                return "'%.2f'" % float(d)
36        if t == 'bool':
37                if type(d) == StringType:
38                        if d == '': return 'NULL'
39                        d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
40                else:
41                        d = not not d
42                return ("'f'", "'t'")[d]
43        if t in ('date', 'inet', 'cidr'):
44                if d == '': return 'NULL'
45        return "'%s'" % str(d).replace('\\','\\\\').replace('\'','\\\'')
46
47def _is_quoted(s):
48        """Check whether this string is a quoted identifier."""
49        s = s.replace('_', 'a')
50        return not s.isalnum() or s[:1].isdigit() or s != s.lower()
51
52def _is_unquoted(s):
53        """Check whether this string is an unquoted identifier."""
54        s = s.replace('_', 'a')
55        return s.isalnum() and not s[:1].isdigit()
56
57def _split_first_part(s):
58        """Split the first part of a dot separated string."""
59        s = s.lstrip()
60        if s[:1] == '"':
61                p = []
62                s = s.split('"', 3)[1:]
63                p.append(s[0])
64                while len(s) == 3 and s[1] == '':
65                        p.append('"')
66                        s = s[2].split('"', 2)
67                        p.append(s[0])
68                p = [''.join(p)]
69                s = '"'.join(s[1:]).lstrip()
70                if s:
71                        if s[:0] == '.':
72                                p.append(s[1:])
73                        else:
74                                s = _split_first_part(s)
75                                p[0] += s[0]
76                                if len(s) > 1:
77                                        p.append(s[1])
78        else:
79                p = s.split('.', 1)
80                s = p[0].rstrip()
81                if _is_unquoted(s):
82                        s = s.lower()
83                p[0] = s
84        return p
85
86def _split_parts(s):
87        """Split all parts of a dot separated string."""
88        q = []
89        while s:
90                s = _split_first_part(s)
91                q.append(s[0])
92                if len(s) < 2: break
93                s = s[1]
94        return q
95
96def _join_parts(s):
97        """Join all parts of a dot separated string."""
98        return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
99
100# The PostGreSQL database connection interface:
101
102class DB:
103        """Wrapper class for the _pg connection type."""
104
105        def __init__(self, *args, **kw):
106                self.db = connect(*args, **kw)
107                self.dbname = self.db.db
108                self.__attnames = {}
109                self.__pkeys = {}
110                self.__args = args, kw
111                self.debug = None # For debugging scripts, this can be set
112                        # * to a string format specification (e.g. in a CGI set to "%s<BR>"),
113                        # * to a function which takes a string argument or
114                        # * to a file object to write debug statements to.
115
116        def __getattr__(self, name):
117                # All undefined members are the same as in the underlying pg connection:
118                if self.db:
119                        return getattr(self.db, name)
120                else:
121                        raise InternalError, 'Connection is not valid'
122
123        def _do_debug(self, s):
124                """Print a debug message."""
125                if not self.debug: return
126                if isinstance(self.debug, StringType): print self.debug % s
127                if isinstance(self.debug, FunctionType): self.debug(s)
128                if isinstance(self.debug, FileType): print >> self.debug, s
129
130        def close(self):
131                """Close the database connection."""
132                # Wraps shared library function so we can track state.
133
134                if self.db:
135                        self.db.close()
136                        self.db = None
137                else:
138                        raise InternalError, 'Connection already closed'
139
140        def reopen(self):
141                """Reopen connection to the database.
142
143                Used in case we need another connection to the same database.
144                Note that we can still reopen a database that we have closed.
145
146                """
147                if self.db:
148                        self.db.close()
149                try:
150                        self.db = connect(*self.__args[0], **self.__args[1])
151                except:
152                        self.db = None
153                        raise
154
155        def query(self, qstr):
156                """Executes a SQL command string.
157
158                This method simply sends a SQL query to the database. If the query is
159                an insert statement, the return value is the OID of the newly
160                inserted row.  If it is otherwise a query that does not return a result
161                (ie. is not a some kind of SELECT statement), it returns None.
162                Otherwise, it returns a pgqueryobject that can be accessed via the
163                getresult or dictresult method or simply printed.
164
165                """
166                # Wraps shared library function for debugging.
167                if not self.db:
168                        raise InternalError, 'Connection is not valid'
169                self._do_debug(qstr)
170                return self.db.query(qstr)
171
172        def _split_schema(self, cl):
173                """Return schema and name of object separately.
174
175                This auxiliary function splits off the namespace (schema)
176                belonging to the class with the name cl. If the class name
177                is not qualified, the function is able to determine the schema
178                of the class, taking into account the current search path.
179
180                """
181                s = _split_parts(cl)
182                if len(s) > 1: # name already qualfied?
183                        # should be database.schema.table or schema.table
184                        if len(s) > 3:
185                                raise ProgrammingError, 'Too many dots in class name %s' % cl
186                        schema, cl = s[-2:]
187                else:
188                        cl = s[0]
189                        # determine search path
190                        query = 'SELECT current_schemas(TRUE)'
191                        schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
192                        if schemas: # non-empty path
193                                # search schema for this object in the current search path
194                                query = ' UNION '.join(["SELECT %d AS n, '%s' AS nspname" % s
195                                        for s in enumerate(schemas)])
196                                query = ("SELECT nspname FROM pg_class"
197                                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
198                                        " JOIN (%s) AS p USING (nspname)"
199                                        " WHERE pg_class.relname='%s'"
200                                        " ORDER BY n LIMIT 1" % (query, cl))
201                                schema = self.db.query(query).getresult()
202                                if schema: # schema found
203                                        schema = schema[0][0]
204                                else: # object not found in current search path
205                                        schema = 'public'
206                        else: # empty path
207                                schema = 'public'
208                return schema, cl
209
210        def pkey(self, cl, newpkey = None):
211                """This method gets or sets the primary key of a class.
212
213                If newpkey is set and is not a dictionary then set that
214                value as the primary key of the class.  If it is a dictionary
215                then replace the __pkeys dictionary with it.
216
217                """
218                # Get all the primary keys at once
219                if isinstance(newpkey, DictType):
220                        self.__pkeys = newpkey
221                        return newpkey
222                qcl = _join_parts(self._split_schema(cl)) # build qualified name
223                if newpkey:
224                        self.__pkeys[qcl] = newpkey
225                        return newpkey
226                if self.__pkeys == {} or not self.__pkeys.has_key(qcl):
227                        # if not found, determine pkey again in case it was added after we started
228                        for r in self.db.query("SELECT pg_namespace.nspname"
229                                ",pg_class.relname,pg_attribute.attname FROM pg_class"
230                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
231                                " AND pg_namespace.nspname NOT LIKE 'pg_%'"
232                                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
233                                " AND pg_attribute.attisdropped='f'"
234                                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
235                                " AND pg_index.indisprimary='t'"
236                                " AND pg_index.indkey[0]=pg_attribute.attnum").getresult():
237                                self.__pkeys[_join_parts(r[:2])] = r[2] # build qualified name
238                        self._do_debug(self.__pkeys)
239                # will raise an exception if primary key doesn't exist
240                return self.__pkeys[qcl]
241
242        def get_databases(self):
243                """Get list of databases in the system."""
244                return [s for s, in
245                        self.db.query('SELECT datname FROM pg_database').getresult()]
246
247        def get_tables(self):
248                """Get list of tables in connected database."""
249                return [_join_parts(s) for s in
250                        self.db.query("SELECT pg_namespace.nspname"
251                                ",pg_class.relname FROM pg_class"
252                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
253                                " WHERE pg_class.relkind='r' AND"
254                                " pg_class.relname!~'^Inv' AND "
255                                " pg_class.relname!~'^pg_' ORDER BY 1,2").getresult()]
256
257        def get_attnames(self, cl, newattnames = None):
258                """Given the name of a table, digs out the set of attribute names.
259
260                Returns a dictionary of attribute names (the names are the keys,
261                the values are the names of the attributes' types).
262                If the optional newattnames exists, it must be a dictionary and
263                will become the new attribute names dictionary.
264
265                """
266                if isinstance(newattnames, DictType):
267                        self.__attnames = newattnames
268                        return
269                elif newattnames:
270                        raise ProgrammingError, \
271                                'If supplied, newattnames must be a dictionary'
272                cl = self._split_schema(cl) # split into schema and cl
273                qcl = _join_parts(cl) # build qualified name
274                # May as well cache them:
275                if self.__attnames.has_key(qcl):
276                        return self.__attnames[qcl]
277                if qcl not in self.get_tables():
278                        raise ProgrammingError, 'Class %s does not exist' % qcl
279                t = {}
280                for att, typ in self.db.query("SELECT pg_attribute.attname"
281                        ",pg_type.typname FROM pg_class"
282                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
283                        " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
284                        " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
285                        " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
286                        " AND pg_attribute.attnum>0 AND pg_attribute.attisdropped='f'"
287                                % cl).getresult():
288                        if typ.startswith('interval'):
289                                t[att] = 'date'
290                        elif typ.startswith('int'):
291                                t[att] = 'int'
292                        elif typ.startswith('oid'):
293                                t[att] = 'int'
294                        elif typ.startswith('text'):
295                                t[att] = 'text'
296                        elif typ.startswith('char'):
297                                t[att] = 'text'
298                        elif typ.startswith('name'):
299                                t[att] = 'text'
300                        elif typ.startswith('abstime'):
301                                t[att] = 'date'
302                        elif typ.startswith('date'):
303                                t[att] = 'date'
304                        elif typ.startswith('timestamp'):
305                                t[att] = 'date'
306                        elif typ.startswith('bool'):
307                                t[att] = 'bool'
308                        elif typ.startswith('float'):
309                                t[att] = 'decimal'
310                        elif typ.startswith('money'):
311                                t[att] = 'money'
312                        else:
313                                t[att] = 'text'
314                t['oid'] = 'int' # every table has this
315                self.__attnames[qcl] = t # cache it
316                return self.__attnames[qcl]
317
318        def get(self, cl, arg, keyname = None, view = 0):
319                """Get a tuple from a database table.
320
321                This method is the basic mechanism to get a single row.  It assumes
322                that the key specifies a unique row.  If keyname is not specified
323                then the primary key for the table is used.  If arg is a dictionary
324                then the value for the key is taken from it and it is modified to
325                include the new values, replacing existing values where necessary.
326                The OID is also put into the dictionary, but in order to allow the
327                caller to work with multiple tables, it is munged as oid(schema.table).
328
329                """
330                if cl.endswith('*'): # scan descendant tables?
331                        cl = cl[:-1].rstrip() # need parent table name
332                qcl = _join_parts(self._split_schema(cl)) # build qualified name
333                # To allow users to work with multiple tables,
334                # we munge the name when the key is "oid"
335                foid = 'oid(%s)' % qcl # build mangled name
336                if keyname == None: # use the primary key by default
337                        keyname = self.pkey(qcl)
338                fnames = self.get_attnames(qcl)
339                if isinstance(arg, DictType):
340                        k = arg[keyname == 'oid' and foid or keyname]
341                else:
342                        k = arg
343                        arg = {}
344                # We want the oid for later updates if that isn't the key
345                if keyname == 'oid':
346                        q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
347                elif view:
348                        q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
349                                (qcl, keyname, _quote(k, fnames[keyname]))
350                else:
351                        q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
352                                (','.join(fnames.keys()), qcl, \
353                                        keyname, _quote(k, fnames[keyname]))
354                self._do_debug(q)
355                res = self.db.query(q).dictresult()
356                if not res:
357                        raise DatabaseError, \
358                                'No such record in %s where %s=%s' % \
359                                        (qcl, keyname, _quote(k, fnames[keyname]))
360                for k, d in res[0].items():
361                        if k == 'oid':
362                                k = foid
363                        arg[k] = d
364                return arg
365
366        def insert(self, cl, a):
367                """Insert a tuple into a database table.
368
369                This method inserts values into the table specified filling in the
370                values from the dictionary.  It then reloads the dictionary with the
371                values from the database.  This causes the dictionary to be updated
372                with values that are modified by rules, triggers, etc.
373
374                Note: The method currently doesn't support insert into views
375                although PostgreSQL does.
376
377                """
378                qcl = _join_parts(self._split_schema(cl)) # build qualified name
379                foid = 'oid(%s)' % qcl # build mangled name
380                fnames = self.get_attnames(qcl)
381                t = []
382                n = []
383                for f in fnames.keys():
384                        if f != 'oid' and a.has_key(f):
385                                t.append(_quote(a[f], fnames[f]))
386                                n.append(f)
387                q = 'INSERT INTO %s (%s) VALUES (%s)' % \
388                        (qcl, ','.join(n), ','.join(t))
389                self._do_debug(q)
390                a[foid] = self.db.query(q)
391                # Reload the dictionary to catch things modified by engine.
392                # Note that get() changes 'oid' below to oid_schema_table.
393                # If no read perms (it can and does happen), return None.
394                try:
395                        return self.get(qcl, a, 'oid')
396                except:
397                        return None
398
399        def update(self, cl, a):
400                """Update an existing row in a database table.
401
402                Similar to insert but updates an existing row.  The update is based
403                on the OID value as munged by get.  The array returned is the
404                one sent modified to reflect any changes caused by the update due
405                to triggers, rules, defaults, etc.
406
407                """
408                # Update always works on the oid which get returns if available,
409                # otherwise use the primary key.  Fail if neither.
410                qcl = _join_parts(self._split_schema(cl)) # build qualified name
411                foid = 'oid(%s)' % qcl # build mangled oid
412                if a.has_key(foid):
413                        where = "oid=%s" % a[foid]
414                else:
415                        try:
416                                pk = self.pkey(qcl)
417                        except:
418                                raise ProgrammingError, \
419                                        'Update needs primary key or oid as %s' % foid
420                        where = "%s='%s'" % (pk, a[pk])
421                v = []
422                k = 0
423                fnames = self.get_attnames(qcl)
424                for ff in fnames.keys():
425                        if ff != 'oid' and a.has_key(ff):
426                                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
427                if v == []:
428                        return None
429                q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
430                self._do_debug(q)
431                self.db.query(q)
432                # Reload the dictionary to catch things modified by engine:
433                if a.has_key(foid):
434                        return self.get(qcl, a, 'oid')
435                else:
436                        return self.get(qcl, a)
437
438        def clear(self, cl, a = None):
439                """
440
441                This method clears all the attributes to values determined by the types.
442                Numeric types are set to 0, Booleans are set to 'f', dates are set
443                to 'now()' and everything else is set to the empty string.
444                If the array argument is present, it is used as the array and any entries
445                matching attribute names are cleared with everything else left unchanged.
446
447                """
448                # At some point we will need a way to get defaults from a table.
449                if a is None: a = {} # empty if argument is not present
450                qcl = _join_parts(self._split_schema(cl)) # build qualified name
451                foid = 'oid(%s)' % qcl # build mangled oid
452                fnames = self.get_attnames(qcl)
453                for k, t in fnames.items():
454                        if k == 'oid': continue
455                        if t in ['int', 'decimal', 'seq', 'money']:
456                                a[k] = 0
457                        elif t == 'bool':
458                                a[k] = 'f'
459                        elif t == 'date':
460                                a[k] = 'now()'
461                        else:
462                                a[k] = ''
463                return a
464
465        def delete(self, cl, a):
466                """Delete an existing row in a database table.
467
468                This method deletes the row from a table.
469                It deletes based on the OID munged as described above.
470
471                """
472                # Like update, delete works on the oid.
473                # One day we will be testing that the record to be deleted
474                # isn't referenced somewhere (or else PostgreSQL will).
475                qcl = _join_parts(self._split_schema(cl)) # build qualified name
476                foid = 'oid(%s)' % qcl # build mangled oid
477                q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
478                self._do_debug(q)
479                self.db.query(q)
480
481# if run as script, print some information
482if __name__ == '__main__':
483        print 'PyGreSQL version', version
484        print
485        print __doc__
Note: See TracBrowser for help on using the repository browser.