source: trunk/tests/test_dbapi20.py @ 814

Last change on this file since 814 was 814, checked in by cito, 4 years ago

Add typecasting of dates, times, timestamps, intervals

So far, PyGreSQL has returned these types only as strings (in various
formats depending on the DateStyle setting) and left it to the user
to parse and interpret the strings. These types are now properly cast
into the corresponding datetime types of Python, and this works with
any setting of DateStyle, even if you change DateStyle in the middle
of a database session.

To implement this, a fast method for getting the datestyle (cached and
without roundtrip to the database) has been added. Also, the typecast
mechanism has been extended so that typecast functions can optionally
also take the connection as argument.

The date and time typecast functions have been implemented in Python
using the new typecast registry and added to both pg and pgdb. Some
duplication of code in the two modules was unavoidable, since we don't
want the modules to be dependent on each other or install additional
helper modules. One day we might want to change this, put everything
in one package and factor out some of the functionality.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 42.2 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3# $Id: test_dbapi20.py 814 2016-02-03 20:23:20Z cito $
4
5try:
6    import unittest2 as unittest  # for Python < 2.7
7except ImportError:
8    import unittest
9
10import pgdb
11
12try:
13    from . import dbapi20
14except (ImportError, ValueError, SystemError):
15    import dbapi20
16
# We need a database to test against.
# If LOCAL_PyGreSQL.py exists we will get our information from that.
# Otherwise we use the defaults.
# The defaults must be assigned before the star import below, so that a
# LOCAL_PyGreSQL module can override any of dbname, dbhost and dbport.
dbname = 'dbapi20_test'
dbhost = ''
dbport = 5432
try:  # relative import (works when this file is part of a package)
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:  # fall back to an absolute import of the local settings
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings available, keep the defaults
30
31from datetime import date, time, datetime, timedelta
32
33try:
34    from datetime import timezone
35except ImportError:  # Python < 3.2
36    timezone = None
37
38try:
39    long
40except NameError:  # Python >= 3.0
41    long = int
42
43try:
44    from collections import OrderedDict
45except ImportError:  # Python 2.6 or 3.0
46    OrderedDict = None
47
48
class PgBitString:
    """Test object with a PostgreSQL representation as Bit String.

    The wrapped integer is rendered by ``__pg_repr__`` as a PostgreSQL
    bit string literal, e.g. ``B'101'`` for the value 5.
    """

    def __init__(self, value):
        # the integer whose binary digits make up the bit string
        self.value = value

    def __pg_repr__(self):
        """Return the value as a PostgreSQL bit string literal."""
        return "B'{0:b}'".format(self.value)
57
58
59class test_PyGreSQL(dbapi20.DatabaseAPI20Test):
60
61    driver = pgdb
62    connect_args = ()
63    connect_kw_args = {'database': dbname,
64        'host': '%s:%d' % (dbhost or '', dbport or -1)}
65
66    lower_func = 'lower'  # For stored procedure test
67
68    def setUp(self):
69        # Call superclass setUp in case this does something in the future
70        dbapi20.DatabaseAPI20Test.setUp(self)
71        try:
72            con = self._connect()
73            con.close()
74        except pgdb.Error:  # try to create a missing database
75            import pg
76            try:  # first try to log in as superuser
77                db = pg.DB('postgres', dbhost or None, dbport or -1,
78                    user='postgres')
79            except Exception:  # then try to log in as current user
80                db = pg.DB('postgres', dbhost or None, dbport or -1)
81            db.query('create database ' + dbname)
82
83    def tearDown(self):
84        dbapi20.DatabaseAPI20Test.tearDown(self)
85
86    def test_callproc_no_params(self):
87        con = self._connect()
88        cur = con.cursor()
89        # note that now() does not change within a transaction
90        cur.execute('select now()')
91        now = cur.fetchone()[0]
92        res = cur.callproc('now')
93        self.assertIsNone(res)
94        res = cur.fetchone()[0]
95        self.assertEqual(res, now)
96
97    def test_callproc_bad_params(self):
98        con = self._connect()
99        cur = con.cursor()
100        self.assertRaises(TypeError, cur.callproc, 'lower', 42)
101        self.assertRaises(pgdb.ProgrammingError, cur.callproc, 'lower', (42,))
102
103    def test_callproc_one_param(self):
104        con = self._connect()
105        cur = con.cursor()
106        params = (42.4382,)
107        res = cur.callproc("round", params)
108        self.assertIs(res, params)
109        res = cur.fetchone()[0]
110        self.assertEqual(res, 42)
111
112    def test_callproc_two_params(self):
113        con = self._connect()
114        cur = con.cursor()
115        params = (9, 4)
116        res = cur.callproc("div", params)
117        self.assertIs(res, params)
118        res = cur.fetchone()[0]
119        self.assertEqual(res, 2)
120
121    def test_cursor_type(self):
122
123        class TestCursor(pgdb.Cursor):
124            pass
125
126        con = self._connect()
127        self.assertIs(con.cursor_type, pgdb.Cursor)
128        cur = con.cursor()
129        self.assertIsInstance(cur, pgdb.Cursor)
130        self.assertNotIsInstance(cur, TestCursor)
131        con.cursor_type = TestCursor
132        cur = con.cursor()
133        self.assertIsInstance(cur, TestCursor)
134        cur = con.cursor()
135        self.assertIsInstance(cur, TestCursor)
136        con = self._connect()
137        self.assertIs(con.cursor_type, pgdb.Cursor)
138        cur = con.cursor()
139        self.assertIsInstance(cur, pgdb.Cursor)
140        self.assertNotIsInstance(cur, TestCursor)
141
142    def test_row_factory(self):
143
144        class TestCursor(pgdb.Cursor):
145
146            def row_factory(self, row):
147                return dict(('column %s' % desc[0], value)
148                    for desc, value in zip(self.description, row))
149
150        con = self._connect()
151        con.cursor_type = TestCursor
152        cur = con.cursor()
153        self.assertIsInstance(cur, TestCursor)
154        res = cur.execute("select 1 as a, 2 as b")
155        self.assertIs(res, cur, 'execute() should return cursor')
156        res = cur.fetchone()
157        self.assertIsInstance(res, dict)
158        self.assertEqual(res, {'column a': 1, 'column b': 2})
159        cur.execute("select 1 as a, 2 as b union select 3, 4 order by 1")
160        res = cur.fetchall()
161        self.assertIsInstance(res, list)
162        self.assertEqual(len(res), 2)
163        self.assertIsInstance(res[0], dict)
164        self.assertEqual(res[0], {'column a': 1, 'column b': 2})
165        self.assertIsInstance(res[1], dict)
166        self.assertEqual(res[1], {'column a': 3, 'column b': 4})
167
168    def test_build_row_factory(self):
169
170        class TestCursor(pgdb.Cursor):
171
172            def build_row_factory(self):
173                keys = [desc[0] for desc in self.description]
174                return lambda row: dict((key, value)
175                    for key, value in zip(keys, row))
176
177        con = self._connect()
178        con.cursor_type = TestCursor
179        cur = con.cursor()
180        self.assertIsInstance(cur, TestCursor)
181        cur.execute("select 1 as a, 2 as b")
182        res = cur.fetchone()
183        self.assertIsInstance(res, dict)
184        self.assertEqual(res, {'a': 1, 'b': 2})
185        cur.execute("select 1 as a, 2 as b union select 3, 4 order by 1")
186        res = cur.fetchall()
187        self.assertIsInstance(res, list)
188        self.assertEqual(len(res), 2)
189        self.assertIsInstance(res[0], dict)
190        self.assertEqual(res[0], {'a': 1, 'b': 2})
191        self.assertIsInstance(res[1], dict)
192        self.assertEqual(res[1], {'a': 3, 'b': 4})
193
194    def test_cursor_with_named_columns(self):
195        con = self._connect()
196        cur = con.cursor()
197        res = cur.execute("select 1 as abc, 2 as de, 3 as f")
198        self.assertIs(res, cur, 'execute() should return cursor')
199        res = cur.fetchone()
200        self.assertIsInstance(res, tuple)
201        self.assertEqual(res, (1, 2, 3))
202        self.assertEqual(res._fields, ('abc', 'de', 'f'))
203        self.assertEqual(res.abc, 1)
204        self.assertEqual(res.de, 2)
205        self.assertEqual(res.f, 3)
206        cur.execute("select 1 as one, 2 as two union select 3, 4 order by 1")
207        res = cur.fetchall()
208        self.assertIsInstance(res, list)
209        self.assertEqual(len(res), 2)
210        self.assertIsInstance(res[0], tuple)
211        self.assertEqual(res[0], (1, 2))
212        self.assertEqual(res[0]._fields, ('one', 'two'))
213        self.assertIsInstance(res[1], tuple)
214        self.assertEqual(res[1], (3, 4))
215        self.assertEqual(res[1]._fields, ('one', 'two'))
216
217    def test_cursor_with_unnamed_columns(self):
218        con = self._connect()
219        cur = con.cursor()
220        cur.execute("select 1, 2, 3")
221        res = cur.fetchone()
222        self.assertIsInstance(res, tuple)
223        self.assertEqual(res, (1, 2, 3))
224        old_py = OrderedDict is None  # Python 2.6 or 3.0
225        # old Python versions cannot rename tuple fields with underscore
226        if old_py:
227            self.assertEqual(res._fields, ('column_0', 'column_1', 'column_2'))
228        else:
229            self.assertEqual(res._fields, ('_0', '_1', '_2'))
230        cur.execute("select 1 as one, 2, 3 as three")
231        res = cur.fetchone()
232        self.assertIsInstance(res, tuple)
233        self.assertEqual(res, (1, 2, 3))
234        if old_py:  # cannot auto rename with underscore
235            self.assertEqual(res._fields, ('one', 'column_1', 'three'))
236        else:
237            self.assertEqual(res._fields, ('one', '_1', 'three'))
238        cur.execute("select 1 as abc, 2 as def")
239        res = cur.fetchone()
240        self.assertIsInstance(res, tuple)
241        self.assertEqual(res, (1, 2))
242        if old_py:
243            self.assertEqual(res._fields, ('column_0', 'column_1'))
244        else:
245            self.assertEqual(res._fields, ('abc', '_1'))
246
247    def test_colnames(self):
248        con = self._connect()
249        cur = con.cursor()
250        cur.execute("select 1, 2, 3")
251        names = cur.colnames
252        self.assertIsInstance(names, list)
253        self.assertEqual(names, ['?column?', '?column?', '?column?'])
254        cur.execute("select 1 as a, 2 as bc, 3 as def, 4 as g")
255        names = cur.colnames
256        self.assertIsInstance(names, list)
257        self.assertEqual(names, ['a', 'bc', 'def', 'g'])
258
259    def test_coltypes(self):
260        con = self._connect()
261        cur = con.cursor()
262        cur.execute("select 1::int2, 2::int4, 3::int8")
263        types = cur.coltypes
264        self.assertIsInstance(types, list)
265        self.assertEqual(types, ['int2', 'int4', 'int8'])
266
267    def test_description_fields(self):
268        con = self._connect()
269        cur = con.cursor()
270        cur.execute("select 123456789::int8 col0,"
271            " 123456.789::numeric(41, 13) as col1,"
272            " 'foobar'::char(39) as col2")
273        desc = cur.description
274        self.assertIsInstance(desc, list)
275        self.assertEqual(len(desc), 3)
276        cols = [('int8', 8, None), ('numeric', 41, 13), ('bpchar', 39, None)]
277        for i in range(3):
278            c, d = cols[i], desc[i]
279            self.assertIsInstance(d, tuple)
280            self.assertEqual(len(d), 7)
281            self.assertIsInstance(d.name, str)
282            self.assertEqual(d.name, 'col%d' % i)
283            self.assertIsInstance(d.type_code, str)
284            self.assertEqual(d.type_code, c[0])
285            self.assertIsNone(d.display_size)
286            self.assertIsInstance(d.internal_size, int)
287            self.assertEqual(d.internal_size, c[1])
288            if c[2] is not None:
289                self.assertIsInstance(d.precision, int)
290                self.assertEqual(d.precision, c[1])
291                self.assertIsInstance(d.scale, int)
292                self.assertEqual(d.scale, c[2])
293            else:
294                self.assertIsNone(d.precision)
295                self.assertIsNone(d.scale)
296            self.assertIsNone(d.null_ok)
297
    def test_type_cache_info(self):
        # The type cache of a connection is filled lazily: type info only
        # shows up in the cache after it has been looked up, and another
        # connection has its own cache without that entry.
        con = self._connect()
        try:
            cur = con.cursor()
            type_cache = con.type_cache
            self.assertNotIn('numeric', type_cache)
            type_info = type_cache['numeric']
            self.assertIn('numeric', type_cache)
            # the type info compares equal to the type name and exposes
            # further attributes such as oid, len, type and category
            self.assertEqual(type_info, 'numeric')
            self.assertEqual(type_info.oid, 1700)
            self.assertEqual(type_info.len, -1)
            self.assertEqual(type_info.type, 'b')  # base
            self.assertEqual(type_info.category, 'N')  # numeric
            self.assertEqual(type_info.delim, ',')
            self.assertEqual(type_info.relid, 0)
            # the same entry can also be looked up by its oid
            self.assertIs(con.type_cache[1700], type_info)
            # composite types (like the pg_type table) can be looked up, too
            self.assertNotIn('pg_type', type_cache)
            type_info = type_cache['pg_type']
            self.assertIn('pg_type', type_cache)
            self.assertEqual(type_info.type, 'c')  # composite
            self.assertEqual(type_info.category, 'C')  # composite
            # get_fields() describes the columns of a composite type; the
            # column types can in turn be looked up in the type cache
            cols = type_cache.get_fields('pg_type')
            self.assertEqual(cols[0].name, 'typname')
            typname = type_cache[cols[0].type]
            self.assertEqual(typname, 'name')
            self.assertEqual(typname.type, 'b')  # base
            self.assertEqual(typname.category, 'S')  # string
            self.assertEqual(cols[3].name, 'typlen')
            typlen = type_cache[cols[3].type]
            self.assertEqual(typlen, 'int2')
            self.assertEqual(typlen.type, 'b')  # base
            self.assertEqual(typlen.category, 'N')  # numeric
            cur.close()
            # the cache is kept when a new cursor is opened on the connection
            cur = con.cursor()
            type_cache = con.type_cache
            self.assertIn('numeric', type_cache)
            cur.close()
        finally:
            con.close()
        # a new connection does not yet have pg_type in its cache
        con = self._connect()
        try:
            cur = con.cursor()
            type_cache = con.type_cache
            self.assertNotIn('pg_type', type_cache)
            # get() looks up and caches like [], but returns None for
            # types that do not exist
            self.assertEqual(type_cache.get('pg_type'), type_info)
            self.assertIn('pg_type', type_cache)
            self.assertIsNone(type_cache.get(
                self.table_prefix + '_surely_does_not_exist'))
            cur.close()
        finally:
            con.close()
349
    def test_type_cache_typecast(self):
        # Typecast functions can be changed per connection via its type
        # cache, for a single type name or a list of names, and be reset
        # for some or all types again.
        con = self._connect()
        try:
            cur = con.cursor()
            type_cache = con.type_cache
            self.assertIs(type_cache.get_typecast('int4'), int)
            # a cast function whose effect is visible in the fetched values
            cast_int = lambda v: 'int(%s)' % v
            type_cache.set_typecast('int4', cast_int)
            query = 'select 2::int2, 4::int4, 8::int8'
            cur.execute(query)
            i2, i4, i8 = cur.fetchone()
            self.assertEqual(i2, 2)
            self.assertEqual(i4, 'int(4)')
            self.assertEqual(i8, 8)
            # typecast() applies the registered function to a value directly
            self.assertEqual(type_cache.typecast(42, 'int4'), 'int(42)')
            # set_typecast() also accepts a list of type names
            type_cache.set_typecast(['int2', 'int8'], cast_int)
            cur.execute(query)
            i2, i4, i8 = cur.fetchone()
            self.assertEqual(i2, 'int(2)')
            self.assertEqual(i4, 'int(4)')
            self.assertEqual(i8, 'int(8)')
            # reset_typecast() restores the default for the given type only
            type_cache.reset_typecast('int4')
            cur.execute(query)
            i2, i4, i8 = cur.fetchone()
            self.assertEqual(i2, 'int(2)')
            self.assertEqual(i4, 4)
            self.assertEqual(i8, 'int(8)')
            # ... or for a list of types
            type_cache.reset_typecast(['int2', 'int8'])
            cur.execute(query)
            i2, i4, i8 = cur.fetchone()
            self.assertEqual(i2, 2)
            self.assertEqual(i4, 4)
            self.assertEqual(i8, 8)
            type_cache.set_typecast(['int2', 'int8'], cast_int)
            cur.execute(query)
            i2, i4, i8 = cur.fetchone()
            self.assertEqual(i2, 'int(2)')
            self.assertEqual(i4, 4)
            self.assertEqual(i8, 'int(8)')
            # without arguments, reset_typecast() restores all defaults
            type_cache.reset_typecast()
            cur.execute(query)
            i2, i4, i8 = cur.fetchone()
            self.assertEqual(i2, 2)
            self.assertEqual(i4, 4)
            self.assertEqual(i8, 8)
            cur.close()
        finally:
            con.close()
398
399    def test_cursor_iteration(self):
400        con = self._connect()
401        cur = con.cursor()
402        cur.execute("select 1 union select 2 union select 3")
403        self.assertEqual([r[0] for r in cur], [1, 2, 3])
404
    def test_fetch_2_rows(self):
        # Insert the same row of values of many different types twice and
        # check that both rows come back unchanged and with proper types.
        Decimal = pgdb.decimal_type()
        values = ('test', pgdb.Binary(b'\xff\x52\xb2'),
            True, 5, 6, 5.7, Decimal('234.234234'), Decimal('75.45'),
            pgdb.Date(2011, 7, 17), pgdb.Time(15, 47, 42),
            pgdb.Timestamp(2008, 10, 20, 15, 25, 35),
            pgdb.Interval(15, 31, 5), 7897234)
        table = self.table_prefix + 'booze'
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("set datestyle to iso")
            cur.execute("create table %s ("
                "stringtest varchar,"
                "binarytest bytea,"
                "booltest bool,"
                "integertest int4,"
                "longtest int8,"
                "floattest float8,"
                "numerictest numeric,"
                "moneytest money,"
                "datetest date,"
                "timetest time,"
                "datetimetest timestamp,"
                "intervaltest interval,"
                "rowidtest oid)" % table)
            cur.execute("set standard_conforming_strings to on")
            # use the C locale for numbers, money and times; presumably so
            # that e.g. money values are formatted predictably
            for s in ('numeric', 'monetary', 'time'):
                cur.execute("set lc_%s to 'C'" % s)
            for _i in range(2):
                # the money parameter is put in quotes and cast explicitly
                cur.execute("insert into %s values ("
                    "%%s,%%s,%%s,%%s,%%s,%%s,%%s,"
                    "'%%s'::money,%%s,%%s,%%s,%%s,%%s)" % table, values)
            cur.execute("select * from %s" % table)
            rows = cur.fetchall()
            self.assertEqual(len(rows), 2)
            row0 = rows[0]
            self.assertEqual(row0, values)
            self.assertEqual(row0, rows[1])
            # check the Python type of each column, in table order
            self.assertIsInstance(row0[0], str)
            self.assertIsInstance(row0[1], bytes)
            self.assertIsInstance(row0[2], bool)
            self.assertIsInstance(row0[3], int)
            self.assertIsInstance(row0[4], long)
            self.assertIsInstance(row0[5], float)
            self.assertIsInstance(row0[6], Decimal)
            self.assertIsInstance(row0[7], Decimal)
            self.assertIsInstance(row0[8], date)
            self.assertIsInstance(row0[9], time)
            self.assertIsInstance(row0[10], datetime)
            self.assertIsInstance(row0[11], timedelta)
        finally:
            con.close()
458
459    def test_sqlstate(self):
460        con = self._connect()
461        cur = con.cursor()
462        try:
463            cur.execute("select 1/0")
464        except pgdb.DatabaseError as error:
465            self.assertTrue(isinstance(error, pgdb.ProgrammingError))
466            # the SQLSTATE error code for division by zero is 22012
467            self.assertEqual(error.sqlstate, '22012')
468
    def test_float(self):
        # Float columns must round-trip ordinary values as well as NaN and
        # +/-infinity, passed either as Python floats or as strings.
        nan, inf = float('nan'), float('inf')
        from math import isnan, isinf
        # sanity check of the special values used below
        self.assertTrue(isnan(nan) and not isinf(nan))
        self.assertTrue(isinf(inf) and not isnan(inf))
        values = [0, 1, 0.03125, -42.53125, nan, inf, -inf,
            'nan', 'inf', '-inf', 'NaN', 'Infinity', '-Infinity']
        table = self.table_prefix + 'booze'
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute(
                "create table %s (n smallint, floattest float)" % table)
            # n stores the position, so the rows can be read back in order
            params = enumerate(values)
            cur.executemany("insert into %s values (%%d,%%s)" % table, params)
            cur.execute("select floattest from %s order by n" % table)
            rows = cur.fetchall()
            self.assertEqual(cur.description[0].type_code, pgdb.FLOAT)
            self.assertNotEqual(cur.description[0].type_code, pgdb.ARRAY)
            self.assertNotEqual(cur.description[0].type_code, pgdb.RECORD)
        finally:
            con.close()
        self.assertEqual(len(rows), len(values))
        rows = [row[0] for row in rows]
        for inval, outval in zip(values, rows):
            # map the string spellings to the float values they denote
            if inval in ('inf', 'Infinity'):
                inval = inf
            elif inval in ('-inf', '-Infinity'):
                inval = -inf
            elif inval in ('nan', 'NaN'):
                inval = nan
            if isinf(inval):
                # infinities must keep their sign
                self.assertTrue(isinf(outval))
                if inval < 0:
                    self.assertTrue(outval < 0)
                else:
                    self.assertTrue(outval > 0)
            elif isnan(inval):
                # NaN never compares equal, so it is checked with isnan()
                self.assertTrue(isnan(outval))
            else:
                self.assertEqual(inval, outval)
510
511    def test_datetime(self):
512        dt = datetime(2011, 7, 17, 15, 47, 42, 317509)
513        td = dt - datetime(1970, 1, 1)
514        table = self.table_prefix + 'booze'
515        con = self._connect()
516        try:
517            cur = con.cursor()
518            cur.execute("set datestyle to iso")
519            cur.execute("set datestyle to iso")
520            cur.execute("create table %s ("
521                "d date, t time,  ts timestamp,"
522                "tz timetz, tsz timestamptz, i interval)" % table)
523            for n in range(3):
524                values = [dt.date(), dt.time(), dt,
525                    dt.time(), dt, td]
526                if timezone:
527                    values[3] = values[3].replace(tzinfo=timezone.utc)
528                    values[4] = values[4].replace(tzinfo=timezone.utc)
529                if n == 0:  # input as objects
530                    params = values
531                if n == 1:  # input as text
532                    params = [v.isoformat() for v in values[:5]]  # as text
533                    params.append('%d days %d seconds %d microseconds '
534                        % (td.days, td.seconds, td.microseconds))
535                elif n == 2:  # input using type helpers
536                    d = (dt.year, dt.month, dt.day)
537                    t = (dt.hour, dt.minute, dt.second, dt.microsecond)
538                    i = (td.days, 0, 0, td.seconds, td.microseconds)
539                    params = [pgdb.Date(*d), pgdb.Time(*t),
540                            pgdb.Timestamp(*(d + t)), pgdb.Time(*t),
541                            pgdb.Timestamp(*(d + t)), pgdb.Interval(*i)]
542                cur.execute("insert into %s"
543                    " values (%%s,%%s,%%s,%%s,%%s,%%s)" % table, params)
544                for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy',
545                        'sql, mdy', 'sql, dmy', 'german'):
546                    cur.execute("set datestyle to %s" % datestyle)
547                    cur.execute("select * from %s" % table)
548                    d = cur.description
549                    for i in range(6):
550                        self.assertEqual(d[i].type_code, pgdb.DATETIME)
551                        self.assertNotEqual(d[i].type_code, pgdb.STRING)
552                        self.assertNotEqual(d[i].type_code, pgdb.ARRAY)
553                        self.assertNotEqual(d[i].type_code, pgdb.RECORD)
554                    self.assertEqual(d[0].type_code, pgdb.DATE)
555                    self.assertEqual(d[1].type_code, pgdb.TIME)
556                    self.assertEqual(d[2].type_code, pgdb.TIMESTAMP)
557                    self.assertEqual(d[3].type_code, pgdb.TIME)
558                    self.assertEqual(d[4].type_code, pgdb.TIMESTAMP)
559                    self.assertEqual(d[5].type_code, pgdb.INTERVAL)
560                    row = cur.fetchone()
561                    self.assertEqual(row, tuple(values))
562                cur.execute("delete from %s" % table)
563        finally:
564            con.close()
565
    def test_insert_array(self):
        # Python lists (including empty, nested and None values) must
        # round-trip through int[] and text[][] columns.
        values = [(None, None), ([], []), ([None], [[None], ['null']]),
            ([1, 2, 3], [['a', 'b'], ['c', 'd']]),
            ([20000, 25000, 25000, 30000],
            [['breakfast', 'consulting'], ['meeting', 'lunch']]),
            ([0, 1, -1], [['Hello, World!', '"Hi!"'], ['{x,y}', ' x y ']])]
        table = self.table_prefix + 'booze'
        con = self._connect()
        try:
            cur = con.cursor()
            cur.execute("create table %s"
                " (n smallint, i int[], t text[][])" % table)
            params = [(n, v[0], v[1]) for n, v in enumerate(values)]
            # Note that we must use explicit casts because we are inserting
            # empty arrays.  Otherwise this is not necessary.
            cur.executemany("insert into %s values"
                " (%%d,%%s::int[],%%s::text[][])" % table, params)
            cur.execute("select i, t from %s order by n" % table)
            # array columns compare equal to ARRAY and also to the type
            # group of their elements
            d = cur.description
            self.assertEqual(d[0].type_code, pgdb.ARRAY)
            self.assertNotEqual(d[0].type_code, pgdb.RECORD)
            self.assertEqual(d[0].type_code, pgdb.NUMBER)
            self.assertEqual(d[0].type_code, pgdb.INTEGER)
            self.assertEqual(d[1].type_code, pgdb.ARRAY)
            self.assertNotEqual(d[1].type_code, pgdb.RECORD)
            self.assertEqual(d[1].type_code, pgdb.STRING)
            rows = cur.fetchall()
        finally:
            con.close()
        self.assertEqual(rows, values)
596
597    def test_select_array(self):
598        values = ([1, 2, 3, None], ['a', 'b', 'c', None])
599        con = self._connect()
600        try:
601            cur = con.cursor()
602            cur.execute("select %s::int[], %s::text[]", values)
603            row = cur.fetchone()
604        finally:
605            con.close()
606        self.assertEqual(row, values)
607
608    def test_insert_record(self):
609        values = [('John', 61), ('Jane', 63),
610                  ('Fred', None), ('Wilma', None),
611                  (None, 42), (None, None)]
612        table = self.table_prefix + 'booze'
613        record = self.table_prefix + 'munch'
614        con = self._connect()
615        try:
616            cur = con.cursor()
617            cur.execute("create type %s as (name varchar, age int)" % record)
618            cur.execute("create table %s (n smallint, r %s)" % (table, record))
619            params = enumerate(values)
620            cur.executemany("insert into %s values (%%d,%%s)" % table, params)
621            cur.execute("select r from %s order by n" % table)
622            type_code = cur.description[0].type_code
623            self.assertEqual(type_code, record)
624            self.assertEqual(type_code, pgdb.RECORD)
625            self.assertNotEqual(type_code, pgdb.ARRAY)
626            columns = con.type_cache.get_fields(type_code)
627            self.assertEqual(columns[0].name, 'name')
628            self.assertEqual(columns[1].name, 'age')
629            self.assertEqual(con.type_cache[columns[0].type], 'varchar')
630            self.assertEqual(con.type_cache[columns[1].type], 'int4')
631            rows = cur.fetchall()
632        finally:
633            cur.execute('drop table %s' % table)
634            cur.execute('drop type %s' % record)
635            con.close()
636        self.assertEqual(len(rows), len(values))
637        rows = [row[0] for row in rows]
638        self.assertEqual(rows, values)
639        self.assertEqual(rows[0].name, 'John')
640        self.assertEqual(rows[0].age, 61)
641
642    def test_select_record(self):
643        value = (1, 25000, 2.5, 'hello', 'Hello World!', 'Hello, World!',
644            '(test)', '(x,y)', ' x y ', 'null', None)
645        con = self._connect()
646        try:
647            cur = con.cursor()
648            cur.execute("select %s as test_record", [value])
649            self.assertEqual(cur.description[0].name, 'test_record')
650            self.assertEqual(cur.description[0].type_code, 'record')
651            row = cur.fetchone()[0]
652        finally:
653            con.close()
654        # Note that the element types get lost since we created an
655        # untyped record (an anonymous composite type). For the same
656        # reason this is also a normal tuple, not a named tuple.
657        text_row = tuple(None if v is None else str(v) for v in value)
658        self.assertEqual(row, text_row)
659
660    def test_custom_type(self):
661        values = [3, 5, 65]
662        values = list(map(PgBitString, values))
663        table = self.table_prefix + 'booze'
664        con = self._connect()
665        try:
666            cur = con.cursor()
667            params = enumerate(values)  # params have __pg_repr__ method
668            cur.execute(
669                'create table "%s" (n smallint, b bit varying(7))' % table)
670            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
671            cur.execute("select * from %s" % table)
672            rows = cur.fetchall()
673        finally:
674            con.close()
675        self.assertEqual(len(rows), len(values))
676        con = self._connect()
677        try:
678            cur = con.cursor()
679            params = (1, object())  # an object that cannot be handled
680            self.assertRaises(pgdb.InterfaceError, cur.execute,
681                "insert into %s values (%%s,%%s)" % table, params)
682        finally:
683            con.close()
684
685    def test_set_decimal_type(self):
686        decimal_type = pgdb.decimal_type()
687        self.assertTrue(decimal_type is not None and callable(decimal_type))
688        con = self._connect()
689        try:
690            cur = con.cursor()
691            # change decimal type globally to int
692            int_type = lambda v: int(float(v))
693            self.assertTrue(pgdb.decimal_type(int_type) is int_type)
694            cur.execute('select 4.25')
695            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
696            value = cur.fetchone()[0]
697            self.assertTrue(isinstance(value, int))
698            self.assertEqual(value, 4)
699            # change decimal type again to float
700            self.assertTrue(pgdb.decimal_type(float) is float)
701            cur.execute('select 4.25')
702            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
703            value = cur.fetchone()[0]
704            # the connection still uses the old setting
705            self.assertTrue(isinstance(value, int))
706            # bust the cache for type functions for the connection
707            con.type_cache.reset_typecast()
708            cur.execute('select 4.25')
709            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
710            value = cur.fetchone()[0]
711            # now the connection uses the new setting
712            self.assertTrue(isinstance(value, float))
713            self.assertEqual(value, 4.25)
714        finally:
715            con.close()
716            pgdb.decimal_type(decimal_type)
717        self.assertTrue(pgdb.decimal_type() is decimal_type)
718
719    def test_global_typecast(self):
720        try:
721            query = 'select 2::int2, 4::int4, 8::int8'
722            self.assertIs(pgdb.get_typecast('int4'), int)
723            cast_int = lambda v: 'int(%s)' % v
724            pgdb.set_typecast('int4', cast_int)
725            con = self._connect()
726            try:
727                i2, i4, i8 = con.cursor().execute(query).fetchone()
728            finally:
729                con.close()
730            self.assertEqual(i2, 2)
731            self.assertEqual(i4, 'int(4)')
732            self.assertEqual(i8, 8)
733            pgdb.set_typecast(['int2', 'int8'], cast_int)
734            con = self._connect()
735            try:
736                i2, i4, i8 = con.cursor().execute(query).fetchone()
737            finally:
738                con.close()
739            self.assertEqual(i2, 'int(2)')
740            self.assertEqual(i4, 'int(4)')
741            self.assertEqual(i8, 'int(8)')
742            pgdb.reset_typecast('int4')
743            con = self._connect()
744            try:
745                i2, i4, i8 = con.cursor().execute(query).fetchone()
746            finally:
747                con.close()
748            self.assertEqual(i2, 'int(2)')
749            self.assertEqual(i4, 4)
750            self.assertEqual(i8, 'int(8)')
751            pgdb.reset_typecast(['int2', 'int8'])
752            con = self._connect()
753            try:
754                i2, i4, i8 = con.cursor().execute(query).fetchone()
755            finally:
756                con.close()
757            self.assertEqual(i2, 2)
758            self.assertEqual(i4, 4)
759            self.assertEqual(i8, 8)
760            pgdb.set_typecast(['int2', 'int8'], cast_int)
761            con = self._connect()
762            try:
763                i2, i4, i8 = con.cursor().execute(query).fetchone()
764            finally:
765                con.close()
766            self.assertEqual(i2, 'int(2)')
767            self.assertEqual(i4, 4)
768            self.assertEqual(i8, 'int(8)')
769        finally:
770            pgdb.reset_typecast()
771        con = self._connect()
772        try:
773            i2, i4, i8 = con.cursor().execute(query).fetchone()
774        finally:
775            con.close()
776        self.assertEqual(i2, 2)
777        self.assertEqual(i4, 4)
778        self.assertEqual(i8, 8)
779
780    def test_unicode_with_utf8(self):
781        table = self.table_prefix + 'booze'
782        input = u"He wes Leovenaðes sone — liðe him be Drihten"
783        con = self._connect()
784        try:
785            cur = con.cursor()
786            cur.execute("create table %s (t text)" % table)
787            try:
788                cur.execute("set client_encoding=utf8")
789                cur.execute(u"select '%s'" % input)
790            except Exception:
791                self.skipTest("database does not support utf8")
792            output1 = cur.fetchone()[0]
793            cur.execute("insert into %s values (%%s)" % table, (input,))
794            cur.execute("select * from %s" % table)
795            output2 = cur.fetchone()[0]
796            cur.execute("select t = '%s' from %s" % (input, table))
797            output3 = cur.fetchone()[0]
798            cur.execute("select t = %%s from %s" % table, (input,))
799            output4 = cur.fetchone()[0]
800        finally:
801            con.close()
802        if str is bytes:  # Python < 3.0
803            input = input.encode('utf8')
804        self.assertIsInstance(output1, str)
805        self.assertEqual(output1, input)
806        self.assertIsInstance(output2, str)
807        self.assertEqual(output2, input)
808        self.assertIsInstance(output3, bool)
809        self.assertTrue(output3)
810        self.assertIsInstance(output4, bool)
811        self.assertTrue(output4)
812
813    def test_unicode_with_latin1(self):
814        table = self.table_prefix + 'booze'
815        input = u"Ehrt den König seine WÃŒrde, ehret uns der HÀnde Fleiß."
816        con = self._connect()
817        try:
818            cur = con.cursor()
819            cur.execute("create table %s (t text)" % table)
820            try:
821                cur.execute("set client_encoding=latin1")
822                cur.execute(u"select '%s'" % input)
823            except Exception:
824                self.skipTest("database does not support latin1")
825            output1 = cur.fetchone()[0]
826            cur.execute("insert into %s values (%%s)" % table, (input,))
827            cur.execute("select * from %s" % table)
828            output2 = cur.fetchone()[0]
829            cur.execute("select t = '%s' from %s" % (input, table))
830            output3 = cur.fetchone()[0]
831            cur.execute("select t = %%s from %s" % table, (input,))
832            output4 = cur.fetchone()[0]
833        finally:
834            con.close()
835        if str is bytes:  # Python < 3.0
836            input = input.encode('latin1')
837        self.assertIsInstance(output1, str)
838        self.assertEqual(output1, input)
839        self.assertIsInstance(output2, str)
840        self.assertEqual(output2, input)
841        self.assertIsInstance(output3, bool)
842        self.assertTrue(output3)
843        self.assertIsInstance(output4, bool)
844        self.assertTrue(output4)
845
846    def test_bool(self):
847        values = [False, True, None, 't', 'f', 'true', 'false']
848        table = self.table_prefix + 'booze'
849        con = self._connect()
850        try:
851            cur = con.cursor()
852            cur.execute(
853                "create table %s (n smallint, booltest bool)" % table)
854            params = enumerate(values)
855            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
856            cur.execute("select booltest from %s order by n" % table)
857            rows = cur.fetchall()
858            self.assertEqual(cur.description[0].type_code, pgdb.BOOL)
859        finally:
860            con.close()
861        rows = [row[0] for row in rows]
862        values[3] = values[5] = True
863        values[4] = values[6] = False
864        self.assertEqual(rows, values)
865
866    def test_literal(self):
867        con = self._connect()
868        try:
869            cur = con.cursor()
870            value = "lower('Hello')"
871            cur.execute("select %s, %s", (value, pgdb.Literal(value)))
872            row = cur.fetchone()
873        finally:
874            con.close()
875        self.assertEqual(row, (value, 'hello'))
876
877
878    def test_json(self):
879        inval = {"employees":
880            [{"firstName": "John", "lastName": "Doe", "age": 61}]}
881        table = self.table_prefix + 'booze'
882        con = self._connect()
883        try:
884            cur = con.cursor()
885            try:
886                cur.execute("create table %s (jsontest json)" % table)
887            except pgdb.ProgrammingError:
888                self.skipTest('database does not support json')
889            params = (pgdb.Json(inval),)
890            cur.execute("insert into %s values (%%s)" % table, params)
891            cur.execute("select jsontest from %s" % table)
892            outval = cur.fetchone()[0]
893            self.assertEqual(cur.description[0].type_code, pgdb.JSON)
894        finally:
895            con.close()
896        self.assertEqual(inval, outval)
897
898    def test_jsonb(self):
899        inval = {"employees":
900            [{"firstName": "John", "lastName": "Doe", "age": 61}]}
901        table = self.table_prefix + 'booze'
902        con = self._connect()
903        try:
904            cur = con.cursor()
905            try:
906                cur.execute("create table %s (jsonbtest jsonb)" % table)
907            except pgdb.ProgrammingError:
908                self.skipTest('database does not support jsonb')
909            params = (pgdb.Json(inval),)
910            cur.execute("insert into %s values (%%s)" % table, params)
911            cur.execute("select jsonbtest from %s" % table)
912            outval = cur.fetchone()[0]
913            self.assertEqual(cur.description[0].type_code, pgdb.JSON)
914        finally:
915            con.close()
916        self.assertEqual(inval, outval)
917
918    def test_execute_edge_cases(self):
919        con = self._connect()
920        try:
921            cur = con.cursor()
922            sql = 'invalid'  # should be ignored with empty parameter list
923            cur.executemany(sql, [])
924            sql = 'select %d + 1'
925            cur.execute(sql, [(1,), (2,)])  # deprecated use of execute()
926            self.assertEqual(cur.fetchone()[0], 3)
927            sql = 'select 1/0'  # cannot be executed
928            self.assertRaises(pgdb.ProgrammingError, cur.execute, sql)
929            cur.close()
930            con.rollback()
931            if pgdb.shortcutmethods:
932                res = con.execute('select %d', (1,)).fetchone()
933                self.assertEqual(res, (1,))
934                res = con.executemany('select %d', [(1,), (2,)]).fetchone()
935                self.assertEqual(res, (2,))
936        finally:
937            con.close()
938        sql = 'select 1'  # cannot be executed after connection is closed
939        self.assertRaises(pgdb.OperationalError, cur.execute, sql)
940
941    def test_fetchmany_with_keep(self):
942        con = self._connect()
943        try:
944            cur = con.cursor()
945            self.assertEqual(cur.arraysize, 1)
946            cur.execute('select * from generate_series(1, 25)')
947            self.assertEqual(len(cur.fetchmany()), 1)
948            self.assertEqual(len(cur.fetchmany()), 1)
949            self.assertEqual(cur.arraysize, 1)
950            cur.arraysize = 3
951            self.assertEqual(len(cur.fetchmany()), 3)
952            self.assertEqual(len(cur.fetchmany()), 3)
953            self.assertEqual(cur.arraysize, 3)
954            self.assertEqual(len(cur.fetchmany(size=2)), 2)
955            self.assertEqual(cur.arraysize, 3)
956            self.assertEqual(len(cur.fetchmany()), 3)
957            self.assertEqual(len(cur.fetchmany()), 3)
958            self.assertEqual(len(cur.fetchmany(size=2, keep=True)), 2)
959            self.assertEqual(cur.arraysize, 2)
960            self.assertEqual(len(cur.fetchmany()), 2)
961            self.assertEqual(len(cur.fetchmany()), 2)
962            self.assertEqual(len(cur.fetchmany(25)), 3)
963        finally:
964            con.close()
965
966    def test_nextset(self):
967        con = self._connect()
968        cur = con.cursor()
969        self.assertRaises(con.NotSupportedError, cur.nextset)
970
971    def test_setoutputsize(self):
972        pass  # not supported
973
974    def test_connection_errors(self):
975        con = self._connect()
976        self.assertEqual(con.Error, pgdb.Error)
977        self.assertEqual(con.Warning, pgdb.Warning)
978        self.assertEqual(con.InterfaceError, pgdb.InterfaceError)
979        self.assertEqual(con.DatabaseError, pgdb.DatabaseError)
980        self.assertEqual(con.InternalError, pgdb.InternalError)
981        self.assertEqual(con.OperationalError, pgdb.OperationalError)
982        self.assertEqual(con.ProgrammingError, pgdb.ProgrammingError)
983        self.assertEqual(con.IntegrityError, pgdb.IntegrityError)
984        self.assertEqual(con.DataError, pgdb.DataError)
985        self.assertEqual(con.NotSupportedError, pgdb.NotSupportedError)
986
987    def test_connection_as_contextmanager(self):
988        table = self.table_prefix + 'booze'
989        con = self._connect()
990        try:
991            cur = con.cursor()
992            cur.execute("create table %s (n smallint check(n!=4))" % table)
993            with con:
994                cur.execute("insert into %s values (1)" % table)
995                cur.execute("insert into %s values (2)" % table)
996            try:
997                with con:
998                    cur.execute("insert into %s values (3)" % table)
999                    cur.execute("insert into %s values (4)" % table)
1000            except con.ProgrammingError as error:
1001                self.assertTrue('check' in str(error).lower())
1002            with con:
1003                cur.execute("insert into %s values (5)" % table)
1004                cur.execute("insert into %s values (6)" % table)
1005            try:
1006                with con:
1007                    cur.execute("insert into %s values (7)" % table)
1008                    cur.execute("insert into %s values (8)" % table)
1009                    raise ValueError('transaction should rollback')
1010            except ValueError as error:
1011                self.assertEqual(str(error), 'transaction should rollback')
1012            with con:
1013                cur.execute("insert into %s values (9)" % table)
1014            cur.execute("select * from %s order by 1" % table)
1015            rows = cur.fetchall()
1016            rows = [row[0] for row in rows]
1017        finally:
1018            con.close()
1019        self.assertEqual(rows, [1, 2, 5, 6, 9])
1020
1021    def test_cursor_connection(self):
1022        con = self._connect()
1023        cur = con.cursor()
1024        self.assertEqual(cur.connection, con)
1025        cur.close()
1026
1027    def test_cursor_as_contextmanager(self):
1028        con = self._connect()
1029        with con.cursor() as cur:
1030            self.assertEqual(cur.connection, con)
1031
1032    def test_pgdb_type(self):
1033        self.assertEqual(pgdb.STRING, pgdb.STRING)
1034        self.assertNotEqual(pgdb.STRING, pgdb.INTEGER)
1035        self.assertNotEqual(pgdb.STRING, pgdb.BOOL)
1036        self.assertNotEqual(pgdb.BOOL, pgdb.INTEGER)
1037        self.assertEqual(pgdb.INTEGER, pgdb.INTEGER)
1038        self.assertNotEqual(pgdb.INTEGER, pgdb.NUMBER)
1039        self.assertEqual('char', pgdb.STRING)
1040        self.assertEqual('varchar', pgdb.STRING)
1041        self.assertEqual('text', pgdb.STRING)
1042        self.assertNotEqual('numeric', pgdb.STRING)
1043        self.assertEqual('numeric', pgdb.NUMERIC)
1044        self.assertEqual('numeric', pgdb.NUMBER)
1045        self.assertEqual('int4', pgdb.NUMBER)
1046        self.assertNotEqual('int4', pgdb.NUMERIC)
1047        self.assertEqual('int2', pgdb.SMALLINT)
1048        self.assertNotEqual('int4', pgdb.SMALLINT)
1049        self.assertEqual('int2', pgdb.INTEGER)
1050        self.assertEqual('int4', pgdb.INTEGER)
1051        self.assertEqual('int8', pgdb.INTEGER)
1052        self.assertNotEqual('int4', pgdb.LONG)
1053        self.assertEqual('int8', pgdb.LONG)
1054        self.assertTrue('char' in pgdb.STRING)
1055        self.assertTrue(pgdb.NUMERIC <= pgdb.NUMBER)
1056        self.assertTrue(pgdb.NUMBER >= pgdb.INTEGER)
1057        self.assertTrue(pgdb.TIME <= pgdb.DATETIME)
1058        self.assertTrue(pgdb.DATETIME >= pgdb.DATE)
1059        self.assertEqual(pgdb.ARRAY, pgdb.ARRAY)
1060        self.assertNotEqual(pgdb.ARRAY, pgdb.STRING)
1061        self.assertEqual('_char', pgdb.ARRAY)
1062        self.assertNotEqual('char', pgdb.ARRAY)
1063        self.assertEqual(pgdb.RECORD, pgdb.RECORD)
1064        self.assertNotEqual(pgdb.RECORD, pgdb.STRING)
1065        self.assertNotEqual(pgdb.RECORD, pgdb.ARRAY)
1066        self.assertEqual('record', pgdb.RECORD)
1067        self.assertNotEqual('_record', pgdb.RECORD)
1068
1069
if __name__ == '__main__':
    # run all tests of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.