source: trunk/module/pg.py @ 201

Last change on this file since 201 was 201, checked in by darcy, 14 years ago

Remove debugging line.

File size: 9.4 KB
Line 
1# pg.py
2# Written by D'Arcy J.M. Cain
3# $Id: pg.py,v 1.30 2005-05-23 14:09:56 darcy Exp $
4
5# This library implements some basic database management stuff.  It
6# includes the pg module and builds on it.  This is known as the
7# "Classic" interface.  For DB-API compliance use the pgdb module.
8
9from _pg import *
10from types import *
11import string, re, sys
12
13# utility function
14# We expect int, seq, decimal, text or date (more later)
15def _quote(d, t):
16        if d == None:
17                return "NULL"
18
19        if t in ['int', 'seq']:
20                if d == "": return "NULL"
21                return "%d" % long(d)
22
23        if t == 'decimal':
24                if d == "": return "NULL"
25                return "%f" % float(d)
26
27        if t == 'money':
28                if d == "": return "NULL"
29                return "'%.2f'" % float(d)
30
31        if t == 'bool':
32                # Can't run upper() on these
33                if d in (0, 1): return ("'f'", "'t'")[d]
34
35                if string.upper(d) in ['T', 'TRUE', 'Y', 'YES', '1', 'ON']:
36                        return "'t'"
37                else:
38                        return "'f'"
39
40        if t == 'date' and d == '': return "NULL"
41        if t in ('inet', 'cidr') and d == '': return "NULL"
42
43        return "'%s'" % string.strip(re.sub("'", "''", \
44                                                        re.sub("\\\\", "\\\\\\\\", "%s" % d)))
45
46class DB:
47        """This class wraps the pg connection type"""
48
49        def __init__(self, *args, **kw):
50                self.db = connect(*args, **kw)
51
52                # Create convenience methods, in a way that is still overridable
53                # (members are not copied because they are actually functions)
54                for e in self.db.__methods__:
55                        if e not in ("close", "query"): # These are wrapped separately
56                                setattr(self, e, getattr(self.db, e))
57
58                self.__attnames = {}
59                self.__pkeys = {}
60                self.__args = args, kw
61                self.debug = None       # For debugging scripts, this can be set to a
62                                                        # string format specification (e.g. in a CGI
63                                                        # set to "%s<BR>",) to a function which takes
64                                                        # a string argument or a file object to write
65                                                        # debug statements to.
66
67        def _do_debug(self, s):
68                if not self.debug: return
69                if isinstance(self.debug, StringType): print self.debug % s
70                if isinstance(self.debug, FunctionType): self.debug(s)
71                if isinstance(self.debug, FileType): print >> self.debug, s
72
73        # wrap close so we can track state
74        def close(self,):
75                self.db.close()
76                self.db = None
77
78        # in case we need another connection to the same database
79        # note that we can still reopen a database that we have closed
80        def reopen(self):
81                if self.db: self.close()
82                try: self.db = connect(*self.__args[0], **self.__args[1])
83                except:
84                        self.db = None
85                        raise
86
87        # wrap query for debugging
88        def query(self, qstr):
89                self._do_debug(qstr)
90                return self.db.query(qstr)
91
92        def pkey(self, cl, newpkey = None):
93                """This method returns the primary key of a class.  If newpkey
94                        is set and is set and is not a dictionary then set that
95                        value as the primary key of the class.  If it is a dictionary
96                        then replace the __pkeys dictionary with it."""
97                # Get all the primary keys at once
98                if isinstance(newpkey, DictType):
99                        self.__pkeys = newpkey
100                        return
101
102                if newpkey:
103                        self.__pkeys[cl] = newpkey
104                        return newpkey
105
106                if self.__pkeys == {}:
107                        for rel, nsp, att in self.db.query("""
108                                                SELECT pg_class.relname, pg_namespace.nspname,
109                                                        pg_attribute.attname
110                                                FROM pg_class, pg_namespace, pg_attribute, pg_index
111                                                WHERE pg_class.oid = pg_attribute.attrelid AND
112                                                        pg_class.relnamespace = pg_namespace.oid AND
113                                                        pg_class.oid = pg_index.indrelid AND
114                                                        pg_index.indkey[0] = pg_attribute.attnum AND
115                                                        pg_namespace.nspname NOT LIKE 'pg_%' AND
116                                                        pg_index.indisprimary = 't' AND
117                                                        pg_attribute.attisdropped = 'f'""").getresult():
118                                self.__pkeys["%s.%s" % (nsp, rel)] = att
119                # Give it one more chance in case it was added after we started
120                elif not self.__pkeys.has_key(cl):
121                        self.__pkeys = {}
122                        return self.pkey(cl)
123
124                # make sure that we have a namespace
125                if cl.find('.') == -1:
126                        cl = 'public.' + cl
127
128                # will raise an exception if primary key doesn't exist
129                return self.__pkeys[cl]
130
131        def get_databases(self):
132                return [x[0] for x in
133                        self.db.query("SELECT datname FROM pg_database").getresult()]
134
135        def get_tables(self):
136                return [x[0] for x in
137                                self.db.query("""SELECT relname FROM pg_class
138                                                WHERE relkind = 'r' AND
139                                                        relname !~ '^Inv' AND
140                                                        relname !~ '^pg_'""").getresult()]
141
142        def get_attnames(self, cl, newattnames = None):
143                """This method gets a list of attribute names for a class.  If
144                        the optional newattnames exists it must be a dictionary and
145                        will become the new attribute names dictionary."""
146
147                if isinstance(newattnames, DictType):
148                        self.__attnames = newattnames
149                        return
150                elif newattnames:
151                        raise ProgrammingError, \
152                                        "If supplied, newattnames must be a dictionary"
153
154                # May as well cache them
155                if self.__attnames.has_key(cl):
156                        return self.__attnames[cl]
157
158                # check for schema name
159                if cl.find('.') == -1:
160                        schema = 'public'
161                        table = cl
162                else:
163                        try: schema, table = cl.split('.')
164                        except ValueError, err:
165                                raise ProgrammingError('Invalid class %s' % cl)
166
167                query = """SELECT pg_attribute.attname, pg_type.typname
168                                        FROM pg_class, pg_attribute, pg_type, pg_namespace
169                                        WHERE pg_class.relname = '%s' AND
170                                                pg_namespace.nspname = '%s' AND
171                                                pg_attribute.attnum > 0 AND
172                                                pg_attribute.attrelid = pg_class.oid AND
173                                                pg_attribute.atttypid = pg_type.oid AND
174                                                pg_class.relnamespace = pg_namespace.oid AND
175                                                pg_attribute.attisdropped = 'f'""" % (table, schema)
176
177                l = {}
178                for attname, typname in self.db.query(query).getresult():
179                        if re.match("^interval", typname):
180                                l[attname] = 'date'
181                        elif re.match("^int", typname):
182                                l[attname] = 'int'
183                        elif re.match("^oid", typname):
184                                l[attname] = 'int'
185                        elif re.match("^text", typname):
186                                l[attname] = 'text'
187                        elif re.match("^char", typname):
188                                l[attname] = 'text'
189                        elif re.match("^name", typname):
190                                l[attname] = 'text'
191                        elif re.match("^abstime", typname):
192                                l[attname] = 'date'
193                        elif re.match("^date", typname):
194                                l[attname] = 'date'
195                        elif re.match("^timestamp", typname):
196                                l[attname] = 'date'
197                        elif re.match("^bool", typname):
198                                l[attname] = 'bool'
199                        elif re.match("^float", typname):
200                                l[attname] = 'decimal'
201                        elif re.match("^money", typname):
202                                l[attname] = 'money'
203                        else:
204                                l[attname] = 'text'
205
206                l['oid'] = 'int'                                # every table has this
207                self.__attnames[cl] = l         # cache it
208                return self.__attnames[cl]
209
210        # return a tuple from a database
211        def get(self, cl, arg, keyname = None, view = 0):
212                if cl[-1] == '*':                       # need parent table name
213                        xcl = cl[:-1]
214                else:
215                        xcl = cl
216
217                if keyname == None:                     # use the primary key by default
218                        keyname = self.pkey(xcl)
219
220                fnames = self.get_attnames(xcl)
221
222                if isinstance(arg, DictType):
223                        # To allow users to work with multiple tables we munge the
224                        # name when the key is "oid"
225                        if keyname == 'oid': k = arg['oid_%s' % xcl]
226                        else: k = arg[keyname]
227                else:
228                        k = arg
229                        arg = {}
230
231                # We want the oid for later updates if that isn't the key
232                if keyname == 'oid':
233                        q = "SELECT * FROM %s WHERE oid = %s" % (cl, k)
234                elif view:
235                        q = "SELECT * FROM %s WHERE %s = %s" % \
236                                (cl, keyname, _quote(k, fnames[keyname]))
237                else:
238                        q = "SELECT oid AS oid_%s, %s FROM %s WHERE %s = %s" % \
239                                (xcl, string.join(fnames.keys(), ','),\
240                                        cl, keyname, _quote(k, fnames[keyname]))
241
242                self._do_debug(q)
243                res = self.db.query(q).dictresult()
244                if res == []:
245                        raise DatabaseError, \
246                                "No such record in %s where %s is %s" % \
247                                                                (cl, keyname, _quote(k, fnames[keyname]))
248                        return None
249
250                for k in res[0].keys():
251                        arg[k] = res[0][k]
252
253                return arg
254
255        # Inserts a new tuple into a table
256        # We currently don't support insert into views although PostgreSQL does
257        def insert(self, cl, a):
258                fnames = self.get_attnames(cl)
259                l = []
260                n = []
261                for f in fnames.keys():
262                        if f != 'oid' and a.has_key(f):
263                                l.append(_quote(a[f], fnames[f]))
264                                n.append(f)
265
266                q = "INSERT INTO %s (%s) VALUES (%s)" % \
267                        (cl, string.join(n, ','), string.join(l, ','))
268                self._do_debug(q)
269                a['oid_%s' % cl] = self.db.query(q)
270
271                # reload the dictionary to catch things modified by engine
272                # note that get() changes 'oid' below to oid_table
273                # if no read perms (it can and does happen) return None
274                try: return self.get(cl, a, 'oid')
275                except: return None
276
277        # Update always works on the oid which get returns if available
278        # otherwise use the primary key.  Fail if neither.
279        def update(self, cl, a):
280                foid = 'oid_%s' % cl
281                if a.has_key(foid):
282                        where = "oid = %s" % a[foid]
283                else:
284                        try: pk = self.pkey(cl)
285                        except: raise ProgrammingError, \
286                                        "Update needs primary key or oid as %s" % foid
287
288                        where = "%s = '%s'" % (pk, a[pk])
289
290                v = []
291                k = 0
292                fnames = self.get_attnames(cl)
293
294                for ff in fnames.keys():
295                        if ff != 'oid' and a.has_key(ff):
296                                v.append("%s = %s" % (ff, _quote(a[ff], fnames[ff])))
297
298                if v == []:
299                        return None
300
301                q = "UPDATE %s SET %s WHERE %s" % (cl, string.join(v, ','), where)
302                self._do_debug(q)
303                self.db.query(q)
304
305                # reload the dictionary to catch things modified by engine
306                if a.has_key(foid):
307                        return self.get(cl, a, 'oid')
308                else:
309                        return self.get(cl, a)
310
311        # At some point we will need a way to get defaults from a table
312        def clear(self, cl, a = {}):
313                fnames = self.get_attnames(cl)
314                for ff in fnames.keys():
315                        if fnames[ff] in ['int', 'decimal', 'seq', 'money']:
316                                a[ff] = 0
317                        else:
318                                a[ff] = ""
319
320                a['oid'] = 0
321                return a
322
323        # Like update, delete works on the oid
324        # one day we will be testing that the record to be deleted
325        # isn't referenced somewhere (or else PostgreSQL will)
326        def delete(self, cl, a):
327                q = "DELETE FROM %s WHERE oid = %s" % (cl, a['oid_%s' % cl])
328                self._do_debug(q)
329                self.db.query(q)
330
Note: See TracBrowser for help on using the repository browser.