source: trunk/module/pg.py @ 181

Last change on this file since 181 was 181, checked in by darcy, 15 years ago

Don't crash in update method if primary key doesn't exist yet.

File size: 8.7 KB
Line 
1# pg.py
2# Written by D'Arcy J.M. Cain
3
4# This library implements some basic database management stuff.  It
5# includes the pg module and builds on it.  This is known as the
6# "Classic" interface.  For DB-API compliance use the pgdb module.
7
8from _pg import *
9from types import *
10import string, re, sys
11
12# utility function
13# We expect int, seq, decimal, text or date (more later)
14def _quote(d, t):
15        if d == None:
16                return "NULL"
17
18        if t in ['int', 'seq']:
19                if d == "": return "NULL"
20                return "%d" % long(d)
21
22        if t == 'decimal':
23                if d == "": return "NULL"
24                return "%f" % float(d)
25
26        if t == 'money':
27                if d == "": return "NULL"
28                return "'%.2f'" % float(d)
29
30        if t == 'bool':
31                # Can't run upper() on these
32                if d in (0, 1): return ("'f'", "'t'")[d]
33
34                if string.upper(d) in ['T', 'TRUE', 'Y', 'YES', '1', 'ON']:
35                        return "'t'"
36                else:
37                        return "'f'"
38
39        if t == 'date' and d == '': return "NULL"
40        if t in ('inet', 'cidr') and d == '': return "NULL"
41
42        return "'%s'" % string.strip(re.sub("'", "''", \
43                                                        re.sub("\\\\", "\\\\\\\\", "%s" % d)))
44
45class DB:
46        """This class wraps the pg connection type"""
47
48        def __init__(self, *args, **kw):
49                self.db = connect(*args, **kw)
50
51                # Create convenience methods, in a way that is still overridable
52                # (members are not copied because they are actually functions)
53                for e in self.db.__methods__:
54                        if e not in ("close", "query"): # These are wrapped separately
55                                setattr(self, e, getattr(self.db, e))
56
57                self.__attnames = {}
58                self.__pkeys = {}
59                self.__args = args, kw
60                self.debug = None       # For debugging scripts, this can be set to a
61                                                        # string format specification (e.g. in a CGI
62                                                        # set to "%s<BR>",) to a function which takes
63                                                        # a string argument or a file object to write
64                                                        # debug statements to.
65
66        def _do_debug(self, s):
67                if not self.debug: return
68                if isinstance(self.debug, StringType): print self.debug % s
69                if isinstance(self.debug, FunctionType): self.debug(s)
70                if isinstance(self.debug, FileType): print >> self.debug, s
71
72        # wrap close so we can track state
73        def close(self,):
74                self.db.close()
75                self.db = None
76
77        # in case we need another connection to the same database
78        # note that we can still reopen a database that we have closed
79        def reopen(self):
80                if self.db: self.close()
81                try: self.db = connect(*self.__args[0], **self.__args[1])
82                except:
83                        self.db = None
84                        raise
85
86        # wrap query for debugging
87        def query(self, qstr):
88                self._do_debug(qstr)
89                return self.db.query(qstr)
90
91        def pkey(self, cl, newpkey = None):
92                """This method returns the primary key of a class.  If newpkey
93                        is set and is set and is not a dictionary then set that
94                        value as the primary key of the class.  If it is a dictionary
95                        then replace the __pkeys dictionary with it."""
96                # Get all the primary keys at once
97                if isinstance(newpkey, DictType):
98                        self.__pkeys = newpkey
99                        return
100
101                if newpkey:
102                        self.__pkeys[cl] = newpkey
103                        return newpkey
104
105                if self.__pkeys == {}:
106                        for rel, att in self.db.query("""SELECT
107                                                        pg_class.relname, pg_attribute.attname
108                                                FROM pg_class, pg_attribute, pg_index
109                                                WHERE pg_class.oid = pg_attribute.attrelid AND
110                                                        pg_class.oid = pg_index.indrelid AND
111                                                        pg_index.indkey[0] = pg_attribute.attnum AND
112                                                        pg_index.indisprimary = 't' AND
113                                                        pg_attribute.attisdropped = 'f'""").getresult():
114                                self.__pkeys[rel] = att
115                # Give it one more chance in case it was added after we started
116                elif not self.__pkeys.has_key(cl):
117                        self.__pkeys = {}
118                        return self.pkey(cl)
119
120                # will raise an exception if primary key doesn't exist
121                return self.__pkeys[cl]
122
123        def get_databases(self):
124                return [x[0] for x in
125                        self.db.query("SELECT datname FROM pg_database").getresult()]
126
127        def get_tables(self):
128                return [x[0] for x in
129                                self.db.query("""SELECT relname FROM pg_class
130                                                WHERE relkind = 'r' AND
131                                                        relname !~ '^Inv' AND
132                                                        relname !~ '^pg_'""").getresult()]
133
134        def get_attnames(self, cl, newattnames = None):
135                """This method gets a list of attribute names for a class.  If
136                        the optional newattnames exists it must be a dictionary and
137                        will become the new attribute names dictionary."""
138
139                if isinstance(newattnames, DictType):
140                        self.__attnames = newattnames
141                        return
142                elif newattnames:
143                        raise ProgrammingError, \
144                                        "If supplied, newattnames must be a dictionary"
145
146                # May as well cache them
147                if self.__attnames.has_key(cl):
148                        return self.__attnames[cl]
149
150                query = """SELECT pg_attribute.attname, pg_type.typname
151                                        FROM pg_class, pg_attribute, pg_type
152                                        WHERE pg_class.relname = '%s' AND
153                                                pg_attribute.attnum > 0 AND
154                                                pg_attribute.attrelid = pg_class.oid AND
155                                                pg_attribute.atttypid = pg_type.oid AND
156                                                pg_attribute.attisdropped = 'f'"""
157
158                l = {}
159                for attname, typname in self.db.query(query % cl).getresult():
160                        if re.match("^interval", typname):
161                                l[attname] = 'text'
162                        if re.match("^int", typname):
163                                l[attname] = 'int'
164                        elif re.match("^oid", typname):
165                                l[attname] = 'int'
166                        elif re.match("^text", typname):
167                                l[attname] = 'text'
168                        elif re.match("^char", typname):
169                                l[attname] = 'text'
170                        elif re.match("^name", typname):
171                                l[attname] = 'text'
172                        elif re.match("^abstime", typname):
173                                l[attname] = 'date'
174                        elif re.match("^date", typname):
175                                l[attname] = 'date'
176                        elif re.match("^timestamp", typname):
177                                l[attname] = 'date'
178                        elif re.match("^bool", typname):
179                                l[attname] = 'bool'
180                        elif re.match("^float", typname):
181                                l[attname] = 'decimal'
182                        elif re.match("^money", typname):
183                                l[attname] = 'money'
184                        else:
185                                l[attname] = 'text'
186
187                l['oid'] = 'int'                                # every table has this
188                self.__attnames[cl] = l         # cache it
189                return self.__attnames[cl]
190
191        # return a tuple from a database
192        def get(self, cl, arg, keyname = None, view = 0):
193                if cl[-1] == '*':                       # need parent table name
194                        xcl = cl[:-1]
195                else:
196                        xcl = cl
197
198                if keyname == None:                     # use the primary key by default
199                        keyname = self.pkey(xcl)
200
201                fnames = self.get_attnames(xcl)
202
203                if isinstance(arg, DictType):
204                        # To allow users to work with multiple tables we munge the
205                        # name when the key is "oid"
206                        if keyname == 'oid': k = arg['oid_%s' % xcl]
207                        else: k = arg[keyname]
208                else:
209                        k = arg
210                        arg = {}
211
212                # We want the oid for later updates if that isn't the key
213                if keyname == 'oid':
214                        q = "SELECT * FROM %s WHERE oid = %s" % (cl, k)
215                elif view:
216                        q = "SELECT * FROM %s WHERE %s = %s" % \
217                                (cl, keyname, _quote(k, fnames[keyname]))
218                else:
219                        q = "SELECT oid AS oid_%s, %s FROM %s WHERE %s = %s" % \
220                                (xcl, string.join(fnames.keys(), ','),\
221                                        cl, keyname, _quote(k, fnames[keyname]))
222
223                self._do_debug(q)
224                res = self.db.query(q).dictresult()
225                if res == []:
226                        raise DatabaseError, \
227                                "No such record in %s where %s is %s" % \
228                                                                (cl, keyname, _quote(k, fnames[keyname]))
229                        return None
230
231                for k in res[0].keys():
232                        arg[k] = res[0][k]
233
234                return arg
235
236        # Inserts a new tuple into a table
237        # We currently don't support insert into views although PostgreSQL does
238        def insert(self, cl, a):
239                fnames = self.get_attnames(cl)
240                l = []
241                n = []
242                for f in fnames.keys():
243                        if f != 'oid' and a.has_key(f):
244                                l.append(_quote(a[f], fnames[f]))
245                                n.append(f)
246
247                q = "INSERT INTO %s (%s) VALUES (%s)" % \
248                        (cl, string.join(n, ','), string.join(l, ','))
249                self._do_debug(q)
250                a['oid_%s' % cl] = self.db.query(q)
251
252                # reload the dictionary to catch things modified by engine
253                # note that get() changes 'oid' below to oid_table
254                # if no read perms (it can and does happen) return None
255                try: return self.get(cl, a, 'oid')
256                except: return None
257
258        # Update always works on the oid which get returns if available
259        # otherwise use the primary key.  Fail if neither.
260        def update(self, cl, a):
261                foid = 'oid_%s' % cl
262                if a.has_key(foid):
263                        where = "oid = %s" % a[foid]
264                else:
265                        try: pk = self.pkeys(cl)
266                        except: raise ProgrammingError, \
267                                        "Update needs primary key or oid as %s" % foid
268
269                        where = "%s = '%s'" % (pk, a[pk])
270
271                v = []
272                k = 0
273                fnames = self.get_attnames(cl)
274
275                for ff in fnames.keys():
276                        if ff != 'oid' and a.has_key(ff):
277                                v.append("%s = %s" % (ff, _quote(a[ff], fnames[ff])))
278
279                if v == []:
280                        return None
281
282                q = "UPDATE %s SET %s WHERE %s" % (cl, string.join(v, ','), where)
283                self._do_debug(q)
284                self.db.query(q)
285
286                # reload the dictionary to catch things modified by engine
287                if a.has_key(foid):
288                        return self.get(cl, a, 'oid')
289                else:
290                        return self.get(cl, a)
291
292        # At some point we will need a way to get defaults from a table
293        def clear(self, cl, a = {}):
294                fnames = self.get_attnames(cl)
295                for ff in fnames.keys():
296                        if fnames[ff] in ['int', 'decimal', 'seq', 'money']:
297                                a[ff] = 0
298                        else:
299                                a[ff] = ""
300
301                a['oid'] = 0
302                return a
303
304        # Like update, delete works on the oid
305        # one day we will be testing that the record to be deleted
306        # isn't referenced somewhere (or else PostgreSQL will)
307        def delete(self, cl, a):
308                q = "DELETE FROM %s WHERE oid = %s" % (cl, a['oid_%s' % cl])
309                self._do_debug(q)
310                self.db.query(q)
311
Note: See TracBrowser for help on using the repository browser.