source: trunk/module/pg.py @ 162

Last change on this file since 162 was 162, checked in by darcy, 15 years ago

Switch to using isinstance instead of type. This allows us to send objects
that are subclassed from the base types.

File size: 9.1 KB
Line 
1# pg.py
2# Written by D'Arcy J.M. Cain
3
4# This library implements some basic database management stuff.  It
5# includes the pg module and builds on it.  This is known as the
6# "Classic" interface.  For DB-API compliance use the pgdb module.
7
8from _pg import *
9from types import *
10import string, re, sys
11
12# utility function
13# We expect int, seq, decimal, text or date (more later)
14def _quote(d, t):
15        if d == None:
16                return "NULL"
17
18        if t in ['int', 'seq']:
19                if d == "": return "NULL"
20                return "%d" % long(d)
21
22        if t == 'decimal':
23                if d == "": return "NULL"
24                return "%f" % float(d)
25
26        if t == 'money':
27                if d == "": return "NULL"
28                return "'%.2f'" % float(d)
29
30        if t == 'bool':
31                # Can't run upper() on these
32                if d in (0, 1): return ("'f'", "'t'")[d]
33
34                if string.upper(d) in ['T', 'TRUE', 'Y', 'YES', '1', 'ON']:
35                        return "'t'"
36                else:
37                        return "'f'"
38
39        if t == 'date' and d == '': return "NULL"
40        if t in ('inet', 'cidr') and d == '': return "NULL"
41
42        return "'%s'" % string.strip(re.sub("'", "''", \
43                                                         re.sub("\\\\", "\\\\\\\\", "%s" % d)))
44
45class DB:
46        """This class wraps the pg connection type"""
47
48        def __init__(self, *args, **kw):
49                self.db = connect(*args, **kw)
50
51                # Create convenience methods, in a way that is still overridable
52                # (members are not copied because they are actually functions)
53                for e in self.db.__methods__:
54                        if e not in ("close", "query"): # These are wrapped separately
55                                setattr(self, e, getattr(self.db, e))
56
57                self.__attnames = {}
58                self.__pkeys = {}
59                self.__args = args, kw
60                self.debug = None       # For debugging scripts, this can be set to a
61                                                        # string format specification (e.g. in a CGI
62                                                        # set to "%s<BR>",) to a function which takes
63                                                        # a string argument or a file object to write
64                                                        # debug statements to.
65
66        def _do_debug(self, s):
67                if not self.debug: return
68                if isinstance(self.debug, StringType): print self.debug % s
69                if isinstance(self.debug, FunctionType): self.debug(s)
70                if isinstance(self.debug, FileType): print >> self.debug, s
71
72        # wrap close so we can track state
73        def close(self,):
74                self.db.close()
75                self.db = None
76
77        # in case we need another connection to the same database
78        # note that we can still reopen a database that we have closed
79        def reopen(self):
80                if self.db: self.close()
81                try: self.db = connect(*self.__args[0], **self.__args[1])
82                except:
83                        self.db = None
84                        raise
85
86        # wrap query for debugging
87        def query(self, qstr):
88                self._do_debug(qstr)
89                return self.db.query(qstr)
90
91        def pkey(self, cl, newpkey = None):
92                """This method returns the primary key of a class.  If newpkey
93                        is set and is set and is not a dictionary then set that
94                        value as the primary key of the class.  If it is a dictionary
95                        then replace the __pkeys dictionary with it."""
96                # Get all the primary keys at once
97                if isinstance(newpkey, DictType):
98                        self.__pkeys = newpkey
99                        return
100
101                if newpkey:
102                        self.__pkeys[cl] = newpkey
103                        return newpkey
104
105                if self.__pkeys == {}:
106                        for rel, att in self.db.query("""SELECT
107                                                        pg_class.relname, pg_attribute.attname
108                                                FROM pg_class, pg_attribute, pg_index
109                                                WHERE pg_class.oid = pg_attribute.attrelid AND
110                                                        pg_class.oid = pg_index.indrelid AND
111                                                        pg_index.indkey[0] = pg_attribute.attnum AND
112                                                        pg_index.indisprimary = 't' AND
113                                                        pg_attribute.attisdropped = 'f'""").getresult():
114                                self.__pkeys[rel] = att
115                # Give it one more chance in case it was added after we started
116                elif not self.__pkeys.has_key(cl):
117                        self.__pkeys = {}
118                        return self.pkey(cl)
119
120                # will raise an exception if primary key doesn't exist
121                return self.__pkeys[cl]
122
123        def get_databases(self):
124                return [x[0] for x in
125                        self.db.query("SELECT datname FROM pg_database").getresult()]
126
127        def get_tables(self):
128                return [x[0] for x in
129                                self.db.query("""SELECT relname FROM pg_class
130                                                WHERE relkind = 'r' AND
131                                                        relname !~ '^Inv' AND
132                                                        relname !~ '^pg_'""").getresult()]
133
134        def get_attnames(self, cl, newattnames = None):
135                """This method gets a list of attribute names for a class.  If
136                        the optional newattnames exists it must be a dictionary and
137                        will become the new attribute names dictionary."""
138
139                if isinstance(newattnames, DictType):
140                        self.__attnames = newattnames
141                        return
142                elif newattnames:
143                        raise error, "If supplied, newattnames must be a dictionary"
144
145                # May as well cache them
146                if self.__attnames.has_key(cl):
147                        return self.__attnames[cl]
148
149                query = """SELECT pg_attribute.attname, pg_type.typname
150                                        FROM pg_class, pg_attribute, pg_type
151                                        WHERE pg_class.relname = '%s' AND
152                                                pg_attribute.attnum > 0 AND
153                                                pg_attribute.attrelid = pg_class.oid AND
154                                                pg_attribute.atttypid = pg_type.oid AND
155                                                pg_attribute.attisdropped = 'f'"""
156
157                l = {}
158                for attname, typname in self.db.query(query % cl).getresult():
159                        if re.match("^interval", typname):
160                                l[attname] = 'text'
161                        if re.match("^int", typname):
162                                l[attname] = 'int'
163                        elif re.match("^oid", typname):
164                                l[attname] = 'int'
165                        elif re.match("^text", typname):
166                                l[attname] = 'text'
167                        elif re.match("^char", typname):
168                                l[attname] = 'text'
169                        elif re.match("^name", typname):
170                                l[attname] = 'text'
171                        elif re.match("^abstime", typname):
172                                l[attname] = 'date'
173                        elif re.match("^date", typname):
174                                l[attname] = 'date'
175                        elif re.match("^timestamp", typname):
176                                l[attname] = 'date'
177                        elif re.match("^bool", typname):
178                                l[attname] = 'bool'
179                        elif re.match("^float", typname):
180                                l[attname] = 'decimal'
181                        elif re.match("^money", typname):
182                                l[attname] = 'money'
183                        else:
184                                l[attname] = 'text'
185
186                l['oid'] = 'int'                                # every table has this
187                self.__attnames[cl] = l         # cache it
188                return self.__attnames[cl]
189
190        # return a tuple from a database
191        def get(self, cl, arg, keyname = None, view = 0):
192                if cl[-1] == '*':                       # need parent table name
193                        xcl = cl[:-1]
194                else:
195                        xcl = cl
196
197                if keyname == None:                     # use the primary key by default
198                        keyname = self.pkey(xcl)
199
200                fnames = self.get_attnames(xcl)
201
202                if isinstance(arg, DictType):
203                        # To allow users to work with multiple tables we munge the
204                        # name when the key is "oid"
205                        if keyname == 'oid': k = arg['oid_%s' % xcl]
206                        else: k = arg[keyname]
207                else:
208                        k = arg
209                        arg = {}
210
211                # We want the oid for later updates if that isn't the key
212                if keyname == 'oid':
213                        q = "SELECT * FROM %s WHERE oid = %s" % (cl, k)
214                elif view:
215                        q = "SELECT * FROM %s WHERE %s = %s" % \
216                                (cl, keyname, _quote(k, fnames[keyname]))
217                else:
218                        q = "SELECT oid AS oid_%s, %s FROM %s WHERE %s = %s" % \
219                                (xcl, string.join(fnames.keys(), ','),\
220                                        cl, keyname, _quote(k, fnames[keyname]))
221
222                self._do_debug(q)
223                res = self.db.query(q).dictresult()
224                if res == []:
225                        raise error, \
226                                "No such record in %s where %s is %s" % \
227                                                                (cl, keyname, _quote(k, fnames[keyname]))
228                        return None
229
230                for k in res[0].keys():
231                        arg[k] = res[0][k]
232
233                return arg
234
235        # Inserts a new tuple into a table
236        # We currently don't support insert into views although PostgreSQL does
237        def insert(self, cl, a):
238                fnames = self.get_attnames(cl)
239                l = []
240                n = []
241                for f in fnames.keys():
242                        if f != 'oid' and a.has_key(f):
243                                l.append(_quote(a[f], fnames[f]))
244                                n.append(f)
245
246                try:
247                        q = "INSERT INTO %s (%s) VALUES (%s)" % \
248                                (cl, string.join(n, ','), string.join(l, ','))
249                        self._do_debug(q)
250                        a['oid_%s' % cl] = self.db.query(q)
251                except:
252                        raise error, "Error inserting into %s: %s" % (cl, sys.exc_value)
253
254                # reload the dictionary to catch things modified by engine
255                # note that get() changes 'oid' below to oid_table
256                # if no read perms (it can and does happen) return None
257                try: return self.get(cl, a, 'oid')
258                except: return None
259
260        # Update always works on the oid which get returns if available
261        # otherwise use the primary key.  Fail if neither.
262        def update(self, cl, a):
263                self.pkey(cl)           # make sure we have a self.__pkeys dictionary
264
265                foid = 'oid_%s' % cl
266                if a.has_key(foid):
267                        where = "oid = %s" % a[foid]
268                elif self.__pkeys.has_key(cl) and a.has_key(self.__pkeys[cl]):
269                        where = "%s = '%s'" % (self.__pkeys[cl], a[self.__pkeys[cl]])
270                else:
271                        raise error, "Update needs primary key or oid as %s" % foid
272
273                v = []
274                k = 0
275                fnames = self.get_attnames(cl)
276
277                for ff in fnames.keys():
278                        if ff != 'oid' and a.has_key(ff):
279                                v.append("%s = %s" % (ff, _quote(a[ff], fnames[ff])))
280
281                if v == []:
282                        return None
283
284                try:
285                        q = "UPDATE %s SET %s WHERE %s" % \
286                                                        (cl, string.join(v, ','), where)
287                        self._do_debug(q)
288                        self.db.query(q)
289                except:
290                        raise error, "Can't update %s: %s" % (cl, sys.exc_value)
291
292                # reload the dictionary to catch things modified by engine
293                if a.has_key(foid):
294                        return self.get(cl, a, 'oid')
295                else:
296                        return self.get(cl, a)
297
298        # At some point we will need a way to get defaults from a table
299        def clear(self, cl, a = {}):
300                fnames = self.get_attnames(cl)
301                for ff in fnames.keys():
302                        if fnames[ff] in ['int', 'decimal', 'seq', 'money']:
303                                a[ff] = 0
304                        else:
305                                a[ff] = ""
306
307                a['oid'] = 0
308                return a
309
310        # Like update, delete works on the oid
311        # one day we will be testing that the record to be deleted
312        # isn't referenced somewhere (or else PostgreSQL will)
313        def delete(self, cl, a):
314                try:
315                        q = "DELETE FROM %s WHERE oid = %s" % (cl, a['oid_%s' % cl])
316                        self._do_debug(q)
317                        self.db.query(q)
318                except:
319                        raise error, "Can't delete %s: %s" % (cl, sys.exc_value)
320
321                return None
322
Note: See TracBrowser for help on using the repository browser.