Changeset 346


Ignore:
Timestamp:
Nov 1, 2008, 2:21:12 PM (11 years ago)
Author:
cito
Message:

Consistently use 4 spaces instead of tabs for all Python moduls, as recommended in PEP8.

Location:
trunk/module
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/module/pg.py

    r333 r346  
    66# Improved by Christoph Zwerschke
    77#
    8 # $Id: pg.py,v 1.57 2008-10-09 20:45:29 cito Exp $
     8# $Id: pg.py,v 1.58 2008-11-01 18:21:12 cito Exp $
    99#
    1010
     
    3131
    3232def _quote(d, t):
    33         """Return quotes if needed."""
    34         if d is None:
    35                 return 'NULL'
    36         if t in ('int', 'seq', 'float', 'num'):
    37                 if d == '':
    38                         return 'NULL'
    39                 return str(d)
    40         if t == 'money':
    41                 if d == '':
    42                         return 'NULL'
    43                 return "'%.2f'" % float(d)
    44         if t == 'bool':
    45                 if type(d) == StringType:
    46                         if d == '':
    47                                 return 'NULL'
    48                         d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
    49                 else:
    50                         d = not not d
    51                 return ("'f'", "'t'")[d]
    52         if t in ('date', 'inet', 'cidr'):
    53                 if d == '':
    54                         return 'NULL'
    55                 if d.lower() in ('current_date', 'current_time',
    56                         'current_timestamp', 'localtime', 'localtimestamp'):
    57                         return d
    58         return "'%s'" % str(d).replace("\\", "\\\\").replace("'", "''")
     33    """Return quotes if needed."""
     34    if d is None:
     35        return 'NULL'
     36    if t in ('int', 'seq', 'float', 'num'):
     37        if d == '':
     38            return 'NULL'
     39        return str(d)
     40    if t == 'money':
     41        if d == '':
     42            return 'NULL'
     43        return "'%.2f'" % float(d)
     44    if t == 'bool':
     45        if type(d) == StringType:
     46            if d == '':
     47                return 'NULL'
     48            d = str(d).lower() in ('t', 'true', '1', 'y', 'yes', 'on')
     49        else:
     50            d = not not d
     51        return ("'f'", "'t'")[d]
     52    if t in ('date', 'inet', 'cidr'):
     53        if d == '':
     54            return 'NULL'
     55        if d.lower() in ('current_date', 'current_time',
     56            'current_timestamp', 'localtime', 'localtimestamp'):
     57            return d
     58    return "'%s'" % str(d).replace("\\", "\\\\").replace("'", "''")
    5959
    6060def _is_quoted(s):
    61         """Check whether this string is a quoted identifier."""
    62         s = s.replace('_', 'a')
    63         return not s.isalnum() or s[:1].isdigit() or s != s.lower()
     61    """Check whether this string is a quoted identifier."""
     62    s = s.replace('_', 'a')
     63    return not s.isalnum() or s[:1].isdigit() or s != s.lower()
    6464
    6565def _is_unquoted(s):
    66         """Check whether this string is an unquoted identifier."""
    67         s = s.replace('_', 'a')
    68         return s.isalnum() and not s[:1].isdigit()
     66    """Check whether this string is an unquoted identifier."""
     67    s = s.replace('_', 'a')
     68    return s.isalnum() and not s[:1].isdigit()
    6969
    7070def _split_first_part(s):
    71         """Split the first part of a dot separated string."""
    72         s = s.lstrip()
    73         if s[:1] == '"':
    74                 p = []
    75                 s = s.split('"', 3)[1:]
    76                 p.append(s[0])
    77                 while len(s) == 3 and s[1] == '':
    78                         p.append('"')
    79                         s = s[2].split('"', 2)
    80                         p.append(s[0])
    81                 p = [''.join(p)]
    82                 s = '"'.join(s[1:]).lstrip()
    83                 if s:
    84                         if s[:0] == '.':
    85                                 p.append(s[1:])
    86                         else:
    87                                 s = _split_first_part(s)
    88                                 p[0] += s[0]
    89                                 if len(s) > 1:
    90                                         p.append(s[1])
    91         else:
    92                 p = s.split('.', 1)
    93                 s = p[0].rstrip()
    94                 if _is_unquoted(s):
    95                         s = s.lower()
    96                 p[0] = s
    97         return p
     71    """Split the first part of a dot separated string."""
     72    s = s.lstrip()
     73    if s[:1] == '"':
     74        p = []
     75        s = s.split('"', 3)[1:]
     76        p.append(s[0])
     77        while len(s) == 3 and s[1] == '':
     78            p.append('"')
     79            s = s[2].split('"', 2)
     80            p.append(s[0])
     81        p = [''.join(p)]
     82        s = '"'.join(s[1:]).lstrip()
     83        if s:
     84            if s[:0] == '.':
     85                p.append(s[1:])
     86            else:
     87                s = _split_first_part(s)
     88                p[0] += s[0]
     89                if len(s) > 1:
     90                    p.append(s[1])
     91    else:
     92        p = s.split('.', 1)
     93        s = p[0].rstrip()
     94        if _is_unquoted(s):
     95            s = s.lower()
     96        p[0] = s
     97    return p
    9898
    9999def _split_parts(s):
    100         """Split all parts of a dot separated string."""
    101         q = []
    102         while s:
    103                 s = _split_first_part(s)
    104                 q.append(s[0])
    105                 if len(s) < 2:
    106                         break
    107                 s = s[1]
    108         return q
     100    """Split all parts of a dot separated string."""
     101    q = []
     102    while s:
     103        s = _split_first_part(s)
     104        q.append(s[0])
     105        if len(s) < 2:
     106            break
     107        s = s[1]
     108    return q
    109109
    110110def _join_parts(s):
    111         """Join all parts of a dot separated string."""
    112         return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
     111    """Join all parts of a dot separated string."""
     112    return '.'.join([_is_quoted(p) and '"%s"' % p or p for p in s])
    113113
    114114
     
    116116
    117117class DB:
    118         """Wrapper class for the _pg connection type."""
    119 
    120         def __init__(self, *args, **kw):
    121                 """Create a new connection.
    122 
    123                 You can pass either the connection parameters or an existing
    124                 _pg or pgdb connection. This allows you to use the methods
    125                 of the classic pg interface with a DB-API 2 pgdb connection.
    126 
    127                 """
    128                 if not args and len(kw) == 1:
    129                         db = kw.get('db')
    130                 elif not kw and len(args) == 1:
    131                         db = args[0]
    132                 else:
    133                         db = None
    134                 if db:
    135                         if isinstance(db, DB):
    136                                 db = db.db
    137                         else:
    138                                 try:
    139                                         db = db._cnx
    140                                 except AttributeError:
    141                                         pass
    142                 if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
    143                         db = connect(*args, **kw)
    144                         self._closeable = 1
    145                 else:
    146                         self._closeable = 0
    147                 self.db = db
    148                 self.dbname = db.db
    149                 self._attnames = {}
    150                 self._pkeys = {}
    151                 self._args = args, kw
    152                 self.debug = None # For debugging scripts, this can be set
    153                         # * to a string format specification (e.g. in CGI set to "%s<BR>"),
    154                         # * to a function which takes a string argument or
    155                         # * to a file object to write debug statements to.
    156 
    157         def __getattr__(self, name):
    158                 # All undefined members are the same as in the underlying pg connection:
    159                 if self.db:
    160                         return getattr(self.db, name)
    161                 else:
    162                         raise InternalError('Connection is not valid')
    163 
    164         # For convenience, define some module functions as static methods also:
    165         escape_string, escape_bytea, unescape_bytea = map(staticmethod,
    166                 (escape_string, escape_bytea, unescape_bytea))
    167 
    168         def _do_debug(self, s):
    169                 """Print a debug message."""
    170                 if self.debug:
    171                         if isinstance(self.debug, StringType):
    172                                 print self.debug % s
    173                         elif isinstance(self.debug, FunctionType):
    174                                 self.debug(s)
    175                         elif isinstance(self.debug, FileType):
    176                                 print >> self.debug, s
    177 
    178         def close(self):
    179                 """Close the database connection."""
    180                 # Wraps shared library function so we can track state.
    181                 if self._closeable:
    182                         if self.db:
    183                                 self.db.close()
    184                                 self.db = None
    185                         else:
    186                                 raise InternalError('Connection already closed')
    187 
    188         def reopen(self):
    189                 """Reopen connection to the database.
    190 
    191                 Used in case we need another connection to the same database.
    192                 Note that we can still reopen a database that we have closed.
    193 
    194                 """
    195                 # There is no such shared library function.
    196                 if self._closeable:
    197                         if self.db:
    198                                 self.db.close()
    199                         try:
    200                                 self.db = connect(*self._args[0], **self._args[1])
    201                         except:
    202                                 self.db = None
    203                                 raise
    204 
    205         def query(self, qstr):
    206                 """Executes a SQL command string.
    207 
    208                 This method simply sends a SQL query to the database. If the query is
    209                 an insert statement, the return value is the OID of the newly
    210                 inserted row.  If it is otherwise a query that does not return a result
    211                 (ie. is not a some kind of SELECT statement), it returns None.
    212                 Otherwise, it returns a pgqueryobject that can be accessed via the
    213                 getresult or dictresult method or simply printed.
    214 
    215                 """
    216                 # Wraps shared library function for debugging.
    217                 if not self.db:
    218                         raise InternalError('Connection is not valid')
    219                 self._do_debug(qstr)
    220                 return self.db.query(qstr)
    221 
    222         def _split_schema(self, cl):
    223                 """Return schema and name of object separately.
    224 
    225                 This auxiliary function splits off the namespace (schema)
    226                 belonging to the class with the name cl. If the class name
    227                 is not qualified, the function is able to determine the schema
    228                 of the class, taking into account the current search path.
    229 
    230                 """
    231                 s = _split_parts(cl)
    232                 if len(s) > 1: # name already qualfied?
    233                         # should be database.schema.table or schema.table
    234                         if len(s) > 3:
    235                                 raise ProgrammingError('Too many dots in class name %s' % cl)
    236                         schema, cl = s[-2:]
    237                 else:
    238                         cl = s[0]
    239                         # determine search path
    240                         query = 'SELECT current_schemas(TRUE)'
    241                         schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
    242                         if schemas: # non-empty path
    243                                 # search schema for this object in the current search path
    244                                 query = ' UNION '.join(
    245                                         ["SELECT %d::integer AS n, '%s'::name AS nspname"
    246                                                 % s for s in enumerate(schemas)])
    247                                 query = ("SELECT nspname FROM pg_class"
    248                                         " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
    249                                         " JOIN (%s) AS p USING (nspname)"
    250                                         " WHERE pg_class.relname='%s'"
    251                                         " ORDER BY n LIMIT 1" % (query, cl))
    252                                 schema = self.db.query(query).getresult()
    253                                 if schema: # schema found
    254                                         schema = schema[0][0]
    255                                 else: # object not found in current search path
    256                                         schema = 'public'
    257                         else: # empty path
    258                                 schema = 'public'
    259                 return schema, cl
    260 
    261         def pkey(self, cl, newpkey = None):
    262                 """This method gets or sets the primary key of a class.
    263 
    264                 If newpkey is set and is not a dictionary then set that
    265                 value as the primary key of the class.  If it is a dictionary
    266                 then replace the _pkeys dictionary with it.
    267 
    268                 """
    269                 # First see if the caller is supplying a dictionary
    270                 if isinstance(newpkey, DictType):
    271                         # make sure that we have a namespace
    272                         self._pkeys = {}
    273                         for x in newpkey.keys():
    274                                 if x.find('.') == -1:
    275                                         self._pkeys['public.' + x] = newpkey[x]
    276                                 else:
    277                                         self._pkeys[x] = newpkey[x]
    278                         return self._pkeys
    279 
    280                 qcl = _join_parts(self._split_schema(cl)) # build qualified name
    281                 if newpkey:
    282                         self._pkeys[qcl] = newpkey
    283                         return newpkey
    284 
    285                 # Get all the primary keys at once
    286                 if self._pkeys == {} or not self._pkeys.has_key(qcl):
    287                         # if not found, check again in case it was added after we started
    288                         self._pkeys = dict([
    289                                 (_join_parts(r[:2]), r[2]) for r in self.db.query(
    290                                 "SELECT pg_namespace.nspname, pg_class.relname"
    291                                         ",pg_attribute.attname FROM pg_class"
    292                                 " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
    293                                         " AND pg_namespace.nspname NOT LIKE 'pg_%'"
    294                                 " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
    295                                         " AND pg_attribute.attisdropped='f'"
    296                                 " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
    297                                         " AND pg_index.indisprimary='t'"
    298                                         " AND pg_index.indkey[0]=pg_attribute.attnum"
    299                                 ).getresult()])
    300                         self._do_debug(self._pkeys)
    301                 # will raise an exception if primary key doesn't exist
    302                 return self._pkeys[qcl]
    303 
    304         def get_databases(self):
    305                 """Get list of databases in the system."""
    306                 return [s[0] for s in
    307                         self.db.query('SELECT datname FROM pg_database').getresult()]
    308 
    309         def get_relations(self, kinds = None):
    310                 """Get list of relations in connected database of specified kinds.
    311 
    312                         If kinds is None or empty, all kinds of relations are returned.
    313                         Otherwise kinds can be a string or sequence of type letters
    314                         specifying which kind of relations you want to list.
    315 
    316                 """
    317                 where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
    318                         ["'%s'" % x for x in kinds]) or ''
    319                 return map(_join_parts, self.db.query(
    320                         "SELECT pg_namespace.nspname, pg_class.relname "
    321                         "FROM pg_class "
    322                         "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
    323                         "WHERE %s pg_class.relname !~ '^Inv' AND "
    324                                 "pg_class.relname !~ '^pg_' "
    325                         "ORDER BY 1, 2" % where).getresult())
    326 
    327         def get_tables(self):
    328                 """Return list of tables in connected database."""
    329                 return self.get_relations('r')
    330 
    331         def get_attnames(self, cl, newattnames = None):
    332                 """Given the name of a table, digs out the set of attribute names.
    333 
    334                 Returns a dictionary of attribute names (the names are the keys,
    335                 the values are the names of the attributes' types).
    336                 If the optional newattnames exists, it must be a dictionary and
    337                 will become the new attribute names dictionary.
    338 
    339                 """
    340                 if isinstance(newattnames, DictType):
    341                         self._attnames = newattnames
    342                         return
    343                 elif newattnames:
    344                         raise ProgrammingError(
    345                                 'If supplied, newattnames must be a dictionary')
    346                 cl = self._split_schema(cl) # split into schema and cl
    347                 qcl = _join_parts(cl) # build qualified name
    348                 # May as well cache them:
    349                 if self._attnames.has_key(qcl):
    350                         return self._attnames[qcl]
    351                 if qcl not in self.get_relations('rv'):
    352                         raise ProgrammingError('Class %s does not exist' % qcl)
    353                 t = {}
    354                 for att, typ in self.db.query("SELECT pg_attribute.attname"
    355                         ",pg_type.typname FROM pg_class"
    356                         " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
    357                         " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
    358                         " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
    359                         " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
    360                         " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
    361                         " AND pg_attribute.attisdropped='f'"
    362                                 % cl).getresult():
    363                         if typ.startswith('bool'):
    364                                 t[att] = 'bool'
    365                         elif typ.startswith('oid'):
    366                                 t[att] = 'int'
    367                         elif typ.startswith('float'):
    368                                 t[att] = 'float'
    369                         elif typ.startswith('numeric'):
    370                                 t[att] = 'num'
    371                         elif typ.startswith('abstime'):
    372                                 t[att] = 'date'
    373                         elif typ.startswith('date'):
    374                                 t[att] = 'date'
    375                         elif typ.startswith('interval'):
    376                                 t[att] = 'date'
    377                         elif typ.startswith('int'):
    378                                 t[att] = 'int'
    379                         elif typ.startswith('timestamp'):
    380                                 t[att] = 'date'
    381                         elif typ.startswith('money'):
    382                                 t[att] = 'money'
    383                         else:
    384                                 t[att] = 'text'
    385                 self._attnames[qcl] = t # cache it
    386                 return self._attnames[qcl]
    387 
    388         def get(self, cl, arg, keyname = None, view = 0):
    389                 """Get a tuple from a database table or view.
    390 
    391                 This method is the basic mechanism to get a single row.  It assumes
    392                 that the key specifies a unique row.  If keyname is not specified
    393                 then the primary key for the table is used.  If arg is a dictionary
    394                 then the value for the key is taken from it and it is modified to
    395                 include the new values, replacing existing values where necessary.
    396                 The OID is also put into the dictionary, but in order to allow the
    397                 caller to work with multiple tables, it is munged as oid(schema.table).
    398 
    399                 """
    400                 if cl.endswith('*'): # scan descendant tables?
    401                         cl = cl[:-1].rstrip() # need parent table name
    402                 qcl = _join_parts(self._split_schema(cl)) # build qualified name
    403                 # To allow users to work with multiple tables,
    404                 # we munge the name when the key is "oid"
    405                 foid = 'oid(%s)' % qcl # build mangled name
    406                 if keyname is None: # use the primary key by default
    407                         keyname = self.pkey(qcl)
    408                 fnames = self.get_attnames(qcl)
    409                 if isinstance(arg, DictType):
    410                         # XXX this code is for backwards compatibility and will be
    411                         # XXX removed eventually
    412                         if not arg.has_key(foid):
    413                                 ofoid = 'oid_' + self._split_schema(cl)[-1]
    414                                 if arg.has_key(ofoid):
    415                                         arg[foid] = arg[ofoid]
    416 
    417                         k = arg[keyname == 'oid' and foid or keyname]
    418                 else:
    419                         k = arg
    420                         arg = {}
    421                 # We want the oid for later updates if that isn't the key
    422                 if keyname == 'oid':
    423                         q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
    424                 elif view:
    425                         q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
    426                                 (qcl, keyname, _quote(k, fnames[keyname]))
    427                 else:
    428                         q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
    429                                 (','.join(fnames.keys()), qcl, \
    430                                         keyname, _quote(k, fnames[keyname]))
    431                 self._do_debug(q)
    432                 res = self.db.query(q).dictresult()
    433                 if not res:
    434                         raise DatabaseError('No such record in %s where %s=%s'
    435                                 % (qcl, keyname, _quote(k, fnames[keyname])))
    436                 for k, d in res[0].items():
    437                         if k == 'oid':
    438                                 k = foid
    439                         arg[k] = d
    440                 return arg
    441 
    442         def insert(self, cl, d = None, **kw):
    443                 """Insert a tuple into a database table.
    444 
    445                 This method inserts values into the table specified filling in the
    446                 values from the dictionary.  It then reloads the dictionary with the
    447                 values from the database.  This causes the dictionary to be updated
    448                 with values that are modified by rules, triggers, etc.
    449 
    450                 Note: The method currently doesn't support insert into views
    451                 although PostgreSQL does.
    452 
    453                 """
    454                 if d is None:
    455                         a = {}
    456                 else:
    457                         a = d
    458                 a.update(kw)
    459 
    460                 qcl = _join_parts(self._split_schema(cl)) # build qualified name
    461                 foid = 'oid(%s)' % qcl # build mangled name
    462                 fnames = self.get_attnames(qcl)
    463                 t = []
    464                 n = []
    465                 for f in fnames.keys():
    466                         if f != 'oid' and a.has_key(f):
    467                                 t.append(_quote(a[f], fnames[f]))
    468                                 n.append('"%s"' % f)
    469                 q = 'INSERT INTO %s (%s) VALUES (%s)' % \
    470                         (qcl, ','.join(n), ','.join(t))
    471                 self._do_debug(q)
    472                 a[foid] = self.db.query(q)
    473                 # Reload the dictionary to catch things modified by engine.
    474                 # Note that get() changes 'oid' below to oid_schema_table.
    475                 # If no read perms (it can and does happen), return None.
    476                 try:
    477                         return self.get(qcl, a, 'oid')
    478                 except:
    479                         return None
    480 
    481         def update(self, cl, d = None, **kw):
    482                 """Update an existing row in a database table.
    483 
    484                 Similar to insert but updates an existing row.  The update is based
    485                 on the OID value as munged by get.  The array returned is the
    486                 one sent modified to reflect any changes caused by the update due
    487                 to triggers, rules, defaults, etc.
    488 
    489                 """
    490                 # Update always works on the oid which get returns if available,
    491                 # otherwise use the primary key.  Fail if neither.
    492                 qcl = _join_parts(self._split_schema(cl)) # build qualified name
    493                 foid = 'oid(%s)' % qcl # build mangled oid
    494 
    495                 # Note that we only accept oid key from named args for safety
    496                 if kw.has_key('oid'):
    497                         kw[foid] = kw['oid']
    498                         del kw['oid']
    499 
    500                 if d is None:
    501                         a = {}
    502                 else:
    503                         a = d
    504                 a.update(kw)
    505 
    506                 # XXX this code is for backwards compatibility and will be
    507                 # XXX removed eventually
    508                 if not a.has_key(foid):
    509                         ofoid = 'oid_' + self._split_schema(cl)[-1]
    510                         if a.has_key(ofoid):
    511                                 a[foid] = a[ofoid]
    512 
    513                 if a.has_key(foid):
    514                         where = "oid=%s" % a[foid]
    515                 else:
    516                         try:
    517                                 pk = self.pkey(qcl)
    518                         except:
    519                                 raise ProgrammingError(
    520                                         'Update needs primary key or oid as %s' % foid)
    521                         where = "%s='%s'" % (pk, a[pk])
    522                 v = []
    523                 k = 0
    524                 fnames = self.get_attnames(qcl)
    525                 for ff in fnames.keys():
    526                         if ff != 'oid' and a.has_key(ff):
    527                                 v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
    528                 if v == []:
    529                         return None
    530                 q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
    531                 self._do_debug(q)
    532                 self.db.query(q)
    533                 # Reload the dictionary to catch things modified by engine:
    534                 if a.has_key(foid):
    535                         return self.get(qcl, a, 'oid')
    536                 else:
    537                         return self.get(qcl, a)
    538 
    539         def clear(self, cl, a = None):
    540                 """
    541 
    542                 This method clears all the attributes to values determined by the types.
    543                 Numeric types are set to 0, Booleans are set to 'f', and everything
    544                 else is set to the empty string.  If the array argument is present,
    545                 it is used as the array and any entries matching attribute names are
    546                 cleared with everything else left unchanged.
    547 
    548                 """
    549                 # At some point we will need a way to get defaults from a table.
    550                 if a is None:
    551                         a = {} # empty if argument is not present
    552                 qcl = _join_parts(self._split_schema(cl)) # build qualified name
    553                 foid = 'oid(%s)' % qcl # build mangled oid
    554                 fnames = self.get_attnames(qcl)
    555                 for k, t in fnames.items():
    556                         if k == 'oid':
    557                                 continue
    558                         if t in ['int', 'seq', 'float', 'num', 'money']:
    559                                 a[k] = 0
    560                         elif t == 'bool':
    561                                 a[k] = 'f'
    562                         else:
    563                                 a[k] = ''
    564                 return a
    565 
    566         def delete(self, cl, d = None, **kw):
    567                 """Delete an existing row in a database table.
    568 
    569                 This method deletes the row from a table.
    570                 It deletes based on the OID munged as described above."""
    571 
    572                 # Like update, delete works on the oid.
    573                 # One day we will be testing that the record to be deleted
    574                 # isn't referenced somewhere (or else PostgreSQL will).
    575                 qcl = _join_parts(self._split_schema(cl)) # build qualified name
    576                 foid = 'oid(%s)' % qcl # build mangled oid
    577 
    578                 # Note that we only accept oid key from named args for safety
    579                 if kw.has_key('oid'):
    580                         kw[foid] = kw['oid']
    581                         del kw['oid']
    582 
    583                 if d is None:
    584                         a = {}
    585                 else:
    586                         a = d
    587                 a.update(kw)
    588 
    589                 # XXX this code is for backwards compatibility and will be
    590                 # XXX removed eventually
    591                 if not a.has_key(foid):
    592                         ofoid = 'oid_' + self._split_schema(cl)[-1]
    593                         if a.has_key(ofoid):
    594                                 a[foid] = a[ofoid]
    595 
    596                 q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
    597                 self._do_debug(q)
    598                 self.db.query(q)
     118    """Wrapper class for the _pg connection type."""
     119
     120    def __init__(self, *args, **kw):
     121        """Create a new connection.
     122
     123        You can pass either the connection parameters or an existing
     124        _pg or pgdb connection. This allows you to use the methods
     125        of the classic pg interface with a DB-API 2 pgdb connection.
     126
     127        """
     128        if not args and len(kw) == 1:
     129            db = kw.get('db')
     130        elif not kw and len(args) == 1:
     131            db = args[0]
     132        else:
     133            db = None
     134        if db:
     135            if isinstance(db, DB):
     136                db = db.db
     137            else:
     138                try:
     139                    db = db._cnx
     140                except AttributeError:
     141                    pass
     142        if not db or not hasattr(db, 'db') or not hasattr(db, 'query'):
     143            db = connect(*args, **kw)
     144            self._closeable = 1
     145        else:
     146            self._closeable = 0
     147        self.db = db
     148        self.dbname = db.db
     149        self._attnames = {}
     150        self._pkeys = {}
     151        self._args = args, kw
     152        self.debug = None # For debugging scripts, this can be set
     153            # * to a string format specification (e.g. in CGI set to "%s<BR>"),
     154            # * to a function which takes a string argument or
     155            # * to a file object to write debug statements to.
     156
     157    def __getattr__(self, name):
     158        # All undefined members are the same as in the underlying pg connection:
     159        if self.db:
     160            return getattr(self.db, name)
     161        else:
     162            raise InternalError('Connection is not valid')
     163
     164    # For convenience, define some module functions as static methods also:
     165    escape_string, escape_bytea, unescape_bytea = map(staticmethod,
     166        (escape_string, escape_bytea, unescape_bytea))
     167
     168    def _do_debug(self, s):
     169        """Print a debug message."""
     170        if self.debug:
     171            if isinstance(self.debug, StringType):
     172                print self.debug % s
     173            elif isinstance(self.debug, FunctionType):
     174                self.debug(s)
     175            elif isinstance(self.debug, FileType):
     176                print >> self.debug, s
     177
     178    def close(self):
     179        """Close the database connection."""
     180        # Wraps shared library function so we can track state.
     181        if self._closeable:
     182            if self.db:
     183                self.db.close()
     184                self.db = None
     185            else:
     186                raise InternalError('Connection already closed')
     187
     188    def reopen(self):
     189        """Reopen connection to the database.
     190
     191        Used in case we need another connection to the same database.
     192        Note that we can still reopen a database that we have closed.
     193
     194        """
     195        # There is no such shared library function.
     196        if self._closeable:
     197            if self.db:
     198                self.db.close()
     199            try:
     200                self.db = connect(*self._args[0], **self._args[1])
     201            except:
     202                self.db = None
     203                raise
     204
     205    def query(self, qstr):
     206        """Executes a SQL command string.
     207
     208        This method simply sends a SQL query to the database. If the query is
     209        an insert statement, the return value is the OID of the newly
     210        inserted row.  If it is otherwise a query that does not return a result
     211        (ie. is not a some kind of SELECT statement), it returns None.
     212        Otherwise, it returns a pgqueryobject that can be accessed via the
     213        getresult or dictresult method or simply printed.
     214
     215        """
     216        # Wraps shared library function for debugging.
     217        if not self.db:
     218            raise InternalError('Connection is not valid')
     219        self._do_debug(qstr)
     220        return self.db.query(qstr)
     221
     222    def _split_schema(self, cl):
     223        """Return schema and name of object separately.
     224
     225        This auxiliary function splits off the namespace (schema)
     226        belonging to the class with the name cl. If the class name
     227        is not qualified, the function is able to determine the schema
     228        of the class, taking into account the current search path.
     229
     230        """
     231        s = _split_parts(cl)
     232        if len(s) > 1: # name already qualfied?
     233            # should be database.schema.table or schema.table
     234            if len(s) > 3:
     235                raise ProgrammingError('Too many dots in class name %s' % cl)
     236            schema, cl = s[-2:]
     237        else:
     238            cl = s[0]
     239            # determine search path
     240            query = 'SELECT current_schemas(TRUE)'
     241            schemas = self.db.query(query).getresult()[0][0][1:-1].split(',')
     242            if schemas: # non-empty path
     243                # search schema for this object in the current search path
     244                query = ' UNION '.join(
     245                    ["SELECT %d::integer AS n, '%s'::name AS nspname"
     246                        % s for s in enumerate(schemas)])
     247                query = ("SELECT nspname FROM pg_class"
     248                    " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
     249                    " JOIN (%s) AS p USING (nspname)"
     250                    " WHERE pg_class.relname='%s'"
     251                    " ORDER BY n LIMIT 1" % (query, cl))
     252                schema = self.db.query(query).getresult()
     253                if schema: # schema found
     254                    schema = schema[0][0]
     255                else: # object not found in current search path
     256                    schema = 'public'
     257            else: # empty path
     258                schema = 'public'
     259        return schema, cl
     260
     261    def pkey(self, cl, newpkey = None):
     262        """This method gets or sets the primary key of a class.
     263
     264        If newpkey is set and is not a dictionary then set that
     265        value as the primary key of the class.  If it is a dictionary
     266        then replace the _pkeys dictionary with it.
     267
     268        """
     269        # First see if the caller is supplying a dictionary
     270        if isinstance(newpkey, DictType):
     271            # make sure that we have a namespace
     272            self._pkeys = {}
     273            for x in newpkey.keys():
     274                if x.find('.') == -1:
     275                    self._pkeys['public.' + x] = newpkey[x]
     276                else:
     277                    self._pkeys[x] = newpkey[x]
     278            return self._pkeys
     279
     280        qcl = _join_parts(self._split_schema(cl)) # build qualified name
     281        if newpkey:
     282            self._pkeys[qcl] = newpkey
     283            return newpkey
     284
     285        # Get all the primary keys at once
     286        if self._pkeys == {} or not self._pkeys.has_key(qcl):
     287            # if not found, check again in case it was added after we started
     288            self._pkeys = dict([
     289                (_join_parts(r[:2]), r[2]) for r in self.db.query(
     290                "SELECT pg_namespace.nspname, pg_class.relname"
     291                    ",pg_attribute.attname FROM pg_class"
     292                " JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace"
     293                    " AND pg_namespace.nspname NOT LIKE 'pg_%'"
     294                " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
     295                    " AND pg_attribute.attisdropped='f'"
     296                " JOIN pg_index ON pg_index.indrelid=pg_class.oid"
     297                    " AND pg_index.indisprimary='t'"
     298                    " AND pg_index.indkey[0]=pg_attribute.attnum"
     299                ).getresult()])
     300            self._do_debug(self._pkeys)
     301        # will raise an exception if primary key doesn't exist
     302        return self._pkeys[qcl]
     303
     304    def get_databases(self):
     305        """Get list of databases in the system."""
     306        return [s[0] for s in
     307            self.db.query('SELECT datname FROM pg_database').getresult()]
     308
     309    def get_relations(self, kinds = None):
     310        """Get list of relations in connected database of specified kinds.
     311
     312            If kinds is None or empty, all kinds of relations are returned.
     313            Otherwise kinds can be a string or sequence of type letters
     314            specifying which kind of relations you want to list.
     315
     316        """
     317        where = kinds and "pg_class.relkind IN (%s) AND" % ','.join(
     318            ["'%s'" % x for x in kinds]) or ''
     319        return map(_join_parts, self.db.query(
     320            "SELECT pg_namespace.nspname, pg_class.relname "
     321            "FROM pg_class "
     322            "JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace "
     323            "WHERE %s pg_class.relname !~ '^Inv' AND "
     324                "pg_class.relname !~ '^pg_' "
     325            "ORDER BY 1, 2" % where).getresult())
     326
     327    def get_tables(self):
     328        """Return list of tables in connected database."""
     329        return self.get_relations('r')
     330
     331    def get_attnames(self, cl, newattnames = None):
     332        """Given the name of a table, digs out the set of attribute names.
     333
     334        Returns a dictionary of attribute names (the names are the keys,
     335        the values are the names of the attributes' types).
     336        If the optional newattnames exists, it must be a dictionary and
     337        will become the new attribute names dictionary.
     338
     339        """
     340        if isinstance(newattnames, DictType):
     341            self._attnames = newattnames
     342            return
     343        elif newattnames:
     344            raise ProgrammingError(
     345                'If supplied, newattnames must be a dictionary')
     346        cl = self._split_schema(cl) # split into schema and cl
     347        qcl = _join_parts(cl) # build qualified name
     348        # May as well cache them:
     349        if self._attnames.has_key(qcl):
     350            return self._attnames[qcl]
     351        if qcl not in self.get_relations('rv'):
     352            raise ProgrammingError('Class %s does not exist' % qcl)
     353        t = {}
     354        for att, typ in self.db.query("SELECT pg_attribute.attname"
     355            ",pg_type.typname FROM pg_class"
     356            " JOIN pg_namespace ON pg_class.relnamespace=pg_namespace.oid"
     357            " JOIN pg_attribute ON pg_attribute.attrelid=pg_class.oid"
     358            " JOIN pg_type ON pg_type.oid=pg_attribute.atttypid"
     359            " WHERE pg_namespace.nspname='%s' AND pg_class.relname='%s'"
     360            " AND (pg_attribute.attnum>0 or pg_attribute.attname='oid')"
     361            " AND pg_attribute.attisdropped='f'"
     362                % cl).getresult():
     363            if typ.startswith('bool'):
     364                t[att] = 'bool'
     365            elif typ.startswith('oid'):
     366                t[att] = 'int'
     367            elif typ.startswith('float'):
     368                t[att] = 'float'
     369            elif typ.startswith('numeric'):
     370                t[att] = 'num'
     371            elif typ.startswith('abstime'):
     372                t[att] = 'date'
     373            elif typ.startswith('date'):
     374                t[att] = 'date'
     375            elif typ.startswith('interval'):
     376                t[att] = 'date'
     377            elif typ.startswith('int'):
     378                t[att] = 'int'
     379            elif typ.startswith('timestamp'):
     380                t[att] = 'date'
     381            elif typ.startswith('money'):
     382                t[att] = 'money'
     383            else:
     384                t[att] = 'text'
     385        self._attnames[qcl] = t # cache it
     386        return self._attnames[qcl]
     387
     388    def get(self, cl, arg, keyname = None, view = 0):
     389        """Get a tuple from a database table or view.
     390
     391        This method is the basic mechanism to get a single row.  It assumes
     392        that the key specifies a unique row.  If keyname is not specified
     393        then the primary key for the table is used.  If arg is a dictionary
     394        then the value for the key is taken from it and it is modified to
     395        include the new values, replacing existing values where necessary.
     396        The OID is also put into the dictionary, but in order to allow the
     397        caller to work with multiple tables, it is munged as oid(schema.table).
     398
     399        """
     400        if cl.endswith('*'): # scan descendant tables?
     401            cl = cl[:-1].rstrip() # need parent table name
     402        qcl = _join_parts(self._split_schema(cl)) # build qualified name
     403        # To allow users to work with multiple tables,
     404        # we munge the name when the key is "oid"
     405        foid = 'oid(%s)' % qcl # build mangled name
     406        if keyname is None: # use the primary key by default
     407            keyname = self.pkey(qcl)
     408        fnames = self.get_attnames(qcl)
     409        if isinstance(arg, DictType):
     410            # XXX this code is for backwards compatibility and will be
     411            # XXX removed eventually
     412            if not arg.has_key(foid):
     413                ofoid = 'oid_' + self._split_schema(cl)[-1]
     414                if arg.has_key(ofoid):
     415                    arg[foid] = arg[ofoid]
     416
     417            k = arg[keyname == 'oid' and foid or keyname]
     418        else:
     419            k = arg
     420            arg = {}
     421        # We want the oid for later updates if that isn't the key
     422        if keyname == 'oid':
     423            q = 'SELECT * FROM %s WHERE oid=%s LIMIT 1' % (qcl, k)
     424        elif view:
     425            q = 'SELECT * FROM %s WHERE %s=%s LIMIT 1' % \
     426                (qcl, keyname, _quote(k, fnames[keyname]))
     427        else:
     428            q = 'SELECT %s FROM %s WHERE %s=%s LIMIT 1' % \
     429                (','.join(fnames.keys()), qcl, \
     430                    keyname, _quote(k, fnames[keyname]))
     431        self._do_debug(q)
     432        res = self.db.query(q).dictresult()
     433        if not res:
     434            raise DatabaseError('No such record in %s where %s=%s'
     435                % (qcl, keyname, _quote(k, fnames[keyname])))
     436        for k, d in res[0].items():
     437            if k == 'oid':
     438                k = foid
     439            arg[k] = d
     440        return arg
     441
     442    def insert(self, cl, d = None, **kw):
     443        """Insert a tuple into a database table.
     444
     445        This method inserts values into the table specified filling in the
     446        values from the dictionary.  It then reloads the dictionary with the
     447        values from the database.  This causes the dictionary to be updated
     448        with values that are modified by rules, triggers, etc.
     449
     450        Note: The method currently doesn't support insert into views
     451        although PostgreSQL does.
     452
     453        """
     454        if d is None:
     455            a = {}
     456        else:
     457            a = d
     458        a.update(kw)
     459
     460        qcl = _join_parts(self._split_schema(cl)) # build qualified name
     461        foid = 'oid(%s)' % qcl # build mangled name
     462        fnames = self.get_attnames(qcl)
     463        t = []
     464        n = []
     465        for f in fnames.keys():
     466            if f != 'oid' and a.has_key(f):
     467                t.append(_quote(a[f], fnames[f]))
     468                n.append('"%s"' % f)
     469        q = 'INSERT INTO %s (%s) VALUES (%s)' % \
     470            (qcl, ','.join(n), ','.join(t))
     471        self._do_debug(q)
     472        a[foid] = self.db.query(q)
     473        # Reload the dictionary to catch things modified by engine.
     474        # Note that get() changes 'oid' below to oid_schema_table.
     475        # If no read perms (it can and does happen), return None.
     476        try:
     477            return self.get(qcl, a, 'oid')
     478        except:
     479            return None
     480
     481    def update(self, cl, d = None, **kw):
     482        """Update an existing row in a database table.
     483
     484        Similar to insert but updates an existing row.  The update is based
     485        on the OID value as munged by get.  The array returned is the
     486        one sent modified to reflect any changes caused by the update due
     487        to triggers, rules, defaults, etc.
     488
     489        """
     490        # Update always works on the oid which get returns if available,
     491        # otherwise use the primary key.  Fail if neither.
     492        qcl = _join_parts(self._split_schema(cl)) # build qualified name
     493        foid = 'oid(%s)' % qcl # build mangled oid
     494
     495        # Note that we only accept oid key from named args for safety
     496        if kw.has_key('oid'):
     497            kw[foid] = kw['oid']
     498            del kw['oid']
     499
     500        if d is None:
     501            a = {}
     502        else:
     503            a = d
     504        a.update(kw)
     505
     506        # XXX this code is for backwards compatibility and will be
     507        # XXX removed eventually
     508        if not a.has_key(foid):
     509            ofoid = 'oid_' + self._split_schema(cl)[-1]
     510            if a.has_key(ofoid):
     511                a[foid] = a[ofoid]
     512
     513        if a.has_key(foid):
     514            where = "oid=%s" % a[foid]
     515        else:
     516            try:
     517                pk = self.pkey(qcl)
     518            except:
     519                raise ProgrammingError(
     520                    'Update needs primary key or oid as %s' % foid)
     521            where = "%s='%s'" % (pk, a[pk])
     522        v = []
     523        k = 0
     524        fnames = self.get_attnames(qcl)
     525        for ff in fnames.keys():
     526            if ff != 'oid' and a.has_key(ff):
     527                v.append('%s=%s' % (ff, _quote(a[ff], fnames[ff])))
     528        if v == []:
     529            return None
     530        q = 'UPDATE %s SET %s WHERE %s' % (qcl, ','.join(v), where)
     531        self._do_debug(q)
     532        self.db.query(q)
     533        # Reload the dictionary to catch things modified by engine:
     534        if a.has_key(foid):
     535            return self.get(qcl, a, 'oid')
     536        else:
     537            return self.get(qcl, a)
     538
     539    def clear(self, cl, a = None):
     540        """
     541
     542        This method clears all the attributes to values determined by the types.
     543        Numeric types are set to 0, Booleans are set to 'f', and everything
     544        else is set to the empty string.  If the array argument is present,
     545        it is used as the array and any entries matching attribute names are
     546        cleared with everything else left unchanged.
     547
     548        """
     549        # At some point we will need a way to get defaults from a table.
     550        if a is None:
     551            a = {} # empty if argument is not present
     552        qcl = _join_parts(self._split_schema(cl)) # build qualified name
     553        foid = 'oid(%s)' % qcl # build mangled oid
     554        fnames = self.get_attnames(qcl)
     555        for k, t in fnames.items():
     556            if k == 'oid':
     557                continue
     558            if t in ['int', 'seq', 'float', 'num', 'money']:
     559                a[k] = 0
     560            elif t == 'bool':
     561                a[k] = 'f'
     562            else:
     563                a[k] = ''
     564        return a
     565
     566    def delete(self, cl, d = None, **kw):
     567        """Delete an existing row in a database table.
     568
     569        This method deletes the row from a table.
     570        It deletes based on the OID munged as described above."""
     571
     572        # Like update, delete works on the oid.
     573        # One day we will be testing that the record to be deleted
     574        # isn't referenced somewhere (or else PostgreSQL will).
     575        qcl = _join_parts(self._split_schema(cl)) # build qualified name
     576        foid = 'oid(%s)' % qcl # build mangled oid
     577
     578        # Note that we only accept oid key from named args for safety
     579        if kw.has_key('oid'):
     580            kw[foid] = kw['oid']
     581            del kw['oid']
     582
     583        if d is None:
     584            a = {}
     585        else:
     586            a = d
     587        a.update(kw)
     588
     589        # XXX this code is for backwards compatibility and will be
     590        # XXX removed eventually
     591        if not a.has_key(foid):
     592            ofoid = 'oid_' + self._split_schema(cl)[-1]
     593            if a.has_key(ofoid):
     594                a[foid] = a[ofoid]
     595
     596        q = 'DELETE FROM %s WHERE oid=%s' % (qcl, a[foid])
     597        self._do_debug(q)
     598        self.db.query(q)
    599599
    600600
     
    602602
    603603if __name__ == '__main__':
    604         print 'PyGreSQL version', version
    605         print
    606         print __doc__
     604    print 'PyGreSQL version', version
     605    print
     606    print __doc__
  • trunk/module/pgdb.py

    r343 r346  
    55# Written by D'Arcy J.M. Cain
    66#
    7 # $Id: pgdb.py,v 1.43 2008-11-01 15:38:02 cito Exp $
     7# $Id: pgdb.py,v 1.44 2008-11-01 18:21:12 cito Exp $
    88#
    99
     
    1919Basic usage:
    2020
    21         pgdb.connect(connect_string) # open a connection
    22         # connect_string = 'host:database:user:password:opt:tty'
    23         # All parts are optional. You may also pass host through
    24         # password as keyword arguments. To pass a port,
    25         # pass it in the host keyword parameter:
    26         pgdb.connect(host='localhost:5432')
    27 
    28         connection.cursor() # open a cursor
    29 
    30         cursor.execute(query[, params])
    31         # Execute a query, binding params (a dictionary) if they are
    32         # passed. The binding syntax is the same as the % operator
    33         # for dictionaries, and no quoting is done.
    34 
    35         cursor.executemany(query, list of params)
    36         # Execute a query many times, binding each param dictionary
    37         # from the list.
    38 
    39         cursor.fetchone() # fetch one row, [value, value, ...]
    40 
    41         cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
    42 
    43         cursor.fetchmany([size])
    44         # returns size or cursor.arraysize number of rows,
    45         # [[value, value, ...], ...] from result set.
    46         # Default cursor.arraysize is 1.
    47 
    48         cursor.description # returns information about the columns
    49         #       [(column_name, type_name, display_size,
    50         #               internal_size, precision, scale, null_ok), ...]
    51         # Note that precision, scale and null_ok are not implemented.
    52 
    53         cursor.rowcount # number of rows available in the result set
    54         # Available after a call to execute.
    55 
    56         connection.commit() # commit transaction
    57 
    58         connection.rollback() # or rollback transaction
    59 
    60         cursor.close() # close the cursor
    61 
    62         connection.close() # close the connection
     21    pgdb.connect(connect_string) # open a connection
     22    # connect_string = 'host:database:user:password:opt:tty'
     23    # All parts are optional. You may also pass host through
     24    # password as keyword arguments. To pass a port,
     25    # pass it in the host keyword parameter:
     26    pgdb.connect(host='localhost:5432')
     27
     28    connection.cursor() # open a cursor
     29
     30    cursor.execute(query[, params])
     31    # Execute a query, binding params (a dictionary) if they are
     32    # passed. The binding syntax is the same as the % operator
     33    # for dictionaries, and no quoting is done.
     34
     35    cursor.executemany(query, list of params)
     36    # Execute a query many times, binding each param dictionary
     37    # from the list.
     38
     39    cursor.fetchone() # fetch one row, [value, value, ...]
     40
     41    cursor.fetchall() # fetch all rows, [[value, value, ...], ...]
     42
     43    cursor.fetchmany([size])
     44    # returns size or cursor.arraysize number of rows,
     45    # [[value, value, ...], ...] from result set.
     46    # Default cursor.arraysize is 1.
     47
     48    cursor.description # returns information about the columns
     49    #   [(column_name, type_name, display_size,
     50    #           internal_size, precision, scale, null_ok), ...]
     51    # Note that precision, scale and null_ok are not implemented.
     52
     53    cursor.rowcount # number of rows available in the result set
     54    # Available after a call to execute.
     55
     56    connection.commit() # commit transaction
     57
     58    connection.rollback() # or rollback transaction
     59
     60    cursor.close() # close the cursor
     61
     62    connection.close() # close the connection
    6363
    6464"""
     
    6767import time
    6868try:
    69         frozenset
     69    frozenset
    7070except NameError: # Python < 2.4
    71         from sets import ImmutableSet as frozenset
     71    from sets import ImmutableSet as frozenset
    7272try: # use mx.DateTime module if available
    73         from mx.DateTime import DateTime, \
    74                 TimeDelta, DateTimeType
     73    from mx.DateTime import DateTime, \
     74        TimeDelta, DateTimeType
    7575except ImportError: # otherwise use standard datetime module
    76         from datetime import datetime as DateTime, \
    77                 timedelta as TimeDelta, datetime as DateTimeType
     76    from datetime import datetime as DateTime, \
     77        timedelta as TimeDelta, datetime as DateTimeType
    7878try: # use Decimal if available
    79         from decimal import Decimal
    80         set_decimal(Decimal)
     79    from decimal import Decimal
     80    set_decimal(Decimal)
    8181except ImportError: # otherwise (Python < 2.4)
    82         Decimal = float # use float instead of Decimal
     82    Decimal = float # use float instead of Decimal
    8383
    8484
     
    9898
    9999def _cast_bool(value):
    100         return value[:1] in ['t', 'T']
     100    return value[:1] in ['t', 'T']
    101101
    102102
    103103def _cast_money(value):
    104         return Decimal(''.join(filter(
    105                 lambda v: v in '0123456789.-', value)))
     104    return Decimal(''.join(filter(
     105        lambda v: v in '0123456789.-', value)))
    106106
    107107
    108108_cast = {'bool': _cast_bool,
    109         'int2': int, 'int4': int, 'serial': int,
    110         'int8': long, 'oid': long, 'oid8': long,
    111         'float4': float, 'float8': float,
    112         'numeric': Decimal, 'money': _cast_money}
     109    'int2': int, 'int4': int, 'serial': int,
     110    'int8': long, 'oid': long, 'oid8': long,
     111    'float4': float, 'float8': float,
     112    'numeric': Decimal, 'money': _cast_money}
    113113
    114114
    115115class pgdbTypeCache(dict):
    116         """Cache for database types."""
    117 
    118         def __init__(self, cnx):
    119                 """Initialize type cache for connection."""
    120                 super(pgdbTypeCache, self).__init__()
    121                 self._src = cnx.source()
    122 
    123         def typecast(typ, value):
    124                 """Cast value to database type."""
    125                 if value is None:
    126                         # for NULL values, no typecast is necessary
    127                         return None
    128                 cast = _cast.get(typ)
    129                 if cast is None:
    130                         # no typecast available or necessary
    131                         return value
    132                 else:
    133                         return cast(value)
    134         typecast = staticmethod(typecast)
    135 
    136         def getdescr(self, oid):
    137                 """Get name of database type with given oid."""
    138                 try:
    139                         return self[oid]
    140                 except KeyError:
    141                         self._src.execute(
    142                                 "SELECT typname, typlen "
    143                                 "FROM pg_type WHERE oid=%s" % oid)
    144                         res = self._src.fetch(1)[0]
    145                         # The column name is omitted from the return value.
    146                         # It will have to be prepended by the caller.
    147                         res = (res[0], None, int(res[1]),
    148                                 None, None, None)
    149                         self[oid] = res
    150                         return res
     116    """Cache for database types."""
     117
     118    def __init__(self, cnx):
     119        """Initialize type cache for connection."""
     120        super(pgdbTypeCache, self).__init__()
     121        self._src = cnx.source()
     122
     123    def typecast(typ, value):
     124        """Cast value to database type."""
     125        if value is None:
     126            # for NULL values, no typecast is necessary
     127            return None
     128        cast = _cast.get(typ)
     129        if cast is None:
     130            # no typecast available or necessary
     131            return value
     132        else:
     133            return cast(value)
     134    typecast = staticmethod(typecast)
     135
     136    def getdescr(self, oid):
     137        """Get name of database type with given oid."""
     138        try:
     139            return self[oid]
     140        except KeyError:
     141            self._src.execute(
     142                "SELECT typname, typlen "
     143                "FROM pg_type WHERE oid=%s" % oid)
     144            res = self._src.fetch(1)[0]
     145            # The column name is omitted from the return value.
     146            # It will have to be prepended by the caller.
     147            res = (res[0], None, int(res[1]),
     148                None, None, None)
     149            self[oid] = res
     150            return res
    151151
    152152
    153153class _quoteitem(dict):
    154         """Dictionary with auto quoting of its items."""
    155 
    156         def __getitem__(self, key):
    157                 return _quote(super(_quoteitem, self).__getitem__(key))
     154    """Dictionary with auto quoting of its items."""
     155
     156    def __getitem__(self, key):
     157        return _quote(super(_quoteitem, self).__getitem__(key))
    158158
    159159
    160160def _quote(val):
    161         """Quote value depending on its type."""
    162         if isinstance(val, DateTimeType):
    163                 val = str(val)
    164         elif isinstance(val, unicode):
    165                 val = val.encode( 'utf-8' )
    166         if isinstance(val, str):
    167                 val = "'%s'" % str(val).replace("\\", "\\\\").replace("'", "''")
    168         elif isinstance(val, (int, long, float)):
    169                 pass
    170         elif val is None:
    171                 val = 'NULL'
    172         elif isinstance(val, (list, tuple)):
    173                 val = '(%s)' % ','.join(map(lambda v: str(_quote(v)), val))
    174         elif Decimal is not float and isinstance(val, Decimal):
    175                 pass
    176         elif hasattr(val, '__pg_repr__'):
    177                 val = val.__pg_repr__()
    178         else:
    179                 raise InterfaceError('do not know how to handle type %s' % type(val))
    180         return val
     161    """Quote value depending on its type."""
     162    if isinstance(val, DateTimeType):
     163        val = str(val)
     164    elif isinstance(val, unicode):
     165        val = val.encode( 'utf-8' )
     166    if isinstance(val, str):
     167        val = "'%s'" % str(val).replace("\\", "\\\\").replace("'", "''")
     168    elif isinstance(val, (int, long, float)):
     169        pass
     170    elif val is None:
     171        val = 'NULL'
     172    elif isinstance(val, (list, tuple)):
     173        val = '(%s)' % ','.join(map(lambda v: str(_quote(v)), val))
     174    elif Decimal is not float and isinstance(val, Decimal):
     175        pass
     176    elif hasattr(val, '__pg_repr__'):
     177        val = val.__pg_repr__()
     178    else:
     179        raise InterfaceError('do not know how to handle type %s' % type(val))
     180    return val
    181181
    182182
    183183def _quoteparams(string, params):
    184         """Quote parameters.
    185 
    186         This function works for both mappings and sequences.
    187 
    188         """
    189         if hasattr(params, 'has_key'):
    190                 params = _quoteitem(params)
    191         else:
    192                 params = tuple(map(_quote, params))
    193         return string % params
     184    """Quote parameters.
     185
     186    This function works for both mappings and sequences.
     187
     188    """
     189    if hasattr(params, 'has_key'):
     190        params = _quoteitem(params)
     191    else:
     192        params = tuple(map(_quote, params))
     193    return string % params
    194194
    195195
     
    197197
    198198class pgdbCursor(object):
    199         """Cursor Object."""
    200 
    201         def __init__(self, dbcnx):
    202                 """Create a cursor object for the database connection."""
    203                 self._dbcnx = dbcnx
    204                 self._cnx = dbcnx._cnx
    205                 self._type_cache = dbcnx._type_cache
    206                 self._src = self._cnx.source()
    207                 self.description = None
    208                 self.rowcount = -1
    209                 self.arraysize = 1
    210                 self.lastrowid = None
    211 
    212         def row_factory(row):
    213                 """Process rows before they are returned.
    214 
    215                 You can overwrite this with a custom row factory,
    216                 e.g. a dict factory:
    217 
    218                 class myCursor(pgdb.pgdbCursor):
    219                         def cursor.row_factory(self, row):
    220                                 d = {}
    221                                 for idx, col in enumerate(self.description):
    222                                         d[col[0]] = row[idx]
    223                                 return d
    224                 cursor = myCursor(cnx)
    225 
    226                 """
    227                 return row
    228         row_factory = staticmethod(row_factory)
    229 
    230         def close(self):
    231                 """Close the cursor object."""
    232                 self._src.close()
    233                 self.description = None
    234                 self.rowcount = -1
    235                 self.lastrowid = None
    236 
    237         def execute(self, operation, params=None):
    238                 """Prepare and execute a database operation (query or command)."""
    239                 # The parameters may also be specified as list of
    240                 # tuples to e.g. insert multiple rows in a single
    241                 # operation, but this kind of usage is deprecated:
    242                 if (params and isinstance(params, list)
    243                                 and isinstance(params[0], tuple)):
    244                         self.executemany(operation, params)
    245                 else:
    246                         # not a list of tuples
    247                         self.executemany(operation, (params,))
    248 
    249         def executemany(self, operation, param_seq):
    250                 """Prepare operation and execute it against a parameter sequence."""
    251                 if not param_seq:
    252                         # don't do anything without parameters
    253                         return
    254                 self.description = None
    255                 self.rowcount = -1
    256                 # first try to execute all queries
    257                 totrows = 0
    258                 sql = "BEGIN"
    259                 try:
    260                         if not self._dbcnx._tnx:
    261                                 try:
    262                                         self._cnx.source().execute(sql)
    263                                 except:
    264                                         raise OperationalError("can't start transaction")
    265                                 self._dbcnx._tnx = True
    266                         for params in param_seq:
    267                                 if params:
    268                                         sql = _quoteparams(operation, params)
    269                                 else:
    270                                         sql = operation
    271                                 rows = self._src.execute(sql)
    272                                 if rows: # true if not DML
    273                                         totrows += rows
    274                                 else:
    275                                         self.rowcount = -1
    276                 except Error, msg:
    277                         raise DatabaseError("error '%s' in '%s'" % (msg, sql))
    278                 except Exception, err:
    279                         raise OperationalError("internal error in '%s': %s" % (sql, err))
    280                 except:
    281                         raise OperationalError("internal error in '%s'" % sql)
    282                 # then initialize result raw count and description
    283                 if self._src.resulttype == RESULT_DQL:
    284                         self.rowcount = self._src.ntuples
    285                         getdescr = self._type_cache.getdescr
    286                         coltypes = self._src.listinfo()
    287                         self.description = [typ[1:2] + getdescr(typ[2]) for typ in coltypes]
    288                         self.lastrowid = self._src.oidstatus()
    289                 else:
    290                         self.rowcount = totrows
    291                         self.description = None
    292                         self.lastrowid = self._src.oidstatus()
    293 
    294         def fetchone(self):
    295                 """Fetch the next row of a query result set."""
    296                 res = self.fetchmany(1, False)
    297                 try:
    298                         return res[0]
    299                 except IndexError:
    300                         return None
    301 
    302         def fetchall(self):
    303                 """Fetch all (remaining) rows of a query result."""
    304                 return self.fetchmany(-1, False)
    305 
    306         def fetchmany(self, size=None, keep=False):
    307                 """Fetch the next set of rows of a query result.
    308 
    309                 The number of rows to fetch per call is specified by the
    310                 size parameter. If it is not given, the cursor's arraysize
    311                 determines the number of rows to be fetched. If you set
    312                 the keep parameter to true, this is kept as new arraysize.
    313 
    314                 """
    315                 if size is None:
    316                         size = self.arraysize
    317                 if keep:
    318                         self.arraysize = size
    319                 try:
    320                         result = self._src.fetch(size)
    321                 except Error, err:
    322                         raise DatabaseError(str(err))
    323                 row_factory = self.row_factory
    324                 typecast = self._type_cache.typecast
    325                 coltypes = [desc[1] for desc in self.description]
    326                 return [row_factory([typecast(*args)
    327                         for args in zip(coltypes, row)]) for row in result]
    328 
    329         def nextset():
    330                 """Not supported."""
    331                 raise NotSupportedError("nextset() is not supported")
    332         nextset = staticmethod(nextset)
    333 
    334         def setinputsizes(sizes):
    335                 """Not supported."""
    336                 pass
    337         setinputsizes = staticmethod(setinputsizes)
    338 
    339         def setoutputsize(size, column=0):
    340                 """Not supported."""
    341                 pass
    342         setoutputsize = staticmethod(setoutputsize)
     199    """Cursor Object."""
     200
     201    def __init__(self, dbcnx):
     202        """Create a cursor object for the database connection."""
     203        self._dbcnx = dbcnx
     204        self._cnx = dbcnx._cnx
     205        self._type_cache = dbcnx._type_cache
     206        self._src = self._cnx.source()
     207        self.description = None
     208        self.rowcount = -1
     209        self.arraysize = 1
     210        self.lastrowid = None
     211
     212    def row_factory(row):
     213        """Process rows before they are returned.
     214
     215        You can overwrite this with a custom row factory,
     216        e.g. a dict factory:
     217
     218        class myCursor(pgdb.pgdbCursor):
     219            def cursor.row_factory(self, row):
     220                d = {}
     221                for idx, col in enumerate(self.description):
     222                    d[col[0]] = row[idx]
     223                return d
     224        cursor = myCursor(cnx)
     225
     226        """
     227        return row
     228    row_factory = staticmethod(row_factory)
     229
     230    def close(self):
     231        """Close the cursor object."""
     232        self._src.close()
     233        self.description = None
     234        self.rowcount = -1
     235        self.lastrowid = None
     236
     237    def execute(self, operation, params=None):
     238        """Prepare and execute a database operation (query or command)."""
     239        # The parameters may also be specified as list of
     240        # tuples to e.g. insert multiple rows in a single
     241        # operation, but this kind of usage is deprecated:
     242        if (params and isinstance(params, list)
     243                and isinstance(params[0], tuple)):
     244            self.executemany(operation, params)
     245        else:
     246            # not a list of tuples
     247            self.executemany(operation, (params,))
     248
     249    def executemany(self, operation, param_seq):
     250        """Prepare operation and execute it against a parameter sequence."""
     251        if not param_seq:
     252            # don't do anything without parameters
     253            return
     254        self.description = None
     255        self.rowcount = -1
     256        # first try to execute all queries
     257        totrows = 0
     258        sql = "BEGIN"
     259        try:
     260            if not self._dbcnx._tnx:
     261                try:
     262                    self._cnx.source().execute(sql)
     263                except:
     264                    raise OperationalError("can't start transaction")
     265                self._dbcnx._tnx = True
     266            for params in param_seq:
     267                if params:
     268                    sql = _quoteparams(operation, params)
     269                else:
     270                    sql = operation
     271                rows = self._src.execute(sql)
     272                if rows: # true if not DML
     273                    totrows += rows
     274                else:
     275                    self.rowcount = -1
     276        except Error, msg:
     277            raise DatabaseError("error '%s' in '%s'" % (msg, sql))
     278        except Exception, err:
     279            raise OperationalError("internal error in '%s': %s" % (sql, err))
     280        except:
     281            raise OperationalError("internal error in '%s'" % sql)
     282        # then initialize result raw count and description
     283        if self._src.resulttype == RESULT_DQL:
     284            self.rowcount = self._src.ntuples
     285            getdescr = self._type_cache.getdescr
     286            coltypes = self._src.listinfo()
     287            self.description = [typ[1:2] + getdescr(typ[2]) for typ in coltypes]
     288            self.lastrowid = self._src.oidstatus()
     289        else:
     290            self.rowcount = totrows
     291            self.description = None
     292            self.lastrowid = self._src.oidstatus()
     293
     294    def fetchone(self):
     295        """Fetch the next row of a query result set."""
     296        res = self.fetchmany(1, False)
     297        try:
     298            return res[0]
     299        except IndexError:
     300            return None
     301
     302    def fetchall(self):
     303        """Fetch all (remaining) rows of a query result."""
     304        return self.fetchmany(-1, False)
     305
     306    def fetchmany(self, size=None, keep=False):
     307        """Fetch the next set of rows of a query result.
     308
     309        The number of rows to fetch per call is specified by the
     310        size parameter. If it is not given, the cursor's arraysize
     311        determines the number of rows to be fetched. If you set
     312        the keep parameter to true, this is kept as new arraysize.
     313
     314        """
     315        if size is None:
     316            size = self.arraysize
     317        if keep:
     318            self.arraysize = size
     319        try:
     320            result = self._src.fetch(size)
     321        except Error, err:
     322            raise DatabaseError(str(err))
     323        row_factory = self.row_factory
     324        typecast = self._type_cache.typecast
     325        coltypes = [desc[1] for desc in self.description]
     326        return [row_factory([typecast(*args)
     327            for args in zip(coltypes, row)]) for row in result]
     328
     329    def nextset():
     330        """Not supported."""
     331        raise NotSupportedError("nextset() is not supported")
     332    nextset = staticmethod(nextset)
     333
     334    def setinputsizes(sizes):
     335        """Not supported."""
     336        pass
     337    setinputsizes = staticmethod(setinputsizes)
     338
     339    def setoutputsize(size, column=0):
     340        """Not supported."""
     341        pass
     342    setoutputsize = staticmethod(setoutputsize)
    343343
    344344
     
    346346
    347347class pgdbCnx(object):
    348         """Connection Object."""
    349 
    350         def __init__(self, cnx):
    351                 """Create a database connection object."""
    352                 self._cnx = cnx # connection
    353                 self._tnx = False # transaction state
    354                 self._type_cache = pgdbTypeCache(cnx)
    355                 try:
    356                         self._cnx.source()
    357                 except:
    358                         raise OperationalError("invalid connection")
    359 
    360         def close(self):
    361                 """Close the connection object."""
    362                 if self._cnx:
    363                         self._cnx.close()
    364                         self._cnx = None
    365                 else:
    366                         raise OperationalError("connection has been closed")
    367 
    368         def commit(self):
    369                 """Commit any pending transaction to the database."""
    370                 if self._cnx:
    371                         if self._tnx:
    372                                 self._tnx = False
    373                                 try:
    374                                         self._cnx.source().execute("COMMIT")
    375                                 except:
    376                                         raise OperationalError("can't commit")
    377                 else:
    378                         raise OperationalError("connection has been closed")
    379 
    380         def rollback(self):
    381                 """Roll back to the start of any pending transaction."""
    382                 if self._cnx:
    383                         if self._tnx:
    384                                 self._tnx = False
    385                                 try:
    386                                         self._cnx.source().execute("ROLLBACK")
    387                                 except:
    388                                         raise OperationalError("can't rollback")
    389                 else:
    390                         raise OperationalError("connection has been closed")
    391 
    392         def cursor(self):
    393                 """Return a new Cursor Object using the connection."""
    394                 if self._cnx:
    395                         try:
    396                                 return pgdbCursor(self)
    397                         except:
    398                                 raise OperationalError("invalid connection")
    399                 else:
    400                         raise OperationalError("connection has been closed")
     348    """Connection Object."""
     349
     350    def __init__(self, cnx):
     351        """Create a database connection object."""
     352        self._cnx = cnx # connection
     353        self._tnx = False # transaction state
     354        self._type_cache = pgdbTypeCache(cnx)
     355        try:
     356            self._cnx.source()
     357        except:
     358            raise OperationalError("invalid connection")
     359
     360    def close(self):
     361        """Close the connection object."""
     362        if self._cnx:
     363            self._cnx.close()
     364            self._cnx = None
     365        else:
     366            raise OperationalError("connection has been closed")
     367
     368    def commit(self):
     369        """Commit any pending transaction to the database."""
     370        if self._cnx:
     371            if self._tnx:
     372                self._tnx = False
     373                try:
     374                    self._cnx.source().execute("COMMIT")
     375                except:
     376                    raise OperationalError("can't commit")
     377        else:
     378            raise OperationalError("connection has been closed")
     379
     380    def rollback(self):
     381        """Roll back to the start of any pending transaction."""
     382        if self._cnx:
     383            if self._tnx:
     384                self._tnx = False
     385                try:
     386                    self._cnx.source().execute("ROLLBACK")
     387                except:
     388                    raise OperationalError("can't rollback")
     389        else:
     390            raise OperationalError("connection has been closed")
     391
     392    def cursor(self):
     393        """Return a new Cursor Object using the connection."""
     394        if self._cnx:
     395            try:
     396                return pgdbCursor(self)
     397            except:
     398                raise OperationalError("invalid connection")
     399        else:
     400            raise OperationalError("connection has been closed")
    401401
    402402
     
    406406
    407407def connect(dsn=None,
    408                 user=None, password=None,
    409                 host=None, database=None):
    410         """Connects to a database."""
    411         # first get params from DSN
    412         dbport = -1
    413         dbhost = ""
    414         dbbase = ""
    415         dbuser = ""
    416         dbpasswd = ""
    417         dbopt = ""
    418         dbtty = ""
    419         try:
    420                 params = dsn.split(":")
    421                 dbhost = params[0]
    422                 dbbase = params[1]
    423                 dbuser = params[2]
    424                 dbpasswd = params[3]
    425                 dbopt = params[4]
    426                 dbtty = params[5]
    427         except (IndexError, TypeError):
    428                 pass
    429 
    430         # override if necessary
    431         if user is not None:
    432                 dbuser = user
    433         if password is not None:
    434                 dbpasswd = password
    435         if database is not None:
    436                 dbbase = database
    437         if host is not None:
    438                 try:
    439                         params = host.split(":")
    440                         dbhost = params[0]
    441                         dbport = int(params[1])
    442                 except (IndexError, TypeError, ValueError):
    443                         pass
    444 
    445         # empty host is localhost
    446         if dbhost == "":
    447                 dbhost = None
    448         if dbuser == "":
    449                 dbuser = None
    450 
    451         # open the connection
    452         cnx = _connect_(dbbase, dbhost, dbport, dbopt,
    453                 dbtty, dbuser, dbpasswd)
    454         return pgdbCnx(cnx)
     408        user=None, password=None,
     409        host=None, database=None):
     410    """Connects to a database."""
     411    # first get params from DSN
     412    dbport = -1
     413    dbhost = ""
     414    dbbase = ""
     415    dbuser = ""
     416    dbpasswd = ""
     417    dbopt = ""
     418    dbtty = ""
     419    try:
     420        params = dsn.split(":")
     421        dbhost = params[0]
     422        dbbase = params[1]
     423        dbuser = params[2]
     424        dbpasswd = params[3]
     425        dbopt = params[4]
     426        dbtty = params[5]
     427    except (IndexError, TypeError):
     428        pass
     429
     430    # override if necessary
     431    if user is not None:
     432        dbuser = user
     433    if password is not None:
     434        dbpasswd = password
     435    if database is not None:
     436        dbbase = database
     437    if host is not None:
     438        try:
     439            params = host.split(":")
     440            dbhost = params[0]
     441            dbport = int(params[1])
     442        except (IndexError, TypeError, ValueError):
     443            pass
     444
     445    # empty host is localhost
     446    if dbhost == "":
     447        dbhost = None
     448    if dbuser == "":
     449        dbuser = None
     450
     451    # open the connection
     452    cnx = _connect_(dbbase, dbhost, dbport, dbopt,
     453        dbtty, dbuser, dbpasswd)
     454    return pgdbCnx(cnx)
    455455
    456456
     
    458458
    459459class pgdbType(frozenset):
    460         """Type class for a couple of PostgreSQL data types.
    461 
    462         PostgreSQL is object-oriented: types are dynamic.
    463         We must thus use type names as internal type codes.
    464 
    465         """
    466 
    467         if frozenset.__module__ == '__builtin__':
    468                 def __new__(cls, values):
    469                         if isinstance(values, basestring):
    470                                 values = values.split()
    471                         return super(pgdbType, cls).__new__(cls, values)
    472         else: # Python < 2.4
    473                 def __init__(self, values):
    474                         if isinstance(values, basestring):
    475                                 values = values.split()
    476                         super(pgdbType, self).__init__(values)
    477 
    478         def __eq__(self, other):
    479                 if isinstance(other, basestring):
    480                         return other in self
    481                 else:
    482                         return super(pgdbType, self).__eq__(other)
    483 
    484         def __ne__(self, other):
    485                 if isinstance(other, basestring):
    486                         return other not in self
    487                 else:
    488                         return super(pgdbType, self).__ne__(other)
     460    """Type class for a couple of PostgreSQL data types.
     461
     462    PostgreSQL is object-oriented: types are dynamic.
     463    We must thus use type names as internal type codes.
     464
     465    """
     466
     467    if frozenset.__module__ == '__builtin__':
     468        def __new__(cls, values):
     469            if isinstance(values, basestring):
     470                values = values.split()
     471            return super(pgdbType, cls).__new__(cls, values)
     472    else: # Python < 2.4
     473        def __init__(self, values):
     474            if isinstance(values, basestring):
     475                values = values.split()
     476            super(pgdbType, self).__init__(values)
     477
     478    def __eq__(self, other):
     479        if isinstance(other, basestring):
     480            return other in self
     481        else:
     482            return super(pgdbType, self).__eq__(other)
     483
     484    def __ne__(self, other):
     485        if isinstance(other, basestring):
     486            return other not in self
     487        else:
     488            return super(pgdbType, self).__ne__(other)
    489489
    490490
     
    495495NUMBER = pgdbType('int2 int4 serial int8 float4 float8 numeric money')
    496496DATETIME = pgdbType('date time timetz timestamp timestamptz datetime abstime'
    497         ' interval tinterval timespan reltime')
     497    ' interval tinterval timespan reltime')
    498498ROWID = pgdbType('oid oid8')
    499499
     
    516516
    517517def Date(year, month, day):
    518         """Construct an object holding a date value."""
    519         return DateTime(year, month, day)
     518    """Construct an object holding a date value."""
     519    return DateTime(year, month, day)
    520520
    521521def Time(hour, minute, second):
    522         """Construct an object holding a time value."""
    523         return TimeDelta(hour, minute, second)
     522    """Construct an object holding a time value."""
     523    return TimeDelta(hour, minute, second)
    524524
    525525def Timestamp(year, month, day, hour, minute, second):
    526         """construct an object holding a time stamp value."""
    527         return DateTime(year, month, day, hour, minute, second)
     526    """construct an object holding a time stamp value."""
     527    return DateTime(year, month, day, hour, minute, second)
    528528
    529529def DateFromTicks(ticks):
    530         """Construct an object holding a date value from the given ticks value."""
    531         return Date(*time.localtime(ticks)[:3])
     530    """Construct an object holding a date value from the given ticks value."""
     531    return Date(*time.localtime(ticks)[:3])
    532532
    533533def TimeFromTicks(ticks):
    534         """construct an object holding a time value from the given ticks value."""
    535         return Time(*time.localtime(ticks)[3:6])
     534    """construct an object holding a time value from the given ticks value."""
     535    return Time(*time.localtime(ticks)[3:6])
    536536
    537537def TimestampFromTicks(ticks):
    538         """construct an object holding a time stamp from the given ticks value."""
    539         return Timestamp(*time.localtime(ticks)[:6])
     538    """construct an object holding a time stamp from the given ticks value."""
     539    return Timestamp(*time.localtime(ticks)[:6])
    540540
    541541def Binary(value):
    542         """construct an object capable of holding a binary (long) string value."""
    543         return value
     542    """construct an object capable of holding a binary (long) string value."""
     543    return value
    544544
    545545
     
    547547
    548548if __name__ == '__main__':
    549         print 'PyGreSQL version', version
    550         print
    551         print __doc__
     549    print 'PyGreSQL version', version
     550    print
     551    print __doc__
  • trunk/module/setup.py

    r338 r346  
    11#!/usr/bin/env python
    2 # $Id: setup.py,v 1.24 2008-10-31 21:13:19 darcy Exp $
     2# $Id: setup.py,v 1.25 2008-11-01 18:21:12 cito Exp $
    33
    44"""Setup script for PyGreSQL version 4.0
     
    4545
    4646def pg_config(s):
    47         """Retrieve information about installed version of PostgreSQL."""
    48         f = os.popen("pg_config --%s" % s)
    49         d = f.readline().strip()
    50         if f.close() is not None:
    51                 raise Exception, "pg_config tool is not available."
    52         if not d:
    53                 raise Exception, "Could not get %s information." % s
    54         return d
     47    """Retrieve information about installed version of PostgreSQL."""
     48    f = os.popen("pg_config --%s" % s)
     49    d = f.readline().strip()
     50    if f.close() is not None:
     51        raise Exception, "pg_config tool is not available."
     52    if not d:
     53        raise Exception, "Could not get %s information." % s
     54    return d
    5555
    5656def mk_include():
    57         """Create a temporary local include directory.
     57    """Create a temporary local include directory.
    5858
    59         The directory will contain a copy of the PostgreSQL server header files,
    60         where all features which are not necessary for PyGreSQL are disabled.
    61         """
    62         os.mkdir('include')
    63         for f in os.listdir(pg_include_dir_server):
    64                 if not f.endswith('.h'):
    65                         continue
    66                 d = open(os.path.join(pg_include_dir_server, f)).read()
    67                 if f == 'pg_config.h':
    68                         d += '\n'
    69                         d += '#undef ENABLE_NLS\n'
    70                         d += '#undef USE_REPL_SNPRINTF\n'
    71                         d += '#undef USE_SSL\n'
    72                 open(os.path.join('include', f), 'w').write(d)
     59    The directory will contain a copy of the PostgreSQL server header files,
     60    where all features which are not necessary for PyGreSQL are disabled.
     61    """
     62    os.mkdir('include')
     63    for f in os.listdir(pg_include_dir_server):
     64        if not f.endswith('.h'):
     65            continue
     66        d = open(os.path.join(pg_include_dir_server, f)).read()
     67        if f == 'pg_config.h':
     68            d += '\n'
     69            d += '#undef ENABLE_NLS\n'
     70            d += '#undef USE_REPL_SNPRINTF\n'
     71            d += '#undef USE_SSL\n'
     72        open(os.path.join('include', f), 'w').write(d)
    7373
    7474def rm_include():
    75         """Remove the temporary local include directory."""
    76         if os.path.exists('include'):
    77                 for f in os.listdir('include'):
    78                         os.remove(os.path.join('include', f))
    79                 os.rmdir('include')
     75    """Remove the temporary local include directory."""
     76    if os.path.exists('include'):
     77        for f in os.listdir('include'):
     78            os.remove(os.path.join('include', f))
     79        os.rmdir('include')
    8080
    8181pg_include_dir = pg_config('includedir')
     
    9393
    9494if sys.platform == "win32":
    95         include_dirs.append(os.path.join(pg_include_dir_server, 'port/win32'))
     95    include_dirs.append(os.path.join(pg_include_dir_server, 'port/win32'))
    9696
    9797setup(
    98         name = "PyGreSQL",
    99         version = "4.0",
    100         description = "Python PostgreSQL Interfaces",
    101         author = "D'Arcy J. M. Cain",
    102         author_email = "darcy@PyGreSQL.org",
    103         url = "http://www.pygresql.org",
    104         license = "Python",
    105         py_modules = ['pg', 'pgdb'],
    106         ext_modules = [Extension(
    107                 '_pg', ['pgmodule.c'],
    108                 include_dirs = include_dirs,
    109                 library_dirs = library_dirs,
    110                 libraries = libraries,
    111                 extra_compile_args = ['-O2'],
    112                 )],
    113         )
     98    name = "PyGreSQL",
     99    version = "4.0",
     100    description = "Python PostgreSQL Interfaces",
     101    author = "D'Arcy J. M. Cain",
     102    author_email = "darcy@PyGreSQL.org",
     103    url = "http://www.pygresql.org",
     104    license = "Python",
     105    py_modules = ['pg', 'pgdb'],
     106    ext_modules = [Extension(
     107        '_pg', ['pgmodule.c'],
     108        include_dirs = include_dirs,
     109        library_dirs = library_dirs,
     110        libraries = libraries,
     111        extra_compile_args = ['-O2'],
     112        )],
     113    )
    114114
    115115rm_include()
  • trunk/module/test_pg.py

    r333 r346  
    55# Written by Christoph Zwerschke
    66#
    7 # $Id: test_pg.py,v 1.13 2008-10-09 20:45:29 cito Exp $
     7# $Id: test_pg.py,v 1.14 2008-11-01 18:21:12 cito Exp $
    88#
    99
     
    3131german = 1
    3232try:
    33         import locale
    34         locale.setlocale(locale.LC_ALL, ('de', 'latin1'))
     33    import locale
     34    locale.setlocale(locale.LC_ALL, ('de', 'latin1'))
    3535except:
    36         try:
    37                 locale.setlocale(locale.LC_ALL, 'german')
    38         except:
    39                 german = 0
     36    try:
     37        locale.setlocale(locale.LC_ALL, 'german')
     38    except:
     39        german = 0
    4040
    4141try:
    42         from decimal import Decimal
     42    from decimal import Decimal
    4343except ImportError:
    44         Decimal = float
     44    Decimal = float
    4545
    4646
    4747def smart_ddl(conn, cmd):
    48         """Execute DDL, but don't complain about minor things."""
    49         try:
    50                 if cmd.startswith('create table '):
    51                         i = cmd.find(' as select ')
    52                         if i < 0:
    53                                 i = len(cmd)
    54                         conn.query(cmd[:i] + ' with oids' + cmd[i:])
    55                 else:
    56                         conn.query(cmd)
    57         except pg.ProgrammingError:
    58                 if cmd.startswith('drop table ') \
    59                         or cmd.startswith('set ') \
    60                         or cmd.startswith('alter database '):
    61                         pass
    62                 elif cmd.startswith('create table '):
    63                         conn.query(cmd)
    64                 else:
    65                         raise
     48    """Execute DDL, but don't complain about minor things."""
     49    try:
     50        if cmd.startswith('create table '):
     51            i = cmd.find(' as select ')
     52            if i < 0:
     53                i = len(cmd)
     54            conn.query(cmd[:i] + ' with oids' + cmd[i:])
     55        else:
     56            conn.query(cmd)
     57    except pg.ProgrammingError:
     58        if cmd.startswith('drop table ') \
     59            or cmd.startswith('set ') \
     60            or cmd.startswith('alter database '):
     61            pass
     62        elif cmd.startswith('create table '):
     63            conn.query(cmd)
     64        else:
     65            raise
    6666
    6767
    6868class TestAuxiliaryFunctions(unittest.TestCase):
    69         """Test the auxiliary functions external to the connection class."""
    70 
    71         def testQuote(self):
    72                 f = pg._quote
    73                 self.assertEqual(f(None, None), 'NULL')
    74                 self.assertEqual(f(None, 'int'), 'NULL')
    75                 self.assertEqual(f(None, 'float'), 'NULL')
    76                 self.assertEqual(f(None, 'num'), 'NULL')
    77                 self.assertEqual(f(None, 'money'), 'NULL')
    78                 self.assertEqual(f(None, 'bool'), 'NULL')
    79                 self.assertEqual(f(None, 'date'), 'NULL')
    80                 self.assertEqual(f('', 'int'), 'NULL')
    81                 self.assertEqual(f('', 'seq'), 'NULL')
    82                 self.assertEqual(f('', 'float'), 'NULL')
    83                 self.assertEqual(f('', 'num'), 'NULL')
    84                 self.assertEqual(f('', 'money'), 'NULL')
    85                 self.assertEqual(f('', 'bool'), 'NULL')
    86                 self.assertEqual(f('', 'date'), 'NULL')
    87                 self.assertEqual(f('', 'text'), "''")
    88                 self.assertEqual(f(123456789, 'int'), '123456789')
    89                 self.assertEqual(f(123654789, 'seq'), '123654789')
    90                 self.assertEqual(f(123456987, 'num'), '123456987')
    91                 self.assertEqual(f(1.23654789, 'num'), '1.23654789')
    92                 self.assertEqual(f(12365478.9, 'num'), '12365478.9')
    93                 self.assertEqual(f('123456789', 'num'), '123456789')
    94                 self.assertEqual(f('1.23456789', 'num'), '1.23456789')
    95                 self.assertEqual(f('12345678.9', 'num'), '12345678.9')
    96                 self.assertEqual(f(123, 'money'), "'123.00'")
    97                 self.assertEqual(f('123', 'money'), "'123.00'")
    98                 self.assertEqual(f(123.45, 'money'), "'123.45'")
    99                 self.assertEqual(f('123.45', 'money'), "'123.45'")
    100                 self.assertEqual(f(123.454, 'money'), "'123.45'")
    101                 self.assertEqual(f('123.454', 'money'), "'123.45'")
    102                 self.assertEqual(f(123.456, 'money'), "'123.46'")
    103                 self.assertEqual(f('123.456', 'money'), "'123.46'")
    104                 self.assertEqual(f('f', 'bool'), "'f'")
    105                 self.assertEqual(f('F', 'bool'), "'f'")
    106                 self.assertEqual(f('false', 'bool'), "'f'")
    107                 self.assertEqual(f('False', 'bool'), "'f'")
    108                 self.assertEqual(f('FALSE', 'bool'), "'f'")
    109                 self.assertEqual(f(0, 'bool'), "'f'")
    110                 self.assertEqual(f('0', 'bool'), "'f'")
    111                 self.assertEqual(f('-', 'bool'), "'f'")
    112                 self.assertEqual(f('n', 'bool'), "'f'")
    113                 self.assertEqual(f('N', 'bool'), "'f'")
    114                 self.assertEqual(f('no', 'bool'), "'f'")
    115                 self.assertEqual(f('off', 'bool'), "'f'")
    116                 self.assertEqual(f('t', 'bool'), "'t'")
    117                 self.assertEqual(f('T', 'bool'), "'t'")
    118                 self.assertEqual(f('true', 'bool'), "'t'")
    119                 self.assertEqual(f('True', 'bool'), "'t'")
    120                 self.assertEqual(f('TRUE', 'bool'), "'t'")
    121                 self.assertEqual(f(1, 'bool'), "'t'")
    122                 self.assertEqual(f(2, 'bool'), "'t'")
    123                 self.assertEqual(f(-1, 'bool'), "'t'")
    124                 self.assertEqual(f(0.5, 'bool'), "'t'")
    125                 self.assertEqual(f('1', 'bool'), "'t'")
    126                 self.assertEqual(f('y', 'bool'), "'t'")
    127                 self.assertEqual(f('Y', 'bool'), "'t'")
    128                 self.assertEqual(f('yes', 'bool'), "'t'")
    129                 self.assertEqual(f('on', 'bool'), "'t'")
    130                 self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
    131                 self.assertEqual(f(123, 'text'), "'123'")
    132                 self.assertEqual(f(1.23, 'text'), "'1.23'")
    133                 self.assertEqual(f('abc', 'text'), "'abc'")
    134                 self.assertEqual(f("ab'c", 'text'), "'ab''c'")
    135                 self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
    136                 self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
    137 
    138         def testIsQuoted(self):
    139                 f = pg._is_quoted
    140                 self.assert_(f('A'))
    141                 self.assert_(f('0'))
    142                 self.assert_(f('#'))
    143                 self.assert_(f('*'))
    144                 self.assert_(f('.'))
    145                 self.assert_(f(' '))
    146                 self.assert_(f('a b'))
    147                 self.assert_(f('a+b'))
    148                 self.assert_(f('a*b'))
    149                 self.assert_(f('a.b'))
    150                 self.assert_(f('0ab'))
    151                 self.assert_(f('aBc'))
    152                 self.assert_(f('ABC'))
    153                 self.assert_(f('"a"'))
    154                 self.assert_(not f('a'))
    155                 self.assert_(not f('a0'))
    156                 self.assert_(not f('_'))
    157                 self.assert_(not f('_a'))
    158                 self.assert_(not f('_0'))
    159                 self.assert_(not f('_a_0_'))
    160                 self.assert_(not f('ab'))
    161                 self.assert_(not f('ab0'))
    162                 self.assert_(not f('abc'))
    163                 self.assert_(not f('abc'))
    164                 if german:
    165                         self.assert_(not f('\xe4'))
    166                         self.assert_(f('\xc4'))
    167                         self.assert_(not f('k\xe4se'))
    168                         self.assert_(f('K\xe4se'))
    169                         self.assert_(not f('emmentaler_k\xe4se'))
    170                         self.assert_(f('emmentaler k\xe4se'))
    171                         self.assert_(f('EmmentalerK\xe4se'))
    172                         self.assert_(f('Emmentaler K\xe4se'))
    173 
    174         def testIsUnquoted(self):
    175                 f = pg._is_unquoted
    176                 self.assert_(f('A'))
    177                 self.assert_(not f('0'))
    178                 self.assert_(not f('#'))
    179                 self.assert_(not f('*'))
    180                 self.assert_(not f('.'))
    181                 self.assert_(not f(' '))
    182                 self.assert_(not f('a b'))
    183                 self.assert_(not f('a+b'))
    184                 self.assert_(not f('a*b'))
    185                 self.assert_(not f('a.b'))
    186                 self.assert_(not f('0ab'))
    187                 self.assert_(f('aBc'))
    188                 self.assert_(f('ABC'))
    189                 self.assert_(not f('"a"'))
    190                 self.assert_(f('a0'))
    191                 self.assert_(f('_'))
    192                 self.assert_(f('_a'))
    193                 self.assert_(f('_0'))
    194                 self.assert_(f('_a_0_'))
    195                 self.assert_(f('ab'))
    196                 self.assert_(f('ab0'))
    197                 self.assert_(f('abc'))
    198                 self.assert_(f('\xe4'))
    199                 self.assert_(f('\xc4'))
    200                 self.assert_(f('k\xe4se'))
    201                 self.assert_(f('K\xe4se'))
    202                 self.assert_(f('emmentaler_k\xe4se'))
    203                 self.assert_(not f('emmentaler k\xe4se'))
    204                 self.assert_(f('EmmentalerK\xe4se'))
    205                 self.assert_(not f('Emmentaler K\xe4se'))
    206 
    207         def testSplitFirstPart(self):
    208                 f = pg._split_first_part
    209                 self.assertEqual(f('a.b'), ['a', 'b'])
    210                 self.assertEqual(f('a.b.c'), ['a', 'b.c'])
    211                 self.assertEqual(f('"a.b".c'), ['a.b', 'c'])
    212                 self.assertEqual(f('a."b.c"'), ['a', '"b.c"'])
    213                 self.assertEqual(f('A.b.c'), ['a', 'b.c'])
    214                 self.assertEqual(f('Ab.c'), ['ab', 'c'])
    215                 self.assertEqual(f('aB.c'), ['ab', 'c'])
    216                 self.assertEqual(f('AB.c'), ['ab', 'c'])
    217                 self.assertEqual(f('A b.c'), ['A b', 'c'])
    218                 self.assertEqual(f('a B.c'), ['a B', 'c'])
    219                 self.assertEqual(f('"A".b.c'), ['A', 'b.c'])
    220                 self.assertEqual(f('"A""B".c'), ['A"B', 'c'])
    221                 self.assertEqual(f('a.b.c.d.e.f.g'), ['a', 'b.c.d.e.f.g'])
    222                 self.assertEqual(f('"a.b.c.d.e.f".g'), ['a.b.c.d.e.f', 'g'])
    223                 self.assertEqual(f('a.B.c.D.e.F.g'), ['a', 'B.c.D.e.F.g'])
    224                 self.assertEqual(f('A.b.C.d.E.f.G'), ['a', 'b.C.d.E.f.G'])
    225 
    226         def testSplitParts(self):
    227                 f = pg._split_parts
    228                 self.assertEqual(f('a.b'), ['a', 'b'])
    229                 self.assertEqual(f('a.b.c'), ['a', 'b', 'c'])
    230                 self.assertEqual(f('"a.b".c'), ['a.b', 'c'])
    231                 self.assertEqual(f('a."b.c"'), ['a', 'b.c'])
    232                 self.assertEqual(f('A.b.c'), ['a', 'b', 'c'])
    233                 self.assertEqual(f('Ab.c'), ['ab', 'c'])
    234                 self.assertEqual(f('aB.c'), ['ab', 'c'])
    235                 self.assertEqual(f('AB.c'), ['ab', 'c'])
    236                 self.assertEqual(f('A b.c'), ['A b', 'c'])
    237                 self.assertEqual(f('a B.c'), ['a B', 'c'])
    238                 self.assertEqual(f('"A".b.c'), ['A', 'b', 'c'])
    239                 self.assertEqual(f('"A""B".c'), ['A"B', 'c'])
    240                 self.assertEqual(f('a.b.c.d.e.f.g'), ['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    241                 self.assertEqual(f('"a.b.c.d.e.f".g'), ['a.b.c.d.e.f', 'g'])
    242                 self.assertEqual(f('a.B.c.D.e.F.g'), ['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    243                 self.assertEqual(f('A.b.C.d.E.f.G'), ['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    244 
    245         def testJoinParts(self):
    246                 f = pg._join_parts
    247                 self.assertEqual(f(('a',)), 'a')
    248                 self.assertEqual(f(('a', 'b')), 'a.b')
    249                 self.assertEqual(f(('a', 'b', 'c')), 'a.b.c')
    250                 self.assertEqual(f(('a', 'b', 'c', 'd', 'e', 'f', 'g')), 'a.b.c.d.e.f.g')
    251                 self.assertEqual(f(('A', 'b')), '"A".b')
    252                 self.assertEqual(f(('a', 'B')), 'a."B"')
    253                 self.assertEqual(f(('a b', 'c')), '"a b".c')
    254                 self.assertEqual(f(('a', 'b c')), 'a."b c"')
    255                 self.assertEqual(f(('a_b', 'c')), 'a_b.c')
    256                 self.assertEqual(f(('a', 'b_c')), 'a.b_c')
    257                 self.assertEqual(f(('0', 'a')), '"0".a')
    258                 self.assertEqual(f(('0_', 'a')), '"0_".a')
    259                 self.assertEqual(f(('_0', 'a')), '_0.a')
    260                 self.assertEqual(f(('_a', 'b')), '_a.b')
    261                 self.assertEqual(f(('a', 'B', '0', 'c0', 'C0', 'd e', 'f_g', 'h.i', 'jklm', 'nopq')),
    262                         'a."B"."0".c0."C0"."d e".f_g."h.i".jklm.nopq')
     69    """Test the auxiliary functions external to the connection class."""
     70
     71    def testQuote(self):
     72        f = pg._quote
     73        self.assertEqual(f(None, None), 'NULL')
     74        self.assertEqual(f(None, 'int'), 'NULL')
     75        self.assertEqual(f(None, 'float'), 'NULL')
     76        self.assertEqual(f(None, 'num'), 'NULL')
     77        self.assertEqual(f(None, 'money'), 'NULL')
     78        self.assertEqual(f(None, 'bool'), 'NULL')
     79        self.assertEqual(f(None, 'date'), 'NULL')
     80        self.assertEqual(f('', 'int'), 'NULL')
     81        self.assertEqual(f('', 'seq'), 'NULL')
     82        self.assertEqual(f('', 'float'), 'NULL')
     83        self.assertEqual(f('', 'num'), 'NULL')
     84        self.assertEqual(f('', 'money'), 'NULL')
     85        self.assertEqual(f('', 'bool'), 'NULL')
     86        self.assertEqual(f('', 'date'), 'NULL')
     87        self.assertEqual(f('', 'text'), "''")
     88        self.assertEqual(f(123456789, 'int'), '123456789')
     89        self.assertEqual(f(123654789, 'seq'), '123654789')
     90        self.assertEqual(f(123456987, 'num'), '123456987')
     91        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
     92        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
     93        self.assertEqual(f('123456789', 'num'), '123456789')
     94        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
     95        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
     96        self.assertEqual(f(123, 'money'), "'123.00'")
     97        self.assertEqual(f('123', 'money'), "'123.00'")
     98        self.assertEqual(f(123.45, 'money'), "'123.45'")
     99        self.assertEqual(f('123.45', 'money'), "'123.45'")
     100        self.assertEqual(f(123.454, 'money'), "'123.45'")
     101        self.assertEqual(f('123.454', 'money'), "'123.45'")
     102        self.assertEqual(f(123.456, 'money'), "'123.46'")
     103        self.assertEqual(f('123.456', 'money'), "'123.46'")
     104        self.assertEqual(f('f', 'bool'), "'f'")
     105        self.assertEqual(f('F', 'bool'), "'f'")
     106        self.assertEqual(f('false', 'bool'), "'f'")
     107        self.assertEqual(f('False', 'bool'), "'f'")
     108        self.assertEqual(f('FALSE', 'bool'), "'f'")
     109        self.assertEqual(f(0, 'bool'), "'f'")
     110        self.assertEqual(f('0', 'bool'), "'f'")
     111        self.assertEqual(f('-', 'bool'), "'f'")
     112        self.assertEqual(f('n', 'bool'), "'f'")
     113        self.assertEqual(f('N', 'bool'), "'f'")
     114        self.assertEqual(f('no', 'bool'), "'f'")
     115        self.assertEqual(f('off', 'bool'), "'f'")
     116        self.assertEqual(f('t', 'bool'), "'t'")
     117        self.assertEqual(f('T', 'bool'), "'t'")
     118        self.assertEqual(f('true', 'bool'), "'t'")
     119        self.assertEqual(f('True', 'bool'), "'t'")
     120        self.assertEqual(f('TRUE', 'bool'), "'t'")
     121        self.assertEqual(f(1, 'bool'), "'t'")
     122        self.assertEqual(f(2, 'bool'), "'t'")
     123        self.assertEqual(f(-1, 'bool'), "'t'")
     124        self.assertEqual(f(0.5, 'bool'), "'t'")
     125        self.assertEqual(f('1', 'bool'), "'t'")
     126        self.assertEqual(f('y', 'bool'), "'t'")
     127        self.assertEqual(f('Y', 'bool'), "'t'")
     128        self.assertEqual(f('yes', 'bool'), "'t'")
     129        self.assertEqual(f('on', 'bool'), "'t'")
     130        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
     131        self.assertEqual(f(123, 'text'), "'123'")
     132        self.assertEqual(f(1.23, 'text'), "'1.23'")
     133        self.assertEqual(f('abc', 'text'), "'abc'")
     134        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
     135        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
     136        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
     137
     138    def testIsQuoted(self):
     139        f = pg._is_quoted
     140        self.assert_(f('A'))
     141        self.assert_(f('0'))
     142        self.assert_(f('#'))
     143        self.assert_(f('*'))
     144        self.assert_(f('.'))
     145        self.assert_(f(' '))
     146        self.assert_(f('a b'))
     147        self.assert_(f('a+b'))
     148        self.assert_(f('a*b'))
     149        self.assert_(f('a.b'))
     150        self.assert_(f('0ab'))
     151        self.assert_(f('aBc'))
     152        self.assert_(f('ABC'))
     153        self.assert_(f('"a"'))
     154        self.assert_(not f('a'))
     155        self.assert_(not f('a0'))
     156        self.assert_(not f('_'))
     157        self.assert_(not f('_a'))
     158        self.assert_(not f('_0'))
     159        self.assert_(not f('_a_0_'))
     160        self.assert_(not f('ab'))
     161        self.assert_(not f('ab0'))
     162        self.assert_(not f('abc'))
     163        self.assert_(not f('abc'))
     164        if german:
     165            self.assert_(not f('\xe4'))
     166            self.assert_(f('\xc4'))
     167            self.assert_(not f('k\xe4se'))
     168            self.assert_(f('K\xe4se'))
     169            self.assert_(not f('emmentaler_k\xe4se'))
     170            self.assert_(f('emmentaler k\xe4se'))
     171            self.assert_(f('EmmentalerK\xe4se'))
     172            self.assert_(f('Emmentaler K\xe4se'))
     173
     174    def testIsUnquoted(self):
     175        f = pg._is_unquoted
     176        self.assert_(f('A'))
     177        self.assert_(not f('0'))
     178        self.assert_(not f('#'))
     179        self.assert_(not f('*'))
     180        self.assert_(not f('.'))
     181        self.assert_(not f(' '))
     182        self.assert_(not f('a b'))
     183        self.assert_(not f('a+b'))
     184        self.assert_(not f('a*b'))
     185        self.assert_(not f('a.b'))
     186        self.assert_(not f('0ab'))
     187        self.assert_(f('aBc'))
     188        self.assert_(f('ABC'))
     189        self.assert_(not f('"a"'))
     190        self.assert_(f('a0'))
     191        self.assert_(f('_'))
     192        self.assert_(f('_a'))
     193        self.assert_(f('_0'))
     194        self.assert_(f('_a_0_'))
     195        self.assert_(f('ab'))
     196        self.assert_(f('ab0'))
     197        self.assert_(f('abc'))
     198        self.assert_(f('\xe4'))
     199        self.assert_(f('\xc4'))
     200        self.assert_(f('k\xe4se'))
     201        self.assert_(f('K\xe4se'))
     202        self.assert_(f('emmentaler_k\xe4se'))
     203        self.assert_(not f('emmentaler k\xe4se'))
     204        self.assert_(f('EmmentalerK\xe4se'))
     205        self.assert_(not f('Emmentaler K\xe4se'))
     206
     207    def testSplitFirstPart(self):
     208        f = pg._split_first_part
     209        self.assertEqual(f('a.b'), ['a', 'b'])
     210        self.assertEqual(f('a.b.c'), ['a', 'b.c'])
     211        self.assertEqual(f('"a.b".c'), ['a.b', 'c'])
     212        self.assertEqual(f('a."b.c"'), ['a', '"b.c"'])
     213        self.assertEqual(f('A.b.c'), ['a', 'b.c'])
     214        self.assertEqual(f('Ab.c'), ['ab', 'c'])
     215        self.assertEqual(f('aB.c'), ['ab', 'c'])
     216        self.assertEqual(f('AB.c'), ['ab', 'c'])
     217        self.assertEqual(f('A b.c'), ['A b', 'c'])
     218        self.assertEqual(f('a B.c'), ['a B', 'c'])
     219        self.assertEqual(f('"A".b.c'), ['A', 'b.c'])
     220        self.assertEqual(f('"A""B".c'), ['A"B', 'c'])
     221        self.assertEqual(f('a.b.c.d.e.f.g'), ['a', 'b.c.d.e.f.g'])
     222        self.assertEqual(f('"a.b.c.d.e.f".g'), ['a.b.c.d.e.f', 'g'])
     223        self.assertEqual(f('a.B.c.D.e.F.g'), ['a', 'B.c.D.e.F.g'])
     224        self.assertEqual(f('A.b.C.d.E.f.G'), ['a', 'b.C.d.E.f.G'])
     225
     226    def testSplitParts(self):
     227        f = pg._split_parts
     228        self.assertEqual(f('a.b'), ['a', 'b'])
     229        self.assertEqual(f('a.b.c'), ['a', 'b', 'c'])
     230        self.assertEqual(f('"a.b".c'), ['a.b', 'c'])
     231        self.assertEqual(f('a."b.c"'), ['a', 'b.c'])
     232        self.assertEqual(f('A.b.c'), ['a', 'b', 'c'])
     233        self.assertEqual(f('Ab.c'), ['ab', 'c'])
     234        self.assertEqual(f('aB.c'), ['ab', 'c'])
     235        self.assertEqual(f('AB.c'), ['ab', 'c'])
     236        self.assertEqual(f('A b.c'), ['A b', 'c'])
     237        self.assertEqual(f('a B.c'), ['a B', 'c'])
     238        self.assertEqual(f('"A".b.c'), ['A', 'b', 'c'])
     239        self.assertEqual(f('"A""B".c'), ['A"B', 'c'])
     240        self.assertEqual(f('a.b.c.d.e.f.g'), ['a', 'b', 'c', 'd', 'e', 'f', 'g'])
     241        self.assertEqual(f('"a.b.c.d.e.f".g'), ['a.b.c.d.e.f', 'g'])
     242        self.assertEqual(f('a.B.c.D.e.F.g'), ['a', 'b', 'c', 'd', 'e', 'f', 'g'])
     243        self.assertEqual(f('A.b.C.d.E.f.G'), ['a', 'b', 'c', 'd', 'e', 'f', 'g'])
     244
     245    def testJoinParts(self):
     246        f = pg._join_parts
     247        self.assertEqual(f(('a',)), 'a')
     248        self.assertEqual(f(('a', 'b')), 'a.b')
     249        self.assertEqual(f(('a', 'b', 'c')), 'a.b.c')
     250        self.assertEqual(f(('a', 'b', 'c', 'd', 'e', 'f', 'g')), 'a.b.c.d.e.f.g')
     251        self.assertEqual(f(('A', 'b')), '"A".b')
     252        self.assertEqual(f(('a', 'B')), 'a."B"')
     253        self.assertEqual(f(('a b', 'c')), '"a b".c')
     254        self.assertEqual(f(('a', 'b c')), 'a."b c"')
     255        self.assertEqual(f(('a_b', 'c')), 'a_b.c')
     256        self.assertEqual(f(('a', 'b_c')), 'a.b_c')
     257        self.assertEqual(f(('0', 'a')), '"0".a')
     258        self.assertEqual(f(('0_', 'a')), '"0_".a')
     259        self.assertEqual(f(('_0', 'a')), '_0.a')
     260        self.assertEqual(f(('_a', 'b')), '_a.b')
     261        self.assertEqual(f(('a', 'B', '0', 'c0', 'C0', 'd e', 'f_g', 'h.i', 'jklm', 'nopq')),
     262            'a."B"."0".c0."C0"."d e".f_g."h.i".jklm.nopq')
    263263
    264264
    265265class TestHasConnect(unittest.TestCase):
    266         """Test existence of basic pg module functions."""
    267 
    268         def testhasPgError(self):
    269                 self.assert_(issubclass(pg.Error, StandardError))
    270 
    271         def testhasPgWarning(self):
    272                 self.assert_(issubclass(pg.Warning, StandardError))
    273 
    274         def testhasPgInterfaceError(self):
    275                 self.assert_(issubclass(pg.InterfaceError, pg.Error))
    276 
    277         def testhasPgDatabaseError(self):
    278                 self.assert_(issubclass(pg.DatabaseError, pg.Error))
    279 
    280         def testhasPgInternalError(self):
    281                 self.assert_(issubclass(pg.InternalError, pg.DatabaseError))
    282 
    283         def testhasPgOperationalError(self):
    284                 self.assert_(issubclass(pg.OperationalError, pg.DatabaseError))
    285 
    286         def testhasPgProgrammingError(self):
    287                 self.assert_(issubclass(pg.ProgrammingError, pg.DatabaseError))
    288 
    289         def testhasPgIntegrityError(self):
    290                 self.assert_(issubclass(pg.IntegrityError, pg.DatabaseError))
    291 
    292         def testhasPgDataError(self):
    293                 self.assert_(issubclass(pg.DataError, pg.DatabaseError))
    294 
    295         def testhasPgNotSupportedError(self):
    296                 self.assert_(issubclass(pg.NotSupportedError, pg.DatabaseError))
    297 
    298         def testhasConnect(self):
    299                 self.assert_(callable(pg.connect))
    300 
    301         def testhasEscapeString(self):
    302                 self.assert_(callable(pg.escape_string))
    303 
    304         def testhasEscapeBytea(self):
    305                 self.assert_(callable(pg.escape_bytea))
    306 
    307         def testhasUnescapeBytea(self):
    308                 self.assert_(callable(pg.unescape_bytea))
    309 
    310         def testDefHost(self):
    311                 d0 = pg.get_defhost()
    312                 d1 = 'pgtesthost'
    313                 pg.set_defhost(d1)
    314                 self.assertEqual(pg.get_defhost(), d1)
    315                 pg.set_defhost(d0)
    316                 self.assertEqual(pg.get_defhost(), d0)
    317 
    318         def testDefPort(self):
    319                 d0 = pg.get_defport()
    320                 d1 = 1234
    321                 pg.set_defport(d1)
    322                 self.assertEqual(pg.get_defport(), d1)
    323                 if d0 is None:
    324                         d0 = -1
    325                 pg.set_defport(d0)
    326                 if d0 == -1:
    327                         d0 = None
    328                 self.assertEqual(pg.get_defport(), d0)
    329 
    330         def testDefOpt(self):
    331                 d0 = pg.get_defopt()
    332                 d1 = '-h pgtesthost -p 1234'
    333                 pg.set_defopt(d1)
    334                 self.assertEqual(pg.get_defopt(), d1)
    335                 pg.set_defopt(d0)
    336                 self.assertEqual(pg.get_defopt(), d0)
    337 
    338         def testDefTty(self):
    339                 d0 = pg.get_deftty()
    340                 d1 = 'pgtesttty'
    341                 pg.set_deftty(d1)
    342                 self.assertEqual(pg.get_deftty(), d1)
    343                 pg.set_deftty(d0)
    344                 self.assertEqual(pg.get_deftty(), d0)
    345 
    346         def testDefBase(self):
    347                 d0 = pg.get_defbase()
    348                 d1 = 'pgtestdb'
    349                 pg.set_defbase(d1)
    350                 self.assertEqual(pg.get_defbase(), d1)
    351                 pg.set_defbase(d0)
    352                 self.assertEqual(pg.get_defbase(), d0)
     266    """Test existence of basic pg module functions."""
     267
     268    def testhasPgError(self):
     269        self.assert_(issubclass(pg.Error, StandardError))
     270
     271    def testhasPgWarning(self):
     272        self.assert_(issubclass(pg.Warning, StandardError))
     273
     274    def testhasPgInterfaceError(self):
     275        self.assert_(issubclass(pg.InterfaceError, pg.Error))
     276
     277    def testhasPgDatabaseError(self):
     278        self.assert_(issubclass(pg.DatabaseError, pg.Error))
     279
     280    def testhasPgInternalError(self):
     281        self.assert_(issubclass(pg.InternalError, pg.DatabaseError))
     282
     283    def testhasPgOperationalError(self):
     284        self.assert_(issubclass(pg.OperationalError, pg.DatabaseError))
     285
     286    def testhasPgProgrammingError(self):
     287        self.assert_(issubclass(pg.ProgrammingError, pg.DatabaseError))
     288
     289    def testhasPgIntegrityError(self):
     290        self.assert_(issubclass(pg.IntegrityError, pg.DatabaseError))
     291
     292    def testhasPgDataError(self):
     293        self.assert_(issubclass(pg.DataError, pg.DatabaseError))
     294
     295    def testhasPgNotSupportedError(self):
     296        self.assert_(issubclass(pg.NotSupportedError, pg.DatabaseError))
     297
     298    def testhasConnect(self):
     299        self.assert_(callable(pg.connect))
     300
     301    def testhasEscapeString(self):
     302        self.assert_(callable(pg.escape_string))
     303
     304    def testhasEscapeBytea(self):
     305        self.assert_(callable(pg.escape_bytea))
     306
     307    def testhasUnescapeBytea(self):
     308        self.assert_(callable(pg.unescape_bytea))
     309
     310    def testDefHost(self):
     311        d0 = pg.get_defhost()
     312        d1 = 'pgtesthost'
     313        pg.set_defhost(d1)
     314        self.assertEqual(pg.get_defhost(), d1)
     315        pg.set_defhost(d0)
     316        self.assertEqual(pg.get_defhost(), d0)
     317
     318    def testDefPort(self):
     319        d0 = pg.get_defport()
     320        d1 = 1234
     321        pg.set_defport(d1)
     322        self.assertEqual(pg.get_defport(), d1)
     323        if d0 is None:
     324            d0 = -1
     325        pg.set_defport(d0)
     326        if d0 == -1:
     327            d0 = None
     328        self.assertEqual(pg.get_defport(), d0)
     329
     330    def testDefOpt(self):
     331        d0 = pg.get_defopt()
     332        d1 = '-h pgtesthost -p 1234'
     333        pg.set_defopt(d1)
     334        self.assertEqual(pg.get_defopt(), d1)
     335        pg.set_defopt(d0)
     336        self.assertEqual(pg.get_defopt(), d0)
     337
     338    def testDefTty(self):
     339        d0 = pg.get_deftty()
     340        d1 = 'pgtesttty'
     341        pg.set_deftty(d1)
     342        self.assertEqual(pg.get_deftty(), d1)
     343        pg.set_deftty(d0)
     344        self.assertEqual(pg.get_deftty(), d0)
     345
     346    def testDefBase(self):
     347        d0 = pg.get_defbase()
     348        d1 = 'pgtestdb'
     349        pg.set_defbase(d1)
     350        self.assertEqual(pg.get_defbase(), d1)
     351        pg.set_defbase(d0)
     352        self.assertEqual(pg.get_defbase(), d0)
    353353
    354354
    355355class TestEscapeFunctions(unittest.TestCase):
    356         """"Test pg escape and unescape functions."""
    357 
    358         def testEscapeString(self):
    359                 self.assertEqual(pg.escape_string('hello'), 'hello')
    360                 self.assertEqual(pg.escape_string(
    361                         r"It's fine to have a \ inside."),
    362                         r"It''s fine to have a \\ inside.")
    363 
    364         def testEscapeBytea(self):
    365                 self.assertEqual(pg.escape_bytea('hello'), 'hello')
    366                 self.assertEqual(pg.escape_bytea(
    367                         'O\x00ps\xff!'), r'O\\000ps\\377!')
    368 
    369         def testUnescapeBytea(self):
    370                 self.assertEqual(pg.unescape_bytea('hello'), 'hello')
    371                 self.assertEqual(pg.unescape_bytea(
    372                         r'O\000ps\377!'), 'O\x00ps\xff!')
     356    """"Test pg escape and unescape functions."""
     357
     358    def testEscapeString(self):
     359        self.assertEqual(pg.escape_string('hello'), 'hello')
     360        self.assertEqual(pg.escape_string(
     361            r"It's fine to have a \ inside."),
     362            r"It''s fine to have a \\ inside.")
     363
     364    def testEscapeBytea(self):
     365        self.assertEqual(pg.escape_bytea('hello'), 'hello')
     366        self.assertEqual(pg.escape_bytea(
     367            'O\x00ps\xff!'), r'O\\000ps\\377!')
     368
     369    def testUnescapeBytea(self):
     370        self.assertEqual(pg.unescape_bytea('hello'), 'hello')
     371        self.assertEqual(pg.unescape_bytea(
     372            r'O\000ps\377!'), 'O\x00ps\xff!')
    373373
    374374
    375375class TestCanConnect(unittest.TestCase):
    376         """Test whether a basic connection to PostgreSQL is possible."""
    377 
    378         def testCanConnectTemplate1(self):
    379                 dbname = 'template1'
    380                 try:
    381                         connection = pg.connect(dbname)
    382                 except:
    383                         self.fail('Cannot connect to database ' + dbname)
    384                 try:
    385                         connection.close()
    386                 except:
    387                         self.fail('Cannot close the database connection')
     376    """Test whether a basic connection to PostgreSQL is possible."""
     377
     378    def testCanConnectTemplate1(self):
     379        dbname = 'template1'
     380        try:
     381            connection = pg.connect(dbname)
     382        except:
     383            self.fail('Cannot connect to database ' + dbname)
     384        try:
     385            connection.close()
     386        except:
     387            self.fail('Cannot close the database connection')
    388388
    389389
    390390class TestConnectObject(unittest.TestCase):
    391         """"Test existence of basic pg connection methods."""
    392 
    393         def setUp(self):
    394                 dbname = 'template1'
    395                 self.dbname = dbname
    396                 self.connection = pg.connect(dbname)
    397 
    398         def tearDown(self):
    399                 self.connection.close()
    400 
    401         def testAllConnectAttributes(self):
    402                 attributes = ['db', 'error', 'host', 'options',
    403                         'port', 'status', 'tty', 'user']
    404                 connection_attributes = [a for a in dir(self.connection)
    405                         if not callable(eval("self.connection." + a))]
    406                 self.assertEqual(attributes, connection_attributes)
    407 
    408         def testAllConnectMethods(self):
    409                 methods = ['cancel', 'close', 'endcopy',
    410                         'fileno', 'getline', 'getlo', 'getnotify',
    411                         'inserttable', 'locreate', 'loimport',
    412                         'parameter', 'putline', 'query', 'reset',
    413                         'source', 'transaction']
    414                 connection_methods = [a for a in dir(self.connection)
    415                         if callable(eval("self.connection." + a))]
    416                 if 'transaction' not in connection_methods:
    417                         # this may be the case for PostgreSQL < 7.4
    418                         connection_methods.append('transaction')
    419                 if 'parameter' not in connection_methods:
    420                         # this may be the case for PostgreSQL < 7.4
    421                         connection_methods.append('parameter')
    422                 self.assertEqual(methods, connection_methods)
    423 
    424         def testAttributeDb(self):
    425                 self.assertEqual(self.connection.db, self.dbname)
    426 
    427         def testAttributeError(self):
    428                 error = self.connection.error
    429                 self.assert_(not error or 'krb5_' in error)
    430 
    431         def testAttributeHost(self):
    432                 def_host = 'localhost'
    433                 self.assertEqual(self.connection.host, def_host)
    434 
    435         def testAttributeOptions(self):
    436                 no_options = ''
    437                 self.assertEqual(self.connection.options, no_options)
    438 
    439         def testAttributePort(self):
    440                 def_port = 5432
    441                 self.assertEqual(self.connection.port, def_port)
    442 
    443         def testAttributeStatus(self):
    444                 status_ok = 1
    445                 self.assertEqual(self.connection.status, status_ok)
    446 
    447         def testAttributeTty(self):
    448                 def_tty = ''
    449                 self.assertEqual(self.connection.tty, def_tty)
    450 
    451         def testAttributeUser(self):
    452                 def_user = 'Deprecated facility'
    453                 self.assertEqual(self.connection.user, def_user)
    454 
    455         def testMethodQuery(self):
    456                 self.connection.query("select 1+1")
    457 
    458         def testMethodEndcopy(self):
    459                 try:
    460                         self.connection.endcopy()
    461                 except IOError:
    462                         pass
    463 
    464         def testMethodClose(self):
    465                 self.connection.close()
    466                 try:
    467                         self.connection.reset()
    468                         fail('Reset should give an error for a closed connection')
    469                 except:
    470                         pass
    471                 self.assertRaises(pg.InternalError, self.connection.close)
    472                 try:
    473                         self.connection.query('select 1')
    474                         self.fail('Query should give an error for a closed connection')
    475                 except:
    476                         pass
    477                 self.connection = pg.connect(self.dbname)
     391    """"Test existence of basic pg connection methods."""
     392
     393    def setUp(self):
     394        dbname = 'template1'
     395        self.dbname = dbname
     396        self.connection = pg.connect(dbname)
     397
     398    def tearDown(self):
     399        self.connection.close()
     400
     401    def testAllConnectAttributes(self):
     402        attributes = ['db', 'error', 'host', 'options',
     403            'port', 'status', 'tty', 'user']
     404        connection_attributes = [a for a in dir(self.connection)
     405            if not callable(eval("self.connection." + a))]
     406        self.assertEqual(attributes, connection_attributes)
     407
     408    def testAllConnectMethods(self):
     409        methods = ['cancel', 'close', 'endcopy',
     410            'fileno', 'getline', 'getlo', 'getnotify',
     411            'inserttable', 'locreate', 'loimport',
     412            'parameter', 'putline', 'query', 'reset',
     413            'source', 'transaction']
     414        connection_methods = [a for a in dir(self.connection)
     415            if callable(eval("self.connection." + a))]
     416        if 'transaction' not in connection_methods:
     417            # this may be the case for PostgreSQL < 7.4
     418            connection_methods.append('transaction')
     419        if 'parameter' not in connection_methods:
     420            # this may be the case for PostgreSQL < 7.4
     421            connection_methods.append('parameter')
     422        self.assertEqual(methods, connection_methods)
     423
     424    def testAttributeDb(self):
     425        self.assertEqual(self.connection.db, self.dbname)
     426
     427    def testAttributeError(self):
     428        error = self.connection.error
     429        self.assert_(not error or 'krb5_' in error)
     430
     431    def testAttributeHost(self):
     432        def_host = 'localhost'
     433        self.assertEqual(self.connection.host, def_host)
     434
     435    def testAttributeOptions(self):
     436        no_options = ''
     437        self.assertEqual(self.connection.options, no_options)
     438
     439    def testAttributePort(self):
     440        def_port = 5432
     441        self.assertEqual(self.connection.port, def_port)
     442
     443    def testAttributeStatus(self):
     444        status_ok = 1
     445        self.assertEqual(self.connection.status, status_ok)
     446
     447    def testAttributeTty(self):
     448        def_tty = ''
     449        self.assertEqual(self.connection.tty, def_tty)
     450
     451    def testAttributeUser(self):
     452        def_user = 'Deprecated facility'
     453        self.assertEqual(self.connection.user, def_user)
     454
     455    def testMethodQuery(self):
     456        self.connection.query("select 1+1")
     457
     458    def testMethodEndcopy(self):
     459        try:
     460            self.connection.endcopy()
     461        except IOError:
     462            pass
     463
     464    def testMethodClose(self):
     465        self.connection.close()
     466        try:
     467            self.connection.reset()
     468            fail('Reset should give an error for a closed connection')
     469        except:
     470            pass
     471        self.assertRaises(pg.InternalError, self.connection.close)
     472        try:
     473            self.connection.query('select 1')
     474            self.fail('Query should give an error for a closed connection')
     475        except:
     476            pass
     477        self.connection = pg.connect(self.dbname)
    478478
    479479
    480480class TestSimpleQueries(unittest.TestCase):
    481         """"Test simple queries via a basic pg connection."""
    482 
    483         def setUp(self):
    484                 dbname = 'template1'
    485                 self.c = pg.connect(dbname)
    486 
    487         def tearDown(self):
    488                 self.c.close()
    489 
    490         def testSelect0(self):
    491                 q = "select 0"
    492                 self.c.query(q)
    493 
    494         def testSelect0Semicolon(self):
    495                 q = "select 0;"
    496                 self.c.query(q)
    497 
    498         def testSelectSemicolon(self):
    499                 q = "select ;"
    500                 self.assertRaises(pg.ProgrammingError, self.c.query, q)
    501 
    502         def testGetresult(self):
    503                 q = "select 0"
    504                 result = [(0,)]
    505                 r = self.c.query(q).getresult()
    506                 self.assertEqual(r, result)
    507 
    508         def testDictresult(self):
    509                 q = "select 0 as alias0"
    510                 result = [{'alias0': 0}]
    511                 r = self.c.query(q).dictresult()
    512                 self.assertEqual(r, result)
    513 
    514         def testGet3Cols(self):
    515                 q = "select 1,2,3"
    516                 result = [(1,2,3)]
    517                 r = self.c.query(q).getresult()
    518                 self.assertEqual(r, result)
    519 
    520         def testGet3DictCols(self):
    521                 q = "select 1 as a,2 as b,3 as c"
    522                 result = [dict(a=1, b=2, c=3)]
    523                 r = self.c.query(q).dictresult()
    524                 self.assertEqual(r, result)
    525 
    526         def testGet3Rows(self):
    527                 q = "select 3 union select 1 union select 2 order by 1"
    528                 result = [(1,), (2,), (3,)]
    529                 r = self.c.query(q).getresult()
    530                 self.assertEqual(r, result)
    531 
    532         def testGet3DictRows(self):
    533                 q = "select 3 as alias3" \
    534                         " union select 1 union select 2 order by 1"
    535                 result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
    536                 r = self.c.query(q).dictresult()
    537                 self.assertEqual(r, result)
    538 
    539         def testDictresultNames(self):
    540                 q = "select 'MixedCase' as MixedCaseAlias"
    541                 result = [{'mixedcasealias': 'MixedCase'}]
    542                 r = self.c.query(q).dictresult()
    543                 self.assertEqual(r, result)
    544                 q = "select 'MixedCase' as \"MixedCaseAlias\""
    545                 result = [{'MixedCaseAlias': 'MixedCase'}]
    546                 r = self.c.query(q).dictresult()
    547                 self.assertEqual(r, result)
    548 
    549         def testBigGetresult(self):
    550                 num_cols = 100
    551                 num_rows = 100
    552                 q = "select " + ','.join(map(str, xrange(num_cols)))
    553                 q = ' union all '.join((q,) * num_rows)
    554                 r = self.c.query(q).getresult()
    555                 result = [tuple(range(num_cols))] * num_rows
    556                 self.assertEqual(r, result)
    557 
    558         def testListfields(self):
    559                 q = 'select 0 as a, 0 as b, 0 as c,' \
    560                         ' 0 as c, 0 as b, 0 as a,' \
    561                         ' 0 as lowercase, 0 as UPPERCASE,' \
    562                         ' 0 as MixedCase, 0 as "MixedCase",' \
    563                         ' 0 as a_long_name_with_underscores,' \
    564                         ' 0 as "A long name with Blanks"'
    565                 r = self.c.query(q).listfields()
    566                 result = ('a', 'b', 'c', 'c', 'b', 'a',
    567                         'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
    568                         'a_long_name_with_underscores',
    569                         'A long name with Blanks')
    570                 self.assertEqual(r, result)
    571 
    572         def testFieldname(self):
    573                 q = "select 0 as z, 0 as a, 0 as x, 0 as y"
    574                 r = self.c.query(q).fieldname(2)
    575                 result = "x"
    576                 self.assertEqual(r, result)
    577 
    578         def testFieldnum(self):
    579                 q = "select 0 as z, 0 as a, 0 as x, 0 as y"
    580                 r = self.c.query(q).fieldnum("x")
    581                 result = 2
    582                 self.assertEqual(r, result)
    583 
    584         def testNtuples(self):
    585                 q = "select 1 as a, 2 as b, 3 as c, 4 as d" \
    586                         " union select 5 as a, 6 as b, 7 as c, 8 as d"
    587                 r = self.c.query(q).ntuples()
    588                 result = 2
    589                 self.assertEqual(r, result)
    590 
    591         def testPrint(self):
    592                 q = "select 1 as a, 'hello' as h, 'w' as world" \
    593                         " union select 2, 'xyz', 'uvw'"
    594                 r = self.c.query(q)
    595                 t = '~test_pg_testPrint_temp.tmp'
    596                 s = open(t, 'w')
    597                 import sys, os
    598                 stdout, sys.stdout = sys.stdout, s
    599                 try:
    600                         print r
    601                 except:
    602                         pass
    603                 sys.stdout = stdout
    604                 s.close()
    605                 r = filter(bool, open(t, 'r').read().splitlines())
    606                 os.remove(t)
    607                 self.assertEqual(r,
    608                         ['a|h    |world',
    609                         '-+-----+-----',
    610                         '1|hello|w    ',
    611                         '2|xyz  |uvw  ',
    612                         '(2 rows)'])
     481    """"Test simple queries via a basic pg connection."""
     482
     483    def setUp(self):
     484        dbname = 'template1'
     485        self.c = pg.connect(dbname)
     486
     487    def tearDown(self):
     488        self.c.close()
     489
     490    def testSelect0(self):
     491        q = "select 0"
     492        self.c.query(q)
     493
     494    def testSelect0Semicolon(self):
     495        q = "select 0;"
     496        self.c.query(q)
     497
     498    def testSelectSemicolon(self):
     499        q = "select ;"
     500        self.assertRaises(pg.ProgrammingError, self.c.query, q)
     501
     502    def testGetresult(self):
     503        q = "select 0"
     504        result = [(0,)]
     505        r = self.c.query(q).getresult()
     506        self.assertEqual(r, result)
     507
     508    def testDictresult(self):
     509        q = "select 0 as alias0"
     510        result = [{'alias0': 0}]
     511        r = self.c.query(q).dictresult()
     512        self.assertEqual(r, result)
     513
     514    def testGet3Cols(self):
     515        q = "select 1,2,3"
     516        result = [(1,2,3)]
     517        r = self.c.query(q).getresult()
     518        self.assertEqual(r, result)
     519
     520    def testGet3DictCols(self):
     521        q = "select 1 as a,2 as b,3 as c"
     522        result = [dict(a=1, b=2, c=3)]
     523        r = self.c.query(q).dictresult()
     524        self.assertEqual(r, result)
     525
     526    def testGet3Rows(self):
     527        q = "select 3 union select 1 union select 2 order by 1"
     528        result = [(1,), (2,), (3,)]
     529        r = self.c.query(q).getresult()
     530        self.assertEqual(r, result)
     531
     532    def testGet3DictRows(self):
     533        q = "select 3 as alias3" \
     534            " union select 1 union select 2 order by 1"
     535        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
     536        r = self.c.query(q).dictresult()
     537        self.assertEqual(r, result)
     538
     539    def testDictresultNames(self):
     540        q = "select 'MixedCase' as MixedCaseAlias"
     541        result = [{'mixedcasealias': 'MixedCase'}]
     542        r = self.c.query(q).dictresult()
     543        self.assertEqual(r, result)
     544        q = "select 'MixedCase' as \"MixedCaseAlias\""
     545        result = [{'MixedCaseAlias': 'MixedCase'}]
     546        r = self.c.query(q).dictresult()
     547        self.assertEqual(r, result)
     548
     549    def testBigGetresult(self):
     550        num_cols = 100
     551        num_rows = 100
     552        q = "select " + ','.join(map(str, xrange(num_cols)))
     553        q = ' union all '.join((q,) * num_rows)
     554        r = self.c.query(q).getresult()
     555        result = [tuple(range(num_cols))] * num_rows
     556        self.assertEqual(r, result)
     557
     558    def testListfields(self):
     559        q = 'select 0 as a, 0 as b, 0 as c,' \
     560            ' 0 as c, 0 as b, 0 as a,' \
     561            ' 0 as lowercase, 0 as UPPERCASE,' \
     562            ' 0 as MixedCase, 0 as "MixedCase",' \
     563            ' 0 as a_long_name_with_underscores,' \
     564            ' 0 as "A long name with Blanks"'
     565        r = self.c.query(q).listfields()
     566        result = ('a', 'b', 'c', 'c', 'b', 'a',
     567            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
     568            'a_long_name_with_underscores',
     569            'A long name with Blanks')
     570        self.assertEqual(r, result)
     571
     572    def testFieldname(self):
     573        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
     574        r = self.c.query(q).fieldname(2)
     575        result = "x"
     576        self.assertEqual(r, result)
     577
     578    def testFieldnum(self):
     579        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
     580        r = self.c.query(q).fieldnum("x")
     581        result = 2
     582        self.assertEqual(r, result)
     583
     584    def testNtuples(self):
     585        q = "select 1 as a, 2 as b, 3 as c, 4 as d" \
     586            " union select 5 as a, 6 as b, 7 as c, 8 as d"
     587        r = self.c.query(q).ntuples()
     588        result = 2
     589        self.assertEqual(r, result)
     590
     591    def testPrint(self):
     592        q = "select 1 as a, 'hello' as h, 'w' as world" \
     593            " union select 2, 'xyz', 'uvw'"
     594        r = self.c.query(q)
     595        t = '~test_pg_testPrint_temp.tmp'
     596        s = open(t, 'w')
     597        import sys, os
     598        stdout, sys.stdout = sys.stdout, s
     599        try:
     600            print r
     601        except:
     602            pass
     603        sys.stdout = stdout
     604        s.close()
     605        r = filter(bool, open(t, 'r').read().splitlines())
     606        os.remove(t)
     607        self.assertEqual(r,
     608            ['a|h    |world',
     609            '-+-----+-----',
     610            '1|hello|w    ',
     611            '2|xyz  |uvw  ',
     612            '(2 rows)'])
    613613
    614614
    615615class TestInserttable(unittest.TestCase):
    616         """"Test inserttable method."""
    617 
    618         # Test database needed: must be run as a DBTestSuite.
    619 
    620         def setUp(self):
    621                 dbname = DBTestSuite.dbname
    622                 self.c = pg.connect(dbname)
    623                 self.c.query('truncate table test')
    624 
    625         def tearDown(self):
    626                 self.c.close()
    627 
    628         def testInserttable1Row(self):
    629                 d = Decimal is float and 1.0 or None
    630                 data = [(1, 1, 1L, d, 1.0, 1.0, d, "1", "1111", "1")]
    631                 self.c.inserttable("test", data)
    632                 r = self.c.query("select * from test").getresult()
    633                 self.assertEqual(r, data)
    634 
    635         def testInserttable4Rows(self):
    636                 data = [(-1, -1, -1L, None, -1.0, -1.0, None, "-1", "-1-1", "-1"),
    637                         (0, 0, 0L, None, 0.0, 0.0, None, "0", "0000", "0"),
    638                         (1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1"),
    639                         (2, 2, 2L, None, 2.0, 2.0, None, "2", "2222", "2")]
    640                 self.c.inserttable("test", data)
    641                 r = self.c.query("select * from test order by 1").getresult()
    642                 self.assertEqual(r, data)
    643 
    644         def testInserttableMultipleRows(self):
    645                 num_rows = 100
    646                 data = [(1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1")] * num_rows
    647                 self.c.inserttable("test", data)
    648                 r = self.c.query("select count(*) from test").getresult()[0][0]
    649                 self.assertEqual(r, num_rows)
    650 
    651         def testInserttableMultipleCalls(self):
    652                 num_rows = 10
    653                 data = [(1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1")]
    654                 for i in range(num_rows):
    655                         self.c.inserttable("test", data)
    656                 r = self.c.query("select count(*) from test").getresult()[0][0]
    657                 self.assertEqual(r, num_rows)
    658 
    659         def testInserttableNullValues(self):
    660                 num_rows = 100
    661                 data = [(None,) * 10]
    662                 self.c.inserttable("test", data)
    663                 r = self.c.query("select * from test").getresult()
    664                 self.assertEqual(r, data)
    665 
    666         def testInserttableMaxValues(self):
    667                 data = [(2**15 - 1, int(2**31 - 1), long(2**31 - 1),
    668                         None, 1.0 + 1.0/32, 1.0 + 1.0/32, None,
    669                         "1234", "1234", "1234" * 10)]
    670                 self.c.inserttable("test", data)
    671                 r = self.c.query("select * from test").getresult()
    672                 self.assertEqual(r, data)
     616    """"Test inserttable method."""
     617
     618    # Test database needed: must be run as a DBTestSuite.
     619
     620    def setUp(self):
     621        dbname = DBTestSuite.dbname
     622        self.c = pg.connect(dbname)
     623        self.c.query('truncate table test')
     624
     625    def tearDown(self):
     626        self.c.close()
     627
     628    def testInserttable1Row(self):
     629        d = Decimal is float and 1.0 or None
     630        data = [(1, 1, 1L, d, 1.0, 1.0, d, "1", "1111", "1")]
     631        self.c.inserttable("test", data)
     632        r = self.c.query("select * from test").getresult()
     633        self.assertEqual(r, data)
     634
     635    def testInserttable4Rows(self):
     636        data = [(-1, -1, -1L, None, -1.0, -1.0, None, "-1", "-1-1", "-1"),
     637            (0, 0, 0L, None, 0.0, 0.0, None, "0", "0000", "0"),
     638            (1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1"),
     639            (2, 2, 2L, None, 2.0, 2.0, None, "2", "2222", "2")]
     640        self.c.inserttable("test", data)
     641        r = self.c.query("select * from test order by 1").getresult()
     642        self.assertEqual(r, data)
     643
     644    def testInserttableMultipleRows(self):
     645        num_rows = 100
     646        data = [(1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1")] * num_rows
     647        self.c.inserttable("test", data)
     648        r = self.c.query("select count(*) from test").getresult()[0][0]
     649        self.assertEqual(r, num_rows)
     650
     651    def testInserttableMultipleCalls(self):
     652        num_rows = 10
     653        data = [(1, 1, 1L, None, 1.0, 1.0, None, "1", "1111", "1")]
     654        for i in range(num_rows):
     655            self.c.inserttable("test", data)
     656        r = self.c.query("select count(*) from test").getresult()[0][0]
     657        self.assertEqual(r, num_rows)
     658
     659    def testInserttableNullValues(self):
     660        num_rows = 100
     661        data = [(None,) * 10]
     662        self.c.inserttable("test", data)
     663        r = self.c.query("select * from test").getresult()
     664        self.assertEqual(r, data)
     665
     666    def testInserttableMaxValues(self):
     667        data = [(2**15 - 1, int(2**31 - 1), long(2**31 - 1),
     668            None, 1.0 + 1.0/32, 1.0 + 1.0/32, None,
     669            "1234", "1234", "1234" * 10)]
     670        self.c.inserttable("test", data)
     671        r = self.c.query("select * from test").getresult()
     672        self.assertEqual(r, data)
    673673
    674674
    675675class TestDBClassBasic(unittest.TestCase):
    676         """"Test existence of the DB class wrapped pg connection methods."""
    677 
    678         def setUp(self):
    679                 dbname = 'template1'
    680                 self.dbname = dbname
    681                 self.db = pg.DB(dbname)
    682 
    683         def tearDown(self):
    684                 self.db.close()
    685 
    686         def testAllDBAttributes(self):
    687                 attributes = ['cancel', 'clear', 'close',
    688                         'db', 'dbname', 'debug', 'delete', 'endcopy',
    689                         'error', 'escape_bytea', 'escape_string',
    690                         'fileno', 'get', 'get_attnames',
    691                         'get_databases', 'get_relations', 'get_tables',
    692                         'getline', 'getlo', 'getnotify', 'host',
    693                         'insert', 'inserttable', 'locreate', 'loimport',
    694                         'options', 'parameter', 'pkey', 'port', 'putline',
    695                         'query', 'reopen', 'reset', 'source', 'status',
    696                         'transaction', 'tty', 'unescape_bytea',
    697                         'update', 'user']
    698                 db_attributes = [a for a in dir(self.db)
    699                         if not a.startswith('_')]
    700                 if 'transaction' not in db_attributes:
    701                         # this may be the case for PostgreSQL < 7.4
    702                         db_attributes.insert(-4, 'transaction')
    703                 if 'parameter' not in db_attributes:
    704                         # this may be the case for PostgreSQL < 7.4
    705                         db_attributes.insert(-12, 'parameter')
    706                 self.assertEqual(attributes, db_attributes)
    707 
    708         def testAttributeDb(self):
    709                 self.assertEqual(self.db.db.db, self.dbname)
    710 
    711         def testAttributeDbname(self):
    712                 self.assertEqual(self.db.dbname, self.dbname)
    713 
    714         def testAttributeError(self):
    715                 error = self.db.error
    716                 self.assert_(not error or 'krb5_' in error)
    717                 self.assertEqual(self.db.error, self.db.db.error)
    718 
    719         def testAttributeHost(self):
    720                 def_host = 'localhost'
    721                 self.assertEqual(self.db.host, def_host)
    722                 self.assertEqual(self.db.db.host, def_host)
    723 
    724         def testAttributeOptions(self):
    725                 no_options = ''
    726                 self.assertEqual(self.db.options, no_options)
    727                 self.assertEqual(self.db.db.options, no_options)
    728 
    729         def testAttributePort(self):
    730                 def_port = 5432
    731                 self.assertEqual(self.db.port, def_port)
    732                 self.assertEqual(self.db.db.port, def_port)
    733 
    734         def testAttributeStatus(self):
    735                 status_ok = 1
    736                 self.assertEqual(self.db.status, status_ok)
    737                 self.assertEqual(self.db.db.status, status_ok)
    738 
    739         def testAttributeTty(self):
    740                 def_tty = ''
    741                 self.assertEqual(self.db.tty, def_tty)
    742                 self.assertEqual(self.db.db.tty, def_tty)
    743 
    744         def testAttributeUser(self):
    745                 def_user = 'Deprecated facility'
    746                 self.assertEqual(self.db.user, def_user)
    747                 self.assertEqual(self.db.db.user, def_user)
    748 
    749         def testMethodEscapeString(self):
    750                 self.assertEqual(self.db.escape_string('hello'), 'hello')
    751 
    752         def testMethodEscapeBytea(self):
    753                 self.assertEqual(self.db.escape_bytea('hello'), 'hello')
    754 
    755         def testMethodUnescapeBytea(self):
    756                 self.assertEqual(self.db.unescape_bytea('hello'), 'hello')
    757 
    758         def testMethodQuery(self):
    759                 self.db.query("select 1+1")
    760 
    761         def testMethodEndcopy(self):
    762                 try:
    763                         self.db.endcopy()
    764                 except IOError:
    765                         pass
    766 
    767         def testMethodClose(self):
    768                 self.db.close()
    769                 try:
    770                         self.db.reset()
    771                         fail('Reset should give an error for a closed connection')
    772                 except:
    773                         pass
    774                 self.assertRaises(pg.InternalError, self.db.close)
    775                 self.assertRaises(pg.InternalError, self.db.query, 'select 1')
    776                 self.db = pg.DB(self.dbname)
    777 
    778         def testExistingConnection(self):
    779                 db = pg.DB(self.db.db)
    780                 self.assertEqual(self.db.db, db.db)
    781                 self.assert_(db.db)
    782                 db.close()
    783                 self.assert_(db.db)
    784                 db.reopen()
    785                 self.assert_(db.db)
    786                 db.close()
    787                 self.assert_(db.db)
    788                 db = pg.DB(self.db)
    789                 self.assertEqual(self.db.db, db.db)
    790                 db = pg.DB(db=self.db.db)
    791                 self.assertEqual(self.db.db, db.db)
    792                 class DB2:
    793                         pass
    794                 db2 = DB2()
    795                 db2._cnx = self.db.db
    796                 db = pg.DB(db2)
    797                 self.assertEqual(self.db.db, db.db)
     676    """"Test existence of the DB class wrapped pg connection methods."""
     677
     678    def setUp(self):
     679        dbname = 'template1'
     680        self.dbname = dbname
     681        self.db = pg.DB(dbname)
     682
     683    def tearDown(self):
     684        self.db.close()
     685
     686    def testAllDBAttributes(self):
     687        attributes = ['cancel', 'clear', 'close',
     688            'db', 'dbname', 'debug', 'delete', 'endcopy',
     689            'error', 'escape_bytea', 'escape_string',
     690            'fileno', 'get', 'get_attnames',
     691            'get_databases', 'get_relations', 'get_tables',
     692            'getline', 'getlo', 'getnotify', 'host',
     693            'insert', 'inserttable', 'locreate', 'loimport',
     694            'options', 'parameter', 'pkey', 'port', 'putline',
     695            'query', 'reopen', 'reset', 'source', 'status',
     696            'transaction', 'tty', 'unescape_bytea',
     697            'update', 'user']
     698        db_attributes = [a for a in dir(self.db)
     699            if not a.startswith('_')]
     700        if 'transaction' not in db_attributes:
     701            # this may be the case for PostgreSQL < 7.4
     702            db_attributes.insert(-4, 'transaction')
     703        if 'parameter' not in db_attributes:
     704            # this may be the case for PostgreSQL < 7.4
     705            db_attributes.insert(-12, 'parameter')
     706        self.assertEqual(attributes, db_attributes)
     707
     708    def testAttributeDb(self):
     709        self.assertEqual(self.db.db.db, self.dbname)
     710
     711    def testAttributeDbname(self):
     712        self.assertEqual(self.db.dbname, self.dbname)
     713
     714    def testAttributeError(self):
     715        error = self.db.error
     716        self.assert_(not error or 'krb5_' in error)
     717        self.assertEqual(self.db.error, self.db.db.error)
     718
     719    def testAttributeHost(self):
     720        def_host = 'localhost'
     721        self.assertEqual(self.db.host, def_host)
     722        self.assertEqual(self.db.db.host, def_host)
     723
     724    def testAttributeOptions(self):
     725        no_options = ''
     726        self.assertEqual(self.db.options, no_options)
     727        self.assertEqual(self.db.db.options, no_options)
     728
     729    def testAttributePort(self):
     730        def_port = 5432
     731        self.assertEqual(self.db.port, def_port)
     732        self.assertEqual(self.db.db.port, def_port)
     733
     734    def testAttributeStatus(self):
     735        status_ok = 1
     736        self.assertEqual(self.db.status, status_ok)
     737        self.assertEqual(self.db.db.status, status_ok)
     738
     739    def testAttributeTty(self):
     740        def_tty = ''
     741        self.assertEqual(self.db.tty, def_tty)
     742        self.assertEqual(self.db.db.tty, def_tty)
     743
     744    def testAttributeUser(self):
     745        def_user = 'Deprecated facility'
     746        self.assertEqual(self.db.user, def_user)
     747        self.assertEqual(self.db.db.user, def_user)
     748
     749    def testMethodEscapeString(self):
     750        self.assertEqual(self.db.escape_string('hello'), 'hello')
     751
     752    def testMethodEscapeBytea(self):
     753        self.assertEqual(self.db.escape_bytea('hello'), 'hello')
     754
     755    def testMethodUnescapeBytea(self):
     756        self.assertEqual(self.db.unescape_bytea('hello'), 'hello')
     757
     758    def testMethodQuery(self):
     759        self.db.query("select 1+1")
     760
     761    def testMethodEndcopy(self):
     762        try:
     763            self.db.endcopy()
     764        except IOError:
     765            pass
     766
     767    def testMethodClose(self):
     768        self.db.close()
     769        try:
     770            self.db.reset()
     771            fail('Reset should give an error for a closed connection')
     772        except:
     773            pass
     774        self.assertRaises(pg.InternalError, self.db.close)
     775        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
     776        self.db = pg.DB(self.dbname)
     777
     778    def testExistingConnection(self):
     779        db = pg.DB(self.db.db)
     780        self.assertEqual(self.db.db, db.db)
     781        self.assert_(db.db)
     782        db.close()
     783        self.assert_(db.db)
     784        db.reopen()
     785        self.assert_(db.db)
     786        db.close()
     787        self.assert_(db.db)
     788        db = pg.DB(self.db)
     789        self.assertEqual(self.db.db, db.db)
     790        db = pg.DB(db=self.db.db)
     791        self.assertEqual(self.db.db, db.db)
     792        class DB2:
     793            pass
     794        db2 = DB2()
     795        db2._cnx = self.db.db
     796        db = pg.DB(db2)
     797        self.assertEqual(self.db.db, db.db)
    798798
    799799
    800800class TestDBClass(unittest.TestCase):
    801         """"Test the methods of the DB class wrapped pg connection."""
    802 
    803         # Test database needed: must be run as a DBTestSuite.
    804 
    805         def setUp(self):
    806                 dbname = DBTestSuite.dbname
    807                 self.dbname = dbname
    808                 self.db = pg.DB(dbname)
    809 
    810         def tearDown(self):
    811                 self.db.close()
    812 
    813         def testEscapeString(self):
    814                 self.assertEqual(self.db.escape_string(
    815                         r"It's fine to have a \ inside."),
    816                         r"It''s fine to have a \\ inside.")
    817 
    818         def testEscapeBytea(self):
    819                 self.assertEqual(self.db.escape_bytea(
    820                         'O\x00ps\xff!'), r'O\\000ps\\377!')
    821 
    822         def testUnescapeBytea(self):
    823                 self.assertEqual(self.db.unescape_bytea(
    824                         r'O\000ps\377!'), 'O\x00ps\xff!')
    825 
    826         def testPkey(self):
    827                 smart_ddl(self.db, 'drop table pkeytest0')
    828                 smart_ddl(self.db, "create table pkeytest0 ("
    829                         "a smallint)")
    830                 smart_ddl(self.db, 'drop table pkeytest1')
    831                 smart_ddl(self.db, "create table pkeytest1 ("
    832                         "b smallint primary key)")
    833                 smart_ddl(self.db, 'drop table pkeytest2')
    834                 smart_ddl(self.db, "create table pkeytest2 ("
    835                         "c smallint, d smallint primary key)")
    836                 smart_ddl(self.db, 'drop table pkeytest3')
    837                 smart_ddl(self.db, "create table pkeytest3 ("
    838                         "e smallint, f smallint, g smallint, "
    839                         "h smallint, i smallint, "
    840                         "primary key (f,h))")
    841                 self.assertRaises(KeyError, self.db.pkey, "pkeytest0")
    842                 self.assertEqual(self.db.pkey("pkeytest1"), "b")
    843                 self.assertEqual(self.db.pkey("pkeytest2"), "d")
    844                 self.assertEqual(self.db.pkey("pkeytest3"), "f")
    845 
    846         def testGetDatabases(self):
    847                 databases = self.db.get_databases()
    848                 self.assert_('template0' in databases)
    849                 self.assert_('template1' in databases)
    850                 self.assert_(self.dbname in databases)
    851 
    852         def testGetTables(self):
    853                 result1 = self.db.get_tables()
    854                 tables = ('"A very Special Name"',
    855                         '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
    856                         'A_MiXeD_NaMe', '"another special name"',
    857                         'averyveryveryveryveryveryverylongtablename',
    858                         'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
    859                 for t in tables:
    860                         smart_ddl(self.db, 'drop table ' + t)
    861                         smart_ddl(self.db, "create table %s"
    862                                 " as select 0" % t)
    863                 result3 = self.db.get_tables()
    864                 result2 = []
    865                 for t in result3:
    866                         if t not in result1:
    867                                 result2.append(t)
    868                 result3 = []
    869                 for t in tables:
    870                         if not t.startswith('"'):
    871                                 t = t.lower()
    872                         result3.append('public.' + t)
    873                 self.assertEqual(result2, result3)
    874                 for t in result2:
    875                         self.db.query('drop table ' + t)
    876                 result2 = self.db.get_tables()
    877                 self.assertEqual(result2, result1)
    878 
    879         def testGetRelations(self):
    880                 result = self.db.get_relations()
    881                 self.assert_('public.test' in result)
    882                 self.assert_('public.test_view' in result)
    883                 result = self.db.get_relations('rv')
    884                 self.assert_('public.test' in result)
    885                 self.assert_('public.test_view' in result)
    886                 result = self.db.get_relations('r')
    887                 self.assert_('public.test' in result)
    888                 self.assert_('public.test_view' not in result)
    889                 result = self.db.get_relations('v')
    890                 self.assert_('public.test' not in result)
    891                 self.assert_('public.test_view' in result)
    892                 result = self.db.get_relations('cisSt')
    893                 self.assert_('public.test' not in result)
    894                 self.assert_('public.test_view' not in result)
    895 
    896         def testAttnames(self):
    897                 self.assertRaises(pg.ProgrammingError,
    898                         self.db.get_attnames, 'does_not_exist')
    899                 self.assertRaises(pg.ProgrammingError,
    900                         self.db.get_attnames, 'has.too.many.dots')
    901                 for table in ('attnames_test_table', 'test table for attnames'):
    902                         smart_ddl(self.db, 'drop table "%s"' % table)
    903                         smart_ddl(self.db, 'create table "%s" ('
    904                                 'a smallint, b integer, c bigint, '
    905                                 'e numeric, f float, f2 double precision, m money, '
    906                                 'x smallint, y smallint, z smallint, '
    907                                 'Normal_NaMe smallint, "Special Name" smallint, '
    908                                 't text, u char(2), v varchar(2), '
    909                                 'primary key (y, u))' % table)
    910                         attributes = self.db.get_attnames(table)
    911                         result = {'a': 'int', 'c': 'int', 'b': 'int',
    912                                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
    913                                 'normal_name': 'int', 'Special Name': 'int',
    914                                 'u': 'text', 't': 'text', 'v': 'text',
    915                                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int' }
    916                         self.assertEqual(attributes, result)
    917 
    918         def testGet(self):
    919                 for table in ('get_test_table', 'test table for get'):
    920                         smart_ddl(self.db, 'drop table "%s"' % table)
    921                         smart_ddl(self.db, 'create table "%s" ('
    922                                 "n integer, t text)" % table)
    923                         for n, t in enumerate('xyz'):
    924                                 self.db.query('insert into "%s" values('
    925                                         "%d, '%s')" % (table, n+1, t))
    926                         self.assertRaises(KeyError, self.db.get, table, 2)
    927                         r = self.db.get(table, 2, 'n')
    928                         oid_table = table
    929                         if ' ' in table:
    930                                 oid_table = '"%s"' % oid_table
    931                         oid_table = 'oid(public.%s)' % oid_table
    932                         self.assert_(oid_table in r)
    933                         self.assert_(isinstance(r[oid_table], int))
    934                         result = {'t': 'y', 'n': 2, oid_table: r[oid_table]}
    935                         self.assertEqual(r, result)
    936                         self.assertEqual(self.db.get(table, r[oid_table], 'oid')['t'], 'y')
    937                         self.assertEqual(self.db.get(table, 1, 'n')['t'], 'x')
    938                         self.assertEqual(self.db.get(table, 3, 'n')['t'], 'z')
    939                         self.assertEqual(self.db.get(table, 2, 'n')['t'], 'y')
    940                         self.assertRaises(pg.DatabaseError, self.db.get, table, 4, 'n')
    941                         r['n'] = 3
    942                         self.assertEqual(self.db.get(table, r, 'n')['t'], 'z')
    943                         self.assertEqual(self.db.get(table, 1, 'n')['t'], 'x')
    944                         self.db.query('alter table "%s" alter n set not null' % table)
    945                         self.db.query('alter table "%s" add primary key (n)' % table)
    946                         self.assertEqual(self.db.get(table, 3)['t'], 'z')
    947                         self.assertEqual(self.db.get(table, 1)['t'], 'x')
    948                         self.assertEqual(self.db.get(table, 2)['t'], 'y')
    949                         r['n'] = 1
    950                         self.assertEqual(self.db.get(table, r)['t'], 'x')
    951                         r['n'] = 3
    952                         self.assertEqual(self.db.get(table, r)['t'], 'z')
    953                         r['n'] = 2
    954                         self.assertEqual(self.db.get(table, r)['t'], 'y')
    955 
    956         def testGetFromView(self):
    957                 self.db.query('delete from test where i4=14')
    958                 self.db.query('insert into test (i4, v4) values('
    959                         "14, 'abc4')")
    960                 r = self.db.get('test_view', 14, 'i4')
    961                 self.assert_('v4' in r)
    962                 self.assertEqual(r['v4'], 'abc4')
    963 
    964         def testInsert(self):
    965                 for table in ('insert_test_table', 'test table for insert'):
    966                         smart_ddl(self.db, 'drop table "%s"' % table)
    967                         smart_ddl(self.db, 'create table "%s" ('
    968                                 "i2 smallint, i4 integer, i8 bigint,"
    969                                 "d numeric, f4 real, f8 double precision, m money, "
    970                                 "v4 varchar(4), c4 char(4), t text,"
    971                                 "b boolean, ts timestamp)" % table)
    972                         data = dict(i2 = 2**15 - 1,
    973                                 i4 = int(2**31 - 1), i8 = long(2**31 - 1),
    974                                 d = Decimal('123456789.9876543212345678987654321'),
    975                                 f4 = 1.0 + 1.0/32, f8 = 1.0 + 1.0/32,
    976                                 m = "1234.56", v4 = "1234", c4 = "1234", t = "1234" * 10,
    977                                 b = 1, ts = 'current_date')
    978                         r = self.db.insert(table, data)
    979                         self.assertEqual(r, data)
    980                         oid_table = table
    981                         if ' ' in table:
    982                                 oid_table = '"%s"' % oid_table
    983                         oid_table = 'oid(public.%s)' % oid_table
    984                         self.assert_(oid_table in r)
    985                         self.assert_(isinstance(r[oid_table], int))
    986                         s = self.db.query('select oid,* from "%s"' % table).dictresult()[0]
    987                         s[oid_table] = s['oid']
    988                         del s['oid']
    989                         self.assertEqual(r, s)
    990 
    991         def testUpdate(self):
    992                 for table in ('update_test_table', 'test table for update'):
    993                         smart_ddl(self.db, 'drop table "%s"' % table)
    994                         smart_ddl(self.db, 'create table "%s" ('
    995                                 "n integer, t text)" % table)
    996                         for n, t in enumerate('xyz'):
    997                                 self.db.query('insert into "%s" values('
    998                                         "%d, '%s')" % (table, n+1, t))
    999                         self.assertRaises(KeyError, self.db.get, table, 2)
    1000                         r = self.db.get(table, 2, 'n')
    1001                         r['t'] = 'u'
    1002                         s = self.db.update(table, r)
    1003                         self.assertEqual(s, r)
    1004                         r = self.db.query('select t from "%s" where n=2' % table
    1005                                 ).getresult()[0][0]
    1006                         self.assertEqual(r, 'u')
    1007 
    1008         def testClear(self):
    1009                 for table in ('clear_test_table', 'test table for clear'):
    1010                         smart_ddl(self.db, 'drop table "%s"' % table)
    1011                         smart_ddl(self.db, 'create table "%s" ('
    1012                                 "n integer, b boolean, d date, t text)" % table)
    1013                         r = self.db.clear(table)
    1014                         result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
    1015                         self.assertEqual(r, result)
    1016                         r['a'] = r['n'] = 1
    1017                         r['d'] = r['t'] = 'x'
    1018                         r['b']
    1019                         r['oid'] = 1L
    1020                         r = self.db.clear(table, r)
    1021                         result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '', 'oid': 1L}
    1022                         self.assertEqual(r, result)
    1023 
    1024         def testDelete(self):
    1025                 for table in ('delete_test_table', 'test table for delete'):
    1026                         smart_ddl(self.db, 'drop table "%s"' % table)
    1027                         smart_ddl(self.db, 'create table "%s" ('
    1028                                 "n integer, t text)" % table)
    1029                         for n, t in enumerate('xyz'):
    1030                                 self.db.query('insert into "%s" values('
    1031                                         "%d, '%s')" % (table, n+1, t))
    1032                         self.assertRaises(KeyError, self.db.get, table, 2)
    1033                         r = self.db.get(table, 1, 'n')
    1034                         s = self.db.delete(table, r)
    1035                         r = self.db.get(table, 3, 'n')
    1036                         s = self.db.delete(table, r)
    1037                         r = self.db.query('select * from "%s"' % table).dictresult()
    1038                         self.assertEqual(len(r), 1)
    1039                         r = r[0]
    1040                         result = {'n': 2, 't': 'y'}
    1041                         self.assertEqual(r, result)
    1042                         r = self.db.get(table, 2, 'n')
    1043                         s = self.db.delete(table, r)
    1044                         self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
    1045 
    1046         def testBytea(self):
    1047                 smart_ddl(self.db, 'drop table bytea_test')
    1048                 smart_ddl(self.db, 'create table bytea_test ('
    1049                         'data bytea)')
    1050                 s = "It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
    1051                 r = self.db.escape_bytea(s)
    1052                 self.db.query('insert into bytea_test values('
    1053                         "'%s')" % r)
    1054                 r = self.db.query('select * from bytea_test').getresult()
    1055                 self.assert_(len(r) == 1)
    1056                 r = r[0]
    1057                 self.assert_(len(r) == 1)
    1058                 r = r[0]
    1059                 r = self.db.unescape_bytea(r)
    1060                 self.assertEqual(r, s)
     801    """"Test the methods of the DB class wrapped pg connection."""
     802
     803    # Test database needed: must be run as a DBTestSuite.
     804
     805    def setUp(self):
     806        dbname = DBTestSuite.dbname
     807        self.dbname = dbname
     808        self.db = pg.DB(dbname)
     809
     810    def tearDown(self):
     811        self.db.close()
     812
     813    def testEscapeString(self):
     814        self.assertEqual(self.db.escape_string(
     815            r"It's fine to have a \ inside."),
     816            r"It''s fine to have a \\ inside.")
     817
     818    def testEscapeBytea(self):
     819        self.assertEqual(self.db.escape_bytea(
     820            'O\x00ps\xff!'), r'O\\000ps\\377!')
     821
     822    def testUnescapeBytea(self):
     823        self.assertEqual(self.db.unescape_bytea(
     824            r'O\000ps\377!'), 'O\x00ps\xff!')
     825
     826    def testPkey(self):
     827        smart_ddl(self.db, 'drop table pkeytest0')
     828        smart_ddl(self.db, "create table pkeytest0 ("
     829            "a smallint)")
     830        smart_ddl(self.db, 'drop table pkeytest1')
     831        smart_ddl(self.db, "create table pkeytest1 ("
     832            "b smallint primary key)")
     833        smart_ddl(self.db, 'drop table pkeytest2')
     834        smart_ddl(self.db, "create table pkeytest2 ("
     835            "c smallint, d smallint primary key)")
     836        smart_ddl(self.db, 'drop table pkeytest3')
     837        smart_ddl(self.db, "create table pkeytest3 ("
     838            "e smallint, f smallint, g smallint, "
     839            "h smallint, i smallint, "
     840            "primary key (f,h))")
     841        self.assertRaises(KeyError, self.db.pkey, "pkeytest0")
     842        self.assertEqual(self.db.pkey("pkeytest1"), "b")
     843        self.assertEqual(self.db.pkey("pkeytest2"), "d")
     844        self.assertEqual(self.db.pkey("pkeytest3"), "f")
     845
     846    def testGetDatabases(self):
     847        databases = self.db.get_databases()
     848        self.assert_('template0' in databases)
     849        self.assert_('template1' in databases)
     850        self.assert_(self.dbname in databases)
     851
     852    def testGetTables(self):
     853        result1 = self.db.get_tables()
     854        tables = ('"A very Special Name"',
     855            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
     856            'A_MiXeD_NaMe', '"another special name"',
     857            'averyveryveryveryveryveryverylongtablename',
     858            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
     859        for t in tables:
     860            smart_ddl(self.db, 'drop table ' + t)
     861            smart_ddl(self.db, "create table %s"
     862                " as select 0" % t)
     863        result3 = self.db.get_tables()
     864        result2 = []
     865        for t in result3:
     866            if t not in result1:
     867                result2.append(t)
     868        result3 = []
     869        for t in tables:
     870            if not t.startswith('"'):
     871                t = t.lower()
     872            result3.append('public.' + t)
     873        self.assertEqual(result2, result3)
     874        for t in result2:
     875            self.db.query('drop table ' + t)
     876        result2 = self.db.get_tables()
     877        self.assertEqual(result2, result1)
     878
     879    def testGetRelations(self):
     880        result = self.db.get_relations()
     881        self.assert_('public.test' in result)
     882        self.assert_('public.test_view' in result)
     883        result = self.db.get_relations('rv')
     884        self.assert_('public.test' in result)
     885        self.assert_('public.test_view' in result)
     886        result = self.db.get_relations('r')
     887        self.assert_('public.test' in result)
     888        self.assert_('public.test_view' not in result)
     889        result = self.db.get_relations('v')
     890        self.assert_('public.test' not in result)
     891        self.assert_('public.test_view' in result)
     892        result = self.db.get_relations('cisSt')
     893        self.assert_('public.test' not in result)
     894        self.assert_('public.test_view' not in result)
     895
     896    def testAttnames(self):
     897        self.assertRaises(pg.ProgrammingError,
     898            self.db.get_attnames, 'does_not_exist')
     899        self.assertRaises(pg.ProgrammingError,
     900            self.db.get_attnames, 'has.too.many.dots')
     901        for table in ('attnames_test_table', 'test table for attnames'):
     902            smart_ddl(self.db, 'drop table "%s"' % table)
     903            smart_ddl(self.db, 'create table "%s" ('
     904                'a smallint, b integer, c bigint, '
     905                'e numeric, f float, f2 double precision, m money, '
     906                'x smallint, y smallint, z smallint, '
     907                'Normal_NaMe smallint, "Special Name" smallint, '
     908                't text, u char(2), v varchar(2), '
     909                'primary key (y, u))' % table)
     910            attributes = self.db.get_attnames(table)
     911            result = {'a': 'int', 'c': 'int', 'b': 'int',
     912                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
     913                'normal_name': 'int', 'Special Name': 'int',
     914                'u': 'text', 't': 'text', 'v': 'text',
     915                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int' }
     916            self.assertEqual(attributes, result)
     917
     918    def testGet(self):
     919        for table in ('get_test_table', 'test table for get'):
     920            smart_ddl(self.db, 'drop table "%s"' % table)
     921            smart_ddl(self.db, 'create table "%s" ('
     922                "n integer, t text)" % table)
     923            for n, t in enumerate('xyz'):
     924                self.db.query('insert into "%s" values('
     925                    "%d, '%s')" % (table, n+1, t))
     926            self.assertRaises(KeyError, self.db.get, table, 2)
     927            r = self.db.get(table, 2, 'n')
     928            oid_table = table
     929            if ' ' in table:
     930                oid_table = '"%s"' % oid_table
     931            oid_table = 'oid(public.%s)' % oid_table
     932            self.assert_(oid_table in r)
     933            self.assert_(isinstance(r[oid_table], int))
     934            result = {'t': 'y', 'n': 2, oid_table: r[oid_table]}
     935            self.assertEqual(r, result)
     936            self.assertEqual(self.db.get(table, r[oid_table], 'oid')['t'], 'y')
     937            self.assertEqual(self.db.get(table, 1, 'n')['t'], 'x')
     938            self.assertEqual(self.db.get(table, 3, 'n')['t'], 'z')
     939            self.assertEqual(self.db.get(table, 2, 'n')['t'], 'y')
     940            self.assertRaises(pg.DatabaseError, self.db.get, table, 4, 'n')
     941            r['n'] = 3
     942            self.assertEqual(self.db.get(table, r, 'n')['t'], 'z')
     943            self.assertEqual(self.db.get(table, 1, 'n')['t'], 'x')
     944            self.db.query('alter table "%s" alter n set not null' % table)
     945            self.db.query('alter table "%s" add primary key (n)' % table)
     946            self.assertEqual(self.db.get(table, 3)['t'], 'z')
     947            self.assertEqual(self.db.get(table, 1)['t'], 'x')
     948            self.assertEqual(self.db.get(table, 2)['t'], 'y')
     949            r['n'] = 1
     950            self.assertEqual(self.db.get(table, r)['t'], 'x')
     951            r['n'] = 3
     952            self.assertEqual(self.db.get(table, r)['t'], 'z')
     953            r['n'] = 2
     954            self.assertEqual(self.db.get(table, r)['t'], 'y')
     955
     956    def testGetFromView(self):
     957        self.db.query('delete from test where i4=14')
     958        self.db.query('insert into test (i4, v4) values('
     959            "14, 'abc4')")
     960        r = self.db.get('test_view', 14, 'i4')
     961        self.assert_('v4' in r)
     962        self.assertEqual(r['v4'], 'abc4')
     963
     964    def testInsert(self):
     965        for table in ('insert_test_table', 'test table for insert'):
     966            smart_ddl(self.db, 'drop table "%s"' % table)
     967            smart_ddl(self.db, 'create table "%s" ('
     968                "i2 smallint, i4 integer, i8 bigint,"
     969                "d numeric, f4 real, f8 double precision, m money, "
     970                "v4 varchar(4), c4 char(4), t text,"
     971                "b boolean, ts timestamp)" % table)
     972            data = dict(i2 = 2**15 - 1,
     973                i4 = int(2**31 - 1), i8 = long(2**31 - 1),
     974                d = Decimal('123456789.9876543212345678987654321'),
     975                f4 = 1.0 + 1.0/32, f8 = 1.0 + 1.0/32,
     976                m = "1234.56", v4 = "1234", c4 = "1234", t = "1234" * 10,
     977                b = 1, ts = 'current_date')
     978            r = self.db.insert(table, data)
     979            self.assertEqual(r, data)
     980            oid_table = table
     981            if ' ' in table:
     982                oid_table = '"%s"' % oid_table
     983            oid_table = 'oid(public.%s)' % oid_table
     984            self.assert_(oid_table in r)
     985            self.assert_(isinstance(r[oid_table], int))
     986            s = self.db.query('select oid,* from "%s"' % table).dictresult()[0]
     987            s[oid_table] = s['oid']
     988            del s['oid']
     989            self.assertEqual(r, s)
     990
     991    def testUpdate(self):
     992        for table in ('update_test_table', 'test table for update'):
     993            smart_ddl(self.db, 'drop table "%s"' % table)
     994            smart_ddl(self.db, 'create table "%s" ('
     995                "n integer, t text)" % table)
     996            for n, t in enumerate('xyz'):
     997                self.db.query('insert into "%s" values('
     998                    "%d, '%s')" % (table, n+1, t))
     999            self.assertRaises(KeyError, self.db.get, table, 2)
     1000            r = self.db.get(table, 2, 'n')
     1001            r['t'] = 'u'
     1002            s = self.db.update(table, r)
     1003            self.assertEqual(s, r)
     1004            r = self.db.query('select t from "%s" where n=2' % table
     1005                ).getresult()[0][0]
     1006            self.assertEqual(r, 'u')
     1007
     1008    def testClear(self):
     1009        for table in ('clear_test_table', 'test table for clear'):
     1010            smart_ddl(self.db, 'drop table "%s"' % table)
     1011            smart_ddl(self.db, 'create table "%s" ('
     1012                "n integer, b boolean, d date, t text)" % table)
     1013            r = self.db.clear(table)
     1014            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
     1015            self.assertEqual(r, result)
     1016            r['a'] = r['n'] = 1
     1017            r['d'] = r['t'] = 'x'
     1018            r['b']
     1019            r['oid'] = 1L
     1020            r = self.db.clear(table, r)
     1021            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '', 'oid': 1L}
     1022            self.assertEqual(r, result)
     1023
     1024    def testDelete(self):
     1025        for table in ('delete_test_table', 'test table for delete'):
     1026            smart_ddl(self.db, 'drop table "%s"' % table)
     1027            smart_ddl(self.db, 'create table "%s" ('
     1028                "n integer, t text)" % table)
     1029            for n, t in enumerate('xyz'):
     1030                self.db.query('insert into "%s" values('
     1031                    "%d, '%s')" % (table, n+1, t))
     1032            self.assertRaises(KeyError, self.db.get, table, 2)
     1033            r = self.db.get(table, 1, 'n')
     1034            s = self.db.delete(table, r)
     1035            r = self.db.get(table, 3, 'n')
     1036            s = self.db.delete(table, r)
     1037            r = self.db.query('select * from "%s"' % table).dictresult()
     1038            self.assertEqual(len(r), 1)
     1039            r = r[0]
     1040            result = {'n': 2, 't': 'y'}
     1041            self.assertEqual(r, result)
     1042            r = self.db.get(table, 2, 'n')
     1043            s = self.db.delete(table, r)
     1044            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
     1045
     1046    def testBytea(self):
     1047        smart_ddl(self.db, 'drop table bytea_test')
     1048        smart_ddl(self.db, 'create table bytea_test ('
     1049            'data bytea)')
     1050        s = "It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
     1051        r = self.db.escape_bytea(s)
     1052        self.db.query('insert into bytea_test values('
     1053            "'%s')" % r)
     1054        r = self.db.query('select * from bytea_test').getresult()
     1055        self.assert_(len(r) == 1)
     1056        r = r[0]
     1057        self.assert_(len(r) == 1)
     1058        r = r[0]
     1059        r = self.db.unescape_bytea(r)
     1060        self.assertEqual(r, s)
    10611061
    10621062
    10631063class TestSchemas(unittest.TestCase):
    1064         """"Test correct handling of schemas (namespaces)."""
    1065 
    1066         # Test database needed: must be run as a DBTestSuite.
    1067 
    1068         def setUp(self):
    1069                 dbname = DBTestSuite.dbname
    1070                 self.dbname = dbname
    1071                 self.db = pg.DB(dbname)
    1072 
    1073         def tearDown(self):
    1074                 self.db.close()
    1075 
    1076         def testGetTables(self):
    1077                 tables = self.db.get_tables()
    1078                 for num_schema in range(5):
    1079                         if num_schema:
    1080                                 schema = "s" + str(num_schema)
    1081                         else:
    1082                                 schema = "public"
    1083                         for t in (schema + ".t",
    1084                                 schema + ".t" + str(num_schema)):
    1085                                 self.assert_(t in tables, t + ' not in get_tables()')
    1086 
    1087         def testGetAttnames(self):
    1088                 result = {'oid': 'int', 'd': 'int', 'n': 'int'}
    1089                 r = self.db.get_attnames("t")
    1090                 self.assertEqual(r, result)
    1091                 r = self.db.get_attnames("s4.t4")
    1092                 self.assertEqual(r, result)
    1093                 smart_ddl(self.db, "create table s3.t3m"
    1094                         " as select 1 as m")
    1095                 result_m = {'oid': 'int', 'm': 'int'}
    1096                 r = self.db.get_attnames("s3.t3m")
    1097                 self.assertEqual(r, result_m)
    1098                 self.db.query("set search_path to s1,s3")
    1099                 r = self.db.get_attnames("t3")
    1100                 self.assertEqual(r, result)
    1101                 r = self.db.get_attnames("t3m")
    1102                 self.assertEqual(r, result_m)
    1103 
    1104         def testGet(self):
    1105                 self.assertEqual(self.db.get("t", 1, 'n')['d'], 0)
    1106                 self.assertEqual(self.db.get("t0", 1, 'n')['d'], 0)
    1107                 self.assertEqual(self.db.get("public.t", 1, 'n')['d'], 0)
    1108                 self.assertEqual(self.db.get("public.t0", 1, 'n')['d'], 0)
    1109                 self.assertRaises(pg.ProgrammingError, self.db.get, "public.t1", 1, 'n')
    1110                 self.assertEqual(self.db.get("s1.t1", 1, 'n')['d'], 1)
    1111                 self.assertEqual(self.db.get("s3.t", 1, 'n')['d'], 3)
    1112                 self.db.query("set search_path to s2,s4")
    1113                 self.assertRaises(pg.ProgrammingError, self.db.get, "t1", 1, 'n')
    1114                 self.assertEqual(self.db.get("t4", 1, 'n')['d'], 4)
    1115                 self.assertRaises(pg.ProgrammingError, self.db.get, "t3", 1, 'n')
    1116                 self.assertEqual(self.db.get("t", 1, 'n')['d'], 2)
    1117                 self.assertEqual(self.db.get("s3.t3", 1, 'n')['d'], 3)
    1118                 self.db.query("set search_path to s1,s3")
    1119                 self.assertRaises(pg.ProgrammingError, self.db.get, "t2", 1, 'n')
    1120                 self.assertEqual(self.db.get("t3", 1, 'n')['d'], 3)
    1121                 self.assertRaises(pg.ProgrammingError, self.db.get, "t4", 1, 'n')
    1122                 self.assertEqual(self.db.get("t", 1, 'n')['d'], 1)
    1123                 self.assertEqual(self.db.get("s4.t4", 1, 'n')['d'], 4)
    1124 
    1125         def testMangling(self):
    1126                 r = self.db.get("t", 1, 'n')
    1127                 self.assert_('oid(public.t)' in r)
    1128                 self.db.query("set search_path to s2")
    1129                 r = self.db.get("t2", 1, 'n')
    1130                 self.assert_('oid(s2.t2)' in r)
    1131                 self.db.query("set search_path to s3")
    1132                 r = self.db.get("t", 1, 'n')
    1133                 self.assert_('oid(s3.t)' in r)
     1064    """"Test correct handling of schemas (namespaces)."""
     1065
     1066    # Test database needed: must be run as a DBTestSuite.
     1067
     1068    def setUp(self):
     1069        dbname = DBTestSuite.dbname
     1070        self.dbname = dbname
     1071        self.db = pg.DB(dbname)
     1072
     1073    def tearDown(self):
     1074        self.db.close()
     1075
     1076    def testGetTables(self):
     1077        tables = self.db.get_tables()
     1078        for num_schema in range(5):
     1079            if num_schema:
     1080                schema = "s" + str(num_schema)
     1081            else:
     1082                schema = "public"
     1083            for t in (schema + ".t",
     1084                schema + ".t" + str(num_schema)):
     1085                self.assert_(t in tables, t + ' not in get_tables()')
     1086
     1087    def testGetAttnames(self):
     1088        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
     1089        r = self.db.get_attnames("t")
     1090        self.assertEqual(r, result)
     1091        r = self.db.get_attnames("s4.t4")
     1092        self.assertEqual(r, result)
     1093        smart_ddl(self.db, "create table s3.t3m"
     1094            " as select 1 as m")
     1095        result_m = {'oid': 'int', 'm': 'int'}
     1096        r = self.db.get_attnames("s3.t3m")
     1097        self.assertEqual(r, result_m)
     1098        self.db.query("set search_path to s1,s3")
     1099        r = self.db.get_attnames("t3")
     1100        self.assertEqual(r, result)
     1101        r = self.db.get_attnames("t3m")
     1102        self.assertEqual(r, result_m)
     1103
     1104    def testGet(self):
     1105        self.assertEqual(self.db.get("t", 1, 'n')['d'], 0)
     1106        self.assertEqual(self.db.get("t0", 1, 'n')['d'], 0)
     1107        self.assertEqual(self.db.get("public.t", 1, 'n')['d'], 0)
     1108        self.assertEqual(self.db.get("public.t0", 1, 'n')['d'], 0)
     1109        self.assertRaises(pg.ProgrammingError, self.db.get, "public.t1", 1, 'n')
     1110        self.assertEqual(self.db.get("s1.t1", 1, 'n')['d'], 1)
     1111        self.assertEqual(self.db.get("s3.t", 1, 'n')['d'], 3)
     1112        self.db.query("set search_path to s2,s4")
     1113        self.assertRaises(pg.ProgrammingError, self.db.get, "t1", 1, 'n')
     1114        self.assertEqual(self.db.get("t4", 1, 'n')['d'], 4)
     1115        self.assertRaises(pg.ProgrammingError, self.db.get, "t3", 1, 'n')
     1116        self.assertEqual(self.db.get("t", 1, 'n')['d'], 2)
     1117        self.assertEqual(self.db.get("s3.t3", 1, 'n')['d'], 3)
     1118        self.db.query("set search_path to s1,s3")
     1119        self.assertRaises(pg.ProgrammingError, self.db.get, "t2", 1, 'n')
     1120        self.assertEqual(self.db.get("t3", 1, 'n')['d'], 3)
     1121        self.assertRaises(pg.ProgrammingError, self.db.get, "t4", 1, 'n')
     1122        self.assertEqual(self.db.get("t", 1, 'n')['d'], 1)
     1123        self.assertEqual(self.db.get("s4.t4", 1, 'n')['d'], 4)
     1124
     1125    def testMangling(self):
     1126        r = self.db.get("t", 1, 'n')
     1127        self.assert_('oid(public.t)' in r)
     1128        self.db.query("set search_path to s2")
     1129        r = self.db.get("t2", 1, 'n')
     1130        self.assert_('oid(s2.t2)' in r)
     1131        self.db.query("set search_path to s3")
     1132        r = self.db.get("t", 1, 'n')
     1133        self.assert_('oid(s3.t)' in r)
    11341134
    11351135
    11361136class DBTestSuite(unittest.TestSuite):
    1137         """Test suite that provides a test database."""
    1138 
    1139         dbname = "testpg_tempdb"
    1140 
    1141         # It would be too slow to create and drop the test database for
    1142         # every single test, so it is done once for the whole suite only.
    1143 
    1144         def setUp(self):
    1145                 dbname = self.dbname
    1146                 c = pg.connect("template1")
    1147                 try:
    1148                         c.query("drop database " + dbname)
    1149                 except pg.Error:
    1150                         pass
    1151                 c.query("create database " + dbname
    1152                         + " template=template0")
    1153                 for s in ('client_min_messages = warning',
    1154                         'default_with_oids = on',
    1155                         'standard_conforming_strings = off',
    1156                         'escape_string_warning = off'):
    1157                         smart_ddl(c, 'alter database %s set %s' % (dbname, s))
    1158                 c.close()
    1159                 c = pg.connect(dbname)
    1160                 smart_ddl(c, "create table test ("
    1161                         "i2 smallint, i4 integer, i8 bigint,"
    1162                         "d numeric, f4 real, f8 double precision, m money, "
    1163                         "v4 varchar(4), c4 char(4), t text)")
    1164                 c.query("create view test_view as"
    1165                         " select i4, v4 from test")
    1166                 for num_schema in range(5):
    1167                         if num_schema:
    1168                                 schema = "s%d" % num_schema
    1169                                 c.query("create schema " + schema)
    1170                         else:
    1171                                 schema = "public"
    1172                         smart_ddl(c, "create table %s.t"
    1173                                 " as select 1 as n, %d as d"
    1174                                 % (schema, num_schema))
    1175                         smart_ddl(c, "create table %s.t%d"
    1176                                 " as select 1 as n, %d as d"
    1177                                 % (schema, num_schema, num_schema))
    1178                 c.close()
    1179 
    1180         def tearDown(self):
    1181                 dbname = self.dbname
    1182                 c = pg.connect(dbname)
    1183                 c.query("checkpoint")
    1184                 c.close()
    1185                 c = pg.connect("template1")
    1186                 c.query("drop database " + dbname)
    1187                 c.close()
    1188 
    1189         def __call__(self, result):
    1190                 self.setUp()
    1191                 unittest.TestSuite.__call__(self, result)
    1192                 self.tearDown()
     1137    """Test suite that provides a test database."""
     1138
     1139    dbname = "testpg_tempdb"
     1140
     1141    # It would be too slow to create and drop the test database for
     1142    # every single test, so it is done once for the whole suite only.
     1143
     1144    def setUp(self):
     1145        dbname = self.dbname
     1146        c = pg.connect("template1")
     1147        try:
     1148            c.query("drop database " + dbname)
     1149        except pg.Error:
     1150            pass
     1151        c.query("create database " + dbname
     1152            + " template=template0")
     1153        for s in ('client_min_messages = warning',
     1154            'default_with_oids = on',
     1155            'standard_conforming_strings = off',
     1156            'escape_string_warning = off'):
     1157            smart_ddl(c, 'alter database %s set %s' % (dbname, s))
     1158        c.close()
     1159        c = pg.connect(dbname)
     1160        smart_ddl(c, "create table test ("
     1161            "i2 smallint, i4 integer, i8 bigint,"
     1162            "d numeric, f4 real, f8 double precision, m money, "
     1163            "v4 varchar(4), c4 char(4), t text)")
     1164        c.query("create view test_view as"
     1165            " select i4, v4 from test")
     1166        for num_schema in range(5):
     1167            if num_schema:
     1168                schema = "s%d" % num_schema
     1169                c.query("create schema " + schema)
     1170            else:
     1171                schema = "public"
     1172            smart_ddl(c, "create table %s.t"
     1173                " as select 1 as n, %d as d"
     1174                % (schema, num_schema))
     1175            smart_ddl(c, "create table %s.t%d"
     1176                " as select 1 as n, %d as d"
     1177                % (schema, num_schema, num_schema))
     1178        c.close()
     1179
     1180    def tearDown(self):
     1181        dbname = self.dbname
     1182        c = pg.connect(dbname)
     1183        c.query("checkpoint")
     1184        c.close()
     1185        c = pg.connect("template1")
     1186        c.query("drop database " + dbname)
     1187        c.close()
     1188
     1189    def __call__(self, result):
     1190        self.setUp()
     1191        unittest.TestSuite.__call__(self, result)
     1192        self.tearDown()
    11931193
    11941194
    11951195if __name__ == '__main__':
    11961196
    1197         # All tests that do not need a database:
    1198         TestSuite1 = unittest.TestSuite((
    1199                 unittest.makeSuite(TestAuxiliaryFunctions),
    1200                 unittest.makeSuite(TestHasConnect),
    1201                 unittest.makeSuite(TestEscapeFunctions),
    1202                 unittest.makeSuite(TestCanConnect),
    1203                 unittest.makeSuite(TestConnectObject),
    1204                 unittest.makeSuite(TestSimpleQueries),
    1205                 unittest.makeSuite(TestDBClassBasic),
    1206                 ))
    1207 
    1208         # All tests that need a test database:
    1209         TestSuite2 = DBTestSuite((
    1210                 unittest.makeSuite(TestInserttable),
    1211                 unittest.makeSuite(TestDBClass),
    1212                 unittest.makeSuite(TestSchemas),
    1213                 ))
    1214 
    1215         # All tests together in one test suite:
    1216         TestSuite = unittest.TestSuite((
    1217                 TestSuite1,
    1218                 TestSuite2
    1219         ))
    1220 
    1221         unittest.TextTestRunner(verbosity=2).run(TestSuite)
     1197    # All tests that do not need a database:
     1198    TestSuite1 = unittest.TestSuite((
     1199        unittest.makeSuite(TestAuxiliaryFunctions),
     1200        unittest.makeSuite(TestHasConnect),
     1201        unittest.makeSuite(TestEscapeFunctions),
     1202        unittest.makeSuite(TestCanConnect),
     1203        unittest.makeSuite(TestConnectObject),
     1204        unittest.makeSuite(TestSimpleQueries),
     1205        unittest.makeSuite(TestDBClassBasic),
     1206        ))
     1207
     1208    # All tests that need a test database:
     1209    TestSuite2 = DBTestSuite((
     1210        unittest.makeSuite(TestInserttable),
     1211        unittest.makeSuite(TestDBClass),
     1212        unittest.makeSuite(TestSchemas),
     1213        ))
     1214
     1215    # All tests together in one test suite:
     1216    TestSuite = unittest.TestSuite((
     1217        TestSuite1,
     1218        TestSuite2
     1219    ))
     1220
     1221    unittest.TextTestRunner(verbosity=2).run(TestSuite)
Note: See TracChangeset for help on using the changeset viewer.