Changeset 732 for trunk


Ignore:
Timestamp:
Jan 13, 2016, 7:45:34 AM (4 years ago)
Author:
cito
Message:

Better handling of quoted identifiers

Methods like get(), update() did not handle quoted identifiers properly
(i.e. identifiers with spaces, mixed case characters or special characters).
This has been improved and tests have been added to make sure this works.

Location:
trunk
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/pg.py

    r730 r732  
    5050# Auxiliary functions that are independent from a DB connection:
    5151
    52 def _quote_class_name(cl):
    53     """Quote a class name.
    54 
    55     Class names are always quoted unless they contain a dot.
    56     In this ambiguous case quotes must be added manually.
    57 
    58     """
    59     if '.' not in cl:
    60         cl = '"%s"' % cl
    61     return cl
    62 
    63 
    64 def _quote_class_param(cl, param):
    65     """Quote parameter representing a class name.
    66 
    67     The parameter is automatically quoted unless the class name contains a dot.
    68     In this ambiguous case quotes must be added manually.
    69 
    70     """
    71     if isinstance(param, int):
    72         param = "$%d" % param
    73     if '.' not in cl:
    74         param = 'quote_ident(%s)' % (param,)
    75     return param
    76 
    77 
    7852def _oid_key(cl):
    79     """Build oid key from qualified class name."""
     53    """Build oid key from a class name."""
    8054    return 'oid(%s)' % cl
    8155
     
    329303                print(s)
    330304
     305    def _escape_qualified_name(self, s):
     306        """Escape a qualified name.
     307
     308        Escapes the name for use as an SQL identifier, unless the
     309        name contains a dot, in which case the name is ambiguous
     310        (could be a qualified name or just a name with a dot in it)
     311        and must be quoted manually by the caller.
     312
     313        """
     314        if '.' not in s:
     315            s = self.escape_identifier(s)
     316        return s
     317
    331318    @staticmethod
    332319    def _make_bool(d):
     
    362349
    363350    def _prepare_bytea(self, d):
     351        """Prepare a bytea parameter."""
    364352        return self.escape_bytea(d)
    365353
     
    384372        return '$%d' % len(params)
    385373
     374    @staticmethod
     375    def _prepare_qualified_param(cl, param):
     376        """Quote parameter representing a qualified name.
     377
     378        Escapes the name for use as an SQL parameter, unless the
     379        name contains a dot, in which case the name is ambiguous
     380        (could be a qualified name or just a name with a dot in it)
     381        and must be quoted manually by the caller.
     382
     383        """
     384        if isinstance(param, int):
     385            param = "$%d" % param
     386        if '.' not in cl:
     387            param = 'quote_ident(%s)' % (param,)
     388        return param
     389
    386390    # Public methods
    387391
     
    508512                " AND NOT a.attisdropped"
    509513                " WHERE i.indrelid=%s::regclass"
    510                 " AND i.indisprimary" % _quote_class_param(cl, 1))
     514                " AND i.indisprimary" % self._prepare_qualified_param(cl, 1))
    511515            pkey = self.db.query(q, (cl,)).getresult()
    512516            if not pkey:
     
    573577                " AND NOT a.attisdropped") % (
    574578                    '::regtype' if self._regtypes else '',
    575                     _quote_class_param(cl, 1))
     579                    self._prepare_qualified_param(cl, 1))
    576580            names = self.db.query(q, (cl,)).getresult()
    577581            if not names:
     
    602606        except KeyError:  # cache miss, ask the database
    603607            q = "SELECT has_table_privilege(%s, $2)" % (
    604                 _quote_class_param(cl, 1),)
     608                self._prepare_qualified_param(cl, 1),)
    605609            q = self.db.query(q, (cl, privilege))
    606610            ret = q.getresult()[0][0] == self._make_bool(True)
     
    637641        params = []
    638642        param = partial(self._prepare_param, params=params)
     643        col = self.escape_identifier
    639644        # We want the oid for later updates if that isn't the key
    640645        if keyname == 'oid':
     
    652657                if len(keyname) > 1:
    653658                    raise _prg_error('Composite key needs dict as arg')
    654                 arg = dict([(k, arg) for k in keyname])
    655             what = ', '.join(attnames)
     659                arg = dict((k, arg) for k in keyname)
     660            what = ', '.join(col(k) for k in attnames)
    656661            where = ' AND '.join(['%s = %s'
    657                 % (k, param(arg[k], attnames[k])) for k in keyname])
     662                % (col(k), param(arg[k], attnames[k])) for k in keyname])
    658663        q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % (
    659             what, _quote_class_name(cl), where)
     664            what, self._escape_qualified_name(cl), where)
    660665        self._do_debug(q, params)
    661666        res = self.db.query(q, params).dictresult()
     
    694699        params = []
    695700        param = partial(self._prepare_param, params=params)
     701        col = self.escape_identifier
    696702        names, values = [], []
    697703        for n in attnames:
    698704            if n != 'oid' and n in d:
    699                 names.append('"%s"' % n)
     705                names.append(col(n))
    700706                values.append(param(d[n], attnames[n]))
    701707        names, values = ', '.join(names), ', '.join(values)
     
    706712            ret = ''
    707713        q = 'INSERT INTO %s (%s) VALUES (%s)%s' % (
    708             _quote_class_name(cl), names, values, ret)
     714            self._escape_qualified_name(cl), names, values, ret)
    709715        self._do_debug(q, params)
    710716        res = self.db.query(q, params)
     
    754760        params = []
    755761        param = partial(self._prepare_param, params=params)
     762        col = self.escape_identifier
    756763        if qoid in d:
    757764            where = 'oid = %s' % param(d[qoid], 'int')
     
    766773            try:
    767774                where = ' AND '.join(['%s = %s'
    768                     % (k, param(d[k], attnames[k])) for k in keyname])
     775                    % (col(k), param(d[k], attnames[k])) for k in keyname])
    769776            except KeyError:
    770777                raise _prg_error('Update needs primary key or oid.')
     
    772779        for n in attnames:
    773780            if n in d and n not in keyname:
    774                 values.append('%s = %s' % (n, param(d[n], attnames[n])))
     781                values.append('%s = %s' % (col(n), param(d[n], attnames[n])))
    775782        if not values:
    776783            return d
     
    782789            ret = ''
    783790        q = 'UPDATE %s SET %s WHERE %s%s' % (
    784             _quote_class_name(cl), values, where, ret)
     791            self._escape_qualified_name(cl), values, where, ret)
    785792        self._do_debug(q, params)
    786793        res = self.db.query(q, params)
     
    859866                keyname = (keyname,)
    860867            attnames = self.get_attnames(cl)
     868            col = self.escape_identifier
    861869            try:
    862870                where = ' AND '.join(['%s = %s'
    863                     % (k, param(d[k], attnames[k])) for k in keyname])
     871                    % (col(k), param(d[k], attnames[k])) for k in keyname])
    864872            except KeyError:
    865873                raise _prg_error('Delete needs primary key or oid.')
    866         q = 'DELETE FROM %s WHERE %s' % (_quote_class_name(cl), where)
     874        q = 'DELETE FROM %s WHERE %s' % (
     875            self._escape_qualified_name(cl), where)
    867876        self._do_debug(q, params)
    868877        return int(self.db.query(q, params))
  • trunk/tests/test_classic_dbwrapper.py

    r730 r732  
    636636        get = self.db.get
    637637        query = self.db.query
    638         for table in ('get_test_table', 'test table for get'):
    639             query('drop table if exists "%s"' % table)
    640             query('create table "%s" ('
    641                 "n integer, t text) with oids" % table)
    642             for n, t in enumerate('xyz'):
    643                 query('insert into "%s" values('"%d, '%s')"
    644                     % (table, n + 1, t))
    645             self.assertRaises(pg.ProgrammingError, get, table, 2)
    646             r = get(table, 2, 'n')
    647             oid_table = 'oid(%s)' % table
    648             self.assertIn(oid_table, r)
    649             oid = r[oid_table]
    650             self.assertIsInstance(oid, int)
    651             result = {'t': 'y', 'n': 2, oid_table: oid}
    652             self.assertEqual(r, result)
    653             self.assertEqual(get(table + ' *', 2, 'n'), r)
    654             self.assertEqual(get(table, oid, 'oid')['t'], 'y')
    655             self.assertEqual(get(table, 1, 'n')['t'], 'x')
    656             self.assertEqual(get(table, 3, 'n')['t'], 'z')
    657             self.assertEqual(get(table, 2, 'n')['t'], 'y')
    658             self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
    659             r['n'] = 3
    660             self.assertEqual(get(table, r, 'n')['t'], 'z')
    661             self.assertEqual(get(table, 1, 'n')['t'], 'x')
    662             query('alter table "%s" alter n set not null' % table)
    663             query('alter table "%s" add primary key (n)' % table)
    664             self.assertEqual(get(table, 3)['t'], 'z')
    665             self.assertEqual(get(table, 1)['t'], 'x')
    666             self.assertEqual(get(table, 2)['t'], 'y')
    667             r['n'] = 1
    668             self.assertEqual(get(table, r)['t'], 'x')
    669             r['n'] = 3
    670             self.assertEqual(get(table, r)['t'], 'z')
    671             r['n'] = 2
    672             self.assertEqual(get(table, r)['t'], 'y')
    673             query('drop table "%s"' % table)
     638        table = 'get_test_table'
     639        query('drop table if exists "%s"' % table)
     640        query('create table "%s" ('
     641            "n integer, t text) with oids" % table)
     642        for n, t in enumerate('xyz'):
     643            query('insert into "%s" values('"%d, '%s')"
     644                % (table, n + 1, t))
     645        self.assertRaises(pg.ProgrammingError, get, table, 2)
     646        r = get(table, 2, 'n')
     647        oid_table = 'oid(%s)' % table
     648        self.assertIn(oid_table, r)
     649        oid = r[oid_table]
     650        self.assertIsInstance(oid, int)
     651        result = {'t': 'y', 'n': 2, oid_table: oid}
     652        self.assertEqual(r, result)
     653        self.assertEqual(get(table + ' *', 2, 'n'), r)
     654        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
     655        self.assertEqual(get(table, 1, 'n')['t'], 'x')
     656        self.assertEqual(get(table, 3, 'n')['t'], 'z')
     657        self.assertEqual(get(table, 2, 'n')['t'], 'y')
     658        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
     659        r['n'] = 3
     660        self.assertEqual(get(table, r, 'n')['t'], 'z')
     661        self.assertEqual(get(table, 1, 'n')['t'], 'x')
     662        query('alter table "%s" alter n set not null' % table)
     663        query('alter table "%s" add primary key (n)' % table)
     664        self.assertEqual(get(table, 3)['t'], 'z')
     665        self.assertEqual(get(table, 1)['t'], 'x')
     666        self.assertEqual(get(table, 2)['t'], 'y')
     667        r['n'] = 1
     668        self.assertEqual(get(table, r)['t'], 'x')
     669        r['n'] = 3
     670        self.assertEqual(get(table, r)['t'], 'z')
     671        r['n'] = 2
     672        self.assertEqual(get(table, r)['t'], 'y')
     673        query('drop table "%s"' % table)
    674674
    675675    def testGetWithCompositeKey(self):
     
    677677        query = self.db.query
    678678        table = 'get_test_table_1'
    679         query("drop table if exists %s" % table)
    680         query("create table %s ("
     679        query('drop table if exists "%s"' % table)
     680        query('create table "%s" ('
    681681            "n integer, t text, primary key (n))" % table)
    682682        for n, t in enumerate('abc'):
    683             query("insert into %s values("
     683            query('insert into "%s" values('
    684684                "%d, '%s')" % (table, n + 1, t))
    685685        self.assertEqual(get(table, 2)['t'], 'b')
    686         query("drop table %s" % table)
     686        query('drop table "%s"' % table)
    687687        table = 'get_test_table_2'
    688         query("drop table if exists %s" % table)
    689         query("create table %s ("
     688        query('drop table if exists "%s"' % table)
     689        query('create table "%s" ('
    690690            "n integer, m integer, t text, primary key (n, m))" % table)
    691691        for n in range(3):
    692692            for m in range(2):
    693693                t = chr(ord('a') + 2 * n + m)
    694                 query("insert into %s values("
     694                query('insert into "%s" values('
    695695                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
    696696        self.assertRaises(pg.ProgrammingError, get, table, 2)
     
    700700        self.assertEqual(get(table, dict(n=3, m=2),
    701701                             frozenset(['n', 'm']))['t'], 'f')
    702         query("drop table %s" % table)
     702        query('drop table "%s"' % table)
     703
     704    def testGetWithQuotedNames(self):
     705        get = self.db.get
     706        query = self.db.query
     707        table = 'test table for get()'
     708        query('drop table if exists "%s"' % table)
     709        query('create table "%s" ('
     710            '"Prime!" smallint primary key,'
     711            '"much space" integer, "Questions?" text)' % table)
     712        query('insert into "%s"'
     713              " values(17, 1001, 'No!')" % table)
     714        r = get(table, 17)
     715        self.assertIsInstance(r, dict)
     716        self.assertEqual(r['Prime!'], 17)
     717        self.assertEqual(r['much space'], 1001)
     718        self.assertEqual(r['Questions?'], 'No!')
     719        query('drop table "%s"' % table)
    703720
    704721    def testGetFromView(self):
     
    715732        bool_on = pg.get_bool()
    716733        decimal = pg.get_decimal()
    717         for table in ('insert_test_table', 'test table for insert'):
    718             query('drop table if exists "%s"' % table)
    719             query('create table "%s" ('
    720                 "i2 smallint, i4 integer, i8 bigint,"
    721                 " d numeric, f4 real, f8 double precision, m money,"
    722                 " v4 varchar(4), c4 char(4), t text,"
    723                 " b boolean, ts timestamp) with oids" % table)
    724             oid_table = 'oid(%s)' % table
    725             tests = [dict(i2=None, i4=None, i8=None),
    726                 (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
    727                 (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
    728                 dict(i2=42, i4=123456, i8=9876543210),
    729                 dict(i2=2 ** 15 - 1,
    730                     i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
    731                 dict(d=None), (dict(d=''), dict(d=None)),
    732                 dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
    733                 dict(f4=None, f8=None), dict(f4=0, f8=0),
    734                 (dict(f4='', f8=''), dict(f4=None, f8=None)),
    735                 (dict(d=1234.5, f4=1234.5, f8=1234.5),
    736                       dict(d=Decimal('1234.5'))),
    737                 dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
    738                 dict(d=Decimal('123456789.9876543212345678987654321')),
    739                 dict(m=None), (dict(m=''), dict(m=None)),
    740                 dict(m=Decimal('-1234.56')),
    741                 (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
    742                 dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
    743                 (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
    744                 (dict(m=1234.5), dict(m=Decimal('1234.5'))),
    745                 (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
    746                 (dict(m=123456), dict(m=Decimal('123456'))),
    747                 (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
    748                 dict(b=None), (dict(b=''), dict(b=None)),
    749                 dict(b='f'), dict(b='t'),
    750                 (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
    751                 (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
    752                 (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
    753                 (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
    754                 (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
    755                 (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
    756                 dict(v4=None, c4=None, t=None),
    757                 (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
    758                 dict(v4='1234', c4='1234', t='1234' * 10),
    759                 dict(v4='abcd', c4='abcd', t='abcdefg'),
    760                 (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
    761                 dict(ts=None), (dict(ts=''), dict(ts=None)),
    762                 (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
    763                 dict(ts='2012-12-21 00:00:00'),
    764                 (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
    765                 dict(ts='2012-12-21 12:21:12'),
    766                 dict(ts='2013-01-05 12:13:14'),
    767                 dict(ts='current_timestamp')]
    768             for test in tests:
    769                 if isinstance(test, dict):
    770                     data = test
    771                     change = {}
     734        table = 'insert_test_table'
     735        query('drop table if exists "%s"' % table)
     736        query('create table "%s" ('
     737            "i2 smallint, i4 integer, i8 bigint,"
     738            " d numeric, f4 real, f8 double precision, m money,"
     739            " v4 varchar(4), c4 char(4), t text,"
     740            " b boolean, ts timestamp) with oids" % table)
     741        oid_table = 'oid(%s)' % table
     742        tests = [dict(i2=None, i4=None, i8=None),
     743            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
     744            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
     745            dict(i2=42, i4=123456, i8=9876543210),
     746            dict(i2=2 ** 15 - 1,
     747                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
     748            dict(d=None), (dict(d=''), dict(d=None)),
     749            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
     750            dict(f4=None, f8=None), dict(f4=0, f8=0),
     751            (dict(f4='', f8=''), dict(f4=None, f8=None)),
     752            (dict(d=1234.5, f4=1234.5, f8=1234.5),
     753                  dict(d=Decimal('1234.5'))),
     754            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
     755            dict(d=Decimal('123456789.9876543212345678987654321')),
     756            dict(m=None), (dict(m=''), dict(m=None)),
     757            dict(m=Decimal('-1234.56')),
     758            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
     759            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
     760            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
     761            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
     762            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
     763            (dict(m=123456), dict(m=Decimal('123456'))),
     764            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
     765            dict(b=None), (dict(b=''), dict(b=None)),
     766            dict(b='f'), dict(b='t'),
     767            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
     768            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
     769            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
     770            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
     771            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
     772            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
     773            dict(v4=None, c4=None, t=None),
     774            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
     775            dict(v4='1234', c4='1234', t='1234' * 10),
     776            dict(v4='abcd', c4='abcd', t='abcdefg'),
     777            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
     778            dict(ts=None), (dict(ts=''), dict(ts=None)),
     779            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
     780            dict(ts='2012-12-21 00:00:00'),
     781            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
     782            dict(ts='2012-12-21 12:21:12'),
     783            dict(ts='2013-01-05 12:13:14'),
     784            dict(ts='current_timestamp')]
     785        for test in tests:
     786            if isinstance(test, dict):
     787                data = test
     788                change = {}
     789            else:
     790                data, change = test
     791            expect = data.copy()
     792            expect.update(change)
     793            if bool_on:
     794                b = expect.get('b')
     795                if b is not None:
     796                    expect['b'] = b == 't'
     797            if decimal is not Decimal:
     798                d = expect.get('d')
     799                if d is not None:
     800                    expect['d'] = decimal(d)
     801                m = expect.get('m')
     802                if m is not None:
     803                    expect['m'] = decimal(m)
     804            self.assertEqual(insert(table, data), data)
     805            self.assertIn(oid_table, data)
     806            oid = data[oid_table]
     807            self.assertIsInstance(oid, int)
     808            data = dict(item for item in data.items()
     809                if item[0] in expect)
     810            ts = expect.get('ts')
     811            if ts == 'current_timestamp':
     812                ts = expect['ts'] = data['ts']
     813                if len(ts) > 19:
     814                    self.assertEqual(ts[19], '.')
     815                    ts = ts[:19]
    772816                else:
    773                     data, change = test
    774                 expect = data.copy()
    775                 expect.update(change)
    776                 if bool_on:
    777                     b = expect.get('b')
    778                     if b is not None:
    779                         expect['b'] = b == 't'
    780                 if decimal is not Decimal:
    781                     d = expect.get('d')
    782                     if d is not None:
    783                         expect['d'] = decimal(d)
    784                     m = expect.get('m')
    785                     if m is not None:
    786                         expect['m'] = decimal(m)
    787                 self.assertEqual(insert(table, data), data)
    788                 self.assertIn(oid_table, data)
    789                 oid = data[oid_table]
    790                 self.assertIsInstance(oid, int)
    791                 data = dict(item for item in data.items()
    792                     if item[0] in expect)
    793                 ts = expect.get('ts')
    794                 if ts == 'current_timestamp':
    795                     ts = expect['ts'] = data['ts']
    796                     if len(ts) > 19:
    797                         self.assertEqual(ts[19], '.')
    798                         ts = ts[:19]
    799                     else:
    800                         self.assertEqual(len(ts), 19)
    801                     self.assertTrue(ts[:4].isdigit())
    802                     self.assertEqual(ts[4], '-')
    803                     self.assertEqual(ts[10], ' ')
    804                     self.assertTrue(ts[11:13].isdigit())
    805                     self.assertEqual(ts[13], ':')
    806                 self.assertEqual(data, expect)
    807                 data = query(
    808                     'select oid,* from "%s"' % table).dictresult()[0]
    809                 self.assertEqual(data['oid'], oid)
    810                 data = dict(item for item in data.items()
    811                     if item[0] in expect)
    812                 self.assertEqual(data, expect)
    813                 query('delete from "%s"' % table)
    814             query('drop table "%s"' % table)
     817                    self.assertEqual(len(ts), 19)
     818                self.assertTrue(ts[:4].isdigit())
     819                self.assertEqual(ts[4], '-')
     820                self.assertEqual(ts[10], ' ')
     821                self.assertTrue(ts[11:13].isdigit())
     822                self.assertEqual(ts[13], ':')
     823            self.assertEqual(data, expect)
     824            data = query(
     825                'select oid,* from "%s"' % table).dictresult()[0]
     826            self.assertEqual(data['oid'], oid)
     827            data = dict(item for item in data.items()
     828                if item[0] in expect)
     829            self.assertEqual(data, expect)
     830            query('delete from "%s"' % table)
     831        query('drop table "%s"' % table)
     832
     833    def testInsertWithQuotedNames(self):
     834        insert = self.db.insert
     835        query = self.db.query
     836        table = 'test table for insert()'
     837        query('drop table if exists "%s"' % table)
     838        query('create table "%s" ('
     839            '"Prime!" smallint primary key,'
     840            '"much space" integer, "Questions?" text)' % table)
     841        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
     842        r = insert(table, r)
     843        self.assertIsInstance(r, dict)
     844        self.assertEqual(r['Prime!'], 11)
     845        self.assertEqual(r['much space'], 2002)
     846        self.assertEqual(r['Questions?'], 'What?')
     847        r = query('select * from "%s" limit 2' % table).dictresult()
     848        self.assertEqual(len(r), 1)
     849        r = r[0]
     850        self.assertEqual(r['Prime!'], 11)
     851        self.assertEqual(r['much space'], 2002)
     852        self.assertEqual(r['Questions?'], 'What?')
     853        query('drop table "%s"' % table)
    815854
    816855    def testUpdate(self):
    817856        update = self.db.update
    818857        query = self.db.query
    819         for table in ('update_test_table', 'test table for update'):
    820             query('drop table if exists "%s"' % table)
    821             query('create table "%s" ('
    822                 "n integer, t text) with oids" % table)
    823             for n, t in enumerate('xyz'):
    824                 query('insert into "%s" values('
    825                     "%d, '%s')" % (table, n + 1, t))
    826             self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
    827             r = self.db.get(table, 2, 'n')
    828             r['t'] = 'u'
    829             s = update(table, r)
    830             self.assertEqual(s, r)
    831             r = query('select t from "%s" where n=2' % table
    832                       ).getresult()[0][0]
    833             self.assertEqual(r, 'u')
    834             query('drop table "%s"' % table)
     858        table = 'update_test_table'
     859        query('drop table if exists "%s"' % table)
     860        query('create table "%s" ('
     861            "n integer, t text) with oids" % table)
     862        for n, t in enumerate('xyz'):
     863            query('insert into "%s" values('
     864                "%d, '%s')" % (table, n + 1, t))
     865        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
     866        r = self.db.get(table, 2, 'n')
     867        r['t'] = 'u'
     868        s = update(table, r)
     869        self.assertEqual(s, r)
     870        r = query('select t from "%s" where n=2' % table
     871                  ).getresult()[0][0]
     872        self.assertEqual(r, 'u')
     873        query('drop table "%s"' % table)
    835874
    836875    def testUpdateWithCompositeKey(self):
     
    838877        query = self.db.query
    839878        table = 'update_test_table_1'
    840         query("drop table if exists %s" % table)
    841         query("create table %s ("
     879        query('drop table if exists "%s"' % table)
     880        query('create table "%s" ('
    842881            "n integer, t text, primary key (n))" % table)
    843882        for n, t in enumerate('abc'):
    844             query("insert into %s values("
     883            query('insert into "%s" values('
    845884                "%d, '%s')" % (table, n + 1, t))
    846885        self.assertRaises(pg.ProgrammingError, update,
     
    850889                  ).getresult()[0][0]
    851890        self.assertEqual(r, 'd')
    852         query("drop table %s" % table)
     891        query('drop table "%s"' % table)
    853892        table = 'update_test_table_2'
    854         query("drop table if exists %s" % table)
    855         query("create table %s ("
     893        query('drop table if exists "%s"' % table)
     894        query('create table "%s" ('
    856895            "n integer, m integer, t text, primary key (n, m))" % table)
    857896        for n in range(3):
    858897            for m in range(2):
    859898                t = chr(ord('a') + 2 * n + m)
    860                 query("insert into %s values("
     899                query('insert into "%s" values('
    861900                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
    862901        self.assertRaises(pg.ProgrammingError, update,
     
    867906            ' order by m' % table).getresult()]
    868907        self.assertEqual(r, ['c', 'x'])
    869         query("drop table %s" % table)
     908        query('drop table "%s"' % table)
     909
     910    def testUpdateWithQuotedNames(self):
     911        update = self.db.update
     912        query = self.db.query
     913        table = 'test table for update()'
     914        query('drop table if exists "%s"' % table)
     915        query('create table "%s" ('
     916            '"Prime!" smallint primary key,'
     917            '"much space" integer, "Questions?" text)' % table)
     918        query('insert into "%s"'
     919              " values(13, 3003, 'Why!')" % table)
     920        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
     921        r = update(table, r)
     922        self.assertIsInstance(r, dict)
     923        self.assertEqual(r['Prime!'], 13)
     924        self.assertEqual(r['much space'], 7007)
     925        self.assertEqual(r['Questions?'], 'When?')
     926        r = query('select * from "%s" limit 2' % table).dictresult()
     927        self.assertEqual(len(r), 1)
     928        r = r[0]
     929        self.assertEqual(r['Prime!'], 13)
     930        self.assertEqual(r['much space'], 7007)
     931        self.assertEqual(r['Questions?'], 'When?')
     932        query('drop table "%s"' % table)
    870933
    871934    def testClear(self):
     
    873936        query = self.db.query
    874937        f = False if pg.get_bool() else 'f'
    875         for table in ('clear_test_table', 'test table for clear'):
    876             query('drop table if exists "%s"' % table)
    877             query('create table "%s" ('
    878                 "n integer, b boolean, d date, t text)" % table)
    879             r = clear(table)
    880             result = {'n': 0, 'b': f, 'd': '', 't': ''}
    881             self.assertEqual(r, result)
    882             r['a'] = r['n'] = 1
    883             r['d'] = r['t'] = 'x'
    884             r['b'] = 't'
    885             r['oid'] = long(1)
    886             r = clear(table, r)
    887             result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
    888                 'oid': long(1)}
    889             self.assertEqual(r, result)
    890             query('drop table "%s"' % table)
     938        table = 'clear_test_table'
     939        query('drop table if exists "%s"' % table)
     940        query('create table "%s" ('
     941            "n integer, b boolean, d date, t text)" % table)
     942        r = clear(table)
     943        result = {'n': 0, 'b': f, 'd': '', 't': ''}
     944        self.assertEqual(r, result)
     945        r['a'] = r['n'] = 1
     946        r['d'] = r['t'] = 'x'
     947        r['b'] = 't'
     948        r['oid'] = long(1)
     949        r = clear(table, r)
     950        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
     951            'oid': long(1)}
     952        self.assertEqual(r, result)
     953        query('drop table "%s"' % table)
     954
     955    def testClearWithQuotedNames(self):
     956        clear = self.db.clear
     957        query = self.db.query
     958        table = 'test table for clear()'
     959        query('drop table if exists "%s"' % table)
     960        query('create table "%s" ('
     961            '"Prime!" smallint primary key,'
     962            '"much space" integer, "Questions?" text)' % table)
     963        r = clear(table)
     964        self.assertIsInstance(r, dict)
     965        self.assertEqual(r['Prime!'], 0)
     966        self.assertEqual(r['much space'], 0)
     967        self.assertEqual(r['Questions?'], '')
     968        query('drop table "%s"' % table)
    891969
    892970    def testDelete(self):
    893971        delete = self.db.delete
    894972        query = self.db.query
    895         for table in ('delete_test_table', 'test table for delete'):
    896             query('drop table if exists "%s"' % table)
    897             query('create table "%s" ('
    898                 "n integer, t text) with oids" % table)
    899             for n, t in enumerate('xyz'):
    900                 query('insert into "%s" values('
    901                     "%d, '%s')" % (table, n + 1, t))
    902             self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
    903             r = self.db.get(table, 1, 'n')
    904             s = delete(table, r)
    905             self.assertEqual(s, 1)
    906             r = self.db.get(table, 3, 'n')
    907             s = delete(table, r)
    908             self.assertEqual(s, 1)
    909             s = delete(table, r)
    910             self.assertEqual(s, 0)
    911             r = query('select * from "%s"' % table).dictresult()
    912             self.assertEqual(len(r), 1)
    913             r = r[0]
    914             result = {'n': 2, 't': 'y'}
    915             self.assertEqual(r, result)
    916             r = self.db.get(table, 2, 'n')
    917             s = delete(table, r)
    918             self.assertEqual(s, 1)
    919             s = delete(table, r)
    920             self.assertEqual(s, 0)
    921             self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
    922             query('drop table "%s"' % table)
     973        table = 'delete_test_table'
     974        query('drop table if exists "%s"' % table)
     975        query('create table "%s" ('
     976            "n integer, t text) with oids" % table)
     977        for n, t in enumerate('xyz'):
     978            query('insert into "%s" values('
     979                "%d, '%s')" % (table, n + 1, t))
     980        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
     981        r = self.db.get(table, 1, 'n')
     982        s = delete(table, r)
     983        self.assertEqual(s, 1)
     984        r = self.db.get(table, 3, 'n')
     985        s = delete(table, r)
     986        self.assertEqual(s, 1)
     987        s = delete(table, r)
     988        self.assertEqual(s, 0)
     989        r = query('select * from "%s"' % table).dictresult()
     990        self.assertEqual(len(r), 1)
     991        r = r[0]
     992        result = {'n': 2, 't': 'y'}
     993        self.assertEqual(r, result)
     994        r = self.db.get(table, 2, 'n')
     995        s = delete(table, r)
     996        self.assertEqual(s, 1)
     997        s = delete(table, r)
     998        self.assertEqual(s, 0)
     999        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
     1000        query('drop table "%s"' % table)
    9231001
    9241002    def testDeleteWithCompositeKey(self):
    9251003        query = self.db.query
    9261004        table = 'delete_test_table_1'
    927         query("drop table if exists %s" % table)
    928         query("create table %s ("
     1005        query('drop table if exists "%s"' % table)
     1006        query('create table "%s" ('
    9291007            "n integer, t text, primary key (n))" % table)
    9301008        for n, t in enumerate('abc'):
     
    9411019                  ).getresult()[0][0]
    9421020        self.assertEqual(r, 'c')
    943         query("drop table %s" % table)
     1021        query('drop table "%s"' % table)
    9441022        table = 'delete_test_table_2'
    945         query("drop table if exists %s" % table)
    946         query("create table %s ("
     1023        query('drop table if exists "%s"' % table)
     1024        query('create table "%s" ('
    9471025            "n integer, m integer, t text, primary key (n, m))" % table)
    9481026        for n in range(3):
    9491027            for m in range(2):
    9501028                t = chr(ord('a') + 2 * n + m)
    951                 query("insert into %s values("
     1029                query('insert into "%s" values('
    9521030                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
    9531031        self.assertRaises(pg.ProgrammingError, self.db.delete,
     
    9651043            ' order by m' % table).getresult()]
    9661044        self.assertEqual(r, ['f'])
    967         query("drop table %s" % table)
     1045        query('drop table "%s"' % table)
     1046
     1047    def testDeleteWithQuotedNames(self):
     1048        delete = self.db.delete
     1049        query = self.db.query
     1050        table = 'test table for delete()'
     1051        query('drop table if exists "%s"' % table)
     1052        query('create table "%s" ('
     1053            '"Prime!" smallint primary key,'
     1054            '"much space" integer, "Questions?" text)' % table)
     1055        query('insert into "%s"'
     1056              " values(19, 5005, 'Yes!')" % table)
     1057        r = {'Prime!': 17}
     1058        r = delete(table, r)
     1059        self.assertEqual(r, 0)
     1060        r = query('select count(*) from "%s"' % table).getresult()
     1061        self.assertEqual(r[0][0], 1)
     1062        r = {'Prime!': 19}
     1063        r = delete(table, r)
     1064        self.assertEqual(r, 1)
     1065        r = query('select count(*) from "%s"' % table).getresult()
     1066        self.assertEqual(r[0][0], 0)
     1067        query('drop table "%s"' % table)
    9681068
    9691069    def testTransaction(self):
Note: See TracChangeset for help on using the changeset viewer.