source: trunk/tests/test_dbapi20.py @ 817

Last change on this file since 817 was 817, checked in by cito, 4 years ago

Support the hstore data type

Added adaptation and typecasting of the hstore type as Python dictionaries.
For the typecasting, a fast parser has been added to the C extension.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 44.3 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3# $Id: test_dbapi20.py 817 2016-02-04 20:18:08Z cito $
4
5try:
6    import unittest2 as unittest  # for Python < 2.7
7except ImportError:
8    import unittest
9
10import pgdb
11
12try:
13    from . import dbapi20
14except (ImportError, ValueError, SystemError):
15    import dbapi20
16
# We need a database to test against.
# If LOCAL_PyGreSQL.py exists we will get our information from that.
# Otherwise we use the defaults.
# Default connection settings. The star imports below run afterwards and
# may rebind any of these names with locally configured values.
dbname = 'dbapi20_test'
dbhost = ''  # falsy: used as "no explicit host" (see `dbhost or None`)
dbport = 5432
try:
    # package-relative import, for when the tests run as a package
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    # presumably the relative import fails like this when the module is
    # not run inside a package -- fall back to a plain top-level import
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings module: keep the defaults above
30
31from datetime import date, time, datetime, timedelta
32
33try:
34    from datetime import timezone
35except ImportError:  # Python < 3.2
36    timezone = None
37
38try:
39    long
40except NameError:  # Python >= 3.0
41    long = int
42
43try:
44    from collections import OrderedDict
45except ImportError:  # Python 2.6 or 3.0
46    OrderedDict = None
47
48
class PgBitString:
    """Test object with a PostgreSQL representation as Bit String."""

    def __init__(self, value):
        # presumably a non-negative int; a negative value would be
        # rendered with a leading minus sign by the 'b' format below
        self.value = value

    def __pg_repr__(self):
        """Return the value as a PostgreSQL bit string literal B'...'."""
        return "B'{0:b}'".format(self.value)
57
58
59class test_PyGreSQL(dbapi20.DatabaseAPI20Test):
60
61    driver = pgdb
62    connect_args = ()
63    connect_kw_args = {'database': dbname,
64        'host': '%s:%d' % (dbhost or '', dbport or -1)}
65
66    lower_func = 'lower'  # For stored procedure test
67
68    def setUp(self):
69        # Call superclass setUp in case this does something in the future
70        dbapi20.DatabaseAPI20Test.setUp(self)
71        try:
72            con = self._connect()
73            con.close()
74        except pgdb.Error:  # try to create a missing database
75            import pg
76            try:  # first try to log in as superuser
77                db = pg.DB('postgres', dbhost or None, dbport or -1,
78                    user='postgres')
79            except Exception:  # then try to log in as current user
80                db = pg.DB('postgres', dbhost or None, dbport or -1)
81            db.query('create database ' + dbname)
82
83    def tearDown(self):
84        dbapi20.DatabaseAPI20Test.tearDown(self)
85
86    def test_callproc_no_params(self):
87        con = self._connect()
88        cur = con.cursor()
89        # note that now() does not change within a transaction
90        cur.execute('select now()')
91        now = cur.fetchone()[0]
92        res = cur.callproc('now')
93        self.assertIsNone(res)
94        res = cur.fetchone()[0]
95        self.assertEqual(res, now)
96
97    def test_callproc_bad_params(self):
98        con = self._connect()
99        cur = con.cursor()
100        self.assertRaises(TypeError, cur.callproc, 'lower', 42)
101        self.assertRaises(pgdb.ProgrammingError, cur.callproc, 'lower', (42,))
102
103    def test_callproc_one_param(self):
104        con = self._connect()
105        cur = con.cursor()
106        params = (42.4382,)
107        res = cur.callproc("round", params)
108        self.assertIs(res, params)
109        res = cur.fetchone()[0]
110        self.assertEqual(res, 42)
111
112    def test_callproc_two_params(self):
113        con = self._connect()
114        cur = con.cursor()
115        params = (9, 4)
116        res = cur.callproc("div", params)
117        self.assertIs(res, params)
118        res = cur.fetchone()[0]
119        self.assertEqual(res, 2)
120
121    def test_cursor_type(self):
122
123        class TestCursor(pgdb.Cursor):
124            pass
125
126        con = self._connect()
127        self.assertIs(con.cursor_type, pgdb.Cursor)
128        cur = con.cursor()
129        self.assertIsInstance(cur, pgdb.Cursor)
130        self.assertNotIsInstance(cur, TestCursor)
131        con.cursor_type = TestCursor
132        cur = con.cursor()
133        self.assertIsInstance(cur, TestCursor)
134        cur = con.cursor()
135        self.assertIsInstance(cur, TestCursor)
136        con = self._connect()
137        self.assertIs(con.cursor_type, pgdb.Cursor)
138        cur = con.cursor()
139        self.assertIsInstance(cur, pgdb.Cursor)
140        self.assertNotIsInstance(cur, TestCursor)
141
142    def test_row_factory(self):
143
144        class TestCursor(pgdb.Cursor):
145
146            def row_factory(self, row):
147                return dict(('column %s' % desc[0], value)
148                    for desc, value in zip(self.description, row))
149
150        con = self._connect()
151        con.cursor_type = TestCursor
152        cur = con.cursor()
153        self.assertIsInstance(cur, TestCursor)
154        res = cur.execute("select 1 as a, 2 as b")
155        self.assertIs(res, cur, 'execute() should return cursor')
156        res = cur.fetchone()
157        self.assertIsInstance(res, dict)
158        self.assertEqual(res, {'column a': 1, 'column b': 2})
159        cur.execute("select 1 as a, 2 as b union select 3, 4 order by 1")
160        res = cur.fetchall()
161        self.assertIsInstance(res, list)
162        self.assertEqual(len(res), 2)
163        self.assertIsInstance(res[0], dict)
164        self.assertEqual(res[0], {'column a': 1, 'column b': 2})
165        self.assertIsInstance(res[1], dict)
166        self.assertEqual(res[1], {'column a': 3, 'column b': 4})
167
168    def test_build_row_factory(self):
169
170        class TestCursor(pgdb.Cursor):
171
172            def build_row_factory(self):
173                keys = [desc[0] for desc in self.description]
174                return lambda row: dict((key, value)
175                    for key, value in zip(keys, row))
176
177        con = self._connect()
178        con.cursor_type = TestCursor
179        cur = con.cursor()
180        self.assertIsInstance(cur, TestCursor)
181        cur.execute("select 1 as a, 2 as b")
182        res = cur.fetchone()
183        self.assertIsInstance(res, dict)
184        self.assertEqual(res, {'a': 1, 'b': 2})
185        cur.execute("select 1 as a, 2 as b union select 3, 4 order by 1")
186        res = cur.fetchall()
187        self.assertIsInstance(res, list)
188        self.assertEqual(len(res), 2)
189        self.assertIsInstance(res[0], dict)
190        self.assertEqual(res[0], {'a': 1, 'b': 2})
191        self.assertIsInstance(res[1], dict)
192        self.assertEqual(res[1], {'a': 3, 'b': 4})
193
194    def test_cursor_with_named_columns(self):
195        con = self._connect()
196        cur = con.cursor()
197        res = cur.execute("select 1 as abc, 2 as de, 3 as f")
198        self.assertIs(res, cur, 'execute() should return cursor')
199        res = cur.fetchone()
200        self.assertIsInstance(res, tuple)
201        self.assertEqual(res, (1, 2, 3))
202        self.assertEqual(res._fields, ('abc', 'de', 'f'))
203        self.assertEqual(res.abc, 1)
204        self.assertEqual(res.de, 2)
205        self.assertEqual(res.f, 3)
206        cur.execute("select 1 as one, 2 as two union select 3, 4 order by 1")
207        res = cur.fetchall()
208        self.assertIsInstance(res, list)
209        self.assertEqual(len(res), 2)
210        self.assertIsInstance(res[0], tuple)
211        self.assertEqual(res[0], (1, 2))
212        self.assertEqual(res[0]._fields, ('one', 'two'))
213        self.assertIsInstance(res[1], tuple)
214        self.assertEqual(res[1], (3, 4))
215        self.assertEqual(res[1]._fields, ('one', 'two'))
216
217    def test_cursor_with_unnamed_columns(self):
218        con = self._connect()
219        cur = con.cursor()
220        cur.execute("select 1, 2, 3")
221        res = cur.fetchone()
222        self.assertIsInstance(res, tuple)
223        self.assertEqual(res, (1, 2, 3))
224        old_py = OrderedDict is None  # Python 2.6 or 3.0
225        # old Python versions cannot rename tuple fields with underscore
226        if old_py:
227            self.assertEqual(res._fields, ('column_0', 'column_1', 'column_2'))
228        else:
229            self.assertEqual(res._fields, ('_0', '_1', '_2'))
230        cur.execute("select 1 as one, 2, 3 as three")
231        res = cur.fetchone()
232        self.assertIsInstance(res, tuple)
233        self.assertEqual(res, (1, 2, 3))
234        if old_py:  # cannot auto rename with underscore
235            self.assertEqual(res._fields, ('one', 'column_1', 'three'))
236        else:
237            self.assertEqual(res._fields, ('one', '_1', 'three'))
238        cur.execute("select 1 as abc, 2 as def")
239        res = cur.fetchone()
240        self.assertIsInstance(res, tuple)
241        self.assertEqual(res, (1, 2))
242        if old_py:
243            self.assertEqual(res._fields, ('column_0', 'column_1'))
244        else:
245            self.assertEqual(res._fields, ('abc', '_1'))
246
247    def test_colnames(self):
248        con = self._connect()
249        cur = con.cursor()
250        cur.execute("select 1, 2, 3")
251        names = cur.colnames
252        self.assertIsInstance(names, list)
253        self.assertEqual(names, ['?column?', '?column?', '?column?'])
254        cur.execute("select 1 as a, 2 as bc, 3 as def, 4 as g")
255        names = cur.colnames
256        self.assertIsInstance(names, list)
257        self.assertEqual(names, ['a', 'bc', 'def', 'g'])
258
259    def test_coltypes(self):
260        con = self._connect()
261        cur = con.cursor()
262        cur.execute("select 1::int2, 2::int4, 3::int8")
263        types = cur.coltypes
264        self.assertIsInstance(types, list)
265        self.assertEqual(types, ['int2', 'int4', 'int8'])
266
267    def test_description_fields(self):
268        con = self._connect()
269        cur = con.cursor()
270        cur.execute("select 123456789::int8 col0,"
271            " 123456.789::numeric(41, 13) as col1,"
272            " 'foobar'::char(39) as col2")
273        desc = cur.description
274        self.assertIsInstance(desc, list)
275        self.assertEqual(len(desc), 3)
276        cols = [('int8', 8, None), ('numeric', 41, 13), ('bpchar', 39, None)]
277        for i in range(3):
278            c, d = cols[i], desc[i]
279            self.assertIsInstance(d, tuple)
280            self.assertEqual(len(d), 7)
281            self.assertIsInstance(d.name, str)
282            self.assertEqual(d.name, 'col%d' % i)
283            self.assertIsInstance(d.type_code, str)
284            self.assertEqual(d.type_code, c[0])
285            self.assertIsNone(d.display_size)
286            self.assertIsInstance(d.internal_size, int)
287            self.assertEqual(d.internal_size, c[1])
288            if c[2] is not None:
289                self.assertIsInstance(d.precision, int)
290                self.assertEqual(d.precision, c[1])
291                self.assertIsInstance(d.scale, int)
292                self.assertEqual(d.scale, c[2])
293            else:
294                self.assertIsNone(d.precision)
295                self.assertIsNone(d.scale)
296            self.assertIsNone(d.null_ok)
297
    def test_type_cache_info(self):
        """Check lazy population and contents of the connection type cache.

        Looking up a type name adds it to con.type_cache; entries are also
        reachable by oid, survive closing a cursor, and are not shared with
        a new connection. get() returns None for unknown names.
        """
        con = self._connect()
        try:
            cur = con.cursor()
            type_cache = con.type_cache
            # the cache starts empty for 'numeric' and is filled on lookup
            self.assertNotIn('numeric', type_cache)
            type_info = type_cache['numeric']
            self.assertIn('numeric', type_cache)
            self.assertEqual(type_info, 'numeric')
            self.assertEqual(type_info.oid, 1700)
            self.assertEqual(type_info.len, -1)
            self.assertEqual(type_info.type, 'b')  # base
            self.assertEqual(type_info.category, 'N')  # numeric
            self.assertEqual(type_info.delim, ',')
            self.assertEqual(type_info.relid, 0)
            # lookup by oid yields the very same cached object
            self.assertIs(con.type_cache[1700], type_info)
            # a composite type (the pg_type catalog's row type)
            self.assertNotIn('pg_type', type_cache)
            type_info = type_cache['pg_type']
            self.assertIn('pg_type', type_cache)
            self.assertEqual(type_info.type, 'c')  # composite
            self.assertEqual(type_info.category, 'C')  # composite
            # the field types of the composite type can be looked up too
            cols = type_cache.get_fields('pg_type')
            self.assertEqual(cols[0].name, 'typname')
            typname = type_cache[cols[0].type]
            self.assertEqual(typname, 'name')
            self.assertEqual(typname.type, 'b')  # base
            self.assertEqual(typname.category, 'S')  # string
            self.assertEqual(cols[3].name, 'typlen')
            typlen = type_cache[cols[3].type]
            self.assertEqual(typlen, 'int2')
            self.assertEqual(typlen.type, 'b')  # base
            self.assertEqual(typlen.category, 'N')  # numeric
            # closing a cursor does not clear the connection's cache
            cur.close()
            cur = con.cursor()
            type_cache = con.type_cache
            self.assertIn('numeric', type_cache)
            cur.close()
        finally:
            con.close()
        # a new connection has its own, initially empty cache
        con = self._connect()
        try:
            cur = con.cursor()
            type_cache = con.type_cache
            self.assertNotIn('pg_type', type_cache)
            # type_info still holds the pg_type entry of the old connection
            self.assertEqual(type_cache.get('pg_type'), type_info)
            self.assertIn('pg_type', type_cache)
            self.assertIsNone(type_cache.get(
                self.table_prefix + '_surely_does_not_exist'))
            cur.close()
        finally:
            con.close()
349
350    def test_type_cache_typecast(self):
351        con = self._connect()
352        try:
353            cur = con.cursor()
354            type_cache = con.type_cache
355            self.assertIs(type_cache.get_typecast('int4'), int)
356            cast_int = lambda v: 'int(%s)' % v
357            type_cache.set_typecast('int4', cast_int)
358            query = 'select 2::int2, 4::int4, 8::int8'
359            cur.execute(query)
360            i2, i4, i8 = cur.fetchone()
361            self.assertEqual(i2, 2)
362            self.assertEqual(i4, 'int(4)')
363            self.assertEqual(i8, 8)
364            self.assertEqual(type_cache.typecast(42, 'int4'), 'int(42)')
365            type_cache.set_typecast(['int2', 'int8'], cast_int)
366            cur.execute(query)
367            i2, i4, i8 = cur.fetchone()
368            self.assertEqual(i2, 'int(2)')
369            self.assertEqual(i4, 'int(4)')
370            self.assertEqual(i8, 'int(8)')
371            type_cache.reset_typecast('int4')
372            cur.execute(query)
373            i2, i4, i8 = cur.fetchone()
374            self.assertEqual(i2, 'int(2)')
375            self.assertEqual(i4, 4)
376            self.assertEqual(i8, 'int(8)')
377            type_cache.reset_typecast(['int2', 'int8'])
378            cur.execute(query)
379            i2, i4, i8 = cur.fetchone()
380            self.assertEqual(i2, 2)
381            self.assertEqual(i4, 4)
382            self.assertEqual(i8, 8)
383            type_cache.set_typecast(['int2', 'int8'], cast_int)
384            cur.execute(query)
385            i2, i4, i8 = cur.fetchone()
386            self.assertEqual(i2, 'int(2)')
387            self.assertEqual(i4, 4)
388            self.assertEqual(i8, 'int(8)')
389            type_cache.reset_typecast()
390            cur.execute(query)
391            i2, i4, i8 = cur.fetchone()
392            self.assertEqual(i2, 2)
393            self.assertEqual(i4, 4)
394            self.assertEqual(i8, 8)
395            cur.close()
396        finally:
397            con.close()
398
399    def test_cursor_iteration(self):
400        con = self._connect()
401        cur = con.cursor()
402        cur.execute("select 1 union select 2 union select 3")
403        self.assertEqual([r[0] for r in cur], [1, 2, 3])
404
405    def test_fetch_2_rows(self):
406        Decimal = pgdb.decimal_type()
407        values = ('test', pgdb.Binary(b'\xff\x52\xb2'),
408            True, 5, 6, 5.7, Decimal('234.234234'), Decimal('75.45'),
409            pgdb.Date(2011, 7, 17), pgdb.Time(15, 47, 42),
410            pgdb.Timestamp(2008, 10, 20, 15, 25, 35),
411            pgdb.Interval(15, 31, 5), 7897234)
412        table = self.table_prefix + 'booze'
413        con = self._connect()
414        try:
415            cur = con.cursor()
416            cur.execute("set datestyle to iso")
417            cur.execute("create table %s ("
418                "stringtest varchar,"
419                "binarytest bytea,"
420                "booltest bool,"
421                "integertest int4,"
422                "longtest int8,"
423                "floattest float8,"
424                "numerictest numeric,"
425                "moneytest money,"
426                "datetest date,"
427                "timetest time,"
428                "datetimetest timestamp,"
429                "intervaltest interval,"
430                "rowidtest oid)" % table)
431            cur.execute("set standard_conforming_strings to on")
432            for s in ('numeric', 'monetary', 'time'):
433                cur.execute("set lc_%s to 'C'" % s)
434            for _i in range(2):
435                cur.execute("insert into %s values ("
436                    "%%s,%%s,%%s,%%s,%%s,%%s,%%s,"
437                    "'%%s'::money,%%s,%%s,%%s,%%s,%%s)" % table, values)
438            cur.execute("select * from %s" % table)
439            rows = cur.fetchall()
440            self.assertEqual(len(rows), 2)
441            row0 = rows[0]
442            self.assertEqual(row0, values)
443            self.assertEqual(row0, rows[1])
444            self.assertIsInstance(row0[0], str)
445            self.assertIsInstance(row0[1], bytes)
446            self.assertIsInstance(row0[2], bool)
447            self.assertIsInstance(row0[3], int)
448            self.assertIsInstance(row0[4], long)
449            self.assertIsInstance(row0[5], float)
450            self.assertIsInstance(row0[6], Decimal)
451            self.assertIsInstance(row0[7], Decimal)
452            self.assertIsInstance(row0[8], date)
453            self.assertIsInstance(row0[9], time)
454            self.assertIsInstance(row0[10], datetime)
455            self.assertIsInstance(row0[11], timedelta)
456        finally:
457            con.close()
458
459    def test_sqlstate(self):
460        con = self._connect()
461        cur = con.cursor()
462        try:
463            cur.execute("select 1/0")
464        except pgdb.DatabaseError as error:
465            self.assertTrue(isinstance(error, pgdb.ProgrammingError))
466            # the SQLSTATE error code for division by zero is 22012
467            self.assertEqual(error.sqlstate, '22012')
468
469    def test_float(self):
470        nan, inf = float('nan'), float('inf')
471        from math import isnan, isinf
472        self.assertTrue(isnan(nan) and not isinf(nan))
473        self.assertTrue(isinf(inf) and not isnan(inf))
474        values = [0, 1, 0.03125, -42.53125, nan, inf, -inf,
475            'nan', 'inf', '-inf', 'NaN', 'Infinity', '-Infinity']
476        table = self.table_prefix + 'booze'
477        con = self._connect()
478        try:
479            cur = con.cursor()
480            cur.execute(
481                "create table %s (n smallint, floattest float)" % table)
482            params = enumerate(values)
483            cur.executemany("insert into %s values (%%d,%%s)" % table, params)
484            cur.execute("select floattest from %s order by n" % table)
485            rows = cur.fetchall()
486            self.assertEqual(cur.description[0].type_code, pgdb.FLOAT)
487            self.assertNotEqual(cur.description[0].type_code, pgdb.ARRAY)
488            self.assertNotEqual(cur.description[0].type_code, pgdb.RECORD)
489        finally:
490            con.close()
491        self.assertEqual(len(rows), len(values))
492        rows = [row[0] for row in rows]
493        for inval, outval in zip(values, rows):
494            if inval in ('inf', 'Infinity'):
495                inval = inf
496            elif inval in ('-inf', '-Infinity'):
497                inval = -inf
498            elif inval in ('nan', 'NaN'):
499                inval = nan
500            if isinf(inval):
501                self.assertTrue(isinf(outval))
502                if inval < 0:
503                    self.assertTrue(outval < 0)
504                else:
505                    self.assertTrue(outval > 0)
506            elif isnan(inval):
507                self.assertTrue(isnan(outval))
508            else:
509                self.assertEqual(inval, outval)
510
511    def test_datetime(self):
512        dt = datetime(2011, 7, 17, 15, 47, 42, 317509)
513        table = self.table_prefix + 'booze'
514        con = self._connect()
515        try:
516            cur = con.cursor()
517            cur.execute("create table %s ("
518                "d date, t time,  ts timestamp,"
519                "tz timetz, tsz timestamptz)" % table)
520            for n in range(3):
521                values = [dt.date(), dt.time(), dt,
522                    dt.time(), dt]
523                if timezone:
524                    values[3] = values[3].replace(tzinfo=timezone.utc)
525                    values[4] = values[4].replace(tzinfo=timezone.utc)
526                if n == 0:  # input as objects
527                    params = values
528                if n == 1:  # input as text
529                    params = [v.isoformat() for v in values]  # as text
530                elif n == 2:  # input using type helpers
531                    d = (dt.year, dt.month, dt.day)
532                    t = (dt.hour, dt.minute, dt.second, dt.microsecond)
533                    params = [pgdb.Date(*d), pgdb.Time(*t),
534                            pgdb.Timestamp(*(d + t)), pgdb.Time(*t),
535                            pgdb.Timestamp(*(d + t))]
536                for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy',
537                        'sql, mdy', 'sql, dmy', 'german'):
538                    cur.execute("set datestyle to %s" % datestyle)
539                    cur.execute("insert into %s"
540                        " values (%%s,%%s,%%s,%%s,%%s)" % table, params)
541                    cur.execute("select * from %s" % table)
542                    d = cur.description
543                    for i in range(5):
544                        self.assertEqual(d[i].type_code, pgdb.DATETIME)
545                        self.assertNotEqual(d[i].type_code, pgdb.STRING)
546                        self.assertNotEqual(d[i].type_code, pgdb.ARRAY)
547                        self.assertNotEqual(d[i].type_code, pgdb.RECORD)
548                    self.assertEqual(d[0].type_code, pgdb.DATE)
549                    self.assertEqual(d[1].type_code, pgdb.TIME)
550                    self.assertEqual(d[2].type_code, pgdb.TIMESTAMP)
551                    self.assertEqual(d[3].type_code, pgdb.TIME)
552                    self.assertEqual(d[4].type_code, pgdb.TIMESTAMP)
553                    row = cur.fetchone()
554                    self.assertEqual(row, tuple(values))
555                    cur.execute("delete from %s" % table)
556        finally:
557            con.close()
558
559    def test_interval(self):
560        td = datetime(2011, 7, 17, 15, 47, 42, 317509) - datetime(1970, 1, 1)
561        table = self.table_prefix + 'booze'
562        con = self._connect()
563        try:
564            cur = con.cursor()
565            cur.execute("create table %s (i interval)" % table)
566            for n in range(3):
567                if n == 0:  # input as objects
568                    param = td
569                if n == 1:  # input as text
570                    param = '%d days %d seconds %d microseconds ' % (
571                        td.days, td.seconds, td.microseconds)
572                elif n == 2:  # input using type helpers
573                    param = pgdb.Interval(
574                        td.days, 0, 0, td.seconds, td.microseconds)
575                for intervalstyle in ('sql_standard ', 'postgres',
576                        'postgres_verbose', 'iso_8601'):
577                    cur.execute("set intervalstyle to %s" % intervalstyle)
578                    cur.execute("insert into %s"
579                        " values (%%s)" % table, [param])
580                    cur.execute("select * from %s" % table)
581                    tc = cur.description[0].type_code
582                    self.assertEqual(tc, pgdb.DATETIME)
583                    self.assertNotEqual(tc, pgdb.STRING)
584                    self.assertNotEqual(tc, pgdb.ARRAY)
585                    self.assertNotEqual(tc, pgdb.RECORD)
586                    self.assertEqual(tc, pgdb.INTERVAL)
587                    row = cur.fetchone()
588                    self.assertEqual(row, (td,))
589                    cur.execute("delete from %s" % table)
590        finally:
591            con.close()
592
593    def test_hstore(self):
594        con = self._connect()
595        try:
596            cur = con.cursor()
597            cur.execute("select 'k=>v'::hstore")
598        except pgdb.ProgrammingError:
599            try:
600                cur.execute("create extension hstore")
601            except pgdb.ProgrammingError:
602                self.skipTest("hstore extension not enabled")
603        finally:
604            con.close()
605        d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever',
606            '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3',
607            '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"',
608            'None': None, 'NULL': 'NULL', 'empty': ''}
609        con = self._connect()
610        try:
611            cur = con.cursor()
612            cur.execute("select %s::hstore", (pgdb.Hstore(d),))
613            result = cur.fetchone()[0]
614        finally:
615            con.close()
616        self.assertIsInstance(result, dict)
617        self.assertEqual(result, d)
618
619    def test_insert_array(self):
620        values = [(None, None), ([], []), ([None], [[None], ['null']]),
621            ([1, 2, 3], [['a', 'b'], ['c', 'd']]),
622            ([20000, 25000, 25000, 30000],
623            [['breakfast', 'consulting'], ['meeting', 'lunch']]),
624            ([0, 1, -1], [['Hello, World!', '"Hi!"'], ['{x,y}', ' x y ']])]
625        table = self.table_prefix + 'booze'
626        con = self._connect()
627        try:
628            cur = con.cursor()
629            cur.execute("create table %s"
630                " (n smallint, i int[], t text[][])" % table)
631            params = [(n, v[0], v[1]) for n, v in enumerate(values)]
632            # Note that we must explicit casts because we are inserting
633            # empty arrays.  Otherwise this is not necessary.
634            cur.executemany("insert into %s values"
635                " (%%d,%%s::int[],%%s::text[][])" % table, params)
636            cur.execute("select i, t from %s order by n" % table)
637            d = cur.description
638            self.assertEqual(d[0].type_code, pgdb.ARRAY)
639            self.assertNotEqual(d[0].type_code, pgdb.RECORD)
640            self.assertEqual(d[0].type_code, pgdb.NUMBER)
641            self.assertEqual(d[0].type_code, pgdb.INTEGER)
642            self.assertEqual(d[1].type_code, pgdb.ARRAY)
643            self.assertNotEqual(d[1].type_code, pgdb.RECORD)
644            self.assertEqual(d[1].type_code, pgdb.STRING)
645            rows = cur.fetchall()
646        finally:
647            con.close()
648        self.assertEqual(rows, values)
649
650    def test_select_array(self):
651        values = ([1, 2, 3, None], ['a', 'b', 'c', None])
652        con = self._connect()
653        try:
654            cur = con.cursor()
655            cur.execute("select %s::int[], %s::text[]", values)
656            row = cur.fetchone()
657        finally:
658            con.close()
659        self.assertEqual(row, values)
660
661    def test_insert_record(self):
662        values = [('John', 61), ('Jane', 63),
663                  ('Fred', None), ('Wilma', None),
664                  (None, 42), (None, None)]
665        table = self.table_prefix + 'booze'
666        record = self.table_prefix + 'munch'
667        con = self._connect()
668        try:
669            cur = con.cursor()
670            cur.execute("create type %s as (name varchar, age int)" % record)
671            cur.execute("create table %s (n smallint, r %s)" % (table, record))
672            params = enumerate(values)
673            cur.executemany("insert into %s values (%%d,%%s)" % table, params)
674            cur.execute("select r from %s order by n" % table)
675            type_code = cur.description[0].type_code
676            self.assertEqual(type_code, record)
677            self.assertEqual(type_code, pgdb.RECORD)
678            self.assertNotEqual(type_code, pgdb.ARRAY)
679            columns = con.type_cache.get_fields(type_code)
680            self.assertEqual(columns[0].name, 'name')
681            self.assertEqual(columns[1].name, 'age')
682            self.assertEqual(con.type_cache[columns[0].type], 'varchar')
683            self.assertEqual(con.type_cache[columns[1].type], 'int4')
684            rows = cur.fetchall()
685        finally:
686            cur.execute('drop table %s' % table)
687            cur.execute('drop type %s' % record)
688            con.close()
689        self.assertEqual(len(rows), len(values))
690        rows = [row[0] for row in rows]
691        self.assertEqual(rows, values)
692        self.assertEqual(rows[0].name, 'John')
693        self.assertEqual(rows[0].age, 61)
694
695    def test_select_record(self):
696        value = (1, 25000, 2.5, 'hello', 'Hello World!', 'Hello, World!',
697            '(test)', '(x,y)', ' x y ', 'null', None)
698        con = self._connect()
699        try:
700            cur = con.cursor()
701            cur.execute("select %s as test_record", [value])
702            self.assertEqual(cur.description[0].name, 'test_record')
703            self.assertEqual(cur.description[0].type_code, 'record')
704            row = cur.fetchone()[0]
705        finally:
706            con.close()
707        # Note that the element types get lost since we created an
708        # untyped record (an anonymous composite type). For the same
709        # reason this is also a normal tuple, not a named tuple.
710        text_row = tuple(None if v is None else str(v) for v in value)
711        self.assertEqual(row, text_row)
712
713    def test_custom_type(self):
714        values = [3, 5, 65]
715        values = list(map(PgBitString, values))
716        table = self.table_prefix + 'booze'
717        con = self._connect()
718        try:
719            cur = con.cursor()
720            params = enumerate(values)  # params have __pg_repr__ method
721            cur.execute(
722                'create table "%s" (n smallint, b bit varying(7))' % table)
723            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
724            cur.execute("select * from %s" % table)
725            rows = cur.fetchall()
726        finally:
727            con.close()
728        self.assertEqual(len(rows), len(values))
729        con = self._connect()
730        try:
731            cur = con.cursor()
732            params = (1, object())  # an object that cannot be handled
733            self.assertRaises(pgdb.InterfaceError, cur.execute,
734                "insert into %s values (%%s,%%s)" % table, params)
735        finally:
736            con.close()
737
738    def test_set_decimal_type(self):
739        decimal_type = pgdb.decimal_type()
740        self.assertTrue(decimal_type is not None and callable(decimal_type))
741        con = self._connect()
742        try:
743            cur = con.cursor()
744            # change decimal type globally to int
745            int_type = lambda v: int(float(v))
746            self.assertTrue(pgdb.decimal_type(int_type) is int_type)
747            cur.execute('select 4.25')
748            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
749            value = cur.fetchone()[0]
750            self.assertTrue(isinstance(value, int))
751            self.assertEqual(value, 4)
752            # change decimal type again to float
753            self.assertTrue(pgdb.decimal_type(float) is float)
754            cur.execute('select 4.25')
755            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
756            value = cur.fetchone()[0]
757            # the connection still uses the old setting
758            self.assertTrue(isinstance(value, int))
759            # bust the cache for type functions for the connection
760            con.type_cache.reset_typecast()
761            cur.execute('select 4.25')
762            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
763            value = cur.fetchone()[0]
764            # now the connection uses the new setting
765            self.assertTrue(isinstance(value, float))
766            self.assertEqual(value, 4.25)
767        finally:
768            con.close()
769            pgdb.decimal_type(decimal_type)
770        self.assertTrue(pgdb.decimal_type() is decimal_type)
771
772    def test_global_typecast(self):
773        try:
774            query = 'select 2::int2, 4::int4, 8::int8'
775            self.assertIs(pgdb.get_typecast('int4'), int)
776            cast_int = lambda v: 'int(%s)' % v
777            pgdb.set_typecast('int4', cast_int)
778            con = self._connect()
779            try:
780                i2, i4, i8 = con.cursor().execute(query).fetchone()
781            finally:
782                con.close()
783            self.assertEqual(i2, 2)
784            self.assertEqual(i4, 'int(4)')
785            self.assertEqual(i8, 8)
786            pgdb.set_typecast(['int2', 'int8'], cast_int)
787            con = self._connect()
788            try:
789                i2, i4, i8 = con.cursor().execute(query).fetchone()
790            finally:
791                con.close()
792            self.assertEqual(i2, 'int(2)')
793            self.assertEqual(i4, 'int(4)')
794            self.assertEqual(i8, 'int(8)')
795            pgdb.reset_typecast('int4')
796            con = self._connect()
797            try:
798                i2, i4, i8 = con.cursor().execute(query).fetchone()
799            finally:
800                con.close()
801            self.assertEqual(i2, 'int(2)')
802            self.assertEqual(i4, 4)
803            self.assertEqual(i8, 'int(8)')
804            pgdb.reset_typecast(['int2', 'int8'])
805            con = self._connect()
806            try:
807                i2, i4, i8 = con.cursor().execute(query).fetchone()
808            finally:
809                con.close()
810            self.assertEqual(i2, 2)
811            self.assertEqual(i4, 4)
812            self.assertEqual(i8, 8)
813            pgdb.set_typecast(['int2', 'int8'], cast_int)
814            con = self._connect()
815            try:
816                i2, i4, i8 = con.cursor().execute(query).fetchone()
817            finally:
818                con.close()
819            self.assertEqual(i2, 'int(2)')
820            self.assertEqual(i4, 4)
821            self.assertEqual(i8, 'int(8)')
822        finally:
823            pgdb.reset_typecast()
824        con = self._connect()
825        try:
826            i2, i4, i8 = con.cursor().execute(query).fetchone()
827        finally:
828            con.close()
829        self.assertEqual(i2, 2)
830        self.assertEqual(i4, 4)
831        self.assertEqual(i8, 8)
832
833    def test_unicode_with_utf8(self):
834        table = self.table_prefix + 'booze'
835        input = u"He wes Leovenaðes sone — liðe him be Drihten"
836        con = self._connect()
837        try:
838            cur = con.cursor()
839            cur.execute("create table %s (t text)" % table)
840            try:
841                cur.execute("set client_encoding=utf8")
842                cur.execute(u"select '%s'" % input)
843            except Exception:
844                self.skipTest("database does not support utf8")
845            output1 = cur.fetchone()[0]
846            cur.execute("insert into %s values (%%s)" % table, (input,))
847            cur.execute("select * from %s" % table)
848            output2 = cur.fetchone()[0]
849            cur.execute("select t = '%s' from %s" % (input, table))
850            output3 = cur.fetchone()[0]
851            cur.execute("select t = %%s from %s" % table, (input,))
852            output4 = cur.fetchone()[0]
853        finally:
854            con.close()
855        if str is bytes:  # Python < 3.0
856            input = input.encode('utf8')
857        self.assertIsInstance(output1, str)
858        self.assertEqual(output1, input)
859        self.assertIsInstance(output2, str)
860        self.assertEqual(output2, input)
861        self.assertIsInstance(output3, bool)
862        self.assertTrue(output3)
863        self.assertIsInstance(output4, bool)
864        self.assertTrue(output4)
865
866    def test_unicode_with_latin1(self):
867        table = self.table_prefix + 'booze'
868        input = u"Ehrt den König seine WÃŒrde, ehret uns der HÀnde Fleiß."
869        con = self._connect()
870        try:
871            cur = con.cursor()
872            cur.execute("create table %s (t text)" % table)
873            try:
874                cur.execute("set client_encoding=latin1")
875                cur.execute(u"select '%s'" % input)
876            except Exception:
877                self.skipTest("database does not support latin1")
878            output1 = cur.fetchone()[0]
879            cur.execute("insert into %s values (%%s)" % table, (input,))
880            cur.execute("select * from %s" % table)
881            output2 = cur.fetchone()[0]
882            cur.execute("select t = '%s' from %s" % (input, table))
883            output3 = cur.fetchone()[0]
884            cur.execute("select t = %%s from %s" % table, (input,))
885            output4 = cur.fetchone()[0]
886        finally:
887            con.close()
888        if str is bytes:  # Python < 3.0
889            input = input.encode('latin1')
890        self.assertIsInstance(output1, str)
891        self.assertEqual(output1, input)
892        self.assertIsInstance(output2, str)
893        self.assertEqual(output2, input)
894        self.assertIsInstance(output3, bool)
895        self.assertTrue(output3)
896        self.assertIsInstance(output4, bool)
897        self.assertTrue(output4)
898
899    def test_bool(self):
900        values = [False, True, None, 't', 'f', 'true', 'false']
901        table = self.table_prefix + 'booze'
902        con = self._connect()
903        try:
904            cur = con.cursor()
905            cur.execute(
906                "create table %s (n smallint, booltest bool)" % table)
907            params = enumerate(values)
908            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
909            cur.execute("select booltest from %s order by n" % table)
910            rows = cur.fetchall()
911            self.assertEqual(cur.description[0].type_code, pgdb.BOOL)
912        finally:
913            con.close()
914        rows = [row[0] for row in rows]
915        values[3] = values[5] = True
916        values[4] = values[6] = False
917        self.assertEqual(rows, values)
918
919    def test_literal(self):
920        con = self._connect()
921        try:
922            cur = con.cursor()
923            value = "lower('Hello')"
924            cur.execute("select %s, %s", (value, pgdb.Literal(value)))
925            row = cur.fetchone()
926        finally:
927            con.close()
928        self.assertEqual(row, (value, 'hello'))
929
930
931    def test_json(self):
932        inval = {"employees":
933            [{"firstName": "John", "lastName": "Doe", "age": 61}]}
934        table = self.table_prefix + 'booze'
935        con = self._connect()
936        try:
937            cur = con.cursor()
938            try:
939                cur.execute("create table %s (jsontest json)" % table)
940            except pgdb.ProgrammingError:
941                self.skipTest('database does not support json')
942            params = (pgdb.Json(inval),)
943            cur.execute("insert into %s values (%%s)" % table, params)
944            cur.execute("select jsontest from %s" % table)
945            outval = cur.fetchone()[0]
946            self.assertEqual(cur.description[0].type_code, pgdb.JSON)
947        finally:
948            con.close()
949        self.assertEqual(inval, outval)
950
951    def test_jsonb(self):
952        inval = {"employees":
953            [{"firstName": "John", "lastName": "Doe", "age": 61}]}
954        table = self.table_prefix + 'booze'
955        con = self._connect()
956        try:
957            cur = con.cursor()
958            try:
959                cur.execute("create table %s (jsonbtest jsonb)" % table)
960            except pgdb.ProgrammingError:
961                self.skipTest('database does not support jsonb')
962            params = (pgdb.Json(inval),)
963            cur.execute("insert into %s values (%%s)" % table, params)
964            cur.execute("select jsonbtest from %s" % table)
965            outval = cur.fetchone()[0]
966            self.assertEqual(cur.description[0].type_code, pgdb.JSON)
967        finally:
968            con.close()
969        self.assertEqual(inval, outval)
970
971    def test_execute_edge_cases(self):
972        con = self._connect()
973        try:
974            cur = con.cursor()
975            sql = 'invalid'  # should be ignored with empty parameter list
976            cur.executemany(sql, [])
977            sql = 'select %d + 1'
978            cur.execute(sql, [(1,), (2,)])  # deprecated use of execute()
979            self.assertEqual(cur.fetchone()[0], 3)
980            sql = 'select 1/0'  # cannot be executed
981            self.assertRaises(pgdb.ProgrammingError, cur.execute, sql)
982            cur.close()
983            con.rollback()
984            if pgdb.shortcutmethods:
985                res = con.execute('select %d', (1,)).fetchone()
986                self.assertEqual(res, (1,))
987                res = con.executemany('select %d', [(1,), (2,)]).fetchone()
988                self.assertEqual(res, (2,))
989        finally:
990            con.close()
991        sql = 'select 1'  # cannot be executed after connection is closed
992        self.assertRaises(pgdb.OperationalError, cur.execute, sql)
993
994    def test_fetchmany_with_keep(self):
995        con = self._connect()
996        try:
997            cur = con.cursor()
998            self.assertEqual(cur.arraysize, 1)
999            cur.execute('select * from generate_series(1, 25)')
1000            self.assertEqual(len(cur.fetchmany()), 1)
1001            self.assertEqual(len(cur.fetchmany()), 1)
1002            self.assertEqual(cur.arraysize, 1)
1003            cur.arraysize = 3
1004            self.assertEqual(len(cur.fetchmany()), 3)
1005            self.assertEqual(len(cur.fetchmany()), 3)
1006            self.assertEqual(cur.arraysize, 3)
1007            self.assertEqual(len(cur.fetchmany(size=2)), 2)
1008            self.assertEqual(cur.arraysize, 3)
1009            self.assertEqual(len(cur.fetchmany()), 3)
1010            self.assertEqual(len(cur.fetchmany()), 3)
1011            self.assertEqual(len(cur.fetchmany(size=2, keep=True)), 2)
1012            self.assertEqual(cur.arraysize, 2)
1013            self.assertEqual(len(cur.fetchmany()), 2)
1014            self.assertEqual(len(cur.fetchmany()), 2)
1015            self.assertEqual(len(cur.fetchmany(25)), 3)
1016        finally:
1017            con.close()
1018
1019    def test_nextset(self):
1020        con = self._connect()
1021        cur = con.cursor()
1022        self.assertRaises(con.NotSupportedError, cur.nextset)
1023
1024    def test_setoutputsize(self):
1025        pass  # not supported
1026
1027    def test_connection_errors(self):
1028        con = self._connect()
1029        self.assertEqual(con.Error, pgdb.Error)
1030        self.assertEqual(con.Warning, pgdb.Warning)
1031        self.assertEqual(con.InterfaceError, pgdb.InterfaceError)
1032        self.assertEqual(con.DatabaseError, pgdb.DatabaseError)
1033        self.assertEqual(con.InternalError, pgdb.InternalError)
1034        self.assertEqual(con.OperationalError, pgdb.OperationalError)
1035        self.assertEqual(con.ProgrammingError, pgdb.ProgrammingError)
1036        self.assertEqual(con.IntegrityError, pgdb.IntegrityError)
1037        self.assertEqual(con.DataError, pgdb.DataError)
1038        self.assertEqual(con.NotSupportedError, pgdb.NotSupportedError)
1039
    def test_connection_as_contextmanager(self):
        """Check the connection as a transaction context manager.

        A block that exits normally must be committed. A block that exits
        with an exception must be rolled back. That applies both to a
        database error (check constraint violation) and to an arbitrary
        Python exception. Hence only the rows 1, 2, 5, 6 and 9 survive.
        """
        table = self.table_prefix + 'booze'
        con = self._connect()
        try:
            cur = con.cursor()
            # the check constraint makes inserting the value 4 fail
            cur.execute("create table %s (n smallint check(n!=4))" % table)
            with con:
                cur.execute("insert into %s values (1)" % table)
                cur.execute("insert into %s values (2)" % table)
            try:
                with con:
                    cur.execute("insert into %s values (3)" % table)
                    # violates the constraint, so 3 is rolled back as well
                    cur.execute("insert into %s values (4)" % table)
            except con.ProgrammingError as error:
                self.assertTrue('check' in str(error).lower())
            with con:
                cur.execute("insert into %s values (5)" % table)
                cur.execute("insert into %s values (6)" % table)
            try:
                with con:
                    cur.execute("insert into %s values (7)" % table)
                    cur.execute("insert into %s values (8)" % table)
                    # a non-database exception must also roll back 7 and 8
                    raise ValueError('transaction should rollback')
            except ValueError as error:
                # the exception must propagate out of the with block
                self.assertEqual(str(error), 'transaction should rollback')
            with con:
                cur.execute("insert into %s values (9)" % table)
            cur.execute("select * from %s order by 1" % table)
            rows = cur.fetchall()
            rows = [row[0] for row in rows]
        finally:
            con.close()
        self.assertEqual(rows, [1, 2, 5, 6, 9])
1073
1074    def test_cursor_connection(self):
1075        con = self._connect()
1076        cur = con.cursor()
1077        self.assertEqual(cur.connection, con)
1078        cur.close()
1079
1080    def test_cursor_as_contextmanager(self):
1081        con = self._connect()
1082        with con.cursor() as cur:
1083            self.assertEqual(cur.connection, con)
1084
1085    def test_pgdb_type(self):
1086        self.assertEqual(pgdb.STRING, pgdb.STRING)
1087        self.assertNotEqual(pgdb.STRING, pgdb.INTEGER)
1088        self.assertNotEqual(pgdb.STRING, pgdb.BOOL)
1089        self.assertNotEqual(pgdb.BOOL, pgdb.INTEGER)
1090        self.assertEqual(pgdb.INTEGER, pgdb.INTEGER)
1091        self.assertNotEqual(pgdb.INTEGER, pgdb.NUMBER)
1092        self.assertEqual('char', pgdb.STRING)
1093        self.assertEqual('varchar', pgdb.STRING)
1094        self.assertEqual('text', pgdb.STRING)
1095        self.assertNotEqual('numeric', pgdb.STRING)
1096        self.assertEqual('numeric', pgdb.NUMERIC)
1097        self.assertEqual('numeric', pgdb.NUMBER)
1098        self.assertEqual('int4', pgdb.NUMBER)
1099        self.assertNotEqual('int4', pgdb.NUMERIC)
1100        self.assertEqual('int2', pgdb.SMALLINT)
1101        self.assertNotEqual('int4', pgdb.SMALLINT)
1102        self.assertEqual('int2', pgdb.INTEGER)
1103        self.assertEqual('int4', pgdb.INTEGER)
1104        self.assertEqual('int8', pgdb.INTEGER)
1105        self.assertNotEqual('int4', pgdb.LONG)
1106        self.assertEqual('int8', pgdb.LONG)
1107        self.assertTrue('char' in pgdb.STRING)
1108        self.assertTrue(pgdb.NUMERIC <= pgdb.NUMBER)
1109        self.assertTrue(pgdb.NUMBER >= pgdb.INTEGER)
1110        self.assertTrue(pgdb.TIME <= pgdb.DATETIME)
1111        self.assertTrue(pgdb.DATETIME >= pgdb.DATE)
1112        self.assertEqual(pgdb.ARRAY, pgdb.ARRAY)
1113        self.assertNotEqual(pgdb.ARRAY, pgdb.STRING)
1114        self.assertEqual('_char', pgdb.ARRAY)
1115        self.assertNotEqual('char', pgdb.ARRAY)
1116        self.assertEqual(pgdb.RECORD, pgdb.RECORD)
1117        self.assertNotEqual(pgdb.RECORD, pgdb.STRING)
1118        self.assertNotEqual(pgdb.RECORD, pgdb.ARRAY)
1119        self.assertEqual('record', pgdb.RECORD)
1120        self.assertNotEqual('_record', pgdb.RECORD)
1121
1122
if __name__ == '__main__':
    # run the tests of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.