source: trunk/tests/test_dbapi20.py

Last change on this file was 1025, checked in by cito, 6 weeks ago

Make tests run with PostgreSQL 12

Skip/adapt tests that use tables with oids

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 54.8 KB
Line 
1#!/usr/bin/python
2# -*- coding: utf-8 -*-
3# $Id: test_dbapi20.py 1025 2019-10-03 23:35:34Z cito $
4
5try:
6    import unittest2 as unittest  # for Python < 2.7
7except ImportError:
8    import unittest
9
10import pgdb
11
12try:
13    from . import dbapi20
14except (ImportError, ValueError, SystemError):
15    import dbapi20
16
# Connection parameters of the database the tests run against.
# An optional LOCAL_PyGreSQL module (looked up relative to this package
# first, then as a top-level module) may override these defaults.
dbname, dbhost, dbport = 'dbapi20_test', '', 5432
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local overrides, keep the defaults
30
31import gc
32import sys
33
34from datetime import date, time, datetime, timedelta
35from uuid import UUID as Uuid
36
# Python 3 folded the separate ``long`` type into ``int``.
try:  # noinspection PyUnresolvedReferences
    long = long
except NameError:  # Python >= 3.0
    long = int

# OrderedDict is missing on Python 2.6 and 3.0; the tests use its
# absence to recognize these old interpreters.
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = None
46
47
class PgBitString:
    """Test object with a PostgreSQL representation as Bit String.

    The wrapped integer is rendered in binary, e.g. 5 becomes B'101'.
    """

    def __init__(self, value):
        self.value = value

    def __pg_repr__(self):
        bits = format(self.value, 'b')
        return "B'%s'" % bits
56
57
58class test_PyGreSQL(dbapi20.DatabaseAPI20Test):
59
60    driver = pgdb
61    connect_args = ()
62    connect_kw_args = {'database': dbname,
63        'host': '%s:%d' % (dbhost or '', dbport or -1)}
64
65    lower_func = 'lower'  # For stored procedure test
66
67    def setUp(self):
68        # Call superclass setUp in case this does something in the future
69        dbapi20.DatabaseAPI20Test.setUp(self)
70        try:
71            con = self._connect()
72            con.close()
73        except pgdb.Error:  # try to create a missing database
74            import pg
75            try:  # first try to log in as superuser
76                db = pg.DB('postgres', dbhost or None, dbport or -1,
77                    user='postgres')
78            except Exception:  # then try to log in as current user
79                db = pg.DB('postgres', dbhost or None, dbport or -1)
80            db.query('create database ' + dbname)
81
82    def tearDown(self):
83        dbapi20.DatabaseAPI20Test.tearDown(self)
84
85    def test_version(self):
86        v = pgdb.version
87        self.assertIsInstance(v, str)
88        self.assertIn('.', v)
89        self.assertEqual(pgdb.__version__, v)
90
91    def test_connect_kwargs(self):
92        application_name = 'PyGreSQL DB API 2.0 Test'
93        self.connect_kw_args['application_name'] = application_name
94        con = self._connect()
95        cur = con.cursor()
96        cur.execute("select application_name from pg_stat_activity"
97            " where application_name = %s", (application_name,))
98        self.assertEqual(cur.fetchone(), (application_name,))
99
100    def test_percent_sign(self):
101        con = self._connect()
102        cur = con.cursor()
103        cur.execute("select %s, 'a %% sign'", ('a % sign',))
104        self.assertEqual(cur.fetchone(), ('a % sign', 'a % sign'))
105        cur.execute("select 'a % sign'")
106        self.assertEqual(cur.fetchone(), ('a % sign',))
107        cur.execute("select 'a %% sign'")
108        self.assertEqual(cur.fetchone(), ('a % sign',))
109
110    def test_callproc_no_params(self):
111        con = self._connect()
112        cur = con.cursor()
113        # note that now() does not change within a transaction
114        cur.execute('select now()')
115        now = cur.fetchone()[0]
116        res = cur.callproc('now')
117        self.assertIsNone(res)
118        res = cur.fetchone()[0]
119        self.assertEqual(res, now)
120
121    def test_callproc_bad_params(self):
122        con = self._connect()
123        cur = con.cursor()
124        self.assertRaises(TypeError, cur.callproc, 'lower', 42)
125        self.assertRaises(pgdb.ProgrammingError, cur.callproc, 'lower', (42,))
126
127    def test_callproc_one_param(self):
128        con = self._connect()
129        cur = con.cursor()
130        params = (42.4382,)
131        res = cur.callproc("round", params)
132        self.assertIs(res, params)
133        res = cur.fetchone()[0]
134        self.assertEqual(res, 42)
135
136    def test_callproc_two_params(self):
137        con = self._connect()
138        cur = con.cursor()
139        params = (9, 4)
140        res = cur.callproc("div", params)
141        self.assertIs(res, params)
142        res = cur.fetchone()[0]
143        self.assertEqual(res, 2)
144
145    def test_cursor_type(self):
146
147        class TestCursor(pgdb.Cursor):
148            pass
149
150        con = self._connect()
151        self.assertIs(con.cursor_type, pgdb.Cursor)
152        cur = con.cursor()
153        self.assertIsInstance(cur, pgdb.Cursor)
154        self.assertNotIsInstance(cur, TestCursor)
155        con.cursor_type = TestCursor
156        cur = con.cursor()
157        self.assertIsInstance(cur, TestCursor)
158        cur = con.cursor()
159        self.assertIsInstance(cur, TestCursor)
160        con = self._connect()
161        self.assertIs(con.cursor_type, pgdb.Cursor)
162        cur = con.cursor()
163        self.assertIsInstance(cur, pgdb.Cursor)
164        self.assertNotIsInstance(cur, TestCursor)
165
166    def test_row_factory(self):
167
168        class TestCursor(pgdb.Cursor):
169
170            def row_factory(self, row):
171                return dict(('column %s' % desc[0], value)
172                    for desc, value in zip(self.description, row))
173
174        con = self._connect()
175        con.cursor_type = TestCursor
176        cur = con.cursor()
177        self.assertIsInstance(cur, TestCursor)
178        res = cur.execute("select 1 as a, 2 as b")
179        self.assertIs(res, cur, 'execute() should return cursor')
180        res = cur.fetchone()
181        self.assertIsInstance(res, dict)
182        self.assertEqual(res, {'column a': 1, 'column b': 2})
183        cur.execute("select 1 as a, 2 as b union select 3, 4 order by 1")
184        res = cur.fetchall()
185        self.assertIsInstance(res, list)
186        self.assertEqual(len(res), 2)
187        self.assertIsInstance(res[0], dict)
188        self.assertEqual(res[0], {'column a': 1, 'column b': 2})
189        self.assertIsInstance(res[1], dict)
190        self.assertEqual(res[1], {'column a': 3, 'column b': 4})
191
192    def test_build_row_factory(self):
193
194        class TestCursor(pgdb.Cursor):
195
196            def build_row_factory(self):
197                keys = [desc[0] for desc in self.description]
198                return lambda row: dict((key, value)
199                    for key, value in zip(keys, row))
200
201        con = self._connect()
202        con.cursor_type = TestCursor
203        cur = con.cursor()
204        self.assertIsInstance(cur, TestCursor)
205        cur.execute("select 1 as a, 2 as b")
206        res = cur.fetchone()
207        self.assertIsInstance(res, dict)
208        self.assertEqual(res, {'a': 1, 'b': 2})
209        cur.execute("select 1 as a, 2 as b union select 3, 4 order by 1")
210        res = cur.fetchall()
211        self.assertIsInstance(res, list)
212        self.assertEqual(len(res), 2)
213        self.assertIsInstance(res[0], dict)
214        self.assertEqual(res[0], {'a': 1, 'b': 2})
215        self.assertIsInstance(res[1], dict)
216        self.assertEqual(res[1], {'a': 3, 'b': 4})
217
218    def test_cursor_with_named_columns(self):
219        con = self._connect()
220        cur = con.cursor()
221        res = cur.execute("select 1 as abc, 2 as de, 3 as f")
222        self.assertIs(res, cur, 'execute() should return cursor')
223        res = cur.fetchone()
224        self.assertIsInstance(res, tuple)
225        self.assertEqual(res, (1, 2, 3))
226        self.assertEqual(res._fields, ('abc', 'de', 'f'))
227        self.assertEqual(res.abc, 1)
228        self.assertEqual(res.de, 2)
229        self.assertEqual(res.f, 3)
230        cur.execute("select 1 as one, 2 as two union select 3, 4 order by 1")
231        res = cur.fetchall()
232        self.assertIsInstance(res, list)
233        self.assertEqual(len(res), 2)
234        self.assertIsInstance(res[0], tuple)
235        self.assertEqual(res[0], (1, 2))
236        self.assertEqual(res[0]._fields, ('one', 'two'))
237        self.assertIsInstance(res[1], tuple)
238        self.assertEqual(res[1], (3, 4))
239        self.assertEqual(res[1]._fields, ('one', 'two'))
240
241    def test_cursor_with_unnamed_columns(self):
242        con = self._connect()
243        cur = con.cursor()
244        cur.execute("select 1, 2, 3")
245        res = cur.fetchone()
246        self.assertIsInstance(res, tuple)
247        self.assertEqual(res, (1, 2, 3))
248        old_py = OrderedDict is None  # Python 2.6 or 3.0
249        # old Python versions cannot rename tuple fields with underscore
250        if old_py:
251            self.assertEqual(res._fields, ('column_0', 'column_1', 'column_2'))
252        else:
253            self.assertEqual(res._fields, ('_0', '_1', '_2'))
254        cur.execute("select 1 as one, 2, 3 as three")
255        res = cur.fetchone()
256        self.assertIsInstance(res, tuple)
257        self.assertEqual(res, (1, 2, 3))
258        if old_py:  # cannot auto rename with underscore
259            self.assertEqual(res._fields, ('one', 'column_1', 'three'))
260        else:
261            self.assertEqual(res._fields, ('one', '_1', 'three'))
262
263    def test_cursor_with_badly_named_columns(self):
264        con = self._connect()
265        cur = con.cursor()
266        cur.execute("select 1 as abc, 2 as def")
267        res = cur.fetchone()
268        self.assertIsInstance(res, tuple)
269        self.assertEqual(res, (1, 2))
270        old_py = OrderedDict is None  # Python 2.6 or 3.0
271        if old_py:
272            self.assertEqual(res._fields, ('abc', 'column_1'))
273        else:
274            self.assertEqual(res._fields, ('abc', '_1'))
275        cur.execute('select 1 as snake_case, 2 as "CamelCase",'
276            ' 3 as "kebap-case", 4 as "_bad", 5 as "0bad", 6 as "bad$"')
277        res = cur.fetchone()
278        self.assertIsInstance(res, tuple)
279        self.assertEqual(res, (1, 2, 3, 4, 5, 6))
280        # old Python versions cannot rename tuple fields with underscore
281        self.assertEqual(res._fields[:2], ('snake_case', 'CamelCase'))
282        fields = ('_2', '_3', '_4', '_5')
283        if old_py:
284            fields = tuple('column' + field for field in fields)
285        self.assertEqual(res._fields[2:], fields)
286
287    def test_colnames(self):
288        con = self._connect()
289        cur = con.cursor()
290        cur.execute("select 1, 2, 3")
291        names = cur.colnames
292        self.assertIsInstance(names, list)
293        self.assertEqual(names, ['?column?', '?column?', '?column?'])
294        cur.execute("select 1 as a, 2 as bc, 3 as def, 4 as g")
295        names = cur.colnames
296        self.assertIsInstance(names, list)
297        self.assertEqual(names, ['a', 'bc', 'def', 'g'])
298
299    def test_coltypes(self):
300        con = self._connect()
301        cur = con.cursor()
302        cur.execute("select 1::int2, 2::int4, 3::int8")
303        types = cur.coltypes
304        self.assertIsInstance(types, list)
305        self.assertEqual(types, ['int2', 'int4', 'int8'])
306
307    def test_description_fields(self):
308        con = self._connect()
309        cur = con.cursor()
310        cur.execute("select 123456789::int8 col0,"
311            " 123456.789::numeric(41, 13) as col1,"
312            " 'foobar'::char(39) as col2")
313        desc = cur.description
314        self.assertIsInstance(desc, list)
315        self.assertEqual(len(desc), 3)
316        cols = [('int8', 8, None), ('numeric', 41, 13), ('bpchar', 39, None)]
317        for i in range(3):
318            c, d = cols[i], desc[i]
319            self.assertIsInstance(d, tuple)
320            self.assertEqual(len(d), 7)
321            self.assertIsInstance(d.name, str)
322            self.assertEqual(d.name, 'col%d' % i)
323            self.assertIsInstance(d.type_code, str)
324            self.assertEqual(d.type_code, c[0])
325            self.assertIsNone(d.display_size)
326            self.assertIsInstance(d.internal_size, int)
327            self.assertEqual(d.internal_size, c[1])
328            if c[2] is not None:
329                self.assertIsInstance(d.precision, int)
330                self.assertEqual(d.precision, c[1])
331                self.assertIsInstance(d.scale, int)
332                self.assertEqual(d.scale, c[2])
333            else:
334                self.assertIsNone(d.precision)
335                self.assertIsNone(d.scale)
336            self.assertIsNone(d.null_ok)
337
    def test_type_cache_info(self):
        """Check the connection's type cache for base and composite types.

        Entries can be looked up by name or by OID and are added to the
        cache on first access. The second part checks that a new
        connection does not yet contain the composite type looked up
        before, but get() fetches an equal entry on demand.
        """
        con = self._connect()
        try:
            cur = con.cursor()
            type_cache = con.type_cache
            # the cache is filled lazily: 'numeric' appears only after
            # it has been looked up
            self.assertNotIn('numeric', type_cache)
            type_info = type_cache['numeric']
            self.assertIn('numeric', type_cache)
            self.assertEqual(type_info, 'numeric')
            self.assertEqual(type_info.oid, 1700)
            self.assertEqual(type_info.len, -1)
            self.assertEqual(type_info.type, 'b')  # base
            self.assertEqual(type_info.category, 'N')  # numeric
            self.assertEqual(type_info.delim, ',')
            self.assertEqual(type_info.relid, 0)
            # looking up by OID yields the very same cached object
            self.assertIs(con.type_cache[1700], type_info)
            # composite (row) types are cached the same way
            self.assertNotIn('pg_type', type_cache)
            type_info = type_cache['pg_type']
            self.assertIn('pg_type', type_cache)
            self.assertEqual(type_info.type, 'c')  # composite
            self.assertEqual(type_info.category, 'C')  # composite
            cols = type_cache.get_fields('pg_type')
            # drop a leading oid column so that the column positions
            # checked below are the same on all server versions
            if cols[0].name == 'oid':  # PostgreSQL < 12
                del cols[0]
            self.assertEqual(cols[0].name, 'typname')
            typname = type_cache[cols[0].type]
            self.assertEqual(typname, 'name')
            self.assertEqual(typname.type, 'b')  # base
            self.assertEqual(typname.category, 'S')  # string
            self.assertEqual(cols[3].name, 'typlen')
            typlen = type_cache[cols[3].type]
            self.assertEqual(typlen, 'int2')
            self.assertEqual(typlen.type, 'b')  # base
            self.assertEqual(typlen.category, 'N')  # numeric
            cur.close()
            # a new cursor on the same connection still sees the entries
            cur = con.cursor()
            type_cache = con.type_cache
            self.assertIn('numeric', type_cache)
            cur.close()
        finally:
            con.close()
        con = self._connect()
        try:
            cur = con.cursor()
            type_cache = con.type_cache
            # a fresh connection has not cached 'pg_type' yet, but get()
            # fetches an entry equal to type_info from the first connection
            self.assertNotIn('pg_type', type_cache)
            self.assertEqual(type_cache.get('pg_type'), type_info)
            self.assertIn('pg_type', type_cache)
            # get() returns None for unknown types instead of raising
            self.assertIsNone(type_cache.get(
                self.table_prefix + '_surely_does_not_exist'))
            cur.close()
        finally:
            con.close()
391
392    def test_type_cache_typecast(self):
393        con = self._connect()
394        try:
395            cur = con.cursor()
396            type_cache = con.type_cache
397            self.assertIs(type_cache.get_typecast('int4'), int)
398            cast_int = lambda v: 'int(%s)' % v
399            type_cache.set_typecast('int4', cast_int)
400            query = 'select 2::int2, 4::int4, 8::int8'
401            cur.execute(query)
402            i2, i4, i8 = cur.fetchone()
403            self.assertEqual(i2, 2)
404            self.assertEqual(i4, 'int(4)')
405            self.assertEqual(i8, 8)
406            self.assertEqual(type_cache.typecast(42, 'int4'), 'int(42)')
407            type_cache.set_typecast(['int2', 'int8'], cast_int)
408            cur.execute(query)
409            i2, i4, i8 = cur.fetchone()
410            self.assertEqual(i2, 'int(2)')
411            self.assertEqual(i4, 'int(4)')
412            self.assertEqual(i8, 'int(8)')
413            type_cache.reset_typecast('int4')
414            cur.execute(query)
415            i2, i4, i8 = cur.fetchone()
416            self.assertEqual(i2, 'int(2)')
417            self.assertEqual(i4, 4)
418            self.assertEqual(i8, 'int(8)')
419            type_cache.reset_typecast(['int2', 'int8'])
420            cur.execute(query)
421            i2, i4, i8 = cur.fetchone()
422            self.assertEqual(i2, 2)
423            self.assertEqual(i4, 4)
424            self.assertEqual(i8, 8)
425            type_cache.set_typecast(['int2', 'int8'], cast_int)
426            cur.execute(query)
427            i2, i4, i8 = cur.fetchone()
428            self.assertEqual(i2, 'int(2)')
429            self.assertEqual(i4, 4)
430            self.assertEqual(i8, 'int(8)')
431            type_cache.reset_typecast()
432            cur.execute(query)
433            i2, i4, i8 = cur.fetchone()
434            self.assertEqual(i2, 2)
435            self.assertEqual(i4, 4)
436            self.assertEqual(i8, 8)
437            cur.close()
438        finally:
439            con.close()
440
441    def test_cursor_iteration(self):
442        con = self._connect()
443        cur = con.cursor()
444        cur.execute("select 1 union select 2 union select 3 order by 1")
445        self.assertEqual([r[0] for r in cur], [1, 2, 3])
446
447    def test_cursor_invalidation(self):
448        con = self._connect()
449        cur = con.cursor()
450        cur.execute("select 1 union select 2")
451        self.assertEqual(cur.fetchone(), (1,))
452        self.assertFalse(con.closed)
453        con.close()
454        self.assertTrue(con.closed)
455        self.assertRaises(pgdb.OperationalError, cur.fetchone)
456
457    def test_fetch_2_rows(self):
458        Decimal = pgdb.decimal_type()
459        values = ('test', pgdb.Binary(b'\xff\x52\xb2'),
460            True, 5, 6, 5.7, Decimal('234.234234'), Decimal('75.45'),
461            pgdb.Date(2011, 7, 17), pgdb.Time(15, 47, 42),
462            pgdb.Timestamp(2008, 10, 20, 15, 25, 35),
463            pgdb.Interval(15, 31, 5), 7897234)
464        table = self.table_prefix + 'booze'
465        con = self._connect()
466        try:
467            cur = con.cursor()
468            cur.execute("set datestyle to iso")
469            cur.execute("create table %s ("
470                "stringtest varchar,"
471                "binarytest bytea,"
472                "booltest bool,"
473                "integertest int4,"
474                "longtest int8,"
475                "floattest float8,"
476                "numerictest numeric,"
477                "moneytest money,"
478                "datetest date,"
479                "timetest time,"
480                "datetimetest timestamp,"
481                "intervaltest interval,"
482                "rowidtest oid)" % table)
483            cur.execute("set standard_conforming_strings to on")
484            for s in ('numeric', 'monetary', 'time'):
485                cur.execute("set lc_%s to 'C'" % s)
486            for _i in range(2):
487                cur.execute("insert into %s values ("
488                    "%%s,%%s,%%s,%%s,%%s,%%s,%%s,"
489                    "'%%s'::money,%%s,%%s,%%s,%%s,%%s)" % table, values)
490            cur.execute("select * from %s" % table)
491            rows = cur.fetchall()
492            self.assertEqual(len(rows), 2)
493            row0 = rows[0]
494            self.assertEqual(row0, values)
495            self.assertEqual(row0, rows[1])
496            self.assertIsInstance(row0[0], str)
497            self.assertIsInstance(row0[1], bytes)
498            self.assertIsInstance(row0[2], bool)
499            self.assertIsInstance(row0[3], int)
500            self.assertIsInstance(row0[4], long)
501            self.assertIsInstance(row0[5], float)
502            self.assertIsInstance(row0[6], Decimal)
503            self.assertIsInstance(row0[7], Decimal)
504            self.assertIsInstance(row0[8], date)
505            self.assertIsInstance(row0[9], time)
506            self.assertIsInstance(row0[10], datetime)
507            self.assertIsInstance(row0[11], timedelta)
508        finally:
509            con.close()
510
511    def test_integrity_error(self):
512        table = self.table_prefix + 'booze'
513        con = self._connect()
514        try:
515            cur = con.cursor()
516            cur.execute("set client_min_messages = warning")
517            cur.execute("create table %s (i int primary key)" % table)
518            cur.execute("insert into %s values (1)" % table)
519            cur.execute("insert into %s values (2)" % table)
520            self.assertRaises(pgdb.IntegrityError, cur.execute,
521                "insert into %s values (1)" % table)
522        finally:
523            con.close()
524
525    def test_update_rowcount(self):
526        table = self.table_prefix + 'booze'
527        con = self._connect()
528        try:
529            cur = con.cursor()
530            cur.execute("create table %s (i int)" % table)
531            cur.execute("insert into %s values (1)" % table)
532            cur.execute("update %s set i=2 where i=2 returning i" % table)
533            self.assertEqual(cur.rowcount, 0)
534            cur.execute("update %s set i=2 where i=1 returning i" % table)
535            self.assertEqual(cur.rowcount, 1)
536            cur.close()
537            # keep rowcount even if cursor is closed (needed by SQLAlchemy)
538            self.assertEqual(cur.rowcount, 1)
539        finally:
540            con.close()
541
542    def test_sqlstate(self):
543        con = self._connect()
544        cur = con.cursor()
545        try:
546            cur.execute("select 1/0")
547        except pgdb.DatabaseError as error:
548            self.assertTrue(isinstance(error, pgdb.DataError))
549            # the SQLSTATE error code for division by zero is 22012
550            self.assertEqual(error.sqlstate, '22012')
551
552    def test_float(self):
553        nan, inf = float('nan'), float('inf')
554        from math import isnan, isinf
555        self.assertTrue(isnan(nan) and not isinf(nan))
556        self.assertTrue(isinf(inf) and not isnan(inf))
557        values = [0, 1, 0.03125, -42.53125, nan, inf, -inf,
558            'nan', 'inf', '-inf', 'NaN', 'Infinity', '-Infinity']
559        table = self.table_prefix + 'booze'
560        con = self._connect()
561        try:
562            cur = con.cursor()
563            cur.execute(
564                "create table %s (n smallint, floattest float)" % table)
565            params = enumerate(values)
566            cur.executemany("insert into %s values (%%d,%%s)" % table, params)
567            cur.execute("select floattest from %s order by n" % table)
568            rows = cur.fetchall()
569            self.assertEqual(cur.description[0].type_code, pgdb.FLOAT)
570            self.assertNotEqual(cur.description[0].type_code, pgdb.ARRAY)
571            self.assertNotEqual(cur.description[0].type_code, pgdb.RECORD)
572        finally:
573            con.close()
574        self.assertEqual(len(rows), len(values))
575        rows = [row[0] for row in rows]
576        for inval, outval in zip(values, rows):
577            if inval in ('inf', 'Infinity'):
578                inval = inf
579            elif inval in ('-inf', '-Infinity'):
580                inval = -inf
581            elif inval in ('nan', 'NaN'):
582                inval = nan
583            if isinf(inval):
584                self.assertTrue(isinf(outval))
585                if inval < 0:
586                    self.assertTrue(outval < 0)
587                else:
588                    self.assertTrue(outval > 0)
589            elif isnan(inval):
590                self.assertTrue(isnan(outval))
591            else:
592                self.assertEqual(inval, outval)
593
594    def test_datetime(self):
595        dt = datetime(2011, 7, 17, 15, 47, 42, 317509)
596        table = self.table_prefix + 'booze'
597        con = self._connect()
598        try:
599            cur = con.cursor()
600            cur.execute("set timezone = UTC")
601            cur.execute("create table %s ("
602                "d date, t time,  ts timestamp,"
603                "tz timetz, tsz timestamptz)" % table)
604            for n in range(3):
605                values = [dt.date(), dt.time(), dt,
606                    dt.time(), dt]
607                values[3] = values[3].replace(tzinfo=pgdb.timezone.utc)
608                values[4] = values[4].replace(tzinfo=pgdb.timezone.utc)
609                if n == 0:  # input as objects
610                    params = values
611                if n == 1:  # input as text
612                    params = [v.isoformat() for v in values]  # as text
613                elif n == 2:  # input using type helpers
614                    d = (dt.year, dt.month, dt.day)
615                    t = (dt.hour, dt.minute, dt.second, dt.microsecond)
616                    z = (pgdb.timezone.utc,)
617                    params = [pgdb.Date(*d), pgdb.Time(*t),
618                            pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)),
619                            pgdb.Timestamp(*(d + t + z))]
620                for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy',
621                        'sql, mdy', 'sql, dmy', 'german'):
622                    cur.execute("set datestyle to %s" % datestyle)
623                    if n != 1:
624                        cur.execute("select %s,%s,%s,%s,%s", params)
625                        row = cur.fetchone()
626                        self.assertEqual(row, tuple(values))
627                    cur.execute("insert into %s"
628                        " values (%%s,%%s,%%s,%%s,%%s)" % table, params)
629                    cur.execute("select * from %s" % table)
630                    d = cur.description
631                    for i in range(5):
632                        self.assertEqual(d[i].type_code, pgdb.DATETIME)
633                        self.assertNotEqual(d[i].type_code, pgdb.STRING)
634                        self.assertNotEqual(d[i].type_code, pgdb.ARRAY)
635                        self.assertNotEqual(d[i].type_code, pgdb.RECORD)
636                    self.assertEqual(d[0].type_code, pgdb.DATE)
637                    self.assertEqual(d[1].type_code, pgdb.TIME)
638                    self.assertEqual(d[2].type_code, pgdb.TIMESTAMP)
639                    self.assertEqual(d[3].type_code, pgdb.TIME)
640                    self.assertEqual(d[4].type_code, pgdb.TIMESTAMP)
641                    row = cur.fetchone()
642                    self.assertEqual(row, tuple(values))
643                    cur.execute("delete from %s" % table)
644        finally:
645            con.close()
646
647    def test_interval(self):
648        td = datetime(2011, 7, 17, 15, 47, 42, 317509) - datetime(1970, 1, 1)
649        table = self.table_prefix + 'booze'
650        con = self._connect()
651        try:
652            cur = con.cursor()
653            cur.execute("create table %s (i interval)" % table)
654            for n in range(3):
655                if n == 0:  # input as objects
656                    param = td
657                if n == 1:  # input as text
658                    param = '%d days %d seconds %d microseconds ' % (
659                        td.days, td.seconds, td.microseconds)
660                elif n == 2:  # input using type helpers
661                    param = pgdb.Interval(
662                        td.days, 0, 0, td.seconds, td.microseconds)
663                for intervalstyle in ('sql_standard ', 'postgres',
664                        'postgres_verbose', 'iso_8601'):
665                    cur.execute("set intervalstyle to %s" % intervalstyle)
666                    cur.execute("insert into %s"
667                        " values (%%s)" % table, [param])
668                    cur.execute("select * from %s" % table)
669                    tc = cur.description[0].type_code
670                    self.assertEqual(tc, pgdb.DATETIME)
671                    self.assertNotEqual(tc, pgdb.STRING)
672                    self.assertNotEqual(tc, pgdb.ARRAY)
673                    self.assertNotEqual(tc, pgdb.RECORD)
674                    self.assertEqual(tc, pgdb.INTERVAL)
675                    row = cur.fetchone()
676                    self.assertEqual(row, (td,))
677                    cur.execute("delete from %s" % table)
678        finally:
679            con.close()
680
681    def test_hstore(self):
682        con = self._connect()
683        try:
684            cur = con.cursor()
685            cur.execute("select 'k=>v'::hstore")
686        except pgdb.DatabaseError:
687            try:
688                cur.execute("create extension hstore")
689            except pgdb.DatabaseError:
690                self.skipTest("hstore extension not enabled")
691        finally:
692            con.close()
693        d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever', 'back\\': '\\slash',
694            '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3',
695            '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"',
696            'None': None, 'NULL': 'NULL', 'empty': ''}
697        con = self._connect()
698        try:
699            cur = con.cursor()
700            cur.execute("select %s::hstore", (pgdb.Hstore(d),))
701            result = cur.fetchone()[0]
702        finally:
703            con.close()
704        self.assertIsInstance(result, dict)
705        self.assertEqual(result, d)
706
707    def test_uuid(self):
708        self.assertIs(Uuid, pgdb.Uuid)
709        d = Uuid('{12345678-1234-5678-1234-567812345678}')
710        con = self._connect()
711        try:
712            cur = con.cursor()
713            cur.execute("select %s::uuid", (d,))
714            result = cur.fetchone()[0]
715        finally:
716            con.close()
717        self.assertIsInstance(result, Uuid)
718        self.assertEqual(result, d)
719
720    def test_insert_array(self):
721        values = [(None, None), ([], []), ([None], [[None], ['null']]),
722            ([1, 2, 3], [['a', 'b'], ['c', 'd']]),
723            ([20000, 25000, 25000, 30000],
724            [['breakfast', 'consulting'], ['meeting', 'lunch']]),
725            ([0, 1, -1], [['Hello, World!', '"Hi!"'], ['{x,y}', ' x y ']])]
726        table = self.table_prefix + 'booze'
727        con = self._connect()
728        try:
729            cur = con.cursor()
730            cur.execute("create table %s"
731                " (n smallint, i int[], t text[][])" % table)
732            params = [(n, v[0], v[1]) for n, v in enumerate(values)]
733            # Note that we must explicit casts because we are inserting
734            # empty arrays.  Otherwise this is not necessary.
735            cur.executemany("insert into %s values"
736                " (%%d,%%s::int[],%%s::text[][])" % table, params)
737            cur.execute("select i, t from %s order by n" % table)
738            d = cur.description
739            self.assertEqual(d[0].type_code, pgdb.ARRAY)
740            self.assertNotEqual(d[0].type_code, pgdb.RECORD)
741            self.assertEqual(d[0].type_code, pgdb.NUMBER)
742            self.assertEqual(d[0].type_code, pgdb.INTEGER)
743            self.assertEqual(d[1].type_code, pgdb.ARRAY)
744            self.assertNotEqual(d[1].type_code, pgdb.RECORD)
745            self.assertEqual(d[1].type_code, pgdb.STRING)
746            rows = cur.fetchall()
747        finally:
748            con.close()
749        self.assertEqual(rows, values)
750
751    def test_select_array(self):
752        values = ([1, 2, 3, None], ['a', 'b', 'c', None])
753        con = self._connect()
754        try:
755            cur = con.cursor()
756            cur.execute("select %s::int[], %s::text[]", values)
757            row = cur.fetchone()
758        finally:
759            con.close()
760        self.assertEqual(row, values)
761
762    def test_unicode_list_and_tuple(self):
763        value = (u'KÀse', u'WÃŒrstchen')
764        con = self._connect()
765        try:
766            cur = con.cursor()
767            try:
768                cur.execute("select %s, %s", value)
769            except pgdb.DatabaseError:
770                self.skipTest('database does not support latin-1')
771            row = cur.fetchone()
772            cur.execute("select %s, %s", (list(value), tuple(value)))
773            as_list, as_tuple = cur.fetchone()
774        finally:
775            con.close()
776        self.assertEqual(as_list, list(row))
777        self.assertEqual(as_tuple, tuple(row))
778
779    def test_insert_record(self):
780        values = [('John', 61), ('Jane', 63),
781                  ('Fred', None), ('Wilma', None),
782                  (None, 42), (None, None)]
783        table = self.table_prefix + 'booze'
784        record = self.table_prefix + 'munch'
785        con = self._connect()
786        try:
787            cur = con.cursor()
788            cur.execute("create type %s as (name varchar, age int)" % record)
789            cur.execute("create table %s (n smallint, r %s)" % (table, record))
790            params = enumerate(values)
791            cur.executemany("insert into %s values (%%d,%%s)" % table, params)
792            cur.execute("select r from %s order by n" % table)
793            type_code = cur.description[0].type_code
794            self.assertEqual(type_code, record)
795            self.assertEqual(type_code, pgdb.RECORD)
796            self.assertNotEqual(type_code, pgdb.ARRAY)
797            columns = con.type_cache.get_fields(type_code)
798            self.assertEqual(columns[0].name, 'name')
799            self.assertEqual(columns[1].name, 'age')
800            self.assertEqual(con.type_cache[columns[0].type], 'varchar')
801            self.assertEqual(con.type_cache[columns[1].type], 'int4')
802            rows = cur.fetchall()
803        finally:
804            cur.execute('drop table %s' % table)
805            cur.execute('drop type %s' % record)
806            con.close()
807        self.assertEqual(len(rows), len(values))
808        rows = [row[0] for row in rows]
809        self.assertEqual(rows, values)
810        self.assertEqual(rows[0].name, 'John')
811        self.assertEqual(rows[0].age, 61)
812
813    def test_select_record(self):
814        value = (1, 25000, 2.5, 'hello', 'Hello World!', 'Hello, World!',
815            '(test)', '(x,y)', ' x y ', 'null', None)
816        con = self._connect()
817        try:
818            cur = con.cursor()
819            cur.execute("select %s as test_record", [value])
820            self.assertEqual(cur.description[0].name, 'test_record')
821            self.assertEqual(cur.description[0].type_code, 'record')
822            row = cur.fetchone()[0]
823        finally:
824            con.close()
825        # Note that the element types get lost since we created an
826        # untyped record (an anonymous composite type). For the same
827        # reason this is also a normal tuple, not a named tuple.
828        text_row = tuple(None if v is None else str(v) for v in value)
829        self.assertEqual(row, text_row)
830
831    def test_custom_type(self):
832        values = [3, 5, 65]
833        values = list(map(PgBitString, values))
834        table = self.table_prefix + 'booze'
835        con = self._connect()
836        try:
837            cur = con.cursor()
838            params = enumerate(values)  # params have __pg_repr__ method
839            cur.execute(
840                'create table "%s" (n smallint, b bit varying(7))' % table)
841            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
842            cur.execute("select * from %s" % table)
843            rows = cur.fetchall()
844        finally:
845            con.close()
846        self.assertEqual(len(rows), len(values))
847        con = self._connect()
848        try:
849            cur = con.cursor()
850            params = (1, object())  # an object that cannot be handled
851            self.assertRaises(pgdb.InterfaceError, cur.execute,
852                "insert into %s values (%%s,%%s)" % table, params)
853        finally:
854            con.close()
855
856    def test_set_decimal_type(self):
857        decimal_type = pgdb.decimal_type()
858        self.assertTrue(decimal_type is not None and callable(decimal_type))
859        con = self._connect()
860        try:
861            cur = con.cursor()
862            # change decimal type globally to int
863            int_type = lambda v: int(float(v))
864            self.assertTrue(pgdb.decimal_type(int_type) is int_type)
865            cur.execute('select 4.25')
866            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
867            value = cur.fetchone()[0]
868            self.assertTrue(isinstance(value, int))
869            self.assertEqual(value, 4)
870            # change decimal type again to float
871            self.assertTrue(pgdb.decimal_type(float) is float)
872            cur.execute('select 4.25')
873            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
874            value = cur.fetchone()[0]
875            # the connection still uses the old setting
876            self.assertTrue(isinstance(value, int))
877            # bust the cache for type functions for the connection
878            con.type_cache.reset_typecast()
879            cur.execute('select 4.25')
880            self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
881            value = cur.fetchone()[0]
882            # now the connection uses the new setting
883            self.assertTrue(isinstance(value, float))
884            self.assertEqual(value, 4.25)
885        finally:
886            con.close()
887            pgdb.decimal_type(decimal_type)
888        self.assertTrue(pgdb.decimal_type() is decimal_type)
889
890    def test_global_typecast(self):
891        try:
892            query = 'select 2::int2, 4::int4, 8::int8'
893            self.assertIs(pgdb.get_typecast('int4'), int)
894            cast_int = lambda v: 'int(%s)' % v
895            pgdb.set_typecast('int4', cast_int)
896            con = self._connect()
897            try:
898                i2, i4, i8 = con.cursor().execute(query).fetchone()
899            finally:
900                con.close()
901            self.assertEqual(i2, 2)
902            self.assertEqual(i4, 'int(4)')
903            self.assertEqual(i8, 8)
904            pgdb.set_typecast(['int2', 'int8'], cast_int)
905            con = self._connect()
906            try:
907                i2, i4, i8 = con.cursor().execute(query).fetchone()
908            finally:
909                con.close()
910            self.assertEqual(i2, 'int(2)')
911            self.assertEqual(i4, 'int(4)')
912            self.assertEqual(i8, 'int(8)')
913            pgdb.reset_typecast('int4')
914            con = self._connect()
915            try:
916                i2, i4, i8 = con.cursor().execute(query).fetchone()
917            finally:
918                con.close()
919            self.assertEqual(i2, 'int(2)')
920            self.assertEqual(i4, 4)
921            self.assertEqual(i8, 'int(8)')
922            pgdb.reset_typecast(['int2', 'int8'])
923            con = self._connect()
924            try:
925                i2, i4, i8 = con.cursor().execute(query).fetchone()
926            finally:
927                con.close()
928            self.assertEqual(i2, 2)
929            self.assertEqual(i4, 4)
930            self.assertEqual(i8, 8)
931            pgdb.set_typecast(['int2', 'int8'], cast_int)
932            con = self._connect()
933            try:
934                i2, i4, i8 = con.cursor().execute(query).fetchone()
935            finally:
936                con.close()
937            self.assertEqual(i2, 'int(2)')
938            self.assertEqual(i4, 4)
939            self.assertEqual(i8, 'int(8)')
940        finally:
941            pgdb.reset_typecast()
942        con = self._connect()
943        try:
944            i2, i4, i8 = con.cursor().execute(query).fetchone()
945        finally:
946            con.close()
947        self.assertEqual(i2, 2)
948        self.assertEqual(i4, 4)
949        self.assertEqual(i8, 8)
950
951    def test_set_typecast_for_arrays(self):
952        query = 'select ARRAY[1,2,3]'
953        try:
954            con = self._connect()
955            try:
956                r = con.cursor().execute(query).fetchone()[0]
957            finally:
958                con.close()
959            self.assertIsInstance(r, list)
960            self.assertEqual(r, [1, 2, 3])
961            pgdb.set_typecast('anyarray', lambda v, basecast: v)
962            con = self._connect()
963            try:
964                r = con.cursor().execute(query).fetchone()[0]
965            finally:
966                con.close()
967            self.assertIsInstance(r, str)
968            self.assertEqual(r, '{1,2,3}')
969        finally:
970            pgdb.reset_typecast()
971        con = self._connect()
972        try:
973            r = con.cursor().execute(query).fetchone()[0]
974        finally:
975            con.close()
976        self.assertIsInstance(r, list)
977        self.assertEqual(r, [1, 2, 3])
978
979    def test_unicode_with_utf8(self):
980        table = self.table_prefix + 'booze'
981        input = u"He wes Leovenaðes sone — liðe him be Drihten"
982        con = self._connect()
983        try:
984            cur = con.cursor()
985            cur.execute("create table %s (t text)" % table)
986            try:
987                cur.execute("set client_encoding=utf8")
988                cur.execute(u"select '%s'" % input)
989            except Exception:
990                self.skipTest("database does not support utf8")
991            output1 = cur.fetchone()[0]
992            cur.execute("insert into %s values (%%s)" % table, (input,))
993            cur.execute("select * from %s" % table)
994            output2 = cur.fetchone()[0]
995            cur.execute("select t = '%s' from %s" % (input, table))
996            output3 = cur.fetchone()[0]
997            cur.execute("select t = %%s from %s" % table, (input,))
998            output4 = cur.fetchone()[0]
999        finally:
1000            con.close()
1001        if str is bytes:  # Python < 3.0
1002            input = input.encode('utf8')
1003        self.assertIsInstance(output1, str)
1004        self.assertEqual(output1, input)
1005        self.assertIsInstance(output2, str)
1006        self.assertEqual(output2, input)
1007        self.assertIsInstance(output3, bool)
1008        self.assertTrue(output3)
1009        self.assertIsInstance(output4, bool)
1010        self.assertTrue(output4)
1011
1012    def test_unicode_with_latin1(self):
1013        table = self.table_prefix + 'booze'
1014        input = u"Ehrt den König seine WÃŒrde, ehret uns der HÀnde Fleiß."
1015        con = self._connect()
1016        try:
1017            cur = con.cursor()
1018            cur.execute("create table %s (t text)" % table)
1019            try:
1020                cur.execute("set client_encoding=latin1")
1021                cur.execute(u"select '%s'" % input)
1022            except Exception:
1023                self.skipTest("database does not support latin1")
1024            output1 = cur.fetchone()[0]
1025            cur.execute("insert into %s values (%%s)" % table, (input,))
1026            cur.execute("select * from %s" % table)
1027            output2 = cur.fetchone()[0]
1028            cur.execute("select t = '%s' from %s" % (input, table))
1029            output3 = cur.fetchone()[0]
1030            cur.execute("select t = %%s from %s" % table, (input,))
1031            output4 = cur.fetchone()[0]
1032        finally:
1033            con.close()
1034        if str is bytes:  # Python < 3.0
1035            input = input.encode('latin1')
1036        self.assertIsInstance(output1, str)
1037        self.assertEqual(output1, input)
1038        self.assertIsInstance(output2, str)
1039        self.assertEqual(output2, input)
1040        self.assertIsInstance(output3, bool)
1041        self.assertTrue(output3)
1042        self.assertIsInstance(output4, bool)
1043        self.assertTrue(output4)
1044
1045    def test_bool(self):
1046        values = [False, True, None, 't', 'f', 'true', 'false']
1047        table = self.table_prefix + 'booze'
1048        con = self._connect()
1049        try:
1050            cur = con.cursor()
1051            cur.execute(
1052                "create table %s (n smallint, booltest bool)" % table)
1053            params = enumerate(values)
1054            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
1055            cur.execute("select booltest from %s order by n" % table)
1056            rows = cur.fetchall()
1057            self.assertEqual(cur.description[0].type_code, pgdb.BOOL)
1058        finally:
1059            con.close()
1060        rows = [row[0] for row in rows]
1061        values[3] = values[5] = True
1062        values[4] = values[6] = False
1063        self.assertEqual(rows, values)
1064
1065    def test_literal(self):
1066        con = self._connect()
1067        try:
1068            cur = con.cursor()
1069            value = "lower('Hello')"
1070            cur.execute("select %s, %s", (value, pgdb.Literal(value)))
1071            row = cur.fetchone()
1072        finally:
1073            con.close()
1074        self.assertEqual(row, (value, 'hello'))
1075
1076    def test_json(self):
1077        inval = {"employees":
1078            [{"firstName": "John", "lastName": "Doe", "age": 61}]}
1079        table = self.table_prefix + 'booze'
1080        con = self._connect()
1081        try:
1082            cur = con.cursor()
1083            try:
1084                cur.execute("create table %s (jsontest json)" % table)
1085            except pgdb.ProgrammingError:
1086                self.skipTest('database does not support json')
1087            params = (pgdb.Json(inval),)
1088            cur.execute("insert into %s values (%%s)" % table, params)
1089            cur.execute("select jsontest from %s" % table)
1090            outval = cur.fetchone()[0]
1091            self.assertEqual(cur.description[0].type_code, pgdb.JSON)
1092        finally:
1093            con.close()
1094        self.assertEqual(inval, outval)
1095
1096    def test_jsonb(self):
1097        inval = {"employees":
1098            [{"firstName": "John", "lastName": "Doe", "age": 61}]}
1099        table = self.table_prefix + 'booze'
1100        con = self._connect()
1101        try:
1102            cur = con.cursor()
1103            try:
1104                cur.execute("create table %s (jsonbtest jsonb)" % table)
1105            except pgdb.ProgrammingError:
1106                self.skipTest('database does not support jsonb')
1107            params = (pgdb.Json(inval),)
1108            cur.execute("insert into %s values (%%s)" % table, params)
1109            cur.execute("select jsonbtest from %s" % table)
1110            outval = cur.fetchone()[0]
1111            self.assertEqual(cur.description[0].type_code, pgdb.JSON)
1112        finally:
1113            con.close()
1114        self.assertEqual(inval, outval)
1115
1116    def test_execute_edge_cases(self):
1117        con = self._connect()
1118        try:
1119            cur = con.cursor()
1120            sql = 'invalid'  # should be ignored with empty parameter list
1121            cur.executemany(sql, [])
1122            sql = 'select %d + 1'
1123            cur.execute(sql, [(1,), (2,)])  # deprecated use of execute()
1124            self.assertEqual(cur.fetchone()[0], 3)
1125            sql = 'select 1/0'  # cannot be executed
1126            self.assertRaises(pgdb.DataError, cur.execute, sql)
1127            cur.close()
1128            con.rollback()
1129            if pgdb.shortcutmethods:
1130                res = con.execute('select %d', (1,)).fetchone()
1131                self.assertEqual(res, (1,))
1132                res = con.executemany('select %d', [(1,), (2,)]).fetchone()
1133                self.assertEqual(res, (2,))
1134        finally:
1135            con.close()
1136        sql = 'select 1'  # cannot be executed after connection is closed
1137        self.assertRaises(pgdb.OperationalError, cur.execute, sql)
1138
1139    def test_fetchmany_with_keep(self):
1140        con = self._connect()
1141        try:
1142            cur = con.cursor()
1143            self.assertEqual(cur.arraysize, 1)
1144            cur.execute('select * from generate_series(1, 25)')
1145            self.assertEqual(len(cur.fetchmany()), 1)
1146            self.assertEqual(len(cur.fetchmany()), 1)
1147            self.assertEqual(cur.arraysize, 1)
1148            cur.arraysize = 3
1149            self.assertEqual(len(cur.fetchmany()), 3)
1150            self.assertEqual(len(cur.fetchmany()), 3)
1151            self.assertEqual(cur.arraysize, 3)
1152            self.assertEqual(len(cur.fetchmany(size=2)), 2)
1153            self.assertEqual(cur.arraysize, 3)
1154            self.assertEqual(len(cur.fetchmany()), 3)
1155            self.assertEqual(len(cur.fetchmany()), 3)
1156            self.assertEqual(len(cur.fetchmany(size=2, keep=True)), 2)
1157            self.assertEqual(cur.arraysize, 2)
1158            self.assertEqual(len(cur.fetchmany()), 2)
1159            self.assertEqual(len(cur.fetchmany()), 2)
1160            self.assertEqual(len(cur.fetchmany(25)), 3)
1161        finally:
1162            con.close()
1163
1164    def test_nextset(self):
1165        con = self._connect()
1166        cur = con.cursor()
1167        self.assertRaises(con.NotSupportedError, cur.nextset)
1168
1169    def test_setoutputsize(self):
1170        pass  # not supported
1171
1172    def test_connection_errors(self):
1173        con = self._connect()
1174        self.assertEqual(con.Error, pgdb.Error)
1175        self.assertEqual(con.Warning, pgdb.Warning)
1176        self.assertEqual(con.InterfaceError, pgdb.InterfaceError)
1177        self.assertEqual(con.DatabaseError, pgdb.DatabaseError)
1178        self.assertEqual(con.InternalError, pgdb.InternalError)
1179        self.assertEqual(con.OperationalError, pgdb.OperationalError)
1180        self.assertEqual(con.ProgrammingError, pgdb.ProgrammingError)
1181        self.assertEqual(con.IntegrityError, pgdb.IntegrityError)
1182        self.assertEqual(con.DataError, pgdb.DataError)
1183        self.assertEqual(con.NotSupportedError, pgdb.NotSupportedError)
1184
1185    def test_transaction(self):
1186        table = self.table_prefix + 'booze'
1187        con1 = self._connect()
1188        cur1 = con1.cursor()
1189        self.executeDDL1(cur1)
1190        con1.commit()
1191        con2 = self._connect()
1192        cur2 = con2.cursor()
1193        cur2.execute("select name from %s" % table)
1194        self.assertIsNone(cur2.fetchone())
1195        cur1.execute("insert into %s values('Schlafly')" % table)
1196        cur2.execute("select name from %s" % table)
1197        self.assertIsNone(cur2.fetchone())
1198        con1.commit()
1199        cur2.execute("select name from %s" % table)
1200        self.assertEqual(cur2.fetchone(), ('Schlafly',))
1201        con2.close()
1202        con1.close()
1203
1204    def test_autocommit(self):
1205        table = self.table_prefix + 'booze'
1206        con1 = self._connect()
1207        con1.autocommit = True
1208        cur1 = con1.cursor()
1209        self.executeDDL1(cur1)
1210        con2 = self._connect()
1211        cur2 = con2.cursor()
1212        cur2.execute("select name from %s" % table)
1213        self.assertIsNone(cur2.fetchone())
1214        cur1.execute("insert into %s values('Shmaltz Pastrami')" % table)
1215        cur2.execute("select name from %s" % table)
1216        self.assertEqual(cur2.fetchone(), ('Shmaltz Pastrami',))
1217        con2.close()
1218        con1.close()
1219
1220    def test_connection_as_contextmanager(self):
1221        table = self.table_prefix + 'booze'
1222        for autocommit in False, True:
1223            con = self._connect()
1224            con.autocommit = autocommit
1225            try:
1226                cur = con.cursor()
1227                if autocommit:
1228                    cur.execute("truncate %s" % table)
1229                else:
1230                    cur.execute(
1231                        "create table %s (n smallint check(n!=4))" % table)
1232                with con:
1233                    cur.execute("insert into %s values (1)" % table)
1234                    cur.execute("insert into %s values (2)" % table)
1235                try:
1236                    with con:
1237                        cur.execute("insert into %s values (3)" % table)
1238                        cur.execute("insert into %s values (4)" % table)
1239                except con.IntegrityError as error:
1240                    self.assertTrue('check' in str(error).lower())
1241                with con:
1242                    cur.execute("insert into %s values (5)" % table)
1243                    cur.execute("insert into %s values (6)" % table)
1244                try:
1245                    with con:
1246                        cur.execute("insert into %s values (7)" % table)
1247                        cur.execute("insert into %s values (8)" % table)
1248                        raise ValueError('transaction should rollback')
1249                except ValueError as error:
1250                    self.assertEqual(str(error), 'transaction should rollback')
1251                with con:
1252                    cur.execute("insert into %s values (9)" % table)
1253                cur.execute("select * from %s order by 1" % table)
1254                rows = cur.fetchall()
1255                rows = [row[0] for row in rows]
1256            finally:
1257                con.close()
1258            self.assertEqual(rows, [1, 2, 5, 6, 9])
1259
1260    def test_cursor_connection(self):
1261        con = self._connect()
1262        cur = con.cursor()
1263        self.assertEqual(cur.connection, con)
1264        cur.close()
1265
1266    def test_cursor_as_contextmanager(self):
1267        con = self._connect()
1268        with con.cursor() as cur:
1269            self.assertEqual(cur.connection, con)
1270
1271    def test_pgdb_type(self):
1272        self.assertEqual(pgdb.STRING, pgdb.STRING)
1273        self.assertNotEqual(pgdb.STRING, pgdb.INTEGER)
1274        self.assertNotEqual(pgdb.STRING, pgdb.BOOL)
1275        self.assertNotEqual(pgdb.BOOL, pgdb.INTEGER)
1276        self.assertEqual(pgdb.INTEGER, pgdb.INTEGER)
1277        self.assertNotEqual(pgdb.INTEGER, pgdb.NUMBER)
1278        self.assertEqual('char', pgdb.STRING)
1279        self.assertEqual('varchar', pgdb.STRING)
1280        self.assertEqual('text', pgdb.STRING)
1281        self.assertNotEqual('numeric', pgdb.STRING)
1282        self.assertEqual('numeric', pgdb.NUMERIC)
1283        self.assertEqual('numeric', pgdb.NUMBER)
1284        self.assertEqual('int4', pgdb.NUMBER)
1285        self.assertNotEqual('int4', pgdb.NUMERIC)
1286        self.assertEqual('int2', pgdb.SMALLINT)
1287        self.assertNotEqual('int4', pgdb.SMALLINT)
1288        self.assertEqual('int2', pgdb.INTEGER)
1289        self.assertEqual('int4', pgdb.INTEGER)
1290        self.assertEqual('int8', pgdb.INTEGER)
1291        self.assertNotEqual('int4', pgdb.LONG)
1292        self.assertEqual('int8', pgdb.LONG)
1293        self.assertTrue('char' in pgdb.STRING)
1294        self.assertTrue(pgdb.NUMERIC <= pgdb.NUMBER)
1295        self.assertTrue(pgdb.NUMBER >= pgdb.INTEGER)
1296        self.assertTrue(pgdb.TIME <= pgdb.DATETIME)
1297        self.assertTrue(pgdb.DATETIME >= pgdb.DATE)
1298        self.assertEqual(pgdb.ARRAY, pgdb.ARRAY)
1299        self.assertNotEqual(pgdb.ARRAY, pgdb.STRING)
1300        self.assertEqual('_char', pgdb.ARRAY)
1301        self.assertNotEqual('char', pgdb.ARRAY)
1302        self.assertEqual(pgdb.RECORD, pgdb.RECORD)
1303        self.assertNotEqual(pgdb.RECORD, pgdb.STRING)
1304        self.assertNotEqual(pgdb.RECORD, pgdb.ARRAY)
1305        self.assertEqual('record', pgdb.RECORD)
1306        self.assertNotEqual('_record', pgdb.RECORD)
1307
    def test_no_close(self):
        """Fetch a row without closing the cursor or the connection.

        This method is also called by test_memory_leaks, which checks that no
        new objects survive garbage collection afterwards. The cursor and
        connection are therefore deliberately not closed here (hence the name).
        """
        data = ('hello', 'world')
        con = self._connect()
        cur = con.cursor()
        # replace the row factory builder so that rows are built with tuple
        cur.build_row_factory = lambda: tuple
        cur.execute("select %s, %s", data)
        row = cur.fetchone()
        self.assertEqual(row, data)
1316
1317    def test_set_row_factory_size(self):
1318        try:
1319            from functools import lru_cache
1320        except ImportError:  # Python < 3.2
1321            lru_cache = None
1322        queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc']
1323        con = self._connect()
1324        cur = con.cursor()
1325        for maxsize in (None, 0, 1, 2, 3, 10, 1024):
1326            pgdb.set_row_factory_size(maxsize)
1327            for i in range(3):
1328                for q in queries:
1329                    cur.execute(q)
1330                    r = cur.fetchone()
1331                    if q.endswith('abc'):
1332                        self.assertEqual(r, (123,))
1333                        self.assertEqual(r._fields, ('abc',))
1334                    else:
1335                        self.assertEqual(r, (1, 2, 3))
1336                        self.assertEqual(r._fields, ('a', 'b', 'c'))
1337            if lru_cache:
1338                info = pgdb._row_factory.cache_info()
1339                self.assertEqual(info.maxsize, maxsize)
1340                self.assertEqual(info.hits + info.misses, 6)
1341                self.assertEqual(info.hits,
1342                    0 if maxsize is not None and maxsize < 2 else 4)
1343
    def test_memory_leaks(self):
        """Check that test_no_close leaves no new gc-tracked objects behind."""
        # The bookkeeping containers are created before the snapshot, so
        # they are not themselves reported as new objects.
        ids = set()
        objs = []
        add_ids = ids.update
        gc.collect()
        # remember the ids of all objects tracked before running the test
        objs[:] = gc.get_objects()
        add_ids(id(obj) for obj in objs)
        self.test_no_close()
        gc.collect()
        # keep only the objects that were not present in the snapshot
        objs[:] = gc.get_objects()
        objs[:] = [obj for obj in objs if id(obj) not in ids]
        if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)):
            # workaround for Python issue 26811
            objs[:] = [obj for obj in objs if repr(obj) != '(<NULL>,)']
        self.assertEqual(len(objs), 0)
1359
1360    def test_cve_2018_1058(self):
1361        # internal queries should use qualified table and operator names,
1362        # see https://nvd.nist.gov/vuln/detail/CVE-2018-1058
1363        con = self._connect()
1364        cur = con.cursor()
1365        execute = cur.execute
1366        try:
1367            execute("SET TIMEZONE TO 'UTC'")
1368            execute("SHOW TIMEZONE")
1369            self.assertEqual(cur.fetchone()[0], 'UTC')
1370            execute("""
1371                CREATE OR REPLACE FUNCTION public.bad_eq(oid, integer)
1372                RETURNS boolean AS $$
1373                BEGIN
1374                  SET TIMEZONE TO 'CET';
1375                  RETURN oideq($1, $2::oid);
1376                END
1377                $$ LANGUAGE plpgsql
1378                """)
1379            execute("""
1380                CREATE OPERATOR public.= (
1381                  PROCEDURE = public.bad_eq,
1382                  LEFTARG = oid, RIGHTARG = integer
1383                );
1384                """)
1385            # the following select changes the time zone as a side effect if
1386            # internal query uses unqualified = operator as it did earlier
1387            execute("SELECT 1")
1388            execute("SHOW TIMEZONE")  # make sure time zone has not changed
1389            self.assertEqual(cur.fetchone()[0], 'UTC')
1390        finally:
1391            execute("DROP OPERATOR IF EXISTS public.= (oid, integer)")
1392            execute("DROP FUNCTION IF EXISTS public.bad_eq(oid, integer)")
1393            cur.close()
1394            con.close()
1395
1396
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.