Changeset 772 for trunk


Ignore:
Timestamp:
Jan 20, 2016, 5:44:20 PM (4 years ago)
Author:
cito
Message:

Improve test coverage for the pgdb module

Includes a simple patch that allows storing Python lists or tuple values
in PostgreSQL array fields (they are not yet converted when read, though).

Also re-activated the shortcut methods on the connection again
since they can be sometimes useful.

Test coverage is now around 95%, the remaining lines are due to support for
old Python versions or obscure database errors that can't easily be aroused.

Location:
trunk
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • trunk/docs/contents/changelog.rst

    r770 r772  
    2626  The column names and types can now also be requested through the
    2727  colnames and coltypes attributes, which are not part of DB-API 2 though.
     28- Re-activated the shortcut methods of the DB-API connection since they
     29  can be handy when doing experiments or writing quick scripts. We keep
     30  them undocumented though and discourage using them in production.
    2831- The tty parameter and attribute of database connections has been
    2932  removed since it is not supported any more since PostgreSQL 7.4.
  • trunk/pgdb.py

    r740 r772  
    113113paramstyle = 'pyformat'
    114114
    115 # shortcut methods are not supported by default
    116 # since they have been excluded from DB API 2
    117 # and are not recommended by the DB SIG.
    118 
    119 shortcutmethods = 0
     115# shortcut methods have been excluded from DB API 2 and
     116# are not recommended by the DB SIG, but they can be handy
     117shortcutmethods = 1
    120118
    121119
     
    145143
    146144def _cast_float(value):
    147     try:
    148         return float(value)
    149     except ValueError:
    150         if value == 'NaN':
    151             return nan
    152         elif value == 'Infinity':
    153             return inf
    154         elif value == '-Infinity':
    155             return -inf
    156         raise
     145    return float(value)  # this also works with NaN and Infinity
    157146
    158147
     
    281270            val = 'NULL'
    282271        elif isinstance(val, (list, tuple)):
    283             val = '(%s)' % ','.join(map(lambda v: str(self._quote(v)), val))
     272            q = self._quote
     273            val = 'ARRAY[%s]' % ','.join(str(q(v)) for v in val)
    284274        elif Decimal is not float and isinstance(val, Decimal):
    285275            pass
     
    340330                    self._cnx.source().execute(sql)
    341331                except DatabaseError:
    342                     raise
    343                 except Exception:
     332                    raise  # database provides error message
     333                except Exception as err:
    344334                    raise _op_error("can't start transaction")
    345335                self._dbcnx._tnx = True
     
    355345                    self.rowcount = -1
    356346        except DatabaseError:
    357             raise
     347            raise  # database provides error message
    358348        except Error as err:
    359             raise _db_error("error in '%s': '%s' " % (sql, err))
     349            raise _db_error(
     350                "error in '%s': '%s' " % (sql, err), InterfaceError)
    360351        except Exception as err:
    361352            raise _op_error("internal error in '%s': %s" % (sql, err))
     
    494485            if size is None:
    495486                size = 8192
     487            elif not isinstance(size, int):
     488                raise TypeError("The size option must be an integer")
    496489            if size > 0:
    497                 if not isinstance(size, int):
    498                     raise TypeError("The size option must be an integer")
    499490
    500491                def chunks():
  • trunk/tests/test_dbapi20.py

    r761 r772  
    2929        pass
    3030
     31from datetime import datetime
     32
    3133try:
    3234    long
     
    3840except ImportError:  # Python 2.6 or 3.0
    3941    OrderedDict = None
     42
     43
     44class PgBitString:
     45    """Test object with a PostgreSQL representation as Bit String."""
     46
     47    def __init__(self, value):
     48        self.value = value
     49
     50    def __pg_repr__(self):
     51         return "B'{0:b}'".format(self.value)
    4052
    4153
     
    341353        self.assertTrue(isnan(nan) and not isinf(nan))
    342354        self.assertTrue(isinf(inf) and not isnan(inf))
    343         values = [0, 1, 0.03125, -42.53125, nan, inf, -inf]
     355        values = [0, 1, 0.03125, -42.53125, nan, inf, -inf,
     356            'nan', 'inf', '-inf', 'NaN', 'Infinity', '-Infinity']
    344357        table = self.table_prefix + 'booze'
    345358        con = self._connect()
     
    357370        rows = [row[1] for row in rows]
    358371        for inval, outval in zip(values, rows):
     372            if inval in ('inf', 'Infinity'):
     373                inval = inf
     374            elif inval in ('-inf', '-Infinity'):
     375                inval = -inf
     376            elif inval in ('nan', 'NaN'):
     377                inval = nan
    359378            if isinf(inval):
    360379                self.assertTrue(isinf(outval))
     
    367386            else:
    368387                self.assertEqual(inval, outval)
     388
     389    def test_datetime(self):
     390        values = ['2011-07-17 15:47:42', datetime(2016, 1, 20, 20, 15, 51)]
     391        table = self.table_prefix + 'booze'
     392        con = self._connect()
     393        try:
     394            cur = con.cursor()
     395            cur.execute("set datestyle to 'iso'")
     396            cur.execute(
     397                "create table %s (n smallint, ts timestamp)" % table)
     398            params = enumerate(values)
     399            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
     400            cur.execute("select * from %s order by 1" % table)
     401            rows = cur.fetchall()
     402        finally:
     403            con.close()
     404        self.assertEqual(len(rows), len(values))
     405        rows = [row[1] for row in rows]
     406        for inval, outval in zip(values, rows):
     407            if isinstance(inval, datetime):
     408                inval = inval.strftime('%Y-%m-%d %H:%M:%S')
     409            self.assertEqual(inval, outval)
     410
     411    def test_array(self):
     412        values = ([20000, 25000, 25000, 30000],
     413            [['breakfast', 'consulting'], ['meeting', 'lunch']])
     414        output = ('{20000,25000,25000,30000}',
     415            '{{breakfast,consulting},{meeting,lunch}}')
     416        table = self.table_prefix + 'booze'
     417        con = self._connect()
     418        try:
     419            cur = con.cursor()
     420            cur.execute("create table %s (i int[], t text[][])" % table)
     421            cur.execute("insert into %s values (%%s,%%s)" % table, values)
     422            cur.execute("select * from %s" % table)
     423            row = cur.fetchone()
     424        finally:
     425            con.close()
     426        self.assertEqual(row, output)
     427
     428    def test_custom_type(self):
     429        values = [3, 5, 65]
     430        values = list(map(PgBitString, values))
     431        table = self.table_prefix + 'booze'
     432        con = self._connect()
     433        try:
     434            cur = con.cursor()
     435            params = enumerate(values)  # params have __pg_repr__ method
     436            cur.execute(
     437                'create table "%s" (n smallint, b bit varying(7))' % table)
     438            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
     439            cur.execute("select * from %s order by 1" % table)
     440            rows = cur.fetchall()
     441        finally:
     442            con.close()
     443        self.assertEqual(len(rows), len(values))
     444        con = self._connect()
     445        try:
     446            cur = con.cursor()
     447            params = (1, object())  # an object that cannot be handled
     448            self.assertRaises(pgdb.InterfaceError, cur.execute,
     449                "insert into %s values (%%s,%%s)" % table, params)
     450        finally:
     451            con.close()
    369452
    370453    def test_set_decimal_type(self):
     
    473556        values[4] = values[6] = False
    474557        self.assertEqual(rows, values)
     558
     559    def test_execute_edge_cases(self):
     560        con = self._connect()
     561        try:
     562            cur = con.cursor()
     563            sql = 'invalid'  # should be ignored with empty parameter list
     564            cur.executemany(sql, [])
     565            sql = 'select %d + 1'
     566            cur.execute(sql, [(1,)])  # deprecated use of execute()
     567            self.assertEqual(cur.fetchone()[0], 2)
     568            sql = 'select 1/0'  # cannot be executed
     569            self.assertRaises(pgdb.ProgrammingError, cur.execute, sql)
     570            cur.close()
     571            con.rollback()
     572            if pgdb.shortcutmethods:
     573                res = con.execute('select %d', (1,)).fetchone()
     574                self.assertEqual(res, (1,))
     575                res = con.executemany('select %d', [(1,), (2,)]).fetchone()
     576                self.assertEqual(res, (2,))
     577        finally:
     578            con.close()
     579        sql = 'select 1'  # cannot be executed after connection is closed
     580        self.assertRaises(pgdb.OperationalError, cur.execute, sql)
     581
     582    def test_fetchmany_with_keep(self):
     583        con = self._connect()
     584        try:
     585            cur = con.cursor()
     586            self.assertEqual(cur.arraysize, 1)
     587            cur.execute('select * from generate_series(1, 25)')
     588            self.assertEqual(len(cur.fetchmany()), 1)
     589            self.assertEqual(len(cur.fetchmany()), 1)
     590            self.assertEqual(cur.arraysize, 1)
     591            cur.arraysize = 3
     592            self.assertEqual(len(cur.fetchmany()), 3)
     593            self.assertEqual(len(cur.fetchmany()), 3)
     594            self.assertEqual(cur.arraysize, 3)
     595            self.assertEqual(len(cur.fetchmany(size=2)), 2)
     596            self.assertEqual(cur.arraysize, 3)
     597            self.assertEqual(len(cur.fetchmany()), 3)
     598            self.assertEqual(len(cur.fetchmany()), 3)
     599            self.assertEqual(len(cur.fetchmany(size=2, keep=True)), 2)
     600            self.assertEqual(cur.arraysize, 2)
     601            self.assertEqual(len(cur.fetchmany()), 2)
     602            self.assertEqual(len(cur.fetchmany()), 2)
     603            self.assertEqual(len(cur.fetchmany(25)), 3)
     604        finally:
     605            con.close()
    475606
    476607    def test_nextset(self):
  • trunk/tests/test_dbapi20_copy.py

    r770 r772  
    5151    def __str__(self):
    5252        data = self.data
    53         if str is unicode:
     53        if str is unicode:  # Python >= 3.0
    5454            data = data.decode('utf-8')
    5555        return data
     
    7676    def __str__(self):
    7777        data = self.data
    78         if str is unicode:
     78        if str is unicode:  # Python >= 3.0
    7979            data = data.decode('utf-8')
    8080        return data
     
    221221             format='text', sep='\t', null='', columns=['id', 'name'])
    222222        self.assertRaises(TypeError, call)
     223        self.assertRaises(TypeError, call, None)
     224        self.assertRaises(TypeError, call, None, None)
    223225        self.assertRaises(TypeError, call, '0\t')
     226        self.assertRaises(TypeError, call, '0\t', None)
    224227        self.assertRaises(TypeError, call, '0\t', 42)
    225228        self.assertRaises(TypeError, call, '0\t', ['copytest'])
     
    231234        self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad')
    232235        self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42)
     236        self.assertRaises(ValueError, call, b'', 'copytest',
     237            format='binary', sep=',')
    233238
    234239    def test_input_string(self):
     
    249254        self.check_rowcount()
    250255
    251     if str is unicode:
     256    if str is unicode:  # Python >= 3.0
    252257
    253258        def test_input_bytes(self):
     
    258263            self.check_table()
    259264
    260     if str is not unicode:
     265    else:  # Python < 3.0
    261266
    262267        def test_input_unicode(self):
     
    272277        self.check_rowcount()
    273278
     279    def test_input_iterable_invalid(self):
     280        self.assertRaises(IOError, self.copy_from, [None])
     281
    274282    def test_input_iterable_with_newlines(self):
    275283        self.copy_from('%s\n' % row for row in self.data_text.splitlines())
    276284        self.check_table()
     285
     286    if str is unicode:  # Python >= 3.0
     287
     288        def test_input_iterable_bytes(self):
     289            self.copy_from(row.encode('utf-8')
     290                for row in self.data_text.splitlines())
     291            self.check_table()
    277292
    278293    def test_sep(self):
     
    367382        self.check_rowcount()
    368383
     384    def test_size_invalid(self):
     385        self.assertRaises(TypeError,
     386            self.copy_from, self.data_file, size='invalid')
     387
    369388
    370389class TestCopyTo(TestCopy):
     
    416435        self.check_rowcount()
    417436
    418     if str is unicode:
     437    if str is unicode:  # Python >= 3.0
    419438
    420439        def test_generator_bytes(self):
     
    427446            self.assertEqual(rows, self.data_text.encode('utf-8'))
    428447
    429     if str is not unicode:
     448    else:  # Python < 3.0
    430449
    431450        def test_generator_unicode(self):
     
    517536
    518537    def test_query(self):
     538        self.assertRaises(ValueError, self.cursor.copy_to, None,
     539            "select name from copytest", columns='noname')
    519540        ret = self.cursor.copy_to(None,
    520541            "select name||'!' from copytest where id=1941")
     
    532553        self.assertEqual(str(stream), self.data_text)
    533554        data = self.data_text
    534         if str is unicode:
     555        if str is unicode:  # Python >= 3.0
    535556            data = data.encode('utf-8')
    536557        sizes = [len(row) + 1 for row in data.splitlines()]
Note: See TracChangeset for help on using the changeset viewer.