source: trunk/module/pg.py @ 29

Last change on this file since 29 was 29, checked in by momjian, 19 years ago

Update to PyGreSQL 3.1:

Fix some quoting functions. In particular handle NULLs better.

Use a method to add primary key information rather than direct
manipulation of the class structures.

Break decimal out in _quote (in pg.py) and treat it as float.

Treat timestamp like date for quoting purposes.

Remove a redundant SELECT from the get method speeding it, and
insert since it calls get, up a little.

Add test for BOOL type in typecast method to pgdbTypeCache class.
(tv@…)

Fix pgdb.py to send port as integer to lower level function
(dildog@…)

Change pg.py to speed up some operations

Allow updates on tables with no primary keys.

D'Arcy J.M. Cain

File size: 7.5 KB
Line 
1# pgutil.py
2# Written by D'Arcy J.M. Cain
3
4# This library implements some basic database management stuff
5# It includes the pg module and builds on it
6
7from _pg import *
8import string, re, sys
9
10# utility function
11# We expect int, seq, decimal, text or date (more later)
12def _quote(d, t):
13        if d == None:
14                return "NULL"
15
16        if t in ['int', 'seq']:
17                if d == "": return "NULL"
18                return "%d" % int(d)
19
20        if t == 'decimal':
21                if d == "": return "NULL"
22                return "%f" % float(d)
23
24        if t == 'money':
25                if d == "": return "NULL"
26                return "'%.2f'" % float(d)
27
28        if t == 'bool':
29                # Can't run upper() on these
30                if d in (0, 1): return ('f', 't')[d]
31
32                if string.upper(d) in ['T', 'TRUE', 'Y', 'YES', '1', 'ON']:
33                        return "'t'"
34                else:
35                        return "'f'"
36
37        if t == 'date' and d == '': return "NULL"
38        if t in ('inet', 'cidr') and d == '': return "NULL"
39
40        return "'%s'" % string.strip(re.sub("'", "''", \
41                                                         re.sub("\\\\", "\\\\\\\\", "%s" %d)))
42
43class DB:
44        """This class wraps the pg connection type"""
45
46        def __init__(self, *args, **kw):
47                self.db = apply(connect, args, kw)
48
49                # Create convience methods, in a way that is still overridable.
50                for e in ( 'query', 'reset', 'close', 'getnotify', 'inserttable',
51                                        'putline', 'getline', 'endcopy',
52                                        'host', 'port', 'db', 'options',
53                                        'tty', 'error', 'status', 'user',
54                                        'locreate', 'getlo', 'loimport' ):
55                        if not hasattr(self,e) and hasattr(self.db,e):
56                                exec 'self.%s = self.db.%s' % ( e, e )
57
58                self.__attnames__ = {}
59                self.__pkeys__ = {}
60                self.debug = None       # For debugging scripts, set to output format
61                                                        # that takes a single string arg.  For example
62                                                        # in a CGI set to "%s<BR>"
63
64                # Get all the primary keys at once
65                for rel, att in self.db.query("""SELECT
66                                                        pg_class.relname, pg_attribute.attname
67                                                FROM pg_class, pg_attribute, pg_index
68                                                WHERE pg_class.oid = pg_attribute.attrelid AND
69                                                        pg_class.oid = pg_index.indrelid AND
70                                                        pg_index.indkey[0] = pg_attribute.attnum AND
71                                                        pg_index.indisprimary = 't'""").getresult():
72                        self.__pkeys__[rel] = att
73
74        # wrap query for debugging
75        def query(self, qstr):
76                if self.debug != None:
77                        print self.debug % qstr
78                return self.db.query(qstr)
79
80        # If third arg supplied set primary key to it
81        def pkey(self, cl, newpkey = None):
82                if newpkey:
83                        self.__pkeys__[cl] = newpkey
84
85                # will raise an exception if primary key doesn't exist
86                return self.__pkeys__[cl]
87
88        def get_databases(self):
89                l = []
90                for n in self.db.query("SELECT datname FROM pg_database").getresult():
91                        l.append(n[0])
92                return l
93
94        def get_tables(self):
95                l = []
96                for n in self.db.query("""SELECT relname FROM pg_class
97                                                WHERE relkind = 'r' AND
98                                                        relname !~ '^Inv' AND
99                                                        relname !~ '^pg_'""").getresult():
100                        l.append(n[0])
101                return l
102
103        def get_attnames(self, cl):
104                # May as well cache them
105                if self.__attnames__.has_key(cl):
106                        return self.__attnames__[cl]
107
108                query = """SELECT pg_attribute.attname, pg_type.typname
109                                        FROM pg_class, pg_attribute, pg_type
110                                        WHERE pg_class.relname = '%s' AND
111                                                pg_attribute.attnum > 0 AND
112                                                pg_attribute.attrelid = pg_class.oid AND
113                                                pg_attribute.atttypid = pg_type.oid"""
114
115                l = {}
116                for attname, typname in self.db.query(query % cl).getresult():
117                        if re.match("^int", typname):
118                                l[attname] = 'int'
119                        elif re.match("^oid", typname):
120                                l[attname] = 'int'
121                        elif re.match("^text", typname):
122                                l[attname] = 'text'
123                        elif re.match("^char", typname):
124                                l[attname] = 'text'
125                        elif re.match("^name", typname):
126                                l[attname] = 'text'
127                        elif re.match("^abstime", typname):
128                                l[attname] = 'date'
129                        elif re.match("^date", typname):
130                                l[attname] = 'date'
131                        elif re.match("^timestamp", typname):
132                                l[attname] = 'date'
133                        elif re.match("^bool", typname):
134                                l[attname] = 'bool'
135                        elif re.match("^float", typname):
136                                l[attname] = 'decimal'
137                        elif re.match("^money", typname):
138                                l[attname] = 'money'
139                        else:
140                                l[attname] = 'text'
141
142                self.__attnames__[cl] = l
143                return self.__attnames__[cl]
144
145        # return a tuple from a database
146        def get(self, cl, arg, keyname = None, view = 0):
147                if cl[-1] == '*':                       # need parent table name
148                        xcl = cl[:-1]
149                else:
150                        xcl = cl
151
152                if keyname == None:                     # use the primary key by default
153                        keyname = self.__pkeys__[xcl]
154
155                fnames = self.get_attnames(xcl)
156
157                if type(arg) == type({}):
158                        # To allow users to work with multiple tables we munge the
159                        # name when the key is "oid"
160                        if keyname == 'oid': k = arg['oid_%s' % xcl]
161                        else: k = arg[keyname]
162                else:
163                        k = arg
164                        arg = {}
165
166                # We want the oid for later updates if that isn't the key
167                if keyname == 'oid':
168                        q = "SELECT * FROM %s WHERE oid = %s" % (cl, k)
169                elif view:
170                        q = "SELECT * FROM %s WHERE %s = %s" % \
171                                (cl, keyname, _quote(k, fnames[keyname]))
172                else:
173                        q = "SELECT oid AS oid_%s, %s FROM %s WHERE %s = %s" % \
174                                (xcl, string.join(fnames.keys(), ','),\
175                                        cl, keyname, _quote(k, fnames[keyname]))
176
177                if self.debug != None: print self.debug % q
178                res = self.db.query(q).dictresult()
179                if res == []:
180                        raise error, \
181                                "No such record in %s where %s is %s" % \
182                                                                (cl, keyname, _quote(k, fnames[keyname]))
183                        return None
184
185                for k in res[0].keys():
186                        arg[k] = res[0][k]
187
188                return arg
189
190        # Inserts a new tuple into a table
191        # We currently don't support insert into views although PostgreSQL does
192        def insert(self, cl, a):
193                fnames = self.get_attnames(cl)
194                l = []
195                n = []
196                for f in fnames.keys():
197                        if a.has_key(f):
198                                l.append(_quote(a[f], fnames[f]))
199                                n.append(f)
200
201                try:
202                        q = "INSERT INTO %s (%s) VALUES (%s)" % \
203                                (cl, string.join(n, ','), string.join(l, ','))
204                        if self.debug != None: print self.debug % q
205                        a['oid_%s' % cl] = self.db.query(q)
206                except:
207                        raise error, "Error inserting into %s: %s" % (cl, sys.exc_value)
208
209                # reload the dictionary to catch things modified by engine
210                # note that get() changes 'oid' below to oid_table
211                # if no read perms (it can and does happen) return None
212                try: return self.get(cl, a, 'oid')
213                except: return None
214
215        # Update always works on the oid which get returns if available
216        # otherwise use the primary key.  Fail if neither.
217        def update(self, cl, a):
218                foid = 'oid_%s' % cl
219                if a.has_key(foid):
220                        where = "oid = %s" % a[foid]
221                elif self.__pkeys__.has_key(cl) and a.has_key(self.__pkeys__[cl]):
222                        where = "%s = '%s'" % (self.__pkeys__[cl], a[self.__pkeys__[cl]])
223                else:
224                        raise error, "Update needs primary key or oid as %s" % foid
225
226                v = []
227                k = 0
228                fnames = self.get_attnames(cl)
229
230                for ff in fnames.keys():
231                        if a.has_key(ff):
232                                v.append("%s = %s" % (ff, _quote(a[ff], fnames[ff])))
233
234                if v == []:
235                        return None
236
237                try:
238                        q = "UPDATE %s SET %s WHERE %s" % \
239                                                        (cl, string.join(v, ','), where)
240                        if self.debug != None: print self.debug % q
241                        self.db.query(q)
242                except:
243                        raise error, "Can't update %s: %s" % (cl, sys.exc_value)
244
245                # reload the dictionary to catch things modified by engine
246                if a.has_key(foid):
247                        return self.get(cl, a, 'oid')
248                else:
249                        return self.get(cl, a)
250
251        # At some point we will need a way to get defaults from a table
252        def clear(self, cl, a = {}):
253                fnames = self.get_attnames(cl)
254                for ff in fnames.keys():
255                        if fnames[ff] in ['int', 'decimal', 'seq', 'money']:
256                                a[ff] = 0
257                        else:
258                                a[ff] = ""
259
260                a['oid'] = 0
261                return a
262
263        # Like update, delete works on the oid
264        # one day we will be testing that the record to be deleted
265        # isn't referenced somewhere (or else PostgreSQL will)
266        def delete(self, cl, a):
267                try:
268                        q = "DELETE FROM %s WHERE oid = %s" % (cl, a['oid_%s' % cl])
269                        if self.debug != None: print self.debug % q
270                        self.db.query(q)
271                except:
272                        raise error, "Can't delete %s: %s" % (cl, sys.exc_value)
273
274                return None
275
Note: See TracBrowser for help on using the repository browser.