Changeset 204 for trunk/module/pg.py


Ignore:
Timestamp:
Aug 12, 2005, 6:02:58 PM (14 years ago)
Author:
cito
Message:

Major improvements in classic pg module

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/module/pg.py

    r202 r204  
    11# pg.py
    22# Written by D'Arcy J.M. Cain
    3 # $Id: pg.py,v 1.31 2005-05-26 13:58:36 cito Exp $
    4 
    5 """PyGreSQL Classic Interface.
    6 
    7 This library implements some basic database management stuff.
    8 It includes the pg module and builds on it.  This is known as the
    9 "Classic" interface.  For DB-API compliance use the pgdb module.
     3# Improved by Christoph Zwerschke
     4# $Id: pg.py,v 1.32 2005-08-12 22:02:58 cito Exp $
     5
     6"""PyGreSQL classic interface.
     7
     8This pg module implements some basic database management stuff.
     9It includes the _pg module and builds on it, providing the higher
     10level wrapper class named DB with addtional functionality.
     11This is known as the "classic" ("old style") PyGreSQL interface.
     12For a DB-API 2 compliant interface use the newer pgdb module.
    1013"""
    1114
     
    1316from types import *
    1417
     18# Auxiliary functions which are independent from a DB connection:
     19
    1520def _quote(d, t):
    16         """Return quotes if needed (utility function)."""
    17         if d == None:
    18                 return "NULL"
    19         if t in ('int', 'seq'):
    20                 if d == "": return "NULL"
    21                 return "%d" % long(d)
    22         if t == 'decimal':
    23                 if d == "": return "NULL"
    24                 return "%f" % float(d)
     21        """Return quotes if needed."""
     22        if d is None:
     23                return 'NULL'
     24        if t in ('int', 'seq', 'decimal'):
     25                if d == '': return 'NULL'
     26                return str(d)
    2527        if t == 'money':
    26                 if d == "": return "NULL"
     28                if d == '': return 'NULL'
    2729                return "'%.2f'" % float(d)
    2830        if t == 'bool':
    29                 d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
     31                if type(d) == StringType:
     32                        if d == '': return 'NULL'
     33                        d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
     34                else:
     35                        d = not not d
    3036                return ("'f'", "'t'")[d]
    31         if t in ('date', 'inet', 'cidr') and d == '': return "NULL"
     37        if t in ('date', 'inet', 'cidr'):
     38                if d == '': return 'NULL'
    3239        return "'%s'" % str(d).replace('\\','\\\\').replace('\'','\\\'')
    3340
     41def _is_quoted(s):
     42        """Check whether this string is a quoted identifier."""
     43        s = s.replace('_', 'a')
     44        return not s.isalnum() or s[:1].isdigit() or s != s.lower()
     45
     46def _is_unquoted(s):
     47        """Check whether this string is an unquoted identifier."""
     48        s = s.replace('_', 'a')
     49        return s.isalnum() and not s[:1].isdigit()
     50
     51def _split_first_part(s):
     52        """Split the first part of a dot separated string."""
     53        s = s.lstrip()
     54        if s[:1] == '"':
     55                p = []
     56                s = s.split('"', 3)[1:]
     57                p.append(s[0])
     58                while len(s) == 3 and s[1] == '':
     59                        p.append('"')
     60                        s = s[2].split('"', 2)
     61                        p.append(s[0])
     62                p = [''.join(p)]
     63                s = '"'.join(s[1:]).lstrip()
     64                if s:
     65                        if s[:0] == '.':
     66                                p.append(s[1:])
     67                        else:
     68                                s = _split_first_part(s)
     69                                p[0] += s[0]
     70                                if len(s) > 1:
     71                                        p.append(s[1])
     72        else:
     73                p = s.split('.', 1)
     74                s = p[0].rstrip()
     75                if _is_unquoted(s):
     76                        s = s.lower()
     77                p[0] = s
     78        return p
     79
     80def _split_parts(s):
     81        """Split all parts of a dot separated string."""
     82        q = []
     83        while s:
     84                s = _split_first_part(s)
     85                q.append(s[0])
     86                if len(s) < 2: break
     87                s = s[1]
     88        return q
     89
     90def _join_parts(s):
     91        """Join all parts of a dot separated string."""
     92        return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
     93
     94# The PostGreSQL database connection interface:
     95
    3496class DB:
    35         """This class wraps the pg connection type."""
     97        """Wrapper class for the _pg connection type."""
    3698
    3799        def __init__(self, *args, **kw):
    38100                self.db = connect(*args, **kw)
    39 
    40                 # Create convenience methods, in a way that is still overridable
    41                 # (members are not copied because they are actually functions):
    42                 for e in self.db.__methods__:
    43                         if e not in ("close", "query"): # These are wrapped separately
    44                                 setattr(self, e, getattr(self.db, e))
    45 
     101                self.dbname = self.db.db
    46102                self.__attnames = {}
    47103                self.__pkeys = {}
     
    52108                        # * to a file object to write debug statements to.
    53109
     110        def __getattr__(self, name):
     111                # All undefined members are the same as in the underlying pg connection:
     112                if self.db:
     113                        return getattr(self.db, name)
     114                else:
     115                        raise InternalError, 'Connection is not valid'
     116
    54117        def _do_debug(self, s):
    55118                """Print a debug message."""
     
    59122                if isinstance(self.debug, FileType): print >> self.debug, s
    60123
    61         def close(self,):
     124        def close(self):
    62125                """Close the database connection."""
    63126                # Wraps shared library function so we can track state.
    64                 self.db.close()
    65                 self.db = None
     127
     128                if self.db:
     129                        self.db.close()
     130                        self.db = None
     131                else:
     132                        raise InternalError, 'Connection already closed'
    66133
    67134        def reopen(self):
     
    72139
    73140                """
    74                 if self.db: self.close()
    75                 try: self.db = connect(*self.__args[0], **self.__args[1])
     141                if self.db:
     142                        self.db.close()
     143                try:
     144                        self.db = connect(*self.__args[0], **self.__args[1])
    76145                except:
    77146                        self.db = None
     
    90159                """
    91160                # Wraps shared library function for debugging.
     161                if not self.db:
     162                        raise InternalError, 'Connection is not valid'
    92163                self._do_debug(qstr)
    93164                return self.db.query(qstr)
    94165
    95         def split_schema(self, cl):
     166        def _split_schema(self, cl):
    96167                """Return schema and name of object separately.
    97168
    98                 If name is not qualified, determine first schema in search path
    99                 contaning the object, otherwise split qualified name.
    100 
    101                 """
    102                 if cl.find('.') < 0:
     169                This auxiliary function splits off the namespace (schema)
     170                belonging to the class with the name cl. If the class name
     171                is not qualified, the function is able to determine the schema
     172                of the class, taking into account the current search path.
     173
     174                """
     175                s = _split_parts(cl)
     176                if len(s) > 1: # name already qualfied?
     177                        # should be database.schema.table or schema.table
     178                        if len(s) > 3:
     179                                raise ProgrammingError, 'Too many dots in class name %s' % cl
     180                        schema, cl = s[-2:]
     181                else:
     182                        cl = s[0]
    103183                        # determine search path
    104                         query = "SELECT current_schemas(TRUE)"
     184                        query = 'SELECT current_schemas(TRUE)'
    105185                        schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
    106186                        if schemas: # non-empty path
    107187                                # search schema for this object in the current search path
    108                                 query = " UNION ".join(["SELECT %d AS n, '%s' AS nspname" % s
     188                                query = ' UNION '.join(["SELECT %d AS n, '%s' AS nspname" % s
    109189                                        for s in enumerate(schemas)])
    110190                                query = ("SELECT nspname FROM pg_class"
     
    120200                        else: # empty path
    121201                                schema = 'public'
    122                 else: # name already qualfied?
    123                         # should be database.schema.table or schema.table
    124                         schema, cl = cl.split('.', 2)[-2:]
    125202                return schema, cl
    126203
     
    133210
    134211                """
     212                # Get all the primary keys at once
    135213                # Get all the primary keys at once
    136214                if isinstance(newpkey, DictType):
    137215                        self.__pkeys = newpkey
    138216                        return newpkey
    139                 qcl = "%s.%s"  % self.split_schema(cl) # determine qualified name
     217                qcl = _join_parts(self._split_schema(cl)) # build qualified name
    140218                if newpkey:
    141219                        self.__pkeys[qcl] = newpkey
     
    143221                if self.__pkeys == {} or not self.__pkeys.has_key(qcl):
    144222                        # if not found, determine pkey again in case it was added after we started
    145                         for q, a in self.db.query("SELECT pg_namespace.nspname||'.'||"
    146                                 "pg_class.relname,pg_attribute.attname FROM pg_class"
     223                        for r in self.db.query("SELECT pg_namespace.nspname"
     224                                ",pg_class.relname,pg_attribute.attname FROM pg_class"
    147225                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
    148226                                " AND pg_namespace.nspname NOT LIKE 'pg_%'"
     
    152230                                " AND pg_index.indisprimary='t'"
    153231                                " AND pg_index.indkey[0]=pg_attribute.attnum").getresult():
    154                                 self.__pkeys[q] = a
     232                                self.__pkeys[_join_parts(r[:2])] = r[2] # build qualified name
    155233                        self._do_debug(self.__pkeys)
    156234                # will raise an exception if primary key doesn't exist
     
    159237        def get_databases(self):
    160238                """Get list of databases in the system."""
    161                 return [x[0] for x in
    162                         self.db.query("SELECT datname FROM pg_database").getresult()]
     239                return [s for s, in
     240                        self.db.query('SELECT datname FROM pg_database').getresult()]
    163241
    164242        def get_tables(self):
    165243                """Get list of tables in connected database."""
    166                 return [x[0] for x in
    167                         self.db.query("SELECT pg_namespace.nspname||'.'||"
    168                                 "pg_class.relname FROM pg_class"
     244                return [_join_parts(s) for s in
     245                        self.db.query("SELECT pg_namespace.nspname"
     246                                ",pg_class.relname FROM pg_class"
    169247                                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
    170248                                " WHERE pg_class.relkind='r' AND"
    171                                 " pg_class.relname !~ '^Inv' AND "
    172                                 " pg_class.relname !~ '^pg_' order by 1").getresult()]
     249                                " pg_class.relname!~'^Inv' AND "
     250                                " pg_class.relname!~'^pg_' ORDER BY 1,2").getresult()]
    173251
    174252        def get_attnames(self, cl, newattnames = None):
     
    186264                elif newattnames:
    187265                        raise ProgrammingError, \
    188                                 "If supplied, newattnames must be a dictionary"
    189                 cl = self.split_schema(cl) # split into schema and cl
    190                 qcl = '%s.%s' % cl # build qualified name
     266                                'If supplied, newattnames must be a dictionary'
     267                cl = self._split_schema(cl) # split into schema and cl
     268                qcl = _join_parts(cl) # build qualified name
    191269                # May as well cache them:
    192270                if self.__attnames.has_key(qcl):
    193271                        return self.__attnames[qcl]
    194                 query = ("SELECT pg_attribute.attname,pg_type.typname FROM pg_class"
     272                if qcl not in self.get_tables():
     273                        raise ProgrammingError, 'Class %s does not exist' % qcl
     274                t = {}
     275                for att, typ in self.db.query("SELECT pg_attribute.attname"
     276                        ",pg_type.typname FROM pg_class"
    195277                        " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
    196278                        " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
    197279                        " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
    198280                        " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
    199                         " AND pg_attribute.attnum>0 AND pg_attribute.attisdropped='f'" % cl)
    200                 l = {}
    201                 for att, typ in self.db.query(query).getresult():
    202                         if typ.startswith("interval"):
    203                                 l[att] = 'date'
    204                         elif typ.startswith("int"):
    205                                 l[att] = 'int'
    206                         elif typ.startswith("oid"):
    207                                 l[att] = 'int'
    208                         elif typ.startswith("text"):
    209                                 l[att] = 'text'
    210                         elif typ.startswith("char"):
    211                                 l[att] = 'text'
    212                         elif typ.startswith("name"):
    213                                 l[att] = 'text'
    214                         elif typ.startswith("abstime"):
    215                                 l[att] = 'date'
    216                         elif typ.startswith("date"):
    217                                 l[att] = 'date'
    218                         elif typ.startswith("timestamp"):
    219                                 l[att] = 'date'
    220                         elif typ.startswith("bool"):
    221                                 l[att] = 'bool'
    222                         elif typ.startswith("float"):
    223                                 l[att] = 'decimal'
    224                         elif typ.startswith("money"):
    225                                 l[att] = 'money'
     281                        " AND pg_attribute.attnum>0 AND pg_attribute.attisdropped='f'"
     282                                % cl).getresult():
     283                        if typ.startswith('interval'):
     284                                t[att] = 'date'
     285                        elif typ.startswith('int'):
     286                                t[att] = 'int'
     287                        elif typ.startswith('oid'):
     288                                t[att] = 'int'
     289                        elif typ.startswith('text'):
     290                                t[att] = 'text'
     291                        elif typ.startswith('char'):
     292                                t[att] = 'text'
     293                        elif typ.startswith('name'):
     294                                t[att] = 'text'
     295                        elif typ.startswith('abstime'):
     296                                t[att] = 'date'
     297                        elif typ.startswith('date'):
     298                                t[att] = 'date'
     299                        elif typ.startswith('timestamp'):
     300                                t[att] = 'date'
     301                        elif typ.startswith('bool'):
     302                                t[att] = 'bool'
     303                        elif typ.startswith('float'):
     304                                t[att] = 'decimal'
     305                        elif typ.startswith('money'):
     306                                t[att] = 'money'
    226307                        else:
    227                                 l[att] = 'text'
    228                 l['oid'] = 'int' # every table has this
    229                 self.__attnames[qcl] = l # cache it
     308                                t[att] = 'text'
     309                t['oid'] = 'int' # every table has this
     310                self.__attnames[qcl] = t # cache it
    230311                return self.__attnames[qcl]
    231312
     
    238319                then the value for the key is taken from it and it is modified to
    239320                include the new values, replacing existing values where necessary.
    240                 The oid is also put into the dictionary but in order to allow the
    241                 caller to work with multiple tables, the attribute name is munge
    242                 as "oid_schema_table" to make it unique.
     321                The OID is also put into the dictionary, but in order to allow the
     322                caller to work with multiple tables, it is munged as oid(schema.table).
    243323
    244324                """
    245325                if cl.endswith('*'): # scan descendant tables?
    246                         xcl = cl[:-1].rstrip() # need parent table name
    247                 else:
    248                         xcl = cl
    249                 xcl = self.split_schema(xcl) # split into schema and cl
    250                 qcl = '%s.%s' % xcl # build qualified name
    251                 if keyname == None:     # use the primary key by default
     326                        cl = cl[:-1].rstrip() # need parent table name
     327                qcl = _join_parts(self._split_schema(cl)) # build qualified name
     328                # To allow users to work with multiple tables,
     329                # we munge the name when the key is "oid"
     330                foid = 'oid(%s)' % qcl # build mangled name
     331                if keyname == None: # use the primary key by default
    252332                        keyname = self.pkey(qcl)
    253333                fnames = self.get_attnames(qcl)
    254334                if isinstance(arg, DictType):
    255                         # To allow users to work with multiple tables,
    256                         # we munge the name when the key is "oid"
    257                         if keyname == 'oid':
    258                                 foid = 'oid_%s_%s' % xcl # build mangled name
    259                                 k = arg[foid]
    260                         else:
    261                                 k = arg[keyname]
     335                        k = arg[keyname == 'oid' and foid or keyname]
    262336                else:
    263337                        k = arg
     
    265339                # We want the oid for later updates if that isn't the key
    266340                if keyname == 'oid':
    267                         q = "SELECT * FROM %s WHERE oid=%s" % (cl, k)
     341                        q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
    268342                elif view:
    269                         q = "SELECT * FROM %s WHERE %s=%s" % \
    270                                 (cl, keyname, _quote(k, fnames[keyname]))
    271                 else:
    272                         foid = 'oid_%s_%s' % xcl # build mangled name
    273                         q = "SELECT oid AS %s,%s FROM %s WHERE %s=%s" % \
    274                                 (foid, ','.join(fnames.keys()), cl, \
     343                        q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
     344                                (qcl, keyname, _quote(k, fnames[keyname]))
     345                else:
     346                        q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
     347                                (','.join(fnames.keys()), qcl, \
    275348                                        keyname, _quote(k, fnames[keyname]))
    276349                self._do_debug(q)
    277350                res = self.db.query(q).dictresult()
    278                 if res == []:
     351                if not res:
    279352                        raise DatabaseError, \
    280                                 "No such record in %s where %s is %s" % \
    281                                         (cl, keyname, _quote(k, fnames[keyname]))
    282                         return None
    283                 for k in res[0].keys():
    284                         arg[k] = res[0][k]
     353                                'No such record in %s where %s=%s' % \
     354                                        (qcl, keyname, _quote(k, fnames[keyname]))
     355                for k, d in res[0].items():
     356                        if k == 'oid':
     357                                k = foid
     358                        arg[k] = d
    285359                return arg
    286360
     
    297371
    298372                """
    299                 cl = self.split_schema(cl) # split into schema and cl
    300                 qcl = '%s.%s' % cl # build qualified name
    301                 foid = 'oid_%s_%s' % cl # build mangled name
     373                qcl = _join_parts(self._split_schema(cl)) # build qualified name
     374                foid = 'oid(%s)' % qcl # build mangled name
    302375                fnames = self.get_attnames(qcl)
    303                 l = []
     376                t = []
    304377                n = []
    305378                for f in fnames.keys():
    306379                        if f != 'oid' and a.has_key(f):
    307                                 l.append(_quote(a[f], fnames[f]))
     380                                t.append(_quote(a[f], fnames[f]))
    308381                                n.append(f)
    309                 q = "INSERT INTO %s (%s) VALUES (%s)" % \
    310                         (qcl, ','.join(n), ','.join(l))
     382                q = 'INSERT INTO %s (%s) VALUES (%s)' % \
     383                        (qcl, ','.join(n), ','.join(t))
    311384                self._do_debug(q)
    312385                a[foid] = self.db.query(q)
     
    330403                # Update always works on the oid which get returns if available,
    331404                # otherwise use the primary key.  Fail if neither.
    332                 cl = self.split_schema(cl) # split into schema and cl
    333                 qcl = '%s.%s' % cl # build qualified name
    334                 foid = 'oid_%s_%s' % cl # build mangled oid
     405                qcl = _join_parts(self._split_schema(cl)) # build qualified name
     406                foid = 'oid(%s)' % qcl # build mangled oid
    335407                if a.has_key(foid):
    336408                        where = "oid=%s" % a[foid]
     
    340412                        except:
    341413                                raise ProgrammingError, \
    342                                         "Update needs primary key or oid as %s" % foid
     414                                        'Update needs primary key or oid as %s' % foid
    343415                        where = "%s='%s'" % (pk, a[pk])
    344416                v = []
     
    347419                for ff in fnames.keys():
    348420                        if ff != 'oid' and a.has_key(ff):
    349                                 v.append("%s=%s" % (ff, _quote(a[ff], fnames[ff])))
     421                                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
    350422                if v == []:
    351423                        return None
    352                 q = "UPDATE %s SET %s WHERE %s" % (qcl, ','.join(v), where)
     424                q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
    353425                self._do_debug(q)
    354426                self.db.query(q)
     
    359431                        return self.get(qcl, a)
    360432
    361         def clear(self, cl, a = {}):
     433        def clear(self, cl, a = None):
    362434                """
    363435
    364436                This method clears all the attributes to values determined by the types.
    365                 Numeric types are set to 0, dates are set to 'TODAY' and everything
    366                 else is set to the empty string.  If the array argument is present,
    367                 it is used as the array and any entries matching attribute names
    368                 are cleared with everything else left unchanged.
     437                Numeric types are set to 0, Booleans are set to 'f', dates are set
     438                to 'now()' and everything else is set to the empty string.
     439                If the array argument is present, it is used as the array and any entries
     440                matching attribute names are cleared with everything else left unchanged.
    369441
    370442                """
    371443                # At some point we will need a way to get defaults from a table.
    372                 qcl = "%s.%s"  % self.split_schema(cl) # determine qualified name
     444                if a is None: a = {} # empty if argument is not present
     445                qcl = _join_parts(self._split_schema(cl)) # build qualified name
     446                foid = 'oid(%s)' % qcl # build mangled oid
    373447                fnames = self.get_attnames(qcl)
    374                 for ff in fnames.keys():
    375                         if fnames[ff] in ['int', 'decimal', 'seq', 'money']:
    376                                 a[ff] = 0
     448                for k, t in fnames.items():
     449                        if k == 'oid': continue
     450                        if t in ['int', 'decimal', 'seq', 'money']:
     451                                a[k] = 0
     452                        elif t == 'bool':
     453                                a[k] = 'f'
     454                        elif t == 'date':
     455                                a[k] = 'now()'
    377456                        else:
    378                                 a[ff] = ""
    379                 a['oid'] = 0
     457                                a[k] = ''
    380458                return a
    381459
     
    390468                # One day we will be testing that the record to be deleted
    391469                # isn't referenced somewhere (or else PostgreSQL will).
    392                 cl = self.split_schema(cl) # split into schema and cl
    393                 qcl = '%s.%s' % cl # build qualified name
    394                 foid = 'oid_%s_%s' % cl # build mangled oid
    395                 q = "DELETE FROM %s WHERE oid=%s" % (qcl, a[foid])
     470                qcl = _join_parts(self._split_schema(cl)) # build qualified name
     471                foid = 'oid(%s)' % qcl # build mangled oid
     472                q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
    396473                self._do_debug(q)
    397474                self.db.query(q)
Note: See TracChangeset for help on using the changeset viewer.