source: trunk/module/pg.py @ 365

Last change on this file since 365 was 365, checked in by cito, 11 years ago

Added support for PQescapeStringConn() and PQescapeByteaConn(). Promoted the _quote() functions to methods of the connection. Made them use the former functions instead of manually and faultily handling backslashes and quotes.

File size: 22.3 KB
Line 
1#!/usr/bin/env python
2#
3# pg.py
4#
5# Written by D'Arcy J.M. Cain
6# Improved by Christoph Zwerschke
7#
8# $Id: pg.py,v 1.63 2008-11-21 21:17:53 cito Exp $
9#
10
11"""PyGreSQL classic interface.
12
13This pg module implements some basic database management stuff.
14It includes the _pg module and builds on it, providing the higher
15level wrapper class named DB with additional functionality.
16This is known as the "classic" ("old style") PyGreSQL interface.
17For a DB-API 2 compliant interface use the newer pgdb module.
18
19"""
20
21from _pg import *
22try:
23    frozenset
24except NameError: # Python < 2.4
25    from sets import ImmutableSet as frozenset
26try:
27    from decimal import Decimal
28    set_decimal(Decimal)
29except ImportError:
30    pass # Python < 2.4
31
32
33# Auxiliary functions which are independent from a DB connection:
34
35def _is_quoted(s):
36    """Check whether this string is a quoted identifier."""
37    s = s.replace('_', 'a')
38    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
39
40def _is_unquoted(s):
41    """Check whether this string is an unquoted identifier."""
42    s = s.replace('_', 'a')
43    return s.isalnum() and not s[:1].isdigit()
44
45def _split_first_part(s):
46    """Split the first part of a dot separated string."""
47    s = s.lstrip()
48    if s[:1] == '"':
49        p = []
50        s = s.split('"', 3)[1:]
51        p.append(s[0])
52        while len(s) == 3 and s[1] == '':
53            p.append('"')
54            s = s[2].split('"', 2)
55            p.append(s[0])
56        p = [''.join(p)]
57        s = '"'.join(s[1:]).lstrip()
58        if s:
59            if s[:0] == '.':
60                p.append(s[1:])
61            else:
62                s = _split_first_part(s)
63                p[0] += s[0]
64                if len(s) > 1:
65                    p.append(s[1])
66    else:
67        p = s.split('.', 1)
68        s = p[0].rstrip()
69        if _is_unquoted(s):
70            s = s.lower()
71        p[0] = s
72    return p
73
def _split_parts(s):
    """Split a dot separated string into the list of all its parts.

    Each part is normalized as by _split_first_part (quotes removed,
    plain identifiers lowercased).

    """
    parts = []
    rest = s
    while rest:
        head = _split_first_part(rest)
        parts.append(head[0])
        if len(head) < 2:
            # no further dot, so this was the last part
            break
        rest = head[1]
    return parts
84
def _join_parts(s):
    """Join a sequence of name parts into a dot separated string.

    Parts that _is_quoted() says need quoting are enclosed in double
    quotes. Double quotes inside such a part are doubled, the SQL way of
    escaping them and the form _split_first_part reads back, so that
    splitting and joining a name round-trips.

    """
    q = []
    for p in s:
        if _is_quoted(p):
            # a part containing '"' is never alphanumeric, so it always
            # takes this branch and gets its quotes escaped
            p = '"%s"' % p.replace('"', '""')
        q.append(p)
    return '.'.join(q)
88
89
90# The PostgreSQL database connection interface:
91
92class DB(object):
93    """Wrapper class for the _pg connection type."""
94
95    def __init__(self, *args, **kw):
96        """Create a new connection.
97
98        You can pass either the connection parameters or an existing
99        _pg or pgdb connection. This allows you to use the methods
100        of the classic pg interface with a DB-API 2 pgdb connection.
101
102        """
103        if not args and len(kw) == 1:
104            db = kw.get('db')
105        elif not kw and len(args) == 1:
106            db = args[0]
107        else:
108            db = None
109        if db:
110            if isinstance(db, DB):
111                db = db.db
112            else:
113                try:
114                    db = db._cnx
115                except AttributeError:
116                    pass
117        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
118            db = connect(*args, **kw)
119            self._closeable = 1
120        else:
121            self._closeable = 0
122        self.db = db
123        self.dbname = db.db
124        self._attnames = {}
125        self._pkeys = {}
126        self._args = args, kw
127        self.debug = None # For debugging scripts, this can be set
128            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
129            # * to a file object to write debug statements or
130            # * to a callable object which takes a string argument.
131
132    def __getattr__(self, name):
133        # All undefined members are the same as in the underlying pg connection:
134        if self.db:
135            return getattr(self.db, name)
136        else:
137            raise InternalError('Connection is not valid')
138
139    # Auxiliary methods
140
141    def _do_debug(self, s):
142        """Print a debug message."""
143        if self.debug:
144            if isinstance(self.debug, basestring):
145                print self.debug % s
146            elif isinstance(self.debug, file):
147                print >> self.debug, s
148            elif callable(self.debug):
149                self.debug(s)
150
151    def _quote_text(self, d):
152        """Quote text value."""
153        if not isinstance(d, basestring):
154            d = str(d)
155        return "'%s'" % self.escape_string(d)
156
157    _bool_true = frozenset('t true 1 y yes on'.split())
158
159    def _quote_bool(self, d):
160        """Quote boolean value."""
161        if isinstance(d, basestring):
162            if not d:
163                return 'NULL'
164            d = d.lower() in self._bool_true
165        else:
166            d = bool(d)
167        return ("'f'", "'t'")[d]
168
169    _date_literals = frozenset('current_date current_time'
170        ' current_timestamp localtime localtimestamp'.split())
171
172    def _quote_date(self, d):
173        """Quote date value."""
174        if not d:
175            return 'NULL'
176        if isinstance(d, basestring) and d.lower() in self._date_literals:
177            return d
178        return self._quote_text(d)
179
180    def _quote_num(self, d):
181        """Quote numeric value."""
182        if not d:
183            return 'NULL'
184        return str(d)
185
186    def _quote_money(self, d):
187        """Quote money value."""
188        if not d:
189            return 'NULL'
190        return "'%.2f'" % float(d)
191
192    _quote_funcs = dict( # quote methods for each type
193        text=_quote_text, bool=_quote_bool, date=_quote_date,
194        int=_quote_num, num=_quote_num, float=_quote_num,
195        money=_quote_money)
196
197    def _quote(self, d, t):
198        """Return quotes if needed."""
199        if d is None:
200            return 'NULL'
201        try:
202            quote_func = self._quote_funcs[t]
203        except KeyError:
204            quote_func = self._quote_funcs['text']
205        return quote_func(self, d)
206
207    def _split_schema(self, cl):
208        """Return schema and name of object separately.
209
210        This auxiliary function splits off the namespace (schema)
211        belonging to the class with the name cl. If the class name
212        is not qualified, the function is able to determine the schema
213        of the class, taking into account the current search path.
214
215        """
216        s = _split_parts(cl)
217        if len(s) > 1: # name already qualfied?
218            # should be database.schema.table or schema.table
219            if len(s) > 3:
220                raise ProgrammingError('Too many dots in class name %s' % cl)
221            schema, cl = s[-2:]
222        else:
223            cl = s[0]
224            # determine search path
225            query = 'SELECT current_schemas(TRUE)'
226            schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
227            if schemas: # non-empty path
228                # search schema for this object in the current search path
229                query = ' UNION '.join(
230                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
231                        % s for s in enumerate(schemas)])
232                query = ("SELECT nspname FROM pg_class"
233                    " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
234                    " JOIN (%s) AS p USING (nspname)"
235                    " WHERE pg_class.relname='%s'"
236                    " ORDER BY n LIMIT 1" % (query, cl))
237                schema = self.db.query(query).getresult()
238                if schema: # schema found
239                    schema = schema[0][0]
240                else: # object not found in current search path
241                    schema = 'public'
242            else: # empty path
243                schema = 'public'
244        return schema, cl
245
246    # Public methods
247
248    # escape_string and escape_bytea exist as methods,
249    # so we define unescape_bytea as a method as well
250    unescape_bytea = staticmethod(unescape_bytea)
251
252    def close(self):
253        """Close the database connection."""
254        # Wraps shared library function so we can track state.
255        if self._closeable:
256            if self.db:
257                self.db.close()
258                self.db = None
259            else:
260                raise InternalError('Connection already closed')
261
262    def reopen(self):
263        """Reopen connection to the database.
264
265        Used in case we need another connection to the same database.
266        Note that we can still reopen a database that we have closed.
267
268        """
269        # There is no such shared library function.
270        if self._closeable:
271            db = connect(*self._args[0], **self._args[1])
272            if self.db:
273                self.db.close()
274            self.db = db
275
276    def query(self, qstr):
277        """Executes a SQL command string.
278
279        This method simply sends a SQL query to the database. If the query is
280        an insert statement, the return value is the OID of the newly
281        inserted row.  If it is otherwise a query that does not return a result
282        (ie. is not a some kind of SELECT statement), it returns None.
283        Otherwise, it returns a pgqueryobject that can be accessed via the
284        getresult or dictresult method or simply printed.
285
286        """
287        # Wraps shared library function for debugging.
288        if not self.db:
289            raise InternalError('Connection is not valid')
290        self._do_debug(qstr)
291        return self.db.query(qstr)
292
293    def pkey(self, cl, newpkey = None):
294        """This method gets or sets the primary key of a class.
295
296        If newpkey is set and is not a dictionary then set that
297        value as the primary key of the class.  If it is a dictionary
298        then replace the _pkeys dictionary with it.
299
300        """
301        # First see if the caller is supplying a dictionary
302        if isinstance(newpkey, dict):
303            # make sure that we have a namespace
304            self._pkeys = {}
305            for x in newpkey.keys():
306                if x.find('.') == -1:
307                    self._pkeys['public.' + x] = newpkey[x]
308                else:
309                    self._pkeys[x] = newpkey[x]
310            return self._pkeys
311
312        qcl = _join_parts(self._split_schema(cl)) # build qualified name
313        if newpkey:
314            self._pkeys[qcl] = newpkey
315            return newpkey
316
317        # Get all the primary keys at once
318        if qcl not in self._pkeys:
319            # if not found, check again in case it was added after we started
320            self._pkeys = dict([
321                (_join_parts(r[:2]), r[2]) for r in self.db.query(
322                "SELECT pg_namespace.nspname, pg_class.relname"
323                    ",pg_attribute.attname FROM pg_class"
324                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
325                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
326                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
327                    " AND pg_attribute.attisdropped='f'"
328                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
329                    " AND pg_index.indisprimary='t'"
330                    " AND pg_index.indkey[0]=pg_attribute.attnum"
331                ).getresult()])
332            self._do_debug(self._pkeys)
333        # will raise an exception if primary key doesn't exist
334        return self._pkeys[qcl]
335
336    def get_databases(self):
337        """Get list of databases in the system."""
338        return [s[0] for s in
339            self.db.query('SELECT datname FROM pg_database').getresult()]
340
341    def get_relations(self, kinds = None):
342        """Get list of relations in connected database of specified kinds.
343
344            If kinds is None or empty, all kinds of relations are returned.
345            Otherwise kinds can be a string or sequence of type letters
346            specifying which kind of relations you want to list.
347
348        """
349        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
350            ["'%s'" % x for x in kinds]) or ''
351        return map(_join_parts, self.db.query(
352            "SELECT pg_namespace.nspname, pg_class.relname "
353            "FROM pg_class "
354            "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
355            "WHERE %s pg_class.relname !~ '^Inv' AND "
356                "pg_class.relname !~ '^pg_' "
357            "ORDER BY 1, 2" % where).getresult())
358
359    def get_tables(self):
360        """Return list of tables in connected database."""
361        return self.get_relations('r')
362
363    def get_attnames(self, cl, newattnames = None):
364        """Given the name of a table, digs out the set of attribute names.
365
366        Returns a dictionary of attribute names (the names are the keys,
367        the values are the names of the attributes' types).
368        If the optional newattnames exists, it must be a dictionary and
369        will become the new attribute names dictionary.
370
371        """
372        if isinstance(newattnames, dict):
373            self._attnames = newattnames
374            return
375        elif newattnames:
376            raise ProgrammingError(
377                'If supplied, newattnames must be a dictionary')
378        cl = self._split_schema(cl) # split into schema and cl
379        qcl = _join_parts(cl) # build qualified name
380        # May as well cache them:
381        if qcl in self._attnames:
382            return self._attnames[qcl]
383        if qcl not in self.get_relations('rv'):
384            raise ProgrammingError('Class %s does not exist' % qcl)
385        t = {}
386        for att, typ in self.db.query("SELECT pg_attribute.attname"
387            ",pg_type.typname FROM pg_class"
388            " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
389            " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
390            " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
391            " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
392            " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
393            " AND pg_attribute.attisdropped='f'"
394                % cl).getresult():
395            if typ.startswith('bool'):
396                t[att] = 'bool'
397            elif typ.startswith('abstime'):
398                t[att] = 'date'
399            elif typ.startswith('date'):
400                t[att] = 'date'
401            elif typ.startswith('interval'):
402                t[att] = 'date'
403            elif typ.startswith('timestamp'):
404                t[att] = 'date'
405            elif typ.startswith('oid'):
406                t[att] = 'int'
407            elif typ.startswith('int'):
408                t[att] = 'int'
409            elif typ.startswith('float'):
410                t[att] = 'float'
411            elif typ.startswith('numeric'):
412                t[att] = 'num'
413            elif typ.startswith('money'):
414                t[att] = 'money'
415            else:
416                t[att] = 'text'
417        self._attnames[qcl] = t # cache it
418        return self._attnames[qcl]
419
420    def get(self, cl, arg, keyname = None, view = 0):
421        """Get a tuple from a database table or view.
422
423        This method is the basic mechanism to get a single row.  It assumes
424        that the key specifies a unique row.  If keyname is not specified
425        then the primary key for the table is used.  If arg is a dictionary
426        then the value for the key is taken from it and it is modified to
427        include the new values, replacing existing values where necessary.
428        The OID is also put into the dictionary, but in order to allow the
429        caller to work with multiple tables, it is munged as oid(schema.table).
430
431        """
432        if cl.endswith('*'): # scan descendant tables?
433            cl = cl[:-1].rstrip() # need parent table name
434        qcl = _join_parts(self._split_schema(cl)) # build qualified name
435        # To allow users to work with multiple tables,
436        # we munge the name when the key is "oid"
437        foid = 'oid(%s)' % qcl # build mangled name
438        if keyname is None: # use the primary key by default
439            keyname = self.pkey(qcl)
440        fnames = self.get_attnames(qcl)
441        if isinstance(arg, dict):
442            # XXX this code is for backwards compatibility and will be
443            # XXX removed eventually
444            if foid not in arg:
445                ofoid = 'oid_' + self._split_schema(cl)[-1]
446                if ofoid in arg:
447                    arg[foid] = arg[ofoid]
448
449            k = arg[keyname == 'oid' and foid or keyname]
450        else:
451            k = arg
452            arg = {}
453        # We want the oid for later updates if that isn't the key
454        if keyname == 'oid':
455            q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
456        elif view:
457            q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
458                (qcl, keyname, _quote(k, fnames[keyname]))
459        else:
460            q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
461                (','.join(fnames.keys()), qcl, \
462                    keyname, self._quote(k, fnames[keyname]))
463        self._do_debug(q)
464        res = self.db.query(q).dictresult()
465        if not res:
466            raise DatabaseError('No such record in %s where %s=%s'
467                % (qcl, keyname, self._quote(k, fnames[keyname])))
468        for k, d in res[0].items():
469            if k == 'oid':
470                k = foid
471            arg[k] = d
472        return arg
473
474    def insert(self, cl, d = None, **kw):
475        """Insert a tuple into a database table.
476
477        This method inserts values into the table specified filling in the
478        values from the dictionary.  It then reloads the dictionary with the
479        values from the database.  This causes the dictionary to be updated
480        with values that are modified by rules, triggers, etc.
481
482        Note: The method currently doesn't support insert into views
483        although PostgreSQL does.
484
485        """
486        if d is None:
487            a = {}
488        else:
489            a = d
490        a.update(kw)
491
492        qcl = _join_parts(self._split_schema(cl)) # build qualified name
493        foid = 'oid(%s)' % qcl # build mangled name
494        fnames = self.get_attnames(qcl)
495        t = []
496        n = []
497        for f in fnames.keys():
498            if f != 'oid' and f in a:
499                t.append(self._quote(a[f], fnames[f]))
500                n.append('"%s"' % f)
501        q = 'INSERT INTO %s (%s) VALUES (%s)' % \
502            (qcl, ','.join(n), ','.join(t))
503        self._do_debug(q)
504        a[foid] = self.db.query(q)
505        # Reload the dictionary to catch things modified by engine.
506        # Note that get() changes 'oid' below to oid_schema_table.
507        # If no read perms (it can and does happen), return None.
508        try:
509            return self.get(qcl, a, 'oid')
510        except Exception:
511            return None
512
513    def update(self, cl, d = None, **kw):
514        """Update an existing row in a database table.
515
516        Similar to insert but updates an existing row.  The update is based
517        on the OID value as munged by get.  The array returned is the
518        one sent modified to reflect any changes caused by the update due
519        to triggers, rules, defaults, etc.
520
521        """
522        # Update always works on the oid which get returns if available,
523        # otherwise use the primary key.  Fail if neither.
524        qcl = _join_parts(self._split_schema(cl)) # build qualified name
525        foid = 'oid(%s)' % qcl # build mangled oid
526
527        # Note that we only accept oid key from named args for safety
528        if 'oid' in kw:
529            kw[foid] = kw['oid']
530            del kw['oid']
531
532        if d is None:
533            a = {}
534        else:
535            a = d
536        a.update(kw)
537
538        # XXX this code is for backwards compatibility and will be
539        # XXX removed eventually
540        if foid not in a:
541            ofoid = 'oid_' + self._split_schema(cl)[-1]
542            if ofoid in a:
543                a[foid] = a[ofoid]
544
545        if foid in a:
546            where = "oid=%s" % a[foid]
547        else:
548            try:
549                pk = self.pkey(qcl)
550            except Exception:
551                raise ProgrammingError(
552                    'Update needs primary key or oid as %s' % foid)
553            where = "%s='%s'" % (pk, a[pk])
554        v = []
555        k = 0
556        fnames = self.get_attnames(qcl)
557        for ff in fnames.keys():
558            if ff != 'oid' and ff in a:
559                v.append('%s=%s' % (ff, self._quote(a[ff], fnames[ff])))
560        if v == []:
561            return None
562        q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
563        self._do_debug(q)
564        self.db.query(q)
565        # Reload the dictionary to catch things modified by engine:
566        if foid in a:
567            return self.get(qcl, a, 'oid')
568        else:
569            return self.get(qcl, a)
570
571    def clear(self, cl, a = None):
572        """
573
574        This method clears all the attributes to values determined by the types.
575        Numeric types are set to 0, Booleans are set to 'f', and everything
576        else is set to the empty string.  If the array argument is present,
577        it is used as the array and any entries matching attribute names are
578        cleared with everything else left unchanged.
579
580        """
581        # At some point we will need a way to get defaults from a table.
582        if a is None:
583            a = {} # empty if argument is not present
584        qcl = _join_parts(self._split_schema(cl)) # build qualified name
585        foid = 'oid(%s)' % qcl # build mangled oid
586        fnames = self.get_attnames(qcl)
587        for k, t in fnames.items():
588            if k == 'oid':
589                continue
590            if t in ('int', 'float', 'num', 'money'):
591                a[k] = 0
592            elif t == 'bool':
593                a[k] = 'f'
594            else:
595                a[k] = ''
596        return a
597
598    def delete(self, cl, d = None, **kw):
599        """Delete an existing row in a database table.
600
601        This method deletes the row from a table.
602        It deletes based on the OID munged as described above."""
603
604        # Like update, delete works on the oid.
605        # One day we will be testing that the record to be deleted
606        # isn't referenced somewhere (or else PostgreSQL will).
607        qcl = _join_parts(self._split_schema(cl)) # build qualified name
608        foid = 'oid(%s)' % qcl # build mangled oid
609
610        # Note that we only accept oid key from named args for safety
611        if 'oid' in kw:
612            kw[foid] = kw['oid']
613            del kw['oid']
614
615        if d is None:
616            a = {}
617        else:
618            a = d
619        a.update(kw)
620
621        # XXX this code is for backwards compatibility and will be
622        # XXX removed eventually
623        if foid not in a:
624            ofoid = 'oid_' + self._split_schema(cl)[-1]
625            if ofoid in a:
626                a[foid] = a[ofoid]
627
628        q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
629        self._do_debug(q)
630        self.db.query(q)
631
632
633# if run as script, print some information
634
635if __name__ == '__main__':
636    print 'PyGreSQL version', version
637    print
638    print __doc__
Note: See TracBrowser for help on using the repository browser.