source: trunk/tests/test_classic_dbwrapper.py @ 956

Last change on this file since 956 was 956, checked in by cito, 10 months ago

Make name in query_prepared a keyword-only parameter

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 184.6 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import gc
21import json
22import tempfile
23
24import pg  # the module under test
25
26from decimal import Decimal
27from datetime import date, time, datetime, timedelta
28from uuid import UUID
29from time import strftime
30from operator import itemgetter
31
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Optional local overrides of the defaults above.  The star import must come
# after the defaults so that it can replace them.  The package-relative form
# is tried first; if that fails with ImportError or ValueError, a plain
# top-level import is tried, and a missing module is silently ignored.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Compatibility names so the same tests run on Python 2 and Python 3.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

if str is bytes:  # Python 2
    from StringIO import StringIO
else:
    from io import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
75
76
def DB():
    """Return a new pg.DB wrapper connected to the configured test database.

    The wrapper's debug attribute is set when the module-level ``debug``
    flag is true, and notices below warning level are suppressed for the
    session.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
84
85
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names."""

    cls = pg.AttrDict  # the class under test
    base = OrderedDict  # the mapping type it is expected to behave like

    def testInit(self):
        a = self.cls()
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base())
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))
        # construction from a one-shot iterator must work as well
        iteritems = iter(items)
        a = self.cls(iteritems)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))

    def testIter(self):
        a = self.cls()
        self.assertEqual(list(a), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        # iteration must preserve the insertion order
        self.assertEqual(list(a), keys)

    def testKeys(self):
        a = self.cls()
        self.assertEqual(list(a.keys()), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a.keys()), keys)

    def testValues(self):
        a = self.cls()
        self.assertEqual(list(a.values()), [])
        items = [('id', 'int'), ('name', 'text')]
        values = [item[1] for item in items]
        a = self.cls(items)
        self.assertEqual(list(a.values()), values)

    def testItems(self):
        a = self.cls()
        self.assertEqual(list(a.items()), [])
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertEqual(list(a.items()), items)

    def testGet(self):
        a = self.cls([('id', 1)])
        try:
            self.assertEqual(a['id'], 1)
        except KeyError:
            self.fail('AttrDict should be readable')

    def testSet(self):
        a = self.cls()
        try:
            a['id'] = 1
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testDel(self):
        a = self.cls([('id', 1)])
        try:
            del a['id']
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        a = self.cls([('id', 1)])
        self.assertEqual(a['id'], 1)
        # Call every mutating method with arguments a writable mapping would
        # accept.  A TypeError can then only come from the read-only
        # protection, not from a wrong call signature or an unhashable key.
        calls = [
            ('clear', ()),
            ('update', ({'name': 'text'},)),
            ('pop', ('id',)),
            ('setdefault', ('name', 'text')),
            ('popitem', ()),
        ]
        for name, args in calls:
            method = getattr(a, name)
            self.assertRaises(TypeError, method, *args)
        # none of the attempts may have changed the contents
        self.assertEqual(list(a.items()), [('id', 1)])
167
168
class TestDBClassInit(unittest.TestCase):
    """Check error handling when DB wrapper instances are created or closed."""

    def testBadParams(self):
        # an unknown keyword argument must be rejected by the constructor
        with self.assertRaises(TypeError):
            pg.DB(invalid=True)

    def testDeleteDb(self):
        wrapper = DB()
        # once the underlying connection attribute is gone, close() must
        # report an InternalError instead of failing silently
        del wrapper.db
        with self.assertRaises(pg.InternalError):
            wrapper.close()
        del wrapper
180
181
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # Some tests close the connection themselves.  Closing it again
        # raises InternalError (see testMethodClose), which is harmless here.
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        attributes = [
            'abort', 'adapter',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'date_format', 'db', 'dbname', 'dbtypes',
            'debug', 'decode_json', 'delete',
            'delete_prepared', 'describe_prepared',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_cast_hook',
            'get_databases', 'get_notice_receiver',
            'get_parameter', 'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'prepare', 'protocol_version', 'putline',
            'query', 'query_formatted', 'query_prepared',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_cast_hook', 'set_notice_receiver',
            'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        # __dir__ is not called in Python 2.6 for old-style classes
        db_attributes = dir(self.db) if hasattr(
            self.db.__class__, '__class__') else self.db.__dir__()
        db_attributes = [a for a in db_attributes
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        # socket directories (paths starting with '/') are reported as
        # 'localhost', as is the default of no host
        if dbhost and not dbhost.startswith('/'):
            host = dbhost
        else:
            host = 'localhost'
        self.assertIsInstance(self.db.host, str)
        self.assertEqual(self.db.host, host)
        self.assertEqual(self.db.db.host, host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        # NOTE(review): the upper bound limits the suite to the server
        # versions it was written for; raise it when testing newer servers
        self.assertTrue(90000 <= server_version < 110000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        # parameters may be passed positionally, as a tuple or as a list
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryDataError(self):
        try:
            self.db.query("select 1/0")
        except pg.DataError as error:
            # 22012 is the SQLSTATE for division_by_zero
            self.assertEqual(error.sqlstate, '22012')
        else:
            # without this the test would pass even if no error was raised
            self.fail('Division by zero should raise a DataError')

    def testMethodEndcopy(self):
        # best effort: endcopy without a running copy may raise IOError
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        # reset keeps the same connection object
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        # reopen replaces the connection object, even after close()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        # a wrapper around a borrowed connection must not close it
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        # any object exposing the connection as _cnx is accepted as well
        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
398
399
400class TestDBClass(unittest.TestCase):
401    """Test the methods of the DB class wrapped pg connection."""
402
403    maxDiff = 80 * 20
404
405    cls_set_up = False
406
407    regtypes = None
408
409    @classmethod
410    def setUpClass(cls):
411        db = DB()
412        db.query("drop table if exists test cascade")
413        db.query("create table test ("
414                 "i2 smallint, i4 integer, i8 bigint,"
415                 " d numeric, f4 real, f8 double precision, m money,"
416                 " v4 varchar(4), c4 char(4), t text)")
417        db.query("create or replace view test_view as"
418                 " select i4, v4 from test")
419        db.close()
420        cls.cls_set_up = True
421
422    @classmethod
423    def tearDownClass(cls):
424        db = DB()
425        db.query("drop table test cascade")
426        db.close()
427
428    def setUp(self):
429        self.assertTrue(self.cls_set_up)
430        self.db = DB()
431        if self.regtypes is None:
432            self.regtypes = self.db.use_regtypes()
433        else:
434            self.db.use_regtypes(self.regtypes)
435        query = self.db.query
436        query('set client_encoding=utf8')
437        query("set lc_monetary='C'")
438        query("set datestyle='ISO,YMD'")
439        query('set standard_conforming_strings=on')
440        try:
441            query('set bytea_output=hex')
442        except pg.ProgrammingError:
443            if self.db.server_version >= 90000:
444                raise  # ignore for older server versions
445
    def tearDown(self):
        # Run the cleanups registered by this test (e.g. the table drops
        # queued by createTable) explicitly while self.db is still open;
        # only then close the connection.
        self.doCleanups()
        self.db.close()
449
450    def createTable(self, table, definition,
451                    temporary=True, oids=None, values=None):
452        query = self.db.query
453        if not '"' in table or '.' in table:
454            table = '"%s"' % table
455        if not temporary:
456            q = 'drop table if exists %s cascade' % table
457            query(q)
458            self.addCleanup(query, q)
459        temporary = 'temporary table' if temporary else 'table'
460        as_query = definition.startswith(('as ', 'AS '))
461        if not as_query and not definition.startswith('('):
462            definition = '(%s)' % definition
463        with_oids = 'with oids' if oids else 'without oids'
464        q = ['create', temporary, table]
465        if as_query:
466            q.extend([with_oids, definition])
467        else:
468            q.extend([definition, with_oids])
469        q = ' '.join(q)
470        query(q)
471        if values:
472            for params in values:
473                if not isinstance(params, (list, tuple)):
474                    params = [params]
475                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
476                q = "insert into %s values (%s)" % (table, values)
477                query(q, params)
478
479    def testClassName(self):
480        self.assertEqual(self.db.__class__.__name__, 'DB')
481
482    def testModuleName(self):
483        self.assertEqual(self.db.__module__, 'pg')
484        self.assertEqual(self.db.__class__.__module__, 'pg')
485
486    def testEscapeLiteral(self):
487        f = self.db.escape_literal
488        r = f(b"plain")
489        self.assertIsInstance(r, bytes)
490        self.assertEqual(r, b"'plain'")
491        r = f(u"plain")
492        self.assertIsInstance(r, unicode)
493        self.assertEqual(r, u"'plain'")
494        r = f(u"that's kÀse".encode('utf-8'))
495        self.assertIsInstance(r, bytes)
496        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
497        r = f(u"that's kÀse")
498        self.assertIsInstance(r, unicode)
499        self.assertEqual(r, u"'that''s kÀse'")
500        self.assertEqual(f(r"It's fine to have a \ inside."),
501                         r" E'It''s fine to have a \\ inside.'")
502        self.assertEqual(f('No "quotes" must be escaped.'),
503                         "'No \"quotes\" must be escaped.'")
504
505    def testEscapeIdentifier(self):
506        f = self.db.escape_identifier
507        r = f(b"plain")
508        self.assertIsInstance(r, bytes)
509        self.assertEqual(r, b'"plain"')
510        r = f(u"plain")
511        self.assertIsInstance(r, unicode)
512        self.assertEqual(r, u'"plain"')
513        r = f(u"that's kÀse".encode('utf-8'))
514        self.assertIsInstance(r, bytes)
515        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
516        r = f(u"that's kÀse")
517        self.assertIsInstance(r, unicode)
518        self.assertEqual(r, u'"that\'s kÀse"')
519        self.assertEqual(f(r"It's fine to have a \ inside."),
520                         '"It\'s fine to have a \\ inside."')
521        self.assertEqual(f('All "quotes" must be escaped.'),
522                         '"All ""quotes"" must be escaped."')
523
524    def testEscapeString(self):
525        f = self.db.escape_string
526        r = f(b"plain")
527        self.assertIsInstance(r, bytes)
528        self.assertEqual(r, b"plain")
529        r = f(u"plain")
530        self.assertIsInstance(r, unicode)
531        self.assertEqual(r, u"plain")
532        r = f(u"that's kÀse".encode('utf-8'))
533        self.assertIsInstance(r, bytes)
534        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
535        r = f(u"that's kÀse")
536        self.assertIsInstance(r, unicode)
537        self.assertEqual(r, u"that''s kÀse")
538        self.assertEqual(f(r"It's fine to have a \ inside."),
539                         r"It''s fine to have a \ inside.")
540
541    def testEscapeBytea(self):
542        f = self.db.escape_bytea
543        # note that escape_byte always returns hex output since Pg 9.0,
544        # regardless of the bytea_output setting
545        r = f(b'plain')
546        self.assertIsInstance(r, bytes)
547        self.assertEqual(r, b'\\x706c61696e')
548        r = f(u'plain')
549        self.assertIsInstance(r, unicode)
550        self.assertEqual(r, u'\\x706c61696e')
551        r = f(u"das is' kÀse".encode('utf-8'))
552        self.assertIsInstance(r, bytes)
553        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
554        r = f(u"das is' kÀse")
555        self.assertIsInstance(r, unicode)
556        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
557        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
558
559    def testUnescapeBytea(self):
560        f = self.db.unescape_bytea
561        r = f(b'plain')
562        self.assertIsInstance(r, bytes)
563        self.assertEqual(r, b'plain')
564        r = f(u'plain')
565        self.assertIsInstance(r, bytes)
566        self.assertEqual(r, b'plain')
567        r = f(b"das is' k\\303\\244se")
568        self.assertIsInstance(r, bytes)
569        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
570        r = f(u"das is' k\\303\\244se")
571        self.assertIsInstance(r, bytes)
572        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
573        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
574        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
575        self.assertEqual(f(r'\\x746861742773206be47365'),
576                         b'\\x746861742773206be47365')
577        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
578
579    def testDecodeJson(self):
580        f = self.db.decode_json
581        self.assertIsNone(f('null'))
582        data = {
583            "id": 1, "name": "Foo", "price": 1234.5,
584            "new": True, "note": None,
585            "tags": ["Bar", "Eek"],
586            "stock": {"warehouse": 300, "retail": 20}}
587        text = json.dumps(data)
588        r = f(text)
589        self.assertIsInstance(r, dict)
590        self.assertEqual(r, data)
591        self.assertIsInstance(r['id'], int)
592        self.assertIsInstance(r['name'], unicode)
593        self.assertIsInstance(r['price'], float)
594        self.assertIsInstance(r['new'], bool)
595        self.assertIsInstance(r['tags'], list)
596        self.assertIsInstance(r['stock'], dict)
597
598    def testEncodeJson(self):
599        f = self.db.encode_json
600        self.assertEqual(f(None), 'null')
601        data = {
602            "id": 1, "name": "Foo", "price": 1234.5,
603            "new": True, "note": None,
604            "tags": ["Bar", "Eek"],
605            "stock": {"warehouse": 300, "retail": 20}}
606        text = json.dumps(data)
607        r = f(data)
608        self.assertIsInstance(r, str)
609        self.assertEqual(r, text)
610
611    def testGetParameter(self):
612        f = self.db.get_parameter
613        self.assertRaises(TypeError, f)
614        self.assertRaises(TypeError, f, None)
615        self.assertRaises(TypeError, f, 42)
616        self.assertRaises(TypeError, f, '')
617        self.assertRaises(TypeError, f, [])
618        self.assertRaises(TypeError, f, [''])
619        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
620        r = f('standard_conforming_strings')
621        self.assertEqual(r, 'on')
622        r = f('lc_monetary')
623        self.assertEqual(r, 'C')
624        r = f('datestyle')
625        self.assertEqual(r, 'ISO, YMD')
626        r = f('bytea_output')
627        self.assertEqual(r, 'hex')
628        r = f(['bytea_output', 'lc_monetary'])
629        self.assertIsInstance(r, list)
630        self.assertEqual(r, ['hex', 'C'])
631        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
632        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
633        r = f(set(['bytea_output', 'lc_monetary']))
634        self.assertIsInstance(r, dict)
635        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
636        r = f(set(['Bytea_Output', ' LC_Monetary ']))
637        self.assertIsInstance(r, dict)
638        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
639        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
640        r = f(s)
641        self.assertIs(r, s)
642        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
643        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
644        r = f(s)
645        self.assertIs(r, s)
646        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
647
648    def testGetParameterServerVersion(self):
649        r = self.db.get_parameter('server_version_num')
650        self.assertIsInstance(r, str)
651        s = self.db.server_version
652        self.assertIsInstance(s, int)
653        self.assertEqual(r, str(s))
654
655    def testGetParameterAll(self):
656        f = self.db.get_parameter
657        r = f('all')
658        self.assertIsInstance(r, dict)
659        self.assertEqual(r['standard_conforming_strings'], 'on')
660        self.assertEqual(r['lc_monetary'], 'C')
661        self.assertEqual(r['DateStyle'], 'ISO, YMD')
662        self.assertEqual(r['bytea_output'], 'hex')
663
    def testSetParameter(self):
        """Check the argument forms accepted by set_parameter.

        The steps build on each other: every call changes session state
        that the following assertions check, so their order matters.
        """
        f = self.db.set_parameter
        g = self.db.get_parameter
        # missing, non-string, empty or empty-list parameter names
        self.assertRaises(TypeError, f)
        self.assertRaises(TypeError, f, None)
        self.assertRaises(TypeError, f, 42)
        self.assertRaises(TypeError, f, '')
        self.assertRaises(TypeError, f, [])
        self.assertRaises(TypeError, f, [''])
        # 'all' cannot be set to a value, and a dict of names cannot be
        # combined with a separate value argument
        self.assertRaises(ValueError, f, 'all', 'invalid')
        self.assertRaises(ValueError, f, {
            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
        # single name and value
        f('standard_conforming_strings', 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        f('datestyle', 'ISO, DMY')
        self.assertEqual(g('datestyle'), 'ISO, DMY')
        # list of names with a list of values, matched by position
        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertEqual(g('datestyle'), 'ISO, DMY')
        # list of names with one value applied to all of them
        f(['default_with_oids', 'standard_conforming_strings'], 'off')
        self.assertEqual(g('default_with_oids'), 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        # the same with tuples
        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertEqual(g('datestyle'), 'ISO, YMD')
        f(('default_with_oids', 'standard_conforming_strings'), 'off')
        self.assertEqual(g('default_with_oids'), 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        # a set of names with one value
        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
        self.assertEqual(g('default_with_oids'), 'on')
        self.assertEqual(g('standard_conforming_strings'), 'on')
        # a set of names with a list of differing values is rejected
        # (presumably because a set has no defined order), while a list of
        # identical values is accepted
        self.assertRaises(ValueError, f, set(['default_with_oids',
            'standard_conforming_strings']), ['off', 'on'])
        f(set(['default_with_oids', 'standard_conforming_strings']),
            ['off', 'off'])
        self.assertEqual(g('default_with_oids'), 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        # a dict maps each name to its own value
        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertEqual(g('datestyle'), 'ISO, YMD')
705
706    def testResetParameter(self):
707        db = DB()
708        f = db.set_parameter
709        g = db.get_parameter
710        r = g('default_with_oids')
711        self.assertIn(r, ('on', 'off'))
712        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
713        r = g('standard_conforming_strings')
714        self.assertIn(r, ('on', 'off'))
715        scs, not_scs = r, 'off' if r == 'on' else 'on'
716        f('default_with_oids', not_dwi)
717        f('standard_conforming_strings', not_scs)
718        self.assertEqual(g('default_with_oids'), not_dwi)
719        self.assertEqual(g('standard_conforming_strings'), not_scs)
720        f('default_with_oids')
721        f('standard_conforming_strings', None)
722        self.assertEqual(g('default_with_oids'), dwi)
723        self.assertEqual(g('standard_conforming_strings'), scs)
724        f('default_with_oids', not_dwi)
725        f('standard_conforming_strings', not_scs)
726        self.assertEqual(g('default_with_oids'), not_dwi)
727        self.assertEqual(g('standard_conforming_strings'), not_scs)
728        f(['default_with_oids', 'standard_conforming_strings'], None)
729        self.assertEqual(g('default_with_oids'), dwi)
730        self.assertEqual(g('standard_conforming_strings'), scs)
731        f('default_with_oids', not_dwi)
732        f('standard_conforming_strings', not_scs)
733        self.assertEqual(g('default_with_oids'), not_dwi)
734        self.assertEqual(g('standard_conforming_strings'), not_scs)
735        f(('default_with_oids', 'standard_conforming_strings'))
736        self.assertEqual(g('default_with_oids'), dwi)
737        self.assertEqual(g('standard_conforming_strings'), scs)
738        f('default_with_oids', not_dwi)
739        f('standard_conforming_strings', not_scs)
740        self.assertEqual(g('default_with_oids'), not_dwi)
741        self.assertEqual(g('standard_conforming_strings'), not_scs)
742        f(set(['default_with_oids', 'standard_conforming_strings']))
743        self.assertEqual(g('default_with_oids'), dwi)
744        self.assertEqual(g('standard_conforming_strings'), scs)
745        db.close()
746
747    def testResetParameterAll(self):
748        db = DB()
749        f = db.set_parameter
750        self.assertRaises(ValueError, f, 'all', 0)
751        self.assertRaises(ValueError, f, 'all', 'off')
752        g = db.get_parameter
753        r = g('default_with_oids')
754        self.assertIn(r, ('on', 'off'))
755        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
756        r = g('standard_conforming_strings')
757        self.assertIn(r, ('on', 'off'))
758        scs, not_scs = r, 'off' if r == 'on' else 'on'
759        f('default_with_oids', not_dwi)
760        f('standard_conforming_strings', not_scs)
761        self.assertEqual(g('default_with_oids'), not_dwi)
762        self.assertEqual(g('standard_conforming_strings'), not_scs)
763        f('all')
764        self.assertEqual(g('default_with_oids'), dwi)
765        self.assertEqual(g('standard_conforming_strings'), scs)
766        db.close()
767
768    def testSetParameterLocal(self):
769        f = self.db.set_parameter
770        g = self.db.get_parameter
771        self.assertEqual(g('standard_conforming_strings'), 'on')
772        self.db.begin()
773        f('standard_conforming_strings', 'off', local=True)
774        self.assertEqual(g('standard_conforming_strings'), 'off')
775        self.db.end()
776        self.assertEqual(g('standard_conforming_strings'), 'on')
777
778    def testSetParameterSession(self):
779        f = self.db.set_parameter
780        g = self.db.get_parameter
781        self.assertEqual(g('standard_conforming_strings'), 'on')
782        self.db.begin()
783        f('standard_conforming_strings', 'off', local=False)
784        self.assertEqual(g('standard_conforming_strings'), 'off')
785        self.db.end()
786        self.assertEqual(g('standard_conforming_strings'), 'off')
787
788    def testReset(self):
789        db = DB()
790        default_datestyle = db.get_parameter('datestyle')
791        changed_datestyle = 'ISO, DMY'
792        if changed_datestyle == default_datestyle:
793            changed_datestyle == 'ISO, YMD'
794        self.db.set_parameter('datestyle', changed_datestyle)
795        r = self.db.get_parameter('datestyle')
796        self.assertEqual(r, changed_datestyle)
797        con = self.db.db
798        q = con.query("show datestyle")
799        self.db.reset()
800        r = q.getresult()[0][0]
801        self.assertEqual(r, changed_datestyle)
802        q = con.query("show datestyle")
803        r = q.getresult()[0][0]
804        self.assertEqual(r, default_datestyle)
805        r = self.db.get_parameter('datestyle')
806        self.assertEqual(r, default_datestyle)
807        db.close()
808
809    def testReopen(self):
810        db = DB()
811        default_datestyle = db.get_parameter('datestyle')
812        changed_datestyle = 'ISO, DMY'
813        if changed_datestyle == default_datestyle:
814            changed_datestyle == 'ISO, YMD'
815        self.db.set_parameter('datestyle', changed_datestyle)
816        r = self.db.get_parameter('datestyle')
817        self.assertEqual(r, changed_datestyle)
818        con = self.db.db
819        q = con.query("show datestyle")
820        self.db.reopen()
821        r = q.getresult()[0][0]
822        self.assertEqual(r, changed_datestyle)
823        self.assertRaises(TypeError, getattr, con, 'query')
824        r = self.db.get_parameter('datestyle')
825        self.assertEqual(r, default_datestyle)
826        db.close()
827
828    def testCreateTable(self):
829        table = 'test hello world'
830        values = [(2, "World!"), (1, "Hello")]
831        self.createTable(table, "n smallint, t varchar",
832                         temporary=True, oids=True, values=values)
833        r = self.db.query('select t from "%s" order by n' % table).getresult()
834        r = ', '.join(row[0] for row in r)
835        self.assertEqual(r, "Hello, World!")
836        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
837        self.assertIsInstance(r[0][0], int)
838
839    def testQuery(self):
840        query = self.db.query
841        table = 'test_table'
842        self.createTable(table, "n integer", oids=True)
843        q = "insert into test_table values (1)"
844        r = query(q)
845        self.assertIsInstance(r, int)
846        q = "insert into test_table select 2"
847        r = query(q)
848        self.assertIsInstance(r, int)
849        oid = r
850        q = "select oid from test_table where n=2"
851        r = query(q).getresult()
852        self.assertEqual(len(r), 1)
853        r = r[0]
854        self.assertEqual(len(r), 1)
855        r = r[0]
856        self.assertEqual(r, oid)
857        q = "insert into test_table select 3 union select 4 union select 5"
858        r = query(q)
859        self.assertIsInstance(r, str)
860        self.assertEqual(r, '3')
861        q = "update test_table set n=4 where n<5"
862        r = query(q)
863        self.assertIsInstance(r, str)
864        self.assertEqual(r, '4')
865        q = "delete from test_table"
866        r = query(q)
867        self.assertIsInstance(r, str)
868        self.assertEqual(r, '5')
869
870    def testMultipleQueries(self):
871        self.assertEqual(self.db.query(
872            "create temporary table test_multi (n integer);"
873            "insert into test_multi values (4711);"
874            "select n from test_multi").getresult()[0][0], 4711)
875
876    def testQueryWithParams(self):
877        query = self.db.query
878        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
879        q = "insert into test_table values ($1, $2)"
880        r = query(q, (1, 2))
881        self.assertIsInstance(r, int)
882        r = query(q, [3, 4])
883        self.assertIsInstance(r, int)
884        r = query(q, [5, 6])
885        self.assertIsInstance(r, int)
886        q = "select * from test_table order by 1, 2"
887        self.assertEqual(query(q).getresult(),
888                         [(1, 2), (3, 4), (5, 6)])
889        q = "select * from test_table where n1=$1 and n2=$2"
890        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
891        q = "update test_table set n2=$2 where n1=$1"
892        r = query(q, 3, 7)
893        self.assertEqual(r, '1')
894        q = "select * from test_table order by 1, 2"
895        self.assertEqual(query(q).getresult(),
896                         [(1, 2), (3, 7), (5, 6)])
897        q = "delete from test_table where n2!=$1"
898        r = query(q, 4)
899        self.assertEqual(r, '3')
900
901    def testEmptyQuery(self):
902        self.assertRaises(ValueError, self.db.query, '')
903
904    def testQueryDataError(self):
905        try:
906            self.db.query("select 1/0")
907        except pg.DataError as error:
908            self.assertEqual(error.sqlstate, '22012')
909
910    def testQueryFormatted(self):
911        f = self.db.query_formatted
912        t = True if pg.get_bool() else 't'
913        # test with tuple
914        q = f("select %s::int, %s::real, %s::text, %s::bool",
915              (3, 2.5, 'hello', True))
916        r = q.getresult()[0]
917        self.assertEqual(r, (3, 2.5, 'hello', t))
918        # test with tuple, inline
919        q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True)
920        r = q.getresult()[0]
921        if isinstance(r[1], Decimal):
922            # Python 2.6 cannot compare float and Decimal
923            r = list(r)
924            r[1] = float(r[1])
925            r = tuple(r)
926        self.assertEqual(r, (3, 2.5, 'hello', t))
927        # test with dict
928        q = f("select %(a)s::int, %(b)s::real, %(c)s::text, %(d)s::bool",
929              dict(a=3, b=2.5, c='hello', d=True))
930        r = q.getresult()[0]
931        self.assertEqual(r, (3, 2.5, 'hello', t))
932        # test with dict, inline
933        q = f("select %(a)s, %(b)s, %(c)s, %(d)s",
934              dict(a=3, b=2.5, c='hello', d=True), inline=True)
935        r = q.getresult()[0]
936        if isinstance(r[1], Decimal):
937            # Python 2.6 cannot compare float and Decimal
938            r = list(r)
939            r[1] = float(r[1])
940            r = tuple(r)
941        self.assertEqual(r, (3, 2.5, 'hello', t))
942        # test with dict and extra values
943        q = f("select %(a)s||%(b)s||%(c)s||%(d)s||'epsilon'",
944              dict(a='alpha', b='beta', c='gamma', d='delta', e='extra'))
945        r = q.getresult()[0][0]
946        self.assertEqual(r, 'alphabetagammadeltaepsilon')
947
948    def testQueryFormattedWithAny(self):
949        f = self.db.query_formatted
950        q = "select 2 = any(%s)"
951        r = f(q, [[1, 3]]).getresult()[0][0]
952        self.assertEqual(r, False if pg.get_bool() else 'f')
953        r = f(q, [[1, 2, 3]]).getresult()[0][0]
954        self.assertEqual(r, True if pg.get_bool() else 't')
955        r = f(q, [[]]).getresult()[0][0]
956        self.assertEqual(r, False if pg.get_bool() else 'f')
957        r = f(q, [[None]]).getresult()[0][0]
958        self.assertIsNone(r)
959
960    def testQueryFormattedWithoutParams(self):
961        f = self.db.query_formatted
962        q = "select 42"
963        r = f(q).getresult()[0][0]
964        self.assertEqual(r, 42)
965        r = f(q, None).getresult()[0][0]
966        self.assertEqual(r, 42)
967        r = f(q, []).getresult()[0][0]
968        self.assertEqual(r, 42)
969        r = f(q, {}).getresult()[0][0]
970        self.assertEqual(r, 42)
971
972    def testPrepare(self):
973        p = self.db.prepare
974        self.assertIsNone(p("select 'hello'", 'my query'))
975        self.assertIsNone(p("select 'world'", 'my other query'))
976        self.assertRaises(pg.ProgrammingError,
977            p, 'my query', "select 'hello, too'")
978
979    def testPrepareUnnamed(self):
980        p = self.db.prepare
981        self.assertIsNone(p("select null"))
982        self.assertIsNone(p("select null", None))
983        self.assertIsNone(p("select null", ''))
984        self.assertIsNone(p("select null", name=None))
985        self.assertIsNone(p("select null", name=''))
986
987    def testQueryPreparedWithoutParams(self):
988        f = self.db.query_prepared
989        self.assertRaises(pg.OperationalError, f, 'q')
990        p = self.db.prepare
991        p("select 17", name='q1')
992        p("select 42", name='q2')
993        r = f(name='q1').getresult()[0][0]
994        self.assertEqual(r, 17)
995        r = f(name='q2').getresult()[0][0]
996        self.assertEqual(r, 42)
997
998    def testQueryPreparedWithParams(self):
999        p = self.db.prepare
1000        p("select 1 + $1 + $2 + $3", name='sum')
1001        p("select initcap($1) || ', ' || $2 || '!'", name='cat')
1002        f = self.db.query_prepared
1003        r = f(2, 3, 5, name='sum').getresult()[0][0]
1004        self.assertEqual(r, 11)
1005        r = f('hello', 'world', name='cat').getresult()[0][0]
1006        self.assertEqual(r, 'Hello, world!')
1007
1008    def testQueryPreparedUnnamedWithOutParams(self):
1009        f = self.db.query_prepared
1010        self.assertRaises(pg.OperationalError, f)
1011        p = self.db.prepare
1012        p("select 'no name'")
1013        r = f().getresult()[0][0]
1014        self.assertEqual(r, 'no name')
1015        r = f(name=None).getresult()[0][0]
1016        self.assertEqual(r, 'no name')
1017        r = f(name='').getresult()[0][0]
1018        self.assertEqual(r, 'no name')
1019
1020    def testQueryPreparedUnnamedWithParams(self):
1021        p = self.db.prepare
1022        p("select 1 + $1 + $2")
1023        f = self.db.query_prepared
1024        r = f(2, 3).getresult()[0][0]
1025        self.assertEqual(r, 6)
1026        r = f(2, 3, name=None).getresult()[0][0]
1027        self.assertEqual(r, 6)
1028        r = f(2, 3, name='').getresult()[0][0]
1029        self.assertEqual(r, 6)
1030
1031    def testDescribePrepared(self):
1032        self.db.prepare("select 1 as first, 2 as second", 'count')
1033        f = self.db.describe_prepared
1034        r = f('count').listfields()
1035        self.assertEqual(r, ('first', 'second'))
1036
1037    def testDescribePreparedUnnamed(self):
1038        self.db.prepare("select null as anon")
1039        f = self.db.describe_prepared
1040        r = f().listfields()
1041        self.assertEqual(r, ('anon',))
1042        r = f(None).listfields()
1043        self.assertEqual(r, ('anon',))
1044        r = f('').listfields()
1045        self.assertEqual(r, ('anon',))
1046
1047    def testDeletePrepared(self):
1048        f = self.db.delete_prepared
1049        f()
1050        e = pg.OperationalError
1051        self.assertRaises(e, f, 'myquery')
1052        p = self.db.prepare
1053        p("select 1", 'q1')
1054        p("select 2", 'q2')
1055        f('q1')
1056        f('q2')
1057        self.assertRaises(e, f, 'q1')
1058        self.assertRaises(e, f, 'q2')
1059        p("select 1", 'q1')
1060        p("select 2", 'q2')
1061        f()
1062        self.assertRaises(e, f, 'q1')
1063        self.assertRaises(e, f, 'q2')
1064
    def testPkey(self):
        """Test pkey() with simple, composite and quoted primary keys.

        Also checks that pkey() results are cached until flushed.
        """
        query = self.db.query
        pkey = self.db.pkey
        # pkey() raises KeyError for the 'test' table, which is set up
        # outside this method and evidently has no primary key
        self.assertRaises(KeyError, pkey, 'test')
        # run all checks once with a simple table name and once with
        # a table name that contains spaces
        for t in ('pkeytest', 'primary key test'):
            self.createTable('%s0' % t, 'a smallint')
            self.createTable('%s1' % t, 'b smallint primary key')
            self.createTable('%s2' % t,
                'c smallint, d smallint primary key')
            self.createTable('%s3' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (f, h)')
            self.createTable('%s4' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (h, f)')
            self.createTable('%s5' % t,
                'more_than_one_letter varchar primary key')
            self.createTable('%s6' % t,
                '"with space" date primary key')
            self.createTable('%s7' % t,
                'a_very_long_column_name varchar, "with space" date, "42" int,'
                ' primary key (a_very_long_column_name, "with space", "42")')
            # a table without a primary key raises a KeyError
            self.assertRaises(KeyError, pkey, '%s0' % t)
            # a single-column key is returned as a string, or as a
            # one-element tuple when composite is set (also positionally)
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s1' % t, True), ('b',))
            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
            self.assertEqual(pkey('%s2' % t), 'd')
            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
            # multi-column keys are always returned as tuples, even with
            # composite=False, and in the order given in the key definition
            r = pkey('%s3' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s3' % t, composite=False)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s4' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('h', 'f'))
            # column names are returned unquoted
            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s6' % t), 'with space')
            r = pkey('%s7' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (
                'a_very_long_column_name', 'with space', '42'))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
1118
1119    def testGetDatabases(self):
1120        databases = self.db.get_databases()
1121        self.assertIn('template0', databases)
1122        self.assertIn('template1', databases)
1123        self.assertNotIn('not existing database', databases)
1124        self.assertIn('postgres', databases)
1125        self.assertIn(dbname, databases)
1126
    def testGetTables(self):
        """Test get_tables() before, during and after creating tables.

        The created tables must show up as new, schema-qualified entries,
        quoted where needed, and must be gone after the cleanups have run.
        """
        get_tables = self.db.get_tables
        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
                  'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
                  'averyveryveryveryveryveryveryreallyreallylongtablename',
                  'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
        # make sure none of the test tables exist yet
        for t in tables:
            self.db.query('drop table if exists "%s" cascade' % t)
        before_tables = get_tables()
        self.assertIsInstance(before_tables, list)
        # every entry must be schema-qualified and not in a system schema
        for t in before_tables:
            t = t.split('.', 1)
            self.assertGreaterEqual(len(t), 2)
            # NOTE(review): split('.', 1) never yields more than two parts,
            # so this check for quoted table names can never run
            if len(t) > 2:
                self.assertTrue(t[1].startswith('"'))
            t = t[0]
            self.assertNotEqual(t, 'information_schema')
            self.assertFalse(t.startswith('pg_'))
        for t in tables:
            self.createTable(t, 'as select 0', temporary=False)
        current_tables = get_tables()
        new_tables = [t for t in current_tables if t not in before_tables]
        # names containing spaces or upper case letters are expected to be
        # quoted, and the new tables are expected in the order given above
        expected_new_tables = ['public.%s' % (
            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
        self.assertEqual(new_tables, expected_new_tables)
        # presumably the cleanups registered by createTable drop the tables
        self.doCleanups()
        after_tables = get_tables()
        self.assertEqual(after_tables, before_tables)
1155
1156    def testGetSystemTables(self):
1157        get_tables = self.db.get_tables
1158        result = get_tables()
1159        self.assertNotIn('pg_catalog.pg_class', result)
1160        self.assertNotIn('information_schema.tables', result)
1161        result = get_tables(system=False)
1162        self.assertNotIn('pg_catalog.pg_class', result)
1163        self.assertNotIn('information_schema.tables', result)
1164        result = get_tables(system=True)
1165        self.assertIn('pg_catalog.pg_class', result)
1166        self.assertNotIn('information_schema.tables', result)
1167
1168    def testGetRelations(self):
1169        get_relations = self.db.get_relations
1170        result = get_relations()
1171        self.assertIn('public.test', result)
1172        self.assertIn('public.test_view', result)
1173        result = get_relations('rv')
1174        self.assertIn('public.test', result)
1175        self.assertIn('public.test_view', result)
1176        result = get_relations('r')
1177        self.assertIn('public.test', result)
1178        self.assertNotIn('public.test_view', result)
1179        result = get_relations('v')
1180        self.assertNotIn('public.test', result)
1181        self.assertIn('public.test_view', result)
1182        result = get_relations('cisSt')
1183        self.assertNotIn('public.test', result)
1184        self.assertNotIn('public.test_view', result)
1185
1186    def testGetSystemRelations(self):
1187        get_relations = self.db.get_relations
1188        result = get_relations()
1189        self.assertNotIn('pg_catalog.pg_class', result)
1190        self.assertNotIn('information_schema.tables', result)
1191        result = get_relations(system=False)
1192        self.assertNotIn('pg_catalog.pg_class', result)
1193        self.assertNotIn('information_schema.tables', result)
1194        result = get_relations(system=True)
1195        self.assertIn('pg_catalog.pg_class', result)
1196        self.assertIn('information_schema.tables', result)
1197
1198    def testGetAttnames(self):
1199        get_attnames = self.db.get_attnames
1200        self.assertRaises(pg.ProgrammingError,
1201                          self.db.get_attnames, 'does_not_exist')
1202        self.assertRaises(pg.ProgrammingError,
1203                          self.db.get_attnames, 'has.too.many.dots')
1204        r = get_attnames('test')
1205        self.assertIsInstance(r, dict)
1206        if self.regtypes:
1207            self.assertEqual(r, dict(
1208                i2='smallint', i4='integer', i8='bigint', d='numeric',
1209                f4='real', f8='double precision', m='money',
1210                v4='character varying', c4='character', t='text'))
1211        else:
1212            self.assertEqual(r, dict(
1213                i2='int', i4='int', i8='int', d='num',
1214                f4='float', f8='float', m='money',
1215                v4='text', c4='text', t='text'))
1216        self.createTable('test_table',
1217                         'n int, alpha smallint, beta bool,'
1218                         ' gamma char(5), tau text, v varchar(3)')
1219        r = get_attnames('test_table')
1220        self.assertIsInstance(r, dict)
1221        if self.regtypes:
1222            self.assertEqual(r, dict(
1223                n='integer', alpha='smallint', beta='boolean',
1224                gamma='character', tau='text', v='character varying'))
1225        else:
1226            self.assertEqual(r, dict(
1227                n='int', alpha='int', beta='bool',
1228                gamma='text', tau='text', v='text'))
1229
1230    def testGetAttnamesWithQuotes(self):
1231        get_attnames = self.db.get_attnames
1232        table = 'test table for get_attnames()'
1233        self.createTable(table,
1234            '"Prime!" smallint, "much space" integer, "Questions?" text')
1235        r = get_attnames(table)
1236        self.assertIsInstance(r, dict)
1237        if self.regtypes:
1238            self.assertEqual(r, {
1239                'Prime!': 'smallint', 'much space': 'integer',
1240                'Questions?': 'text'})
1241        else:
1242            self.assertEqual(r, {
1243                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1244        table = 'yet another test table for get_attnames()'
1245        self.createTable(table,
1246                         'a smallint, b integer, c bigint,'
1247                         ' e numeric, f real, f2 double precision, m money,'
1248                         ' x smallint, y smallint, z smallint,'
1249                         ' Normal_NaMe smallint, "Special Name" smallint,'
1250                         ' t text, u char(2), v varchar(2),'
1251                         ' primary key (y, u)', oids=True)
1252        r = get_attnames(table)
1253        self.assertIsInstance(r, dict)
1254        if self.regtypes:
1255            self.assertEqual(r, {
1256                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1257                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1258                'm': 'money', 'normal_name': 'smallint',
1259                'Special Name': 'smallint', 'u': 'character',
1260                't': 'text', 'v': 'character varying', 'y': 'smallint',
1261                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1262        else:
1263            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1264                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1265                 'normal_name': 'int', 'Special Name': 'int',
1266                 'u': 'text', 't': 'text', 'v': 'text',
1267                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1268
1269    def testGetAttnamesWithRegtypes(self):
1270        get_attnames = self.db.get_attnames
1271        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1272            ' gamma char(5), tau text, v varchar(3)')
1273        use_regtypes = self.db.use_regtypes
1274        regtypes = use_regtypes()
1275        self.assertEqual(regtypes, self.regtypes)
1276        use_regtypes(True)
1277        try:
1278            r = get_attnames("test_table")
1279            self.assertIsInstance(r, dict)
1280        finally:
1281            use_regtypes(regtypes)
1282        self.assertEqual(r, dict(
1283            n='integer', alpha='smallint', beta='boolean',
1284            gamma='character', tau='text', v='character varying'))
1285
1286    def testGetAttnamesWithoutRegtypes(self):
1287        get_attnames = self.db.get_attnames
1288        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1289            ' gamma char(5), tau text, v varchar(3)')
1290        use_regtypes = self.db.use_regtypes
1291        regtypes = use_regtypes()
1292        self.assertEqual(regtypes, self.regtypes)
1293        use_regtypes(False)
1294        try:
1295            r = get_attnames("test_table")
1296            self.assertIsInstance(r, dict)
1297        finally:
1298            use_regtypes(regtypes)
1299        self.assertEqual(r, dict(
1300            n='int', alpha='int', beta='bool',
1301            gamma='text', tau='text', v='text'))
1302
1303    def testGetAttnamesIsCached(self):
1304        get_attnames = self.db.get_attnames
1305        int_type = 'integer' if self.regtypes else 'int'
1306        text_type = 'text'
1307        query = self.db.query
1308        self.createTable('test_table', 'col int')
1309        r = get_attnames("test_table")
1310        self.assertIsInstance(r, dict)
1311        self.assertEqual(r, dict(col=int_type))
1312        query("alter table test_table alter column col type text")
1313        query("alter table test_table add column col2 int")
1314        r = get_attnames("test_table")
1315        self.assertEqual(r, dict(col=int_type))
1316        r = get_attnames("test_table", flush=True)
1317        self.assertEqual(r, dict(col=text_type, col2=int_type))
1318        query("alter table test_table drop column col2")
1319        r = get_attnames("test_table")
1320        self.assertEqual(r, dict(col=text_type, col2=int_type))
1321        r = get_attnames("test_table", flush=True)
1322        self.assertEqual(r, dict(col=text_type))
1323        query("alter table test_table drop column col")
1324        r = get_attnames("test_table")
1325        self.assertEqual(r, dict(col=text_type))
1326        r = get_attnames("test_table", flush=True)
1327        self.assertEqual(r, dict())
1328
1329    def testGetAttnamesIsOrdered(self):
1330        get_attnames = self.db.get_attnames
1331        r = get_attnames('test', flush=True)
1332        self.assertIsInstance(r, OrderedDict)
1333        if self.regtypes:
1334            self.assertEqual(r, OrderedDict([
1335                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1336                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1337                ('m', 'money'), ('v4', 'character varying'),
1338                ('c4', 'character'), ('t', 'text')]))
1339        else:
1340            self.assertEqual(r, OrderedDict([
1341                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1342                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1343                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1344        if OrderedDict is not dict:
1345            r = ' '.join(list(r.keys()))
1346            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1347        table = 'test table for get_attnames'
1348        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1349                ' gamma char(5), tau text, beta bool')
1350        r = get_attnames(table)
1351        self.assertIsInstance(r, OrderedDict)
1352        if self.regtypes:
1353            self.assertEqual(r, OrderedDict([
1354                ('n', 'integer'), ('alpha', 'smallint'),
1355                ('v', 'character varying'), ('gamma', 'character'),
1356                ('tau', 'text'), ('beta', 'boolean')]))
1357        else:
1358            self.assertEqual(r, OrderedDict([
1359                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1360                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1361        if OrderedDict is not dict:
1362            r = ' '.join(list(r.keys()))
1363            self.assertEqual(r, 'n alpha v gamma tau beta')
1364        else:
1365            self.skipTest('OrderedDict is not supported')
1366
1367    def testGetAttnamesIsAttrDict(self):
1368        AttrDict = pg.AttrDict
1369        get_attnames = self.db.get_attnames
1370        r = get_attnames('test', flush=True)
1371        self.assertIsInstance(r, AttrDict)
1372        if self.regtypes:
1373            self.assertEqual(r, AttrDict([
1374                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1375                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1376                ('m', 'money'), ('v4', 'character varying'),
1377                ('c4', 'character'), ('t', 'text')]))
1378        else:
1379            self.assertEqual(r, AttrDict([
1380                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1381                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1382                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1383        r = ' '.join(list(r.keys()))
1384        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1385        table = 'test table for get_attnames'
1386        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1387                ' gamma char(5), tau text, beta bool')
1388        r = get_attnames(table)
1389        self.assertIsInstance(r, AttrDict)
1390        if self.regtypes:
1391            self.assertEqual(r, AttrDict([
1392                ('n', 'integer'), ('alpha', 'smallint'),
1393                ('v', 'character varying'), ('gamma', 'character'),
1394                ('tau', 'text'), ('beta', 'boolean')]))
1395        else:
1396            self.assertEqual(r, AttrDict([
1397                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1398                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1399        r = ' '.join(list(r.keys()))
1400        self.assertEqual(r, 'n alpha v gamma tau beta')
1401
1402    def testHasTablePrivilege(self):
1403        can = self.db.has_table_privilege
1404        self.assertEqual(can('test'), True)
1405        self.assertEqual(can('test', 'select'), True)
1406        self.assertEqual(can('test', 'SeLeCt'), True)
1407        self.assertEqual(can('test', 'SELECT'), True)
1408        self.assertEqual(can('test', 'insert'), True)
1409        self.assertEqual(can('test', 'update'), True)
1410        self.assertEqual(can('test', 'delete'), True)
1411        self.assertRaises(pg.DataError, can, 'test', 'foobar')
1412        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1413        r = self.db.query('select rolsuper FROM pg_roles'
1414            ' where rolname=current_user').getresult()[0][0]
1415        if not pg.get_bool():
1416            r = r == 't'
1417        if r:
1418            self.skipTest('must not be superuser')
1419        self.assertEqual(can('pg_views', 'select'), True)
1420        self.assertEqual(can('pg_views', 'delete'), False)
1421
    def testGet(self):
        """Test get() with explicit key names and with a primary key.

        Rows can be looked up by a single value, a tuple of values or a
        dict; a passed dict is updated in place and returned.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # the table name and the row are required arguments
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        self.createTable(table, 'n integer, t text',
                         values=enumerate('xyz', start=1))
        # the table has no primary key yet, so a key name must be given
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        # values and key names can also be passed as tuples
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # lookups that cannot succeed (no key name given while there is no
        # primary key, or no matching row) raise a DatabaseError
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        # a dict passed as row is updated in place and returned
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(t='x')
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        # with a primary key, no key name needs to be passed any more
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # a trailing '*' (with or without space) after the name is accepted
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # two values do not match the single-column primary key
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        # r is the same dict as s here
        s.update(n=2)
        self.assertEqual(get(table, r)['t'], 'y')
        # a dict without the primary key column raises a KeyError
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1476
    def testGetWithOid(self):
        """Test get() on a table with oids, before and after adding a pkey.

        The oid of a row is returned under the key 'oid(<table>)' and can
        be used to look up the row again, either as a value with key name
        'oid' or inside a dict under 'oid' or 'oid(<table>)'.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        # no primary key yet, so a key name must be given
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        # an empty dict does not contain an oid to look up
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        # all the different ways of looking up the row by its oid
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # a trailing '*' (with or without space) after the name is accepted
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        # r was presumably updated in place by the get() two lines above,
        # so its oid now refers to the row with n=3
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        # with a primary key, no key name needs to be passed any more
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # lookups by oid still work after adding the primary key
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # if the dict also contains the primary key, that takes precedence
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # and an explicitly given key name takes precedence over the oid
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1540
1541    def testGetWithCompositeKey(self):
1542        get = self.db.get
1543        query = self.db.query
1544        table = 'get_test_table_1'
1545        self.createTable(table, 'n integer primary key, t text',
1546                         values=enumerate('abc', start=1))
1547        self.assertEqual(get(table, 2)['t'], 'b')
1548        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1549        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1550        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1551        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1552        self.assertEqual(get(table, 'b', 't')['n'], 2)
1553        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1554        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1555        table = 'get_test_table_2'
1556        self.createTable(table,
1557                         'n integer, m integer, t text, primary key (n, m)',
1558                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1559                                 for n in range(3) for m in range(2)])
1560        self.assertRaises(KeyError, get, table, 2)
1561        self.assertEqual(get(table, (1, 1))['t'], 'a')
1562        self.assertEqual(get(table, (1, 2))['t'], 'b')
1563        self.assertEqual(get(table, (2, 1))['t'], 'c')
1564        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1565        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1566        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1567        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1568        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1569        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1570        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1571        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1572
1573    def testGetWithQuotedNames(self):
1574        get = self.db.get
1575        query = self.db.query
1576        table = 'test table for get()'
1577        self.createTable(table, '"Prime!" smallint primary key,'
1578                                ' "much space" integer, "Questions?" text',
1579                         values=[(17, 1001, 'No!')])
1580        r = get(table, 17)
1581        self.assertIsInstance(r, dict)
1582        self.assertEqual(r['Prime!'], 17)
1583        self.assertEqual(r['much space'], 1001)
1584        self.assertEqual(r['Questions?'], 'No!')
1585
1586    def testGetFromView(self):
1587        self.db.query('delete from test where i4=14')
1588        self.db.query('insert into test (i4, v4) values('
1589                      "14, 'abc4')")
1590        r = self.db.get('test_view', 14, 'i4')
1591        self.assertIn('v4', r)
1592        self.assertEqual(r['v4'], 'abc4')
1593
1594    def testGetLittleBobbyTables(self):
1595        get = self.db.get
1596        query = self.db.query
1597        self.createTable('test_students',
1598                         'firstname varchar primary key, nickname varchar, grade char(2)',
1599                         values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1600                                 ('Robert', 'Little Bobby Tables', 'D-')])
1601        r = get('test_students', 'Sheldon')
1602        self.assertEqual(r, dict(
1603            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1604        r = get('test_students', 'Robert')
1605        self.assertEqual(r, dict(
1606            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1607        r = get('test_students', "D'Arcy")
1608        self.assertEqual(r, dict(
1609            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1610        try:
1611            get('test_students', "D' Arcy")
1612        except pg.DatabaseError as error:
1613            self.assertEqual(str(error),
1614                'No such record in test_students\nwhere "firstname" = $1\n'
1615                'with $1="D\' Arcy"')
1616        try:
1617            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1618        except pg.DatabaseError as error:
1619            self.assertEqual(str(error),
1620                'No such record in test_students\nwhere "firstname" = $1\n'
1621                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1622        q = "select * from test_students order by 1 limit 4"
1623        r = query(q).getresult()
1624        self.assertEqual(len(r), 3)
1625        self.assertEqual(r[1][2], 'D-')
1626
    def testInsert(self):
        """Test insert() with values of various column types.

        Each case is inserted and then checked twice: in the dict that
        insert() returns, and in the row read back from the table.
        """
        insert = self.db.insert
        query = self.db.query
        # expected values depend on the module's bool and decimal settings
        bool_on = pg.get_bool()
        decimal = pg.get_decimal()
        table = 'insert_test_table'
        self.createTable(table,
            'i2 smallint, i4 integer, i8 bigint,'
            ' d numeric, f4 real, f8 double precision, m money,'
            ' v4 varchar(4), c4 char(4), t text,'
            ' b boolean, ts timestamp', oids=True)
        oid_table = 'oid(%s)' % table
        # Each entry is either a dict, which is expected back unchanged, or
        # a (data, change) pair, where change holds the expected values that
        # differ from the inserted data.
        tests = [dict(i2=None, i4=None, i8=None),
             # integers: '' is stored as NULL
             (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
             (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
             dict(i2=42, i4=123456, i8=9876543210),
             dict(i2=2 ** 15 - 1,
                  i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
             # numeric and floating point
             dict(d=None), (dict(d=''), dict(d=None)),
             dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
             dict(f4=None, f8=None), dict(f4=0, f8=0),
             (dict(f4='', f8=''), dict(f4=None, f8=None)),
             (dict(d=1234.5, f4=1234.5, f8=1234.5),
              dict(d=Decimal('1234.5'))),
             dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
             dict(d=Decimal('123456789.9876543212345678987654321')),
             # money
             dict(m=None), (dict(m=''), dict(m=None)),
             dict(m=Decimal('-1234.56')),
             (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
             dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
             (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
             (dict(m=1234.5), dict(m=Decimal('1234.5'))),
             (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
             (dict(m=123456), dict(m=Decimal('123456'))),
             (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
             # booleans: various spellings of true/false
             dict(b=None), (dict(b=''), dict(b=None)),
             dict(b='f'), dict(b='t'),
             (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
             (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
             (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
             (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
             (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
             (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
             # strings: char(4) values come back blank-padded
             dict(v4=None, c4=None, t=None),
             (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
             dict(v4='1234', c4='1234', t='1234' * 10),
             dict(v4='abcd', c4='abcd', t='abcdefg'),
             (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
             # timestamps: empty/false values become NULL, dates get 00:00
             dict(ts=None), (dict(ts=''), dict(ts=None)),
             (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
             dict(ts='2012-12-21 00:00:00'),
             (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
             dict(ts='2012-12-21 12:21:12'),
             dict(ts='2013-01-05 12:13:14'),
             dict(ts='current_timestamp')]
        for test in tests:
            if isinstance(test, dict):
                data = test
                change = {}
            else:
                data, change = test
            expect = data.copy()
            expect.update(change)
            # with bool_on, booleans are expected as Python bools, not 't'/'f'
            if bool_on:
                b = expect.get('b')
                if b is not None:
                    expect['b'] = b == 't'
            # convert expected numeric/money values to the configured type
            if decimal is not Decimal:
                d = expect.get('d')
                if d is not None:
                    expect['d'] = decimal(d)
                m = expect.get('m')
                if m is not None:
                    expect['m'] = decimal(m)
            # insert() returns a dict equal to data; data itself now also
            # holds the OID of the new row under 'oid(<table>)'
            self.assertEqual(insert(table, data), data)
            self.assertIn(oid_table, data)
            oid = data[oid_table]
            self.assertIsInstance(oid, int)
            # compare only the columns covered by the expectation
            data = dict(item for item in data.items()
                        if item[0] in expect)
            # timestamps are returned as datetime objects
            ts = expect.get('ts')
            if ts:
                if ts == 'current_timestamp':
                    ts = data['ts']
                    self.assertIsInstance(ts, datetime)
                    # NOTE(review): compares the server timestamp's date with
                    # the client's local date; may be flaky around midnight or
                    # when server and client time zones differ
                    self.assertEqual(ts.strftime('%Y-%m-%d'),
                        strftime('%Y-%m-%d'))
                else:
                    ts = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S')
                expect['ts'] = ts
            self.assertEqual(data, expect)
            # read the row back directly and check it again
            data = query(
                'select oid,* from "%s"' % table).dictresult()[0]
            self.assertEqual(data['oid'], oid)
            data = dict(item for item in data.items()
                        if item[0] in expect)
            self.assertEqual(data, expect)
            # empty the table so the next case reads back only its own row
            query('delete from "%s"' % table)
1725
    def testInsertWithOid(self):
        """Test insert() on a table with OIDs and no primary key.

        Every insert creates a new row with a new OID, which is returned
        under the qualified key 'oid(test_table)'.  An OID passed by the
        caller is not reused.
        """
        insert = self.db.insert
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True)
        # m is not a column of the table
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
        r = insert('test_table', n=1)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        # the OID is returned under the qualified key only
        self.assertNotIn('oid', r)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        # an oid keyword does not reuse that OID; a new row gets a new one
        r = insert('test_table', n=2, oid=oid)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 2)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # the row argument may be None when values come from keywords
        r = insert('test_table', None, n=3)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 3)
        # a passed row dict is returned as the same object; inserting it
        # again creates another row (see the '3 3 3' check below)
        s = r
        r = insert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        r = insert('test_table *', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # keyword values override and update the values in the row dict
        r = insert('test_table', r, n=4)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 4)
        self.assertNotIn('oid', r)
        self.assertIn(qoid, r)
        oid = r[qoid]
        r = insert('test_table', r, n=5, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # an 'oid' key in the row dict is dropped and not reused either
        r['oid'] = oid = r[qoid]
        r = insert('test_table', r, n=6)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 6)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        q = 'select n from test_table order by 1 limit 9'
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '1 2 3 3 3 4 5 6')
        # with a unique constraint, duplicates raise IntegrityError
        query("truncate test_table")
        query("alter table test_table add unique (n)")
        r = insert('test_table', dict(n=7))
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        self.assertRaises(pg.IntegrityError, insert, 'test_table', r)
        r['n'] = 6
        self.assertRaises(pg.IntegrityError, insert, 'test_table', r, n=7)
        # the keyword value was merged into r even though the insert failed
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        r['n'] = 6
        r = insert('test_table', r)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 6)
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '6 7')
1793
1794    def testInsertWithQuotedNames(self):
1795        insert = self.db.insert
1796        query = self.db.query
1797        table = 'test table for insert()'
1798        self.createTable(table, '"Prime!" smallint primary key,'
1799                                ' "much space" integer, "Questions?" text')
1800        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1801        r = insert(table, r)
1802        self.assertIsInstance(r, dict)
1803        self.assertEqual(r['Prime!'], 11)
1804        self.assertEqual(r['much space'], 2002)
1805        self.assertEqual(r['Questions?'], 'What?')
1806        r = query('select * from "%s" limit 2' % table).dictresult()
1807        self.assertEqual(len(r), 1)
1808        r = r[0]
1809        self.assertEqual(r['Prime!'], 11)
1810        self.assertEqual(r['much space'], 2002)
1811        self.assertEqual(r['Questions?'], 'What?')
1812
    def testInsertIntoView(self):
        """Test insert() into a base table and into a view over it.

        Inserting into the table fills the row dict with all table columns.
        Inserting into the view fills in only the view's columns.
        """
        insert = self.db.insert
        query = self.db.query
        query("truncate test")
        q = 'select * from test_view order by i4 limit 3'
        r = query(q).getresult()
        self.assertEqual(r, [])
        # inserting into the table adds its missing columns to r as None
        r = dict(i4=1234, v4='abcd')
        insert('test', r)
        self.assertIsNone(r['i2'])
        self.assertEqual(r['i4'], 1234)
        self.assertIsNone(r['i8'])
        self.assertEqual(r['v4'], 'abcd')
        self.assertIsNone(r['c4'])
        # the view apparently exposes only (i4, v4)
        r = query(q).getresult()
        self.assertEqual(r, [(1234, 'abcd')])
        r = dict(i4=5678, v4='efgh')
        try:
            insert('test_view', r)
        except (pg.OperationalError, pg.NotSupportedError) as error:
            if self.db.server_version < 90300:
                # must setup rules in older PostgreSQL versions
                self.skipTest('database cannot insert into view')
            self.fail(str(error))
        # inserting into the view does not add table-only columns to r
        self.assertNotIn('i2', r)
        self.assertEqual(r['i4'], 5678)
        self.assertNotIn('i8', r)
        self.assertEqual(r['v4'], 'efgh')
        self.assertNotIn('c4', r)
        r = query(q).getresult()
        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1844
1845    def testUpdate(self):
1846        update = self.db.update
1847        query = self.db.query
1848        self.assertRaises(pg.ProgrammingError, update,
1849                          'test', i2=2, i4=4, i8=8)
1850        table = 'update_test_table'
1851        self.createTable(table, 'n integer, t text', oids=True,
1852                         values=enumerate('xyz', start=1))
1853        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1854        r = self.db.get(table, 2, 'n')
1855        r['t'] = 'u'
1856        s = update(table, r)
1857        self.assertEqual(s, r)
1858        q = 'select t from "%s" where n=2' % table
1859        r = query(q).getresult()[0][0]
1860        self.assertEqual(r, 'u')
1861
    def testUpdateWithOid(self):
        """Test update() on a table with OIDs, before and after adding a key.

        Without a primary key, the row to update is found through the
        qualified 'oid(test_table)' key in the row dict or the oid keyword.
        After a primary key is added, the key is used, unless an oid keyword
        is passed.
        """
        update = self.db.update
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        s = get('test_table', 1, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 1)
        # the fetched row carries its OID, so it can be updated in place
        s['n'] = 2
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        # the OID can also be supplied as the oid keyword
        r['n'] = 3
        oid = r.pop(qoid)
        r = update('test_table', r, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # neither an OID nor a primary key: the row cannot be identified
        r.pop(qoid)
        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
        # a row dict holding only the OID: nothing but the OID comes back
        s = get('test_table', 3, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 3)
        s.pop('n')
        r = update('test_table', s)
        oid = r.pop(qoid)
        self.assertEqual(r, {})
        q = "select n from test_table limit 2"
        r = query(q).getresult()
        self.assertEqual(r, [(3,)])
        query("insert into test_table values (1)")
        # a plain 'oid' key in the row dict is rejected; use the keyword
        self.assertRaises(pg.ProgrammingError,
                          update, 'test_table', dict(oid=oid, n=4))
        r = update('test_table', dict(n=4), oid=oid)
        self.assertEqual(r['n'], 4)
        r = update('test_table *', dict(n=5), oid=oid)
        self.assertEqual(r['n'], 5)
        # add a column and a primary key; flush=True presumably refreshes
        # cached table metadata -- confirm
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        # the primary key now identifies the row
        s = dict(n=1, m=4)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 4)
        # the key value can also come from a keyword argument
        s = dict(m=7)
        r = update('test_table', s, n=5)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 7)
        q = "select n, m from test_table order by 1 limit 3"
        r = query(q).getresult()
        self.assertEqual(r, [(1, 4), (5, 7)])
        # an 'oid' key does not replace a missing primary key value ...
        s = dict(m=9, oid=oid)
        self.assertRaises(KeyError, update, 'test_table', s)
        # ... but the oid keyword does, and n is filled in from the row
        r = update('test_table', s, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 9)
        # the primary key in the row dict wins over an 'oid' key
        s = dict(n=1, m=3, oid=oid)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        r = query(q).getresult()
        self.assertEqual(r, [(1, 3), (5, 9)])
        # with the oid keyword, the row with that OID (n=5) gets the new key
        s.update(n=4, m=7)
        r = update('test_table', s, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 4)
        self.assertEqual(r['m'], 7)
        r = query(q).getresult()
        self.assertEqual(r, [(1, 3), (4, 7)])
1939
1940    def testUpdateWithoutOid(self):
1941        update = self.db.update
1942        query = self.db.query
1943        self.assertRaises(pg.ProgrammingError, update,
1944                          'test', i2=2, i4=4, i8=8)
1945        table = 'update_test_table'
1946        self.createTable(table, 'n integer primary key, t text', oids=False,
1947                         values=enumerate('xyz', start=1))
1948        r = self.db.get(table, 2)
1949        r['t'] = 'u'
1950        s = update(table, r)
1951        self.assertEqual(s, r)
1952        q = 'select t from "%s" where n=2' % table
1953        r = query(q).getresult()[0][0]
1954        self.assertEqual(r, 'u')
1955
    def testUpdateWithCompositeKey(self):
        """Test update() with single-column and two-column primary keys.

        Every primary key column must be present in the row dict.  Updating
        a key value that matches no row changes nothing in the table.
        """
        update = self.db.update
        query = self.db.query
        table = 'update_test_table_1'
        self.createTable(table, 'n integer primary key, t text',
                         values=enumerate('abc', start=1))
        # the primary key value n is missing
        self.assertRaises(KeyError, update, table, dict(t='b'))
        s = dict(n=2, t='d')
        r = update(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'd')
        q = 'select t from "%s" where n=2' % table
        r = query(q).getresult()[0][0]
        self.assertEqual(r, 'd')
        # there is no row with n=4: the returned dict keeps the passed
        # values, and neither row 2 nor a new row 4 is written
        s.update(dict(n=4, t='e'))
        r = update(table, s)
        self.assertEqual(r['n'], 4)
        self.assertEqual(r['t'], 'e')
        q = 'select t from "%s" where n=2' % table
        r = query(q).getresult()[0][0]
        self.assertEqual(r, 'd')
        q = 'select t from "%s" where n=4' % table
        r = query(q).getresult()
        self.assertEqual(len(r), 0)
        query('drop table "%s"' % table)
        table = 'update_test_table_2'
        self.createTable(table,
                         'n integer, m integer, t text, primary key (n, m)',
                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
                                 for n in range(3) for m in range(2)])
        # part of the composite key (m) is missing
        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
        self.assertEqual(update(table,
                                dict(n=2, m=2, t='x'))['t'], 'x')
        # only the (2, 2) row was changed; (2, 1) still holds 'c'
        q = 'select t from "%s" where n=2 order by m' % table
        r = [r[0] for r in query(q).getresult()]
        self.assertEqual(r, ['c', 'x'])
1993
1994    def testUpdateWithQuotedNames(self):
1995        update = self.db.update
1996        query = self.db.query
1997        table = 'test table for update()'
1998        self.createTable(table, '"Prime!" smallint primary key,'
1999                                ' "much space" integer, "Questions?" text',
2000                         values=[(13, 3003, 'Why!')])
2001        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
2002        r = update(table, r)
2003        self.assertIsInstance(r, dict)
2004        self.assertEqual(r['Prime!'], 13)
2005        self.assertEqual(r['much space'], 7007)
2006        self.assertEqual(r['Questions?'], 'When?')
2007        r = query('select * from "%s" limit 2' % table).dictresult()
2008        self.assertEqual(len(r), 1)
2009        r = r[0]
2010        self.assertEqual(r['Prime!'], 13)
2011        self.assertEqual(r['much space'], 7007)
2012        self.assertEqual(r['Questions?'], 'When?')
2013
    def testUpsert(self):
        """Test upsert() on a table with a single-column primary key.

        Keyword arguments control what happens to a column on conflict:
        False keeps the stored value, True takes the new value, and a
        string is used as an SQL expression in which "included" refers to
        the stored row and "excluded" to the row proposed for insertion.
        """
        upsert = self.db.upsert
        query = self.db.query
        # upserting into the shared "test" table like this must raise
        # (presumably because it has no primary key -- confirm)
        self.assertRaises(pg.ProgrammingError, upsert,
                          'test', i2=2, i4=4, i8=8)
        table = 'upsert_test_table'
        self.createTable(table, 'n integer primary key, t text', oids=True)
        s = dict(n=1, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # servers older than 9.5 cannot do upserts
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        # passes None for every column; n=2 is new, so it is just inserted
        s.update(n=2, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # conflict with no keyword: the new value is written
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=False keeps the stored value and reports it back in s
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=True takes the new value
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # "included" refers to the stored row: 'y' becomes 'y2'
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # "excluded" refers to the proposed row: 'y' becomes 'y3'
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        # both can be combined: stored 'x' plus proposed '2'
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        # not existing columns and oid parameter should be ignored
        s = dict(m=3, u='z')
        r = upsert(table, s, oid='invalid')
        self.assertIs(r, s)
2085
    def testUpsertWithOid(self):
        """Test upsert() on a table with OIDs, before and after adding a key.

        Without a primary key, upsert() raises ProgrammingError even when an
        OID is given.  With a primary key, an OID given in the row dict or
        as a keyword is ignored, and the returned OID is the OID of the
        affected row.
        """
        upsert = self.db.upsert
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        # no primary key: upsert is rejected
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2))
        r = get('test_table', 1, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        oid = r[qoid]
        # ... and supplying an OID does not help
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2, oid=oid))
        # add a column and a primary key; flush=True presumably refreshes
        # cached table metadata -- confirm
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        s = dict(n=2)
        try:
            r = upsert('test_table', s)
        except pg.ProgrammingError as error:
            # servers older than 9.5 cannot do upserts
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        # the new row is returned with all columns, m being None
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, None), (2, None)])
        # an 'oid' key in the row (here the OID of row 1) is removed and
        # replaced by the OID of the upserted row 2
        r['oid'] = oid
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertNotEqual(r[qoid], oid)
        r['m'] = 7
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        r.update(n=1, m=3)
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
        # an oid keyword is ignored as well
        r = upsert('test_table', r, oid='invalid')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=False keeps the stored value on conflict
        r['m'] = 5
        r = upsert('test_table', r, m=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=True takes the new value
        r['m'] = 5
        r = upsert('test_table', r, m=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 5)
        # 'included.m' keeps the stored value of row 2 (7)
        r.update(n=2, m=1)
        r = upsert('test_table', r, m='included.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # 'excluded.m' takes the proposed value
        r['m'] = 9
        r = upsert('test_table', r, m='excluded.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 9)
        # expressions may compute new values; 'test_table *' is accepted
        r['m'] = 8
        r = upsert('test_table *', r, m='included.m + 1')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 10)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
2169
    def testUpsertWithCompositeKey(self):
        """Test upsert() on a table with the two-column primary key (n, m).

        A conflict happens only when both key values match an existing row.
        Otherwise a new row is inserted and keyword expressions are not
        applied.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        self.createTable(table,
                         'n integer, m integer, t text, primary key (n, m)')
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # servers older than 9.5 cannot do upserts
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # (1, 3) is a new key, so a second row is inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # conflict on (1, 3) with no keyword: the new value is written
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False keeps the stored value
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True takes the new value
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # (2, 3) is new: the row is inserted as given and the expression
        # for t is not used
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # conflict on (1, 3): stored 'n' plus proposed 'm'
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2236
2237    def testUpsertWithQuotedNames(self):
2238        upsert = self.db.upsert
2239        query = self.db.query
2240        table = 'test table for upsert()'
2241        self.createTable(table, '"Prime!" smallint primary key,'
2242                                ' "much space" integer, "Questions?" text')
2243        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2244        try:
2245            r = upsert(table, s)
2246        except pg.ProgrammingError as error:
2247            if self.db.server_version < 90500:
2248                self.skipTest('database does not support upsert')
2249            self.fail(str(error))
2250        self.assertIs(r, s)
2251        self.assertEqual(r['Prime!'], 31)
2252        self.assertEqual(r['much space'], 9009)
2253        self.assertEqual(r['Questions?'], 'Yes.')
2254        q = 'select * from "%s" limit 2' % table
2255        r = query(q).getresult()
2256        self.assertEqual(r, [(31, 9009, 'Yes.')])
2257        s.update({'Questions?': 'No.'})
2258        r = upsert(table, s)
2259        self.assertIs(r, s)
2260        self.assertEqual(r['Prime!'], 31)
2261        self.assertEqual(r['much space'], 9009)
2262        self.assertEqual(r['Questions?'], 'No.')
2263        r = query(q).getresult()
2264        self.assertEqual(r, [(31, 9009, 'No.')])
2265
2266    def testClear(self):
2267        clear = self.db.clear
2268        f = False if pg.get_bool() else 'f'
2269        r = clear('test')
2270        result = dict(
2271            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2272        self.assertEqual(r, result)
2273        table = 'clear_test_table'
2274        self.createTable(table,
2275            'n integer, f float, b boolean, d date, t text', oids=True)
2276        r = clear(table)
2277        result = dict(n=0, f=0, b=f, d='', t='')
2278        self.assertEqual(r, result)
2279        r['a'] = r['f'] = r['n'] = 1
2280        r['d'] = r['t'] = 'x'
2281        r['b'] = 't'
2282        r['oid'] = long(1)
2283        r = clear(table, r)
2284        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2285        self.assertEqual(r, result)
2286
2287    def testClearWithQuotedNames(self):
2288        clear = self.db.clear
2289        table = 'test table for clear()'
2290        self.createTable(table, '"Prime!" smallint primary key,'
2291            ' "much space" integer, "Questions?" text')
2292        r = clear(table)
2293        self.assertIsInstance(r, dict)
2294        self.assertEqual(r['Prime!'], 0)
2295        self.assertEqual(r['much space'], 0)
2296        self.assertEqual(r['Questions?'], '')
2297
2298    def testDelete(self):
2299        delete = self.db.delete
2300        query = self.db.query
2301        self.assertRaises(pg.ProgrammingError, delete,
2302                          'test', dict(i2=2, i4=4, i8=8))
2303        table = 'delete_test_table'
2304        self.createTable(table, 'n integer, t text', oids=True,
2305                         values=enumerate('xyz', start=1))
2306        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
2307        r = self.db.get(table, 1, 'n')
2308        s = delete(table, r)
2309        self.assertEqual(s, 1)
2310        r = self.db.get(table, 3, 'n')
2311        s = delete(table, r)
2312        self.assertEqual(s, 1)
2313        s = delete(table, r)
2314        self.assertEqual(s, 0)
2315        r = query('select * from "%s"' % table).dictresult()
2316        self.assertEqual(len(r), 1)
2317        r = r[0]
2318        result = {'n': 2, 't': 'y'}
2319        self.assertEqual(r, result)
2320        r = self.db.get(table, 2, 'n')
2321        s = delete(table, r)
2322        self.assertEqual(s, 1)
2323        s = delete(table, r)
2324        self.assertEqual(s, 0)
2325        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
2326        # not existing columns and oid parameter should be ignored
2327        r.update(m=3, u='z', oid='invalid')
2328        s = delete(table, r)
2329        self.assertEqual(s, 0)
2330
2331    def testDeleteWithOid(self):
2332        delete = self.db.delete
2333        get = self.db.get
2334        query = self.db.query
2335        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
2336        r = dict(n=3)
2337        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
2338        s = get('test_table', 1, 'n')
2339        qoid = 'oid(test_table)'
2340        self.assertIn(qoid, s)
2341        r = delete('test_table', s)
2342        self.assertEqual(r, 1)
2343        r = delete('test_table', s)
2344        self.assertEqual(r, 0)
2345        q = "select min(n),count(n) from test_table"
2346        self.assertEqual(query(q).getresult()[0], (2, 5))
2347        oid = get('test_table', 2, 'n')[qoid]
2348        s = dict(oid=oid, n=2)
2349        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
2350        r = delete('test_table', None, oid=oid)
2351        self.assertEqual(r, 1)
2352        r = delete('test_table', None, oid=oid)
2353        self.assertEqual(r, 0)
2354        self.assertEqual(query(q).getresult()[0], (3, 4))
2355        s = dict(oid=oid, n=2)
2356        oid = get('test_table', 3, 'n')[qoid]
2357        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
2358        r = delete('test_table', s, oid=oid)
2359        self.assertEqual(r, 1)
2360        r = delete('test_table', s, oid=oid)
2361        self.assertEqual(r, 0)
2362        self.assertEqual(query(q).getresult()[0], (4, 3))
2363        s = get('test_table', 4, 'n')
2364        r = delete('test_table *', s)
2365        self.assertEqual(r, 1)
2366        r = delete('test_table *', s)
2367        self.assertEqual(r, 0)
2368        self.assertEqual(query(q).getresult()[0], (5, 2))
2369        oid = get('test_table', 5, 'n')[qoid]
2370        s = {qoid: oid, 'm': 4}
2371        r = delete('test_table', s, m=6)
2372        self.assertEqual(r, 1)
2373        r = delete('test_table *', s)
2374        self.assertEqual(r, 0)
2375        self.assertEqual(query(q).getresult()[0], (6, 1))
2376        query("alter table test_table add column m int")
2377        query("alter table test_table add primary key (n)")
2378        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
2379        self.assertEqual('n', self.db.pkey('test_table', flush=True))
2380        for i in range(5):
2381            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
2382        s = dict(m=2)
2383        self.assertRaises(KeyError, delete, 'test_table', s)
2384        s = dict(m=2, oid=oid)
2385        self.assertRaises(KeyError, delete, 'test_table', s)
2386        r = delete('test_table', dict(m=2), oid=oid)
2387        self.assertEqual(r, 0)
2388        oid = get('test_table', 1, 'n')[qoid]
2389        s = dict(oid=oid)
2390        self.assertRaises(KeyError, delete, 'test_table', s)
2391        r = delete('test_table', s, oid=oid)
2392        self.assertEqual(r, 1)
2393        r = delete('test_table', s, oid=oid)
2394        self.assertEqual(r, 0)
2395        self.assertEqual(query(q).getresult()[0], (2, 5))
2396        s = get('test_table', 2, 'n')
2397        del s['n']
2398        r = delete('test_table', s)
2399        self.assertEqual(r, 1)
2400        r = delete('test_table', s)
2401        self.assertEqual(r, 0)
2402        self.assertEqual(query(q).getresult()[0], (3, 4))
2403        r = delete('test_table', n=3)
2404        self.assertEqual(r, 1)
2405        r = delete('test_table', n=3)
2406        self.assertEqual(r, 0)
2407        self.assertEqual(query(q).getresult()[0], (4, 3))
2408        r = delete('test_table', None, n=4)
2409        self.assertEqual(r, 1)
2410        r = delete('test_table', None, n=4)
2411        self.assertEqual(r, 0)
2412        self.assertEqual(query(q).getresult()[0], (5, 2))
2413        s = dict(n=6)
2414        r = delete('test_table', s, n=5)
2415        self.assertEqual(r, 1)
2416        r = delete('test_table', s, n=5)
2417        self.assertEqual(r, 0)
2418        s = get('test_table', 6, 'n')
2419        self.assertEqual(s['n'], 6)
2420        s['n'] = 7
2421        r = delete('test_table', s)
2422        self.assertEqual(r, 1)
2423        self.assertEqual(query(q).getresult()[0], (None, 0))
2424
2425    def testDeleteWithCompositeKey(self):
2426        query = self.db.query
2427        table = 'delete_test_table_1'
2428        self.createTable(table, 'n integer primary key, t text',
2429                         values=enumerate('abc', start=1))
2430        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2431        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2432        r = query('select t from "%s" where n=2' % table).getresult()
2433        self.assertEqual(r, [])
2434        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2435        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2436        self.assertEqual(r, 'c')
2437        table = 'delete_test_table_2'
2438        self.createTable(table,
2439             'n integer, m integer, t text, primary key (n, m)',
2440             values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2441                     for n in range(3) for m in range(2)])
2442        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2443        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2444        r = [r[0] for r in query('select t from "%s" where n=2'
2445            ' order by m' % table).getresult()]
2446        self.assertEqual(r, ['c'])
2447        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2448        r = [r[0] for r in query('select t from "%s" where n=3'
2449             ' order by m' % table).getresult()]
2450        self.assertEqual(r, ['e', 'f'])
2451        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2452        r = [r[0] for r in query('select t from "%s" where n=3'
2453             ' order by m' % table).getresult()]
2454        self.assertEqual(r, ['f'])
2455
2456    def testDeleteWithQuotedNames(self):
2457        delete = self.db.delete
2458        query = self.db.query
2459        table = 'test table for delete()'
2460        self.createTable(table, '"Prime!" smallint primary key,'
2461            ' "much space" integer, "Questions?" text',
2462            values=[(19, 5005, 'Yes!')])
2463        r = {'Prime!': 17}
2464        r = delete(table, r)
2465        self.assertEqual(r, 0)
2466        r = query('select count(*) from "%s"' % table).getresult()
2467        self.assertEqual(r[0][0], 1)
2468        r = {'Prime!': 19}
2469        r = delete(table, r)
2470        self.assertEqual(r, 1)
2471        r = query('select count(*) from "%s"' % table).getresult()
2472        self.assertEqual(r[0][0], 0)
2473
2474    def testDeleteReferenced(self):
2475        delete = self.db.delete
2476        query = self.db.query
2477        self.createTable('test_parent',
2478            'n smallint primary key', values=range(3))
2479        self.createTable('test_child',
2480            'n smallint primary key references test_parent', values=range(3))
2481        q = ("select (select count(*) from test_parent),"
2482             " (select count(*) from test_child)")
2483        self.assertEqual(query(q).getresult()[0], (3, 3))
2484        self.assertRaises(pg.IntegrityError,
2485                          delete, 'test_parent', None, n=2)
2486        self.assertRaises(pg.IntegrityError,
2487                          delete, 'test_parent *', None, n=2)
2488        r = delete('test_child', None, n=2)
2489        self.assertEqual(r, 1)
2490        self.assertEqual(query(q).getresult()[0], (3, 2))
2491        r = delete('test_parent', None, n=2)
2492        self.assertEqual(r, 1)
2493        self.assertEqual(query(q).getresult()[0], (2, 2))
2494        self.assertRaises(pg.IntegrityError,
2495                          delete, 'test_parent', dict(n=0))
2496        self.assertRaises(pg.IntegrityError,
2497                          delete, 'test_parent *', dict(n=0))
2498        r = delete('test_child', dict(n=0))
2499        self.assertEqual(r, 1)
2500        self.assertEqual(query(q).getresult()[0], (2, 1))
2501        r = delete('test_child', dict(n=0))
2502        self.assertEqual(r, 0)
2503        r = delete('test_parent', dict(n=0))
2504        self.assertEqual(r, 1)
2505        self.assertEqual(query(q).getresult()[0], (1, 1))
2506        r = delete('test_parent', None, n=0)
2507        self.assertEqual(r, 0)
2508        q = "select n from test_parent natural join test_child limit 2"
2509        self.assertEqual(query(q).getresult(), [(1,)])
2510
2511    def testTempCrud(self):
2512        table = 'test_temp_table'
2513        self.createTable(table, "n int primary key, t varchar", temporary=True)
2514        self.db.insert(table, dict(n=1, t='one'))
2515        self.db.insert(table, dict(n=2, t='too'))
2516        self.db.insert(table, dict(n=3, t='three'))
2517        r = self.db.get(table, 2)
2518        self.assertEqual(r['t'], 'too')
2519        self.db.update(table, dict(n=2, t='two'))
2520        r = self.db.get(table, 2)
2521        self.assertEqual(r['t'], 'two')
2522        self.db.delete(table, r)
2523        r = self.db.query('select n, t from %s order by 1' % table).getresult()
2524        self.assertEqual(r, [(1, 'one'), (3, 'three')])
2525
2526    def testTruncate(self):
2527        truncate = self.db.truncate
2528        self.assertRaises(TypeError, truncate, None)
2529        self.assertRaises(TypeError, truncate, 42)
2530        self.assertRaises(TypeError, truncate, dict(test_table=None))
2531        query = self.db.query
2532        self.createTable('test_table', 'n smallint',
2533                         temporary=False, values=[1] * 3)
2534        q = "select count(*) from test_table"
2535        r = query(q).getresult()[0][0]
2536        self.assertEqual(r, 3)
2537        truncate('test_table')
2538        r = query(q).getresult()[0][0]
2539        self.assertEqual(r, 0)
2540        for i in range(3):
2541            query("insert into test_table values (1)")
2542        r = query(q).getresult()[0][0]
2543        self.assertEqual(r, 3)
2544        truncate('public.test_table')
2545        r = query(q).getresult()[0][0]
2546        self.assertEqual(r, 0)
2547        self.createTable('test_table_2', 'n smallint', temporary=True)
2548        for t in (list, tuple, set):
2549            for i in range(3):
2550                query("insert into test_table values (1)")
2551                query("insert into test_table_2 values (2)")
2552            q = ("select (select count(*) from test_table),"
2553                 " (select count(*) from test_table_2)")
2554            r = query(q).getresult()[0]
2555            self.assertEqual(r, (3, 3))
2556            truncate(t(['test_table', 'test_table_2']))
2557            r = query(q).getresult()[0]
2558            self.assertEqual(r, (0, 0))
2559
2560    def testTruncateRestart(self):
2561        truncate = self.db.truncate
2562        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2563        query = self.db.query
2564        self.createTable('test_table', 'n serial, t text')
2565        for n in range(3):
2566            query("insert into test_table (t) values ('test')")
2567        q = "select count(n), min(n), max(n) from test_table"
2568        r = query(q).getresult()[0]
2569        self.assertEqual(r, (3, 1, 3))
2570        truncate('test_table')
2571        r = query(q).getresult()[0]
2572        self.assertEqual(r, (0, None, None))
2573        for n in range(3):
2574            query("insert into test_table (t) values ('test')")
2575        r = query(q).getresult()[0]
2576        self.assertEqual(r, (3, 4, 6))
2577        truncate('test_table', restart=True)
2578        r = query(q).getresult()[0]
2579        self.assertEqual(r, (0, None, None))
2580        for n in range(3):
2581            query("insert into test_table (t) values ('test')")
2582        r = query(q).getresult()[0]
2583        self.assertEqual(r, (3, 1, 3))
2584
2585    def testTruncateCascade(self):
2586        truncate = self.db.truncate
2587        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2588        query = self.db.query
2589        self.createTable('test_parent', 'n smallint primary key',
2590                         values=range(3))
2591        self.createTable('test_child',
2592                         'n smallint primary key references test_parent (n)',
2593                         values=range(3))
2594        q = ("select (select count(*) from test_parent),"
2595             " (select count(*) from test_child)")
2596        r = query(q).getresult()[0]
2597        self.assertEqual(r, (3, 3))
2598        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2599        truncate(['test_parent', 'test_child'])
2600        r = query(q).getresult()[0]
2601        self.assertEqual(r, (0, 0))
2602        for n in range(3):
2603            query("insert into test_parent (n) values (%d)" % n)
2604            query("insert into test_child (n) values (%d)" % n)
2605        r = query(q).getresult()[0]
2606        self.assertEqual(r, (3, 3))
2607        truncate('test_parent', cascade=True)
2608        r = query(q).getresult()[0]
2609        self.assertEqual(r, (0, 0))
2610        for n in range(3):
2611            query("insert into test_parent (n) values (%d)" % n)
2612            query("insert into test_child (n) values (%d)" % n)
2613        r = query(q).getresult()[0]
2614        self.assertEqual(r, (3, 3))
2615        truncate('test_child')
2616        r = query(q).getresult()[0]
2617        self.assertEqual(r, (3, 0))
2618        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2619        truncate('test_parent', cascade=True)
2620        r = query(q).getresult()[0]
2621        self.assertEqual(r, (0, 0))
2622
2623    def testTruncateOnly(self):
2624        truncate = self.db.truncate
2625        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2626        query = self.db.query
2627        self.createTable('test_parent', 'n smallint')
2628        self.createTable('test_child', 'm smallint) inherits (test_parent')
2629        for n in range(3):
2630            query("insert into test_parent (n) values (1)")
2631            query("insert into test_child (n, m) values (2, 3)")
2632        q = ("select (select count(*) from test_parent),"
2633             " (select count(*) from test_child)")
2634        r = query(q).getresult()[0]
2635        self.assertEqual(r, (6, 3))
2636        truncate('test_parent')
2637        r = query(q).getresult()[0]
2638        self.assertEqual(r, (0, 0))
2639        for n in range(3):
2640            query("insert into test_parent (n) values (1)")
2641            query("insert into test_child (n, m) values (2, 3)")
2642        r = query(q).getresult()[0]
2643        self.assertEqual(r, (6, 3))
2644        truncate('test_parent*')
2645        r = query(q).getresult()[0]
2646        self.assertEqual(r, (0, 0))
2647        for n in range(3):
2648            query("insert into test_parent (n) values (1)")
2649            query("insert into test_child (n, m) values (2, 3)")
2650        r = query(q).getresult()[0]
2651        self.assertEqual(r, (6, 3))
2652        truncate('test_parent', only=True)
2653        r = query(q).getresult()[0]
2654        self.assertEqual(r, (3, 3))
2655        truncate('test_parent', only=False)
2656        r = query(q).getresult()[0]
2657        self.assertEqual(r, (0, 0))
2658        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2659        truncate('test_parent*', only=False)
2660        self.createTable('test_parent_2', 'n smallint')
2661        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2662        for t in '', '_2':
2663            for n in range(3):
2664                query("insert into test_parent%s (n) values (1)" % t)
2665                query("insert into test_child%s (n, m) values (2, 3)" % t)
2666        q = ("select (select count(*) from test_parent),"
2667             " (select count(*) from test_child),"
2668             " (select count(*) from test_parent_2),"
2669             " (select count(*) from test_child_2)")
2670        r = query(q).getresult()[0]
2671        self.assertEqual(r, (6, 3, 6, 3))
2672        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2673        r = query(q).getresult()[0]
2674        self.assertEqual(r, (0, 0, 3, 3))
2675        truncate(['test_parent', 'test_parent_2'], only=False)
2676        r = query(q).getresult()[0]
2677        self.assertEqual(r, (0, 0, 0, 0))
2678        self.assertRaises(ValueError, truncate,
2679            ['test_parent*', 'test_child'], only=[True, False])
2680        truncate(['test_parent*', 'test_child'], only=[False, True])
2681
2682    def testTruncateQuoted(self):
2683        truncate = self.db.truncate
2684        query = self.db.query
2685        table = "test table for truncate()"
2686        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2687        q = 'select count(*) from "%s"' % table
2688        r = query(q).getresult()[0][0]
2689        self.assertEqual(r, 3)
2690        truncate(table)
2691        r = query(q).getresult()[0][0]
2692        self.assertEqual(r, 0)
2693        for i in range(3):
2694            query('insert into "%s" values (1)' % table)
2695        r = query(q).getresult()[0][0]
2696        self.assertEqual(r, 3)
2697        truncate('public."%s"' % table)
2698        r = query(q).getresult()[0][0]
2699        self.assertEqual(r, 0)
2700
2701    def testGetAsList(self):
2702        get_as_list = self.db.get_as_list
2703        self.assertRaises(TypeError, get_as_list)
2704        self.assertRaises(TypeError, get_as_list, None)
2705        query = self.db.query
2706        table = 'test_aslist'
2707        r = query('select 1 as colname').namedresult()[0]
2708        self.assertIsInstance(r, tuple)
2709        named = hasattr(r, 'colname')
2710        names = [(1, 'Homer'), (2, 'Marge'),
2711                 (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
2712        self.createTable(table,
2713            'id smallint primary key, name varchar', values=names)
2714        r = get_as_list(table)
2715        self.assertIsInstance(r, list)
2716        self.assertEqual(r, names)
2717        for t, n in zip(r, names):
2718            self.assertIsInstance(t, tuple)
2719            self.assertEqual(t, n)
2720            if named:
2721                self.assertEqual(t.id, n[0])
2722                self.assertEqual(t.name, n[1])
2723                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
2724        r = get_as_list(table, what='name')
2725        self.assertIsInstance(r, list)
2726        expected = sorted((row[1],) for row in names)
2727        self.assertEqual(r, expected)
2728        r = get_as_list(table, what='name, id')
2729        self.assertIsInstance(r, list)
2730        expected = sorted(tuple(reversed(row)) for row in names)
2731        self.assertEqual(r, expected)
2732        r = get_as_list(table, what=['name', 'id'])
2733        self.assertIsInstance(r, list)
2734        self.assertEqual(r, expected)
2735        r = get_as_list(table, where="name like 'Ba%'")
2736        self.assertIsInstance(r, list)
2737        self.assertEqual(r, names[2:3])
2738        r = get_as_list(table, what='name', where="name like 'Ma%'")
2739        self.assertIsInstance(r, list)
2740        self.assertEqual(r, [('Maggie',), ('Marge',)])
2741        r = get_as_list(table, what='name',
2742            where=["name like 'Ma%'", "name like '%r%'"])
2743        self.assertIsInstance(r, list)
2744        self.assertEqual(r, [('Marge',)])
2745        r = get_as_list(table, what='name', order='id')
2746        self.assertIsInstance(r, list)
2747        expected = [(row[1],) for row in names]
2748        self.assertEqual(r, expected)
2749        r = get_as_list(table, what=['name'], order=['id'])
2750        self.assertIsInstance(r, list)
2751        self.assertEqual(r, expected)
2752        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
2753        self.assertIsInstance(r, list)
2754        self.assertEqual(r, names)
2755        r = get_as_list(table, what='id * 2 as num', order='id desc')
2756        self.assertIsInstance(r, list)
2757        expected = [(n,) for n in range(10, 0, -2)]
2758        self.assertEqual(r, expected)
2759        r = get_as_list(table, limit=2)
2760        self.assertIsInstance(r, list)
2761        self.assertEqual(r, names[:2])
2762        r = get_as_list(table, offset=3)
2763        self.assertIsInstance(r, list)
2764        self.assertEqual(r, names[3:])
2765        r = get_as_list(table, limit=1, offset=2)
2766        self.assertIsInstance(r, list)
2767        self.assertEqual(r, names[2:3])
2768        r = get_as_list(table, scalar=True)
2769        self.assertIsInstance(r, list)
2770        self.assertEqual(r, list(range(1, 6)))
2771        r = get_as_list(table, what='name', scalar=True)
2772        self.assertIsInstance(r, list)
2773        expected = sorted(row[1] for row in names)
2774        self.assertEqual(r, expected)
2775        r = get_as_list(table, what='name', limit=1, scalar=True)
2776        self.assertIsInstance(r, list)
2777        self.assertEqual(r, expected[:1])
2778        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2779        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2780        names.insert(1, (1, 'Snowball'))
2781        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
2782        r = get_as_list(table)
2783        self.assertIsInstance(r, list)
2784        self.assertEqual(r, names)
2785        r = get_as_list(table, what='name', where='id=1', scalar=True)
2786        self.assertIsInstance(r, list)
2787        self.assertEqual(r, ['Homer', 'Snowball'])
2788        # test with unordered query
2789        r = get_as_list(table, order=False)
2790        self.assertIsInstance(r, list)
2791        self.assertEqual(set(r), set(names))
2792        # test with arbitrary from clause
2793        from_table = '(select lower(name) as n2 from "%s") as t2' % table
2794        r = get_as_list(from_table)
2795        self.assertIsInstance(r, list)
2796        r = set(row[0] for row in r)
2797        expected = set(row[1].lower() for row in names)
2798        self.assertEqual(r, expected)
2799        r = get_as_list(from_table, order='n2', scalar=True)
2800        self.assertIsInstance(r, list)
2801        self.assertEqual(r, sorted(expected))
2802        r = get_as_list(from_table, order='n2', limit=1)
2803        self.assertIsInstance(r, list)
2804        self.assertEqual(len(r), 1)
2805        t = r[0]
2806        self.assertIsInstance(t, tuple)
2807        if named:
2808            self.assertEqual(t.n2, 'bart')
2809            self.assertEqual(t._asdict(), dict(n2='bart'))
2810        else:
2811            self.assertEqual(t, ('bart',))
2812
2813    def testGetAsDict(self):
2814        get_as_dict = self.db.get_as_dict
2815        self.assertRaises(TypeError, get_as_dict)
2816        self.assertRaises(TypeError, get_as_dict, None)
2817        # the test table has no primary key
2818        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2819        query = self.db.query
2820        table = 'test_asdict'
2821        r = query('select 1 as colname').namedresult()[0]
2822        self.assertIsInstance(r, tuple)
2823        named = hasattr(r, 'colname')
2824        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2825                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2826        self.createTable(table,
2827            'id smallint primary key, rgb char(7), name varchar',
2828            values=colors)
2829        # keyname must be string, list or tuple
2830        self.assertRaises(KeyError, get_as_dict, table, 3)
2831        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2832        # missing keyname in row
2833        self.assertRaises(KeyError, get_as_dict, table,
2834                          keyname='rgb', what='name')
2835        r = get_as_dict(table)
2836        self.assertIsInstance(r, OrderedDict)
2837        expected = OrderedDict((row[0], row[1:]) for row in colors)
2838        self.assertEqual(r, expected)
2839        for key in r:
2840            self.assertIsInstance(key, int)
2841            self.assertIn(key, expected)
2842            row = r[key]
2843            self.assertIsInstance(row, tuple)
2844            t = expected[key]
2845            self.assertEqual(row, t)
2846            if named:
2847                self.assertEqual(row.rgb, t[0])
2848                self.assertEqual(row.name, t[1])
2849                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2850        if OrderedDict is not dict:  # Python > 2.6
2851            self.assertEqual(r.keys(), expected.keys())
2852        r = get_as_dict(table, keyname='rgb')
2853        self.assertIsInstance(r, OrderedDict)
2854        expected = OrderedDict((row[1], (row[0], row[2]))
2855                               for row in sorted(colors, key=itemgetter(1)))
2856        self.assertEqual(r, expected)
2857        for key in r:
2858            self.assertIsInstance(key, str)
2859            self.assertIn(key, expected)
2860            row = r[key]
2861            self.assertIsInstance(row, tuple)
2862            t = expected[key]
2863            self.assertEqual(row, t)
2864            if named:
2865                self.assertEqual(row.id, t[0])
2866                self.assertEqual(row.name, t[1])
2867                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2868        if OrderedDict is not dict:  # Python > 2.6
2869            self.assertEqual(r.keys(), expected.keys())
2870        r = get_as_dict(table, keyname=['id', 'rgb'])
2871        self.assertIsInstance(r, OrderedDict)
2872        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2873        self.assertEqual(r, expected)
2874        for key in r:
2875            self.assertIsInstance(key, tuple)
2876            self.assertIsInstance(key[0], int)
2877            self.assertIsInstance(key[1], str)
2878            if named:
2879                self.assertEqual(key, (key.id, key.rgb))
2880                self.assertEqual(key._fields, ('id', 'rgb'))
2881            row = r[key]
2882            self.assertIsInstance(row, tuple)
2883            self.assertIsInstance(row[0], str)
2884            t = expected[key]
2885            self.assertEqual(row, t)
2886            if named:
2887                self.assertEqual(row.name, t[0])
2888                self.assertEqual(row._asdict(), dict(name=t[0]))
2889        if OrderedDict is not dict:  # Python > 2.6
2890            self.assertEqual(r.keys(), expected.keys())
2891        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2892        self.assertIsInstance(r, OrderedDict)
2893        expected = OrderedDict((row[:2], row[2]) for row in colors)
2894        self.assertEqual(r, expected)
2895        for key in r:
2896            self.assertIsInstance(key, tuple)
2897            row = r[key]
2898            self.assertIsInstance(row, str)
2899            t = expected[key]
2900            self.assertEqual(row, t)
2901        if OrderedDict is not dict:  # Python > 2.6
2902            self.assertEqual(r.keys(), expected.keys())
2903        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2904        self.assertIsInstance(r, OrderedDict)
2905        expected = OrderedDict((row[1], row[2])
2906            for row in sorted(colors, key=itemgetter(1)))
2907        self.assertEqual(r, expected)
2908        for key in r:
2909            self.assertIsInstance(key, str)
2910            row = r[key]
2911            self.assertIsInstance(row, str)
2912            t = expected[key]
2913            self.assertEqual(row, t)
2914        if OrderedDict is not dict:  # Python > 2.6
2915            self.assertEqual(r.keys(), expected.keys())
2916        r = get_as_dict(table, what='id, name',
2917            where="rgb like '#b%'", scalar=True)
2918        self.assertIsInstance(r, OrderedDict)
2919        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2920        self.assertEqual(r, expected)
2921        for key in r:
2922            self.assertIsInstance(key, int)
2923            row = r[key]
2924            self.assertIsInstance(row, str)
2925            t = expected[key]
2926            self.assertEqual(row, t)
2927        if OrderedDict is not dict:  # Python > 2.6
2928            self.assertEqual(r.keys(), expected.keys())
2929        expected = r
2930        r = get_as_dict(table, what=['name', 'id'],
2931            where=['id > 1', 'id < 4', "rgb like '#b%'",
2932                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2933        self.assertEqual(r, expected)
2934        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2935        self.assertEqual(r, expected)
2936        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2937            where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2938        self.assertEqual(r, expected)
2939        r = get_as_dict(table, limit=1)
2940        self.assertEqual(len(r), 1)
2941        self.assertEqual(r[1][1], 'Aero')
2942        r = get_as_dict(table, offset=3)
2943        self.assertEqual(len(r), 1)
2944        self.assertEqual(r[4][1], 'Desert')
2945        r = get_as_dict(table, order='id desc')
2946        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2947        self.assertEqual(r, expected)
2948        r = get_as_dict(table, where='id > 5')
2949        self.assertIsInstance(r, OrderedDict)
2950        self.assertEqual(len(r), 0)
2951        # test with unordered query
2952        expected = dict((row[0], row[1:]) for row in colors)
2953        r = get_as_dict(table, order=False)
2954        self.assertIsInstance(r, dict)
2955        self.assertEqual(r, expected)
2956        if dict is not OrderedDict:  # Python > 2.6
2957            self.assertNotIsInstance(self, OrderedDict)
2958        # test with arbitrary from clause
2959        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2960        # primary key must be passed explicitly in this case
2961        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2962        r = get_as_dict(from_table, 'id')
2963        self.assertIsInstance(r, OrderedDict)
2964        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2965        self.assertEqual(r, expected)
2966        # test without a primary key
2967        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2968        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2969        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2970        r = get_as_dict(table, keyname='id')
2971        expected = OrderedDict((row[0], row[1:]) for row in colors)
2972        self.assertIsInstance(r, dict)
2973        self.assertEqual(r, expected)
2974        r = (1, '#007fff', 'Azure')
2975        query('insert into "%s" values ($1, $2, $3)' % table, r)
2976        # the last entry will win
2977        expected[1] = r[1:]
2978        r = get_as_dict(table, keyname='id')
2979        self.assertEqual(r, expected)
2980
2981    def testTransaction(self):
2982        query = self.db.query
2983        self.createTable('test_table', 'n integer', temporary=False)
2984        self.db.begin()
2985        query("insert into test_table values (1)")
2986        query("insert into test_table values (2)")
2987        self.db.commit()
2988        self.db.begin()
2989        query("insert into test_table values (3)")
2990        query("insert into test_table values (4)")
2991        self.db.rollback()
2992        self.db.begin()
2993        query("insert into test_table values (5)")
2994        self.db.savepoint('before6')
2995        query("insert into test_table values (6)")
2996        self.db.rollback('before6')
2997        query("insert into test_table values (7)")
2998        self.db.commit()
2999        self.db.begin()
3000        self.db.savepoint('before8')
3001        query("insert into test_table values (8)")
3002        self.db.release('before8')
3003        self.assertRaises(pg.InternalError, self.db.rollback, 'before8')
3004        self.db.commit()
3005        self.db.start()
3006        query("insert into test_table values (9)")
3007        self.db.end()
3008        r = [r[0] for r in query(
3009            "select * from test_table order by 1").getresult()]
3010        self.assertEqual(r, [1, 2, 5, 7, 9])
3011        self.db.begin(mode='read only')
3012        self.assertRaises(pg.InternalError,
3013                          query, "insert into test_table values (0)")
3014        self.db.rollback()
3015        self.db.start(mode='Read Only')
3016        self.assertRaises(pg.InternalError,
3017                          query, "insert into test_table values (0)")
3018        self.db.abort()
3019
3020    def testTransactionAliases(self):
3021        self.assertEqual(self.db.begin, self.db.start)
3022        self.assertEqual(self.db.commit, self.db.end)
3023        self.assertEqual(self.db.rollback, self.db.abort)
3024
3025    def testContextManager(self):
3026        query = self.db.query
3027        self.createTable('test_table', 'n integer check(n>0)')
3028        with self.db:
3029            query("insert into test_table values (1)")
3030            query("insert into test_table values (2)")
3031        try:
3032            with self.db:
3033                query("insert into test_table values (3)")
3034                query("insert into test_table values (4)")
3035                raise ValueError('test transaction should rollback')
3036        except ValueError as error:
3037            self.assertEqual(str(error), 'test transaction should rollback')
3038        with self.db:
3039            query("insert into test_table values (5)")
3040        try:
3041            with self.db:
3042                query("insert into test_table values (6)")
3043                query("insert into test_table values (-1)")
3044        except pg.IntegrityError as error:
3045            self.assertTrue('check' in str(error))
3046        with self.db:
3047            query("insert into test_table values (7)")
3048        r = [r[0] for r in query(
3049            "select * from test_table order by 1").getresult()]
3050        self.assertEqual(r, [1, 2, 5, 7])
3051
3052    def testBytea(self):
3053        query = self.db.query
3054        self.createTable('bytea_test', 'n smallint primary key, data bytea')
3055        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
3056        r = self.db.escape_bytea(s)
3057        query('insert into bytea_test values(3, $1)', (r,))
3058        r = query('select * from bytea_test where n=3').getresult()
3059        self.assertEqual(len(r), 1)
3060        r = r[0]
3061        self.assertEqual(len(r), 2)
3062        self.assertEqual(r[0], 3)
3063        r = r[1]
3064        if pg.get_bytea_escaped():
3065            self.assertNotEqual(r, s)
3066            r = pg.unescape_bytea(r)
3067        self.assertIsInstance(r, bytes)
3068        self.assertEqual(r, s)
3069
3070    def testInsertUpdateGetBytea(self):
3071        query = self.db.query
3072        unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None
3073        self.createTable('bytea_test', 'n smallint primary key, data bytea')
3074        # insert null value
3075        r = self.db.insert('bytea_test', n=0, data=None)
3076        self.assertIsInstance(r, dict)
3077        self.assertIn('n', r)
3078        self.assertEqual(r['n'], 0)
3079        self.assertIn('data', r)
3080        self.assertIsNone(r['data'])
3081        s = b'None'
3082        r = self.db.update('bytea_test', n=0, data=s)
3083        self.assertIsInstance(r, dict)
3084        self.assertIn('n', r)
3085        self.assertEqual(r['n'], 0)
3086        self.assertIn('data', r)
3087        r = r['data']
3088        if unescape:
3089            self.assertNotEqual(r, s)
3090            r = unescape(r)
3091        self.assertIsInstance(r, bytes)
3092        self.assertEqual(r, s)
3093        r = self.db.update('bytea_test', n=0, data=None)
3094        self.assertIsNone(r['data'])
3095        # insert as bytes
3096        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
3097        r = self.db.insert('bytea_test', n=5, data=s)
3098        self.assertIsInstance(r, dict)
3099        self.assertIn('n', r)
3100        self.assertEqual(r['n'], 5)
3101        self.assertIn('data', r)
3102        r = r['data']
3103        if unescape:
3104            self.assertNotEqual(r, s)
3105            r = unescape(r)
3106        self.assertIsInstance(r, bytes)
3107        self.assertEqual(r, s)
3108        # update as bytes
3109        s += b"and now even more \x00 nasty \t stuff!\f"
3110        r = self.db.update('bytea_test', n=5, data=s)
3111        self.assertIsInstance(r, dict)
3112        self.assertIn('n', r)
3113        self.assertEqual(r['n'], 5)
3114        self.assertIn('data', r)
3115        r = r['data']
3116        if unescape:
3117            self.assertNotEqual(r, s)
3118            r = unescape(r)
3119        self.assertIsInstance(r, bytes)
3120        self.assertEqual(r, s)
3121        r = query('select * from bytea_test where n=5').getresult()
3122        self.assertEqual(len(r), 1)
3123        r = r[0]
3124        self.assertEqual(len(r), 2)
3125        self.assertEqual(r[0], 5)
3126        r = r[1]
3127        if unescape:
3128            self.assertNotEqual(r, s)
3129            r = unescape(r)
3130        self.assertIsInstance(r, bytes)
3131        self.assertEqual(r, s)
3132        r = self.db.get('bytea_test', dict(n=5))
3133        self.assertIsInstance(r, dict)
3134        self.assertIn('n', r)
3135        self.assertEqual(r['n'], 5)
3136        self.assertIn('data', r)
3137        r = r['data']
3138        if unescape:
3139            self.assertNotEqual(r, s)
3140            r = pg.unescape_bytea(r)
3141        self.assertIsInstance(r, bytes)
3142        self.assertEqual(r, s)
3143
3144    def testUpsertBytea(self):
3145        self.createTable('bytea_test', 'n smallint primary key, data bytea')
3146        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
3147        r = dict(n=7, data=s)
3148        try:
3149            r = self.db.upsert('bytea_test', r)
3150        except pg.ProgrammingError as error:
3151            if self.db.server_version < 90500:
3152                self.skipTest('database does not support upsert')
3153            self.fail(str(error))
3154        self.assertIsInstance(r, dict)
3155        self.assertIn('n', r)
3156        self.assertEqual(r['n'], 7)
3157        self.assertIn('data', r)
3158        if pg.get_bytea_escaped():
3159            self.assertNotEqual(r['data'], s)
3160            r['data'] = pg.unescape_bytea(r['data'])
3161        self.assertIsInstance(r['data'], bytes)
3162        self.assertEqual(r['data'], s)
3163        r['data'] = None
3164        r = self.db.upsert('bytea_test', r)
3165        self.assertIsInstance(r, dict)
3166        self.assertIn('n', r)
3167        self.assertEqual(r['n'], 7)
3168        self.assertIn('data', r)
3169        self.assertIsNone(r['data'])
3170
3171    def testInsertGetJson(self):
3172        try:
3173            self.createTable('json_test', 'n smallint primary key, data json')
3174        except pg.ProgrammingError as error:
3175            if self.db.server_version < 90200:
3176                self.skipTest('database does not support json')
3177            self.fail(str(error))
3178        jsondecode = pg.get_jsondecode()
3179        # insert null value
3180        r = self.db.insert('json_test', n=0, data=None)
3181        self.assertIsInstance(r, dict)
3182        self.assertIn('n', r)
3183        self.assertEqual(r['n'], 0)
3184        self.assertIn('data', r)
3185        self.assertIsNone(r['data'])
3186        r = self.db.get('json_test', 0)
3187        self.assertIsInstance(r, dict)
3188        self.assertIn('n', r)
3189        self.assertEqual(r['n'], 0)
3190        self.assertIn('data', r)
3191        self.assertIsNone(r['data'])
3192        # insert JSON object
3193        data = {
3194            "id": 1, "name": "Foo", "price": 1234.5,
3195            "new": True, "note": None,
3196            "tags": ["Bar", "Eek"],
3197            "stock": {"warehouse": 300, "retail": 20}}
3198        r = self.db.insert('json_test', n=1, data=data)
3199        self.assertIsInstance(r, dict)
3200        self.assertIn('n', r)
3201        self.assertEqual(r['n'], 1)
3202        self.assertIn('data', r)
3203        r = r['data']
3204        if jsondecode is None:
3205            self.assertIsInstance(r, str)
3206            r = json.loads(r)
3207        self.assertIsInstance(r, dict)
3208        self.assertEqual(r, data)
3209        self.assertIsInstance(r['id'], int)
3210        self.assertIsInstance(r['name'], unicode)
3211        self.assertIsInstance(r['price'], float)
3212        self.assertIsInstance(r['new'], bool)
3213        self.assertIsInstance(r['tags'], list)
3214        self.assertIsInstance(r['stock'], dict)
3215        r = self.db.get('json_test', 1)
3216        self.assertIsInstance(r, dict)
3217        self.assertIn('n', r)
3218        self.assertEqual(r['n'], 1)
3219        self.assertIn('data', r)
3220        r = r['data']
3221        if jsondecode is None:
3222            self.assertIsInstance(r, str)
3223            r = json.loads(r)
3224        self.assertIsInstance(r, dict)
3225        self.assertEqual(r, data)
3226        self.assertIsInstance(r['id'], int)
3227        self.assertIsInstance(r['name'], unicode)
3228        self.assertIsInstance(r['price'], float)
3229        self.assertIsInstance(r['new'], bool)
3230        self.assertIsInstance(r['tags'], list)
3231        self.assertIsInstance(r['stock'], dict)
3232        # insert JSON object as text
3233        self.db.insert('json_test', n=2, data=json.dumps(data))
3234        q = "select data from json_test where n in (1, 2) order by n"
3235        r = self.db.query(q).getresult()
3236        self.assertEqual(len(r), 2)
3237        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
3238        self.assertEqual(r[0][0], r[1][0])
3239
3240    def testInsertGetJsonb(self):
3241        try:
3242            self.createTable('jsonb_test',
3243                             'n smallint primary key, data jsonb')
3244        except pg.ProgrammingError as error:
3245            if self.db.server_version < 90400:
3246                self.skipTest('database does not support jsonb')
3247            self.fail(str(error))
3248        jsondecode = pg.get_jsondecode()
3249        # insert null value
3250        r = self.db.insert('jsonb_test', n=0, data=None)
3251        self.assertIsInstance(r, dict)
3252        self.assertIn('n', r)
3253        self.assertEqual(r['n'], 0)
3254        self.assertIn('data', r)
3255        self.assertIsNone(r['data'])
3256        r = self.db.get('jsonb_test', 0)
3257        self.assertIsInstance(r, dict)
3258        self.assertIn('n', r)
3259        self.assertEqual(r['n'], 0)
3260        self.assertIn('data', r)
3261        self.assertIsNone(r['data'])
3262        # insert JSON object
3263        data = {
3264            "id": 1, "name": "Foo", "price": 1234.5,
3265            "new": True, "note": None,
3266            "tags": ["Bar", "Eek"],
3267            "stock": {"warehouse": 300, "retail": 20}}
3268        r = self.db.insert('jsonb_test', n=1, data=data)
3269        self.assertIsInstance(r, dict)
3270        self.assertIn('n', r)
3271        self.assertEqual(r['n'], 1)
3272        self.assertIn('data', r)
3273        r = r['data']
3274        if jsondecode is None:
3275            self.assertIsInstance(r, str)
3276            r = json.loads(r)
3277        self.assertIsInstance(r, dict)
3278        self.assertEqual(r, data)
3279        self.assertIsInstance(r['id'], int)
3280        self.assertIsInstance(r['name'], unicode)
3281        self.assertIsInstance(r['price'], float)
3282        self.assertIsInstance(r['new'], bool)
3283        self.assertIsInstance(r['tags'], list)
3284        self.assertIsInstance(r['stock'], dict)
3285        r = self.db.get('jsonb_test', 1)
3286        self.assertIsInstance(r, dict)
3287        self.assertIn('n', r)
3288        self.assertEqual(r['n'], 1)
3289        self.assertIn('data', r)
3290        r = r['data']
3291        if jsondecode is None:
3292            self.assertIsInstance(r, str)
3293            r = json.loads(r)
3294        self.assertIsInstance(r, dict)
3295        self.assertEqual(r, data)
3296        self.assertIsInstance(r['id'], int)
3297        self.assertIsInstance(r['name'], unicode)
3298        self.assertIsInstance(r['price'], float)
3299        self.assertIsInstance(r['new'], bool)
3300        self.assertIsInstance(r['tags'], list)
3301        self.assertIsInstance(r['stock'], dict)
3302
3303    def testArray(self):
3304        returns_arrays = pg.get_array()
3305        self.createTable('arraytest',
3306            'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
3307            ' d numeric[], f4 real[], f8 double precision[], m money[],'
3308            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
3309        r = self.db.get_attnames('arraytest')
3310        if self.regtypes:
3311            self.assertEqual(r, dict(
3312                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
3313                d='numeric[]', f4='real[]', f8='double precision[]',
3314                m='money[]', b='boolean[]',
3315                v4='character varying[]', c4='character[]', t='text[]'))
3316        else:
3317            self.assertEqual(r, dict(
3318                id='int', i2='int[]', i4='int[]', i8='int[]',
3319                d='num[]', f4='float[]', f8='float[]', m='money[]',
3320                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
3321        decimal = pg.get_decimal()
3322        if decimal is Decimal:
3323            long_decimal = decimal('123456789.123456789')
3324            odd_money = decimal('1234567891234567.89')
3325        else:
3326            long_decimal = decimal('12345671234.5')
3327            odd_money = decimal('1234567123.25')
3328        t, f = (True, False) if pg.get_bool() else ('t', 'f')
3329        data = dict(id=42, i2=[42, 1234, None, 0, -1],
3330            i4=[42, 123456789, None, 0, 1, -1],
3331            i8=[long(42), long(123456789123456789), None,
3332                long(0), long(1), long(-1)],
3333            d=[decimal(42), long_decimal, None,
3334               decimal(0), decimal(1), decimal(-1), -long_decimal],
3335            f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
3336                float('inf'), float('-inf')],
3337            f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
3338                float('inf'), float('-inf')],
3339            m=[decimal('42.00'), odd_money, None,
3340               decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
3341            b=[t, f, t, None, f, t, None, None, t],
3342            v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
3343            t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
3344        r = data.copy()
3345        self.db.insert('arraytest', r)
3346        if returns_arrays:
3347            self.assertEqual(r, data)
3348        else:
3349            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3350        self.db.insert('arraytest', r)
3351        r = self.db.get('arraytest', 42, 'id')
3352        if returns_arrays:
3353            self.assertEqual(r, data)
3354        else:
3355            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3356        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
3357        if returns_arrays:
3358            self.assertEqual(r, data)
3359        else:
3360            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3361
3362    def testArrayLiteral(self):
3363        insert = self.db.insert
3364        returns_arrays = pg.get_array()
3365        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3366        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3367        insert('arraytest', r)
3368        if returns_arrays:
3369            self.assertEqual(r['i'], [1, 2, 3])
3370            self.assertEqual(r['t'], ['a', 'b', 'c'])
3371        else:
3372            self.assertEqual(r['i'], '{1,2,3}')
3373            self.assertEqual(r['t'], '{a,b,c}')
3374        r = dict(i='{1,2,3}', t='{a,b,c}')
3375        self.db.insert('arraytest', r)
3376        if returns_arrays:
3377            self.assertEqual(r['i'], [1, 2, 3])
3378            self.assertEqual(r['t'], ['a', 'b', 'c'])
3379        else:
3380            self.assertEqual(r['i'], '{1,2,3}')
3381            self.assertEqual(r['t'], '{a,b,c}')
3382        L = pg.Literal
3383        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3384        self.db.insert('arraytest', r)
3385        if returns_arrays:
3386            self.assertEqual(r['i'], [1, 2, 3])
3387            self.assertEqual(r['t'], ['a', 'b', 'c'])
3388        else:
3389            self.assertEqual(r['i'], '{1,2,3}')
3390            self.assertEqual(r['t'], '{a,b,c}')
3391        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3392        self.assertRaises(pg.DataError, self.db.insert, 'arraytest', r)
3393
3394    def testArrayOfIds(self):
3395        array_on = pg.get_array()
3396        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3397        r = self.db.get_attnames('arraytest')
3398        if self.regtypes:
3399            self.assertEqual(r, dict(
3400                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3401        else:
3402            self.assertEqual(r, dict(
3403                oid='int', c='int[]', o='int[]', x='int[]'))
3404        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3405        r = data.copy()
3406        self.db.insert('arraytest', r)
3407        qoid = 'oid(arraytest)'
3408        oid = r.pop(qoid)
3409        if array_on:
3410            self.assertEqual(r, data)
3411        else:
3412            self.assertEqual(r['o'], '{21,22,23}')
3413        r = {qoid: oid}
3414        self.db.get('arraytest', r)
3415        self.assertEqual(oid, r.pop(qoid))
3416        if array_on:
3417            self.assertEqual(r, data)
3418        else:
3419            self.assertEqual(r['o'], '{21,22,23}')
3420
3421    def testArrayOfText(self):
3422        array_on = pg.get_array()
3423        self.createTable('arraytest', 'data text[]', oids=True)
3424        r = self.db.get_attnames('arraytest')
3425        self.assertEqual(r['data'], 'text[]')
3426        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3427                'null', 'NULL', 'Null', 'nulL',
3428                "It's all \\ kinds of\r nasty stuff!\n"]
3429        r = dict(data=data)
3430        self.db.insert('arraytest', r)
3431        if not array_on:
3432            r['data'] = pg.cast_array(r['data'])
3433        self.assertEqual(r['data'], data)
3434        self.assertIsInstance(r['data'][1], str)
3435        self.assertIsNone(r['data'][2])
3436        r['data'] = None
3437        self.db.get('arraytest', r)
3438        if not array_on:
3439            r['data'] = pg.cast_array(r['data'])
3440        self.assertEqual(r['data'], data)
3441        self.assertIsInstance(r['data'][1], str)
3442        self.assertIsNone(r['data'][2])
3443
3444    def testArrayOfBytea(self):
3445        array_on = pg.get_array()
3446        bytea_escaped = pg.get_bytea_escaped()
3447        self.createTable('arraytest', 'data bytea[]', oids=True)
3448        r = self.db.get_attnames('arraytest')
3449        self.assertEqual(r['data'], 'bytea[]')
3450        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3451                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3452        r = dict(data=data)
3453        self.db.insert('arraytest', r)
3454        if array_on:
3455            self.assertIsInstance(r['data'], list)
3456        if array_on and not bytea_escaped:
3457            self.assertEqual(r['data'], data)
3458            self.assertIsInstance(r['data'][1], bytes)
3459            self.assertIsNone(r['data'][2])
3460        else:
3461            self.assertNotEqual(r['data'], data)
3462        r['data'] = None
3463        self.db.get('arraytest', r)
3464        if array_on:
3465            self.assertIsInstance(r['data'], list)
3466        if array_on and not bytea_escaped:
3467            self.assertEqual(r['data'], data)
3468            self.assertIsInstance(r['data'][1], bytes)
3469            self.assertIsNone(r['data'][2])
3470        else:
3471            self.assertNotEqual(r['data'], data)
3472
3473    def testArrayOfJson(self):
3474        try:
3475            self.createTable('arraytest', 'data json[]', oids=True)
3476        except pg.ProgrammingError as error:
3477            if self.db.server_version < 90200:
3478                self.skipTest('database does not support json')
3479            self.fail(str(error))
3480        r = self.db.get_attnames('arraytest')
3481        self.assertEqual(r['data'], 'json[]')
3482        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3483        array_on = pg.get_array()
3484        jsondecode = pg.get_jsondecode()
3485        r = dict(data=data)
3486        self.db.insert('arraytest', r)
3487        if not array_on:
3488            r['data'] = pg.cast_array(r['data'], jsondecode)
3489        if jsondecode is None:
3490            r['data'] = [json.loads(d) for d in r['data']]
3491        self.assertEqual(r['data'], data)
3492        r['data'] = None
3493        self.db.get('arraytest', r)
3494        if not array_on:
3495            r['data'] = pg.cast_array(r['data'], jsondecode)
3496        if jsondecode is None:
3497            r['data'] = [json.loads(d) for d in r['data']]
3498        self.assertEqual(r['data'], data)
3499        r = dict(data=[json.dumps(d) for d in data])
3500        self.db.insert('arraytest', r)
3501        if not array_on:
3502            r['data'] = pg.cast_array(r['data'], jsondecode)
3503        if jsondecode is None:
3504            r['data'] = [json.loads(d) for d in r['data']]
3505        self.assertEqual(r['data'], data)
3506        r['data'] = None
3507        self.db.get('arraytest', r)
3508        # insert empty json values
3509        r = dict(data=['', None])
3510        self.db.insert('arraytest', r)
3511        r = r['data']
3512        if array_on:
3513            self.assertIsInstance(r, list)
3514            self.assertEqual(len(r), 2)
3515            self.assertIsNone(r[0])
3516            self.assertIsNone(r[1])
3517        else:
3518            self.assertEqual(r, '{NULL,NULL}')
3519
3520    def testArrayOfJsonb(self):
3521        try:
3522            self.createTable('arraytest', 'data jsonb[]', oids=True)
3523        except pg.ProgrammingError as error:
3524            if self.db.server_version < 90400:
3525                self.skipTest('database does not support jsonb')
3526            self.fail(str(error))
3527        r = self.db.get_attnames('arraytest')
3528        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3529        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3530        array_on = pg.get_array()
3531        jsondecode = pg.get_jsondecode()
3532        r = dict(data=data)
3533        self.db.insert('arraytest', r)
3534        if not array_on:
3535            r['data'] = pg.cast_array(r['data'], jsondecode)
3536        if jsondecode is None:
3537            r['data'] = [json.loads(d) for d in r['data']]
3538        self.assertEqual(r['data'], data)
3539        r['data'] = None
3540        self.db.get('arraytest', r)
3541        if not array_on:
3542            r['data'] = pg.cast_array(r['data'], jsondecode)
3543        if jsondecode is None:
3544            r['data'] = [json.loads(d) for d in r['data']]
3545        self.assertEqual(r['data'], data)
3546        r = dict(data=[json.dumps(d) for d in data])
3547        self.db.insert('arraytest', r)
3548        if not array_on:
3549            r['data'] = pg.cast_array(r['data'], jsondecode)
3550        if jsondecode is None:
3551            r['data'] = [json.loads(d) for d in r['data']]
3552        self.assertEqual(r['data'], data)
3553        r['data'] = None
3554        self.db.get('arraytest', r)
3555        # insert empty json values
3556        r = dict(data=['', None])
3557        self.db.insert('arraytest', r)
3558        r = r['data']
3559        if array_on:
3560            self.assertIsInstance(r, list)
3561            self.assertEqual(len(r), 2)
3562            self.assertIsNone(r[0])
3563            self.assertIsNone(r[1])
3564        else:
3565            self.assertEqual(r, '{NULL,NULL}')
3566
3567    def testDeepArray(self):
3568        array_on = pg.get_array()
3569        self.createTable('arraytest', 'data text[][][]', oids=True)
3570        r = self.db.get_attnames('arraytest')
3571        self.assertEqual(r['data'], 'text[]')
3572        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3573        r = dict(data=data)
3574        self.db.insert('arraytest', r)
3575        if array_on:
3576            self.assertEqual(r['data'], data)
3577        else:
3578            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3579        r['data'] = None
3580        self.db.get('arraytest', r)
3581        if array_on:
3582            self.assertEqual(r['data'], data)
3583        else:
3584            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3585
3586    def testInsertUpdateGetRecord(self):
3587        query = self.db.query
3588        query('create type test_person_type as'
3589              ' (name varchar, age smallint, married bool,'
3590              ' weight real, salary money)')
3591        self.addCleanup(query, 'drop type test_person_type')
3592        self.createTable('test_person', 'person test_person_type',
3593                         temporary=False, oids=True)
3594        attnames = self.db.get_attnames('test_person')
3595        self.assertEqual(len(attnames), 2)
3596        self.assertIn('oid', attnames)
3597        self.assertIn('person', attnames)
3598        person_typ = attnames['person']
3599        if self.regtypes:
3600            self.assertEqual(person_typ, 'test_person_type')
3601        else:
3602            self.assertEqual(person_typ, 'record')
3603        if self.regtypes:
3604            self.assertEqual(person_typ.attnames,
3605                dict(name='character varying', age='smallint',
3606                    married='boolean', weight='real', salary='money'))
3607        else:
3608            self.assertEqual(person_typ.attnames,
3609                dict(name='text', age='int', married='bool',
3610                    weight='float', salary='money'))
3611        decimal = pg.get_decimal()
3612        if pg.get_bool():
3613            bool_class = bool
3614            t, f = True, False
3615        else:
3616            bool_class = str
3617            t, f = 't', 'f'
3618        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3619        r = self.db.insert('test_person', None, person=person)
3620        p = r['person']
3621        self.assertIsInstance(p, tuple)
3622        self.assertEqual(p, person)
3623        self.assertEqual(p.name, 'John Doe')
3624        self.assertIsInstance(p.name, str)
3625        self.assertIsInstance(p.age, int)
3626        self.assertIsInstance(p.married, bool_class)
3627        self.assertIsInstance(p.weight, float)
3628        self.assertIsInstance(p.salary, decimal)
3629        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3630        r['person'] = person
3631        self.db.update('test_person', r)
3632        p = r['person']
3633        self.assertIsInstance(p, tuple)
3634        self.assertEqual(p, person)
3635        self.assertEqual(p.name, 'Jane Roe')
3636        self.assertIsInstance(p.name, str)
3637        self.assertIsInstance(p.age, int)
3638        self.assertIsInstance(p.married, bool_class)
3639        self.assertIsInstance(p.weight, float)
3640        self.assertIsInstance(p.salary, decimal)
3641        r['person'] = None
3642        self.db.get('test_person', r)
3643        p = r['person']
3644        self.assertIsInstance(p, tuple)
3645        self.assertEqual(p, person)
3646        self.assertEqual(p.name, 'Jane Roe')
3647        self.assertIsInstance(p.name, str)
3648        self.assertIsInstance(p.age, int)
3649        self.assertIsInstance(p.married, bool_class)
3650        self.assertIsInstance(p.weight, float)
3651        self.assertIsInstance(p.salary, decimal)
3652        person = (None,) * 5
3653        r = self.db.insert('test_person', None, person=person)
3654        p = r['person']
3655        self.assertIsInstance(p, tuple)
3656        self.assertIsNone(p.name)
3657        self.assertIsNone(p.age)
3658        self.assertIsNone(p.married)
3659        self.assertIsNone(p.weight)
3660        self.assertIsNone(p.salary)
3661        r['person'] = None
3662        self.db.get('test_person', r)
3663        p = r['person']
3664        self.assertIsInstance(p, tuple)
3665        self.assertIsNone(p.name)
3666        self.assertIsNone(p.age)
3667        self.assertIsNone(p.married)
3668        self.assertIsNone(p.weight)
3669        self.assertIsNone(p.salary)
3670        r = self.db.insert('test_person', None, person=None)
3671        self.assertIsNone(r['person'])
3672        r['person'] = None
3673        self.db.get('test_person', r)
3674        self.assertIsNone(r['person'])
3675
3676    def testRecordInsertBytea(self):
3677        query = self.db.query
3678        query('create type test_person_type as'
3679              ' (name text, picture bytea)')
3680        self.addCleanup(query, 'drop type test_person_type')
3681        self.createTable('test_person', 'person test_person_type',
3682                         temporary=False, oids=True)
3683        person_typ = self.db.get_attnames('test_person')['person']
3684        self.assertEqual(person_typ.attnames,
3685                         dict(name='text', picture='bytea'))
3686        person = ('John Doe', b'O\x00ps\xff!')
3687        r = self.db.insert('test_person', None, person=person)
3688        p = r['person']
3689        self.assertIsInstance(p, tuple)
3690        self.assertEqual(p, person)
3691        self.assertEqual(p.name, 'John Doe')
3692        self.assertIsInstance(p.name, str)
3693        self.assertEqual(p.picture, person[1])
3694        self.assertIsInstance(p.picture, bytes)
3695
3696    def testRecordInsertJson(self):
3697        query = self.db.query
3698        try:
3699            query('create type test_person_type as'
3700                  ' (name text, data json)')
3701        except pg.ProgrammingError as error:
3702            if self.db.server_version < 90200:
3703                self.skipTest('database does not support json')
3704            self.fail(str(error))
3705        self.addCleanup(query, 'drop type test_person_type')
3706        self.createTable('test_person', 'person test_person_type',
3707                         temporary=False, oids=True)
3708        person_typ = self.db.get_attnames('test_person')['person']
3709        self.assertEqual(person_typ.attnames,
3710                         dict(name='text', data='json'))
3711        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3712        r = self.db.insert('test_person', None, person=person)
3713        p = r['person']
3714        self.assertIsInstance(p, tuple)
3715        if pg.get_jsondecode() is None:
3716            p = p._replace(data=json.loads(p.data))
3717        self.assertEqual(p, person)
3718        self.assertEqual(p.name, 'John Doe')
3719        self.assertIsInstance(p.name, str)
3720        self.assertEqual(p.data, person[1])
3721        self.assertIsInstance(p.data, dict)
3722
3723    def testRecordLiteral(self):
3724        query = self.db.query
3725        query('create type test_person_type as'
3726              ' (name varchar, age smallint)')
3727        self.addCleanup(query, 'drop type test_person_type')
3728        self.createTable('test_person', 'person test_person_type',
3729                         temporary=False, oids=True)
3730        person_typ = self.db.get_attnames('test_person')['person']
3731        if self.regtypes:
3732            self.assertEqual(person_typ, 'test_person_type')
3733        else:
3734            self.assertEqual(person_typ, 'record')
3735        if self.regtypes:
3736            self.assertEqual(person_typ.attnames,
3737                             dict(name='character varying', age='smallint'))
3738        else:
3739            self.assertEqual(person_typ.attnames,
3740                             dict(name='text', age='int'))
3741        person = pg.Literal("('John Doe', 61)")
3742        r = self.db.insert('test_person', None, person=person)
3743        p = r['person']
3744        self.assertIsInstance(p, tuple)
3745        self.assertEqual(p.name, 'John Doe')
3746        self.assertIsInstance(p.name, str)
3747        self.assertEqual(p.age, 61)
3748        self.assertIsInstance(p.age, int)
3749
3750    def testDate(self):
3751        query = self.db.query
3752        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3753                'SQL, MDY', 'SQL, DMY', 'German'):
3754            self.db.set_parameter('datestyle', datestyle)
3755            d = date(2016, 3, 14)
3756            q = "select $1::date"
3757            r = query(q, (d,)).getresult()[0][0]
3758            self.assertIsInstance(r, date)
3759            self.assertEqual(r, d)
3760            q = "select '10000-08-01'::date, '0099-01-08 BC'::date"
3761            r = query(q).getresult()[0]
3762            self.assertIsInstance(r[0], date)
3763            self.assertIsInstance(r[1], date)
3764            self.assertEqual(r[0], date.max)
3765            self.assertEqual(r[1], date.min)
3766        q = "select 'infinity'::date, '-infinity'::date"
3767        r = query(q).getresult()[0]
3768        self.assertIsInstance(r[0], date)
3769        self.assertIsInstance(r[1], date)
3770        self.assertEqual(r[0], date.max)
3771        self.assertEqual(r[1], date.min)
3772
3773    def testTime(self):
3774        query = self.db.query
3775        d = time(15, 9, 26)
3776        q = "select $1::time"
3777        r = query(q, (d,)).getresult()[0][0]
3778        self.assertIsInstance(r, time)
3779        self.assertEqual(r, d)
3780        d = time(15, 9, 26, 535897)
3781        q = "select $1::time"
3782        r = query(q, (d,)).getresult()[0][0]
3783        self.assertIsInstance(r, time)
3784        self.assertEqual(r, d)
3785
3786    def testTimetz(self):
3787        query = self.db.query
3788        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3789        for timezone in sorted(timezones):
3790            tz = '%+03d00' % timezones[timezone]
3791            try:
3792                tzinfo = datetime.strptime(tz, '%z').tzinfo
3793            except ValueError:  # Python < 3.2
3794                tzinfo = pg._get_timezone(tz)
3795            self.db.set_parameter('timezone', timezone)
3796            d = time(15, 9, 26, tzinfo=tzinfo)
3797            q = "select $1::timetz"
3798            r = query(q, (d,)).getresult()[0][0]
3799            self.assertIsInstance(r, time)
3800            self.assertEqual(r, d)
3801            d = time(15, 9, 26, 535897, tzinfo)
3802            q = "select $1::timetz"
3803            r = query(q, (d,)).getresult()[0][0]
3804            self.assertIsInstance(r, time)
3805            self.assertEqual(r, d)
3806
3807    def testTimestamp(self):
3808        query = self.db.query
3809        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3810                'SQL, MDY', 'SQL, DMY', 'German'):
3811            self.db.set_parameter('datestyle', datestyle)
3812            d = datetime(2016, 3, 14)
3813            q = "select $1::timestamp"
3814            r = query(q, (d,)).getresult()[0][0]
3815            self.assertIsInstance(r, datetime)
3816            self.assertEqual(r, d)
3817            d = datetime(2016, 3, 14, 15, 9, 26)
3818            q = "select $1::timestamp"
3819            r = query(q, (d,)).getresult()[0][0]
3820            self.assertIsInstance(r, datetime)
3821            self.assertEqual(r, d)
3822            d = datetime(2016, 3, 14, 15, 9, 26, 535897)
3823            q = "select $1::timestamp"
3824            r = query(q, (d,)).getresult()[0][0]
3825            self.assertIsInstance(r, datetime)
3826            self.assertEqual(r, d)
3827            q = ("select '10000-08-01 AD'::timestamp,"
3828                " '0099-01-08 BC'::timestamp")
3829            r = query(q).getresult()[0]
3830            self.assertIsInstance(r[0], datetime)
3831            self.assertIsInstance(r[1], datetime)
3832            self.assertEqual(r[0], datetime.max)
3833            self.assertEqual(r[1], datetime.min)
3834        q = "select 'infinity'::timestamp, '-infinity'::timestamp"
3835        r = query(q).getresult()[0]
3836        self.assertIsInstance(r[0], datetime)
3837        self.assertIsInstance(r[1], datetime)
3838        self.assertEqual(r[0], datetime.max)
3839        self.assertEqual(r[1], datetime.min)
3840
3841    def testTimestamptz(self):
3842        query = self.db.query
3843        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3844        for timezone in sorted(timezones):
3845            tz = '%+03d00' % timezones[timezone]
3846            try:
3847                tzinfo = datetime.strptime(tz, '%z').tzinfo
3848            except ValueError:  # Python < 3.2
3849                tzinfo = pg._get_timezone(tz)
3850            self.db.set_parameter('timezone', timezone)
3851            for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3852                    'SQL, MDY', 'SQL, DMY', 'German'):
3853                self.db.set_parameter('datestyle', datestyle)
3854                d = datetime(2016, 3, 14, tzinfo=tzinfo)
3855                q = "select $1::timestamptz"
3856                r = query(q, (d,)).getresult()[0][0]
3857                self.assertIsInstance(r, datetime)
3858                self.assertEqual(r, d)
3859                d = datetime(2016, 3, 14, 15, 9, 26, tzinfo=tzinfo)
3860                q = "select $1::timestamptz"
3861                r = query(q, (d,)).getresult()[0][0]
3862                self.assertIsInstance(r, datetime)
3863                self.assertEqual(r, d)
3864                d = datetime(2016, 3, 14, 15, 9, 26, 535897, tzinfo)
3865                q = "select $1::timestamptz"
3866                r = query(q, (d,)).getresult()[0][0]
3867                self.assertIsInstance(r, datetime)
3868                self.assertEqual(r, d)
3869                q = ("select '10000-08-01 AD'::timestamptz,"
3870                    " '0099-01-08 BC'::timestamptz")
3871                r = query(q).getresult()[0]
3872                self.assertIsInstance(r[0], datetime)
3873                self.assertIsInstance(r[1], datetime)
3874                self.assertEqual(r[0], datetime.max)
3875                self.assertEqual(r[1], datetime.min)
3876        q = "select 'infinity'::timestamptz, '-infinity'::timestamptz"
3877        r = query(q).getresult()[0]
3878        self.assertIsInstance(r[0], datetime)
3879        self.assertIsInstance(r[1], datetime)
3880        self.assertEqual(r[0], datetime.max)
3881        self.assertEqual(r[1], datetime.min)
3882
3883    def testInterval(self):
3884        query = self.db.query
3885        for intervalstyle in (
3886                'sql_standard', 'postgres', 'postgres_verbose', 'iso_8601'):
3887            self.db.set_parameter('intervalstyle', intervalstyle)
3888            d = timedelta(3)
3889            q = "select $1::interval"
3890            r = query(q, (d,)).getresult()[0][0]
3891            self.assertIsInstance(r, timedelta)
3892            self.assertEqual(r, d)
3893            d = timedelta(-30)
3894            r = query(q, (d,)).getresult()[0][0]
3895            self.assertIsInstance(r, timedelta)
3896            self.assertEqual(r, d)
3897            d = timedelta(hours=3, minutes=31, seconds=42, microseconds=5678)
3898            q = "select $1::interval"
3899            r = query(q, (d,)).getresult()[0][0]
3900            self.assertIsInstance(r, timedelta)
3901            self.assertEqual(r, d)
3902
3903    def testDateAndTimeArrays(self):
3904        dt = (date(2016, 3, 14), time(15, 9, 26))
3905        q = "select ARRAY[$1::date], ARRAY[$2::time]"
3906        r = self.db.query(q, dt).getresult()[0]
3907        self.assertIsInstance(r[0], list)
3908        self.assertEqual(r[0][0], dt[0])
3909        self.assertIsInstance(r[1], list)
3910        self.assertEqual(r[1][0], dt[1])
3911
3912    def testHstore(self):
3913        try:
3914            self.db.query("select 'k=>v'::hstore")
3915        except pg.DatabaseError:
3916            try:
3917                self.db.query("create extension hstore")
3918            except pg.DatabaseError:
3919                self.skipTest("hstore extension not enabled")
3920        d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever',
3921            '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3',
3922            '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"',
3923            'None': None, 'NULL': 'NULL', 'empty': ''}
3924        q = "select $1::hstore"
3925        r = self.db.query(q, (pg.Hstore(d),)).getresult()[0][0]
3926        self.assertIsInstance(r, dict)
3927        self.assertEqual(r, d)
3928
3929    def testUuid(self):
3930        d = UUID('{12345678-1234-5678-1234-567812345678}')
3931        q = 'select $1::uuid'
3932        r = self.db.query(q, (d,)).getresult()[0][0]
3933        self.assertIsInstance(r, UUID)
3934        self.assertEqual(r, d)
3935
    def testDbTypesInfo(self):
        """Check the type information provided by the dbtypes mapping.

        Entries appear in dbtypes only after their first lookup, so the
        membership checks before and after each lookup depend on the exact
        order of the statements below.
        """
        dbtypes = self.db.dbtypes
        self.assertIsInstance(dbtypes, dict)
        # a base type: not present yet, added by the lookup below
        self.assertNotIn('numeric', dbtypes)
        typ = dbtypes['numeric']
        self.assertIn('numeric', dbtypes)
        # the entry compares equal to the regtype name or the simple type
        # name, depending on whether regtypes are in use
        self.assertEqual(typ, 'numeric' if self.regtypes else 'num')
        self.assertEqual(typ.oid, 1700)
        self.assertEqual(typ.pgtype, 'numeric')
        self.assertEqual(typ.regtype, 'numeric')
        self.assertEqual(typ.simple, 'num')
        self.assertEqual(typ.typtype, 'b')  # base
        self.assertEqual(typ.category, 'N')  # numeric
        self.assertEqual(typ.delim, ',')
        # base types have no associated relation
        self.assertEqual(typ.relid, 0)
        # the same entry is also reachable through its oid
        self.assertIs(dbtypes[1700], typ)
        # a composite type: the row type of the pg_type catalog
        self.assertNotIn('pg_type', dbtypes)
        typ = dbtypes['pg_type']
        self.assertIn('pg_type', dbtypes)
        self.assertEqual(typ, 'pg_type' if self.regtypes else 'record')
        self.assertIsInstance(typ.oid, int)
        self.assertEqual(typ.pgtype, 'pg_type')
        self.assertEqual(typ.regtype, 'pg_type')
        self.assertEqual(typ.simple, 'record')
        self.assertEqual(typ.typtype, 'c')  # composite
        self.assertEqual(typ.category, 'C')  # composite
        self.assertEqual(typ.delim, ',')
        # unlike the base type above, the composite type has a relid
        self.assertNotEqual(typ.relid, 0)
        # get_attnames() returns the very same attribute mapping object
        attnames = typ.attnames
        self.assertIsInstance(attnames, dict)
        self.assertIs(attnames, dbtypes.get_attnames('pg_type'))
        # attributes are expected in the catalog's column order
        self.assertEqual(list(attnames)[0], 'typname')
        typname = attnames['typname']
        self.assertEqual(typname, 'name' if self.regtypes else 'text')
        self.assertEqual(typname.typtype, 'b')  # base
        self.assertEqual(typname.category, 'S')  # string
        self.assertEqual(list(attnames)[3], 'typlen')
        typlen = attnames['typlen']
        self.assertEqual(typlen, 'smallint' if self.regtypes else 'int')
        self.assertEqual(typlen.typtype, 'b')  # base
        self.assertEqual(typlen.category, 'N')  # numeric
3977
    def testDbTypesTypecast(self):
        """Check getting, setting and resetting typecasts on dbtypes.

        These typecasts belong to the dbtypes mapping of this connection.
        The statements depend on their order: membership of 'circle' is
        checked before and after a query that returns a circle value.
        """
        dbtypes = self.db.dbtypes
        self.assertIsInstance(dbtypes, dict)
        self.assertNotIn('int4', dbtypes)
        self.assertIs(dbtypes.get_typecast('int4'), int)
        # reset a single typecast by name
        dbtypes.set_typecast('int4', float)
        self.assertIs(dbtypes.get_typecast('int4'), float)
        dbtypes.reset_typecast('int4')
        self.assertIs(dbtypes.get_typecast('int4'), int)
        # reset all typecasts when no name is given
        dbtypes.set_typecast('int4', float)
        self.assertIs(dbtypes.get_typecast('int4'), float)
        dbtypes.reset_typecast()
        self.assertIs(dbtypes.get_typecast('int4'), int)
        # circle has no typecast by default
        self.assertNotIn('circle', dbtypes)
        self.assertIsNone(dbtypes.get_typecast('circle'))
        squared_circle = lambda v: 'Squared Circle: %s' % v
        dbtypes.set_typecast('circle', squared_circle)
        self.assertIs(dbtypes.get_typecast('circle'), squared_circle)
        # the custom typecast is applied to query results
        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
        self.assertIn('circle', dbtypes)
        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
        # typecast() applies the registered cast to an arbitrary value
        self.assertEqual(dbtypes.typecast('Impossible', 'circle'),
            'Squared Circle: Impossible')
        dbtypes.reset_typecast('circle')
        self.assertIsNone(dbtypes.get_typecast('circle'))
4003
    def testGetSetTypeCast(self):
        """Check the module-level get_typecast() and set_typecast().

        A typecast set globally on the pg module is used for query results
        of this connection, but setting it does not add the type to the
        connection's dbtypes mapping; only the query does.
        """
        get_typecast = pg.get_typecast
        set_typecast = pg.set_typecast
        dbtypes = self.db.dbtypes
        self.assertIsInstance(dbtypes, dict)
        self.assertNotIn('int4', dbtypes)
        self.assertNotIn('real', dbtypes)
        self.assertNotIn('bool', dbtypes)
        self.assertIs(get_typecast('int4'), int)
        self.assertIs(get_typecast('float4'), float)
        self.assertIs(get_typecast('bool'), pg.cast_bool)
        # remember the global circle typecast and restore it even on failure
        cast_circle = get_typecast('circle')
        self.addCleanup(set_typecast, 'circle', cast_circle)
        squared_circle = lambda v: 'Squared Circle: %s' % v
        self.assertNotIn('circle', dbtypes)
        set_typecast('circle', squared_circle)
        # setting the global typecast does not touch dbtypes
        self.assertNotIn('circle', dbtypes)
        self.assertIs(get_typecast('circle'), squared_circle)
        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
        self.assertIn('circle', dbtypes)
        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
        set_typecast('circle', cast_circle)
        self.assertIs(get_typecast('circle'), cast_circle)
4027
4028    def testNotificationHandler(self):
4029        # the notification handler itself is tested separately
4030        f = self.db.notification_handler
4031        callback = lambda arg_dict: None
4032        handler = f('test', callback)
4033        self.assertIsInstance(handler, pg.NotificationHandler)
4034        self.assertIs(handler.db, self.db)
4035        self.assertEqual(handler.event, 'test')
4036        self.assertEqual(handler.stop_event, 'stop_test')
4037        self.assertIs(handler.callback, callback)
4038        self.assertIsInstance(handler.arg_dict, dict)
4039        self.assertEqual(handler.arg_dict, {})
4040        self.assertIsNone(handler.timeout)
4041        self.assertFalse(handler.listening)
4042        handler.close()
4043        self.assertIsNone(handler.db)
4044        self.db.reopen()
4045        self.assertIsNone(handler.db)
4046        handler = f('test2', callback, timeout=2)
4047        self.assertIsInstance(handler, pg.NotificationHandler)
4048        self.assertIs(handler.db, self.db)
4049        self.assertEqual(handler.event, 'test2')
4050        self.assertEqual(handler.stop_event, 'stop_test2')
4051        self.assertIs(handler.callback, callback)
4052        self.assertIsInstance(handler.arg_dict, dict)
4053        self.assertEqual(handler.arg_dict, {})
4054        self.assertEqual(handler.timeout, 2)
4055        self.assertFalse(handler.listening)
4056        handler.close()
4057        self.assertIsNone(handler.db)
4058        self.db.reopen()
4059        self.assertIsNone(handler.db)
4060        arg_dict = {'testing': 3}
4061        handler = f('test3', callback, arg_dict=arg_dict)
4062        self.assertIsInstance(handler, pg.NotificationHandler)
4063        self.assertIs(handler.db, self.db)
4064        self.assertEqual(handler.event, 'test3')
4065        self.assertEqual(handler.stop_event, 'stop_test3')
4066        self.assertIs(handler.callback, callback)
4067        self.assertIs(handler.arg_dict, arg_dict)
4068        self.assertEqual(arg_dict['testing'], 3)
4069        self.assertIsNone(handler.timeout)
4070        self.assertFalse(handler.listening)
4071        handler.close()
4072        self.assertIsNone(handler.db)
4073        self.db.reopen()
4074        self.assertIsNone(handler.db)
4075        handler = f('test4', callback, stop_event='stop4')
4076        self.assertIsInstance(handler, pg.NotificationHandler)
4077        self.assertIs(handler.db, self.db)
4078        self.assertEqual(handler.event, 'test4')
4079        self.assertEqual(handler.stop_event, 'stop4')
4080        self.assertIs(handler.callback, callback)
4081        self.assertIsInstance(handler.arg_dict, dict)
4082        self.assertEqual(handler.arg_dict, {})
4083        self.assertIsNone(handler.timeout)
4084        self.assertFalse(handler.listening)
4085        handler.close()
4086        self.assertIsNone(handler.db)
4087        self.db.reopen()
4088        self.assertIsNone(handler.db)
4089        arg_dict = {'testing': 5}
4090        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
4091        self.assertIsInstance(handler, pg.NotificationHandler)
4092        self.assertIs(handler.db, self.db)
4093        self.assertEqual(handler.event, 'test5')
4094        self.assertEqual(handler.stop_event, 'stop5')
4095        self.assertIs(handler.callback, callback)
4096        self.assertIs(handler.arg_dict, arg_dict)
4097        self.assertEqual(arg_dict['testing'], 5)
4098        self.assertEqual(handler.timeout, 1.5)
4099        self.assertFalse(handler.listening)
4100        handler.close()
4101        self.assertIsNone(handler.db)
4102        self.db.reopen()
4103        self.assertIsNone(handler.db)
4104
4105
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    All inherited tests run again after several module-wide options of pg
    have been changed.  set_option() saves the original value of each
    option and reset_option() restores it.

    unittest does not call tearDownClass() when setUpClass() raises, so
    setUpClass() restores the options itself when it fails.  Otherwise the
    changed options would leak into every test class that runs afterwards.
    """

    @classmethod
    def setUpClass(cls):
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            cls.set_option('bool', not pg.get_bool())
            cls.set_option('array', not pg.get_array())
            cls.set_option('bytea_escaped', not pg.get_bytea_escaped())
            cls.set_option('namedresult', None)
            cls.set_option('jsondecode', None)
            db = DB()
            try:
                # use the opposite of the default regtypes setting
                # (assumes the base class applies cls.regtypes - confirm)
                cls.regtypes = not db.use_regtypes()
            finally:
                db.close()
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # tearDownClass() will not run, so undo the option changes here
            cls._reset_all_options()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # restore the options even if the inherited teardown fails
            cls._reset_all_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option and set a new one.

        Returns whatever the corresponding pg.set_<option>() returns.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option().

        Returns whatever the corresponding pg.set_<option>() returns.
        """
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _reset_all_options(cls):
        """Restore every global pg option saved so far by set_option()."""
        for option in list(cls.saved_options):
            cls.reset_option(option)
4144
4145
4146class TestDBClassAdapter(unittest.TestCase):
4147    """Test the adapter object associatd with the DB class."""
4148
4149    def setUp(self):
4150        self.db = DB()
4151        self.adapter = self.db.adapter
4152
4153    def tearDown(self):
4154        try:
4155            self.db.close()
4156        except pg.InternalError:
4157            pass
4158
4159    def testGuessSimpleType(self):
4160        f = self.adapter.guess_simple_type
4161        self.assertEqual(f(pg.Bytea(b'test')), 'bytea')
4162        self.assertEqual(f('string'), 'text')
4163        self.assertEqual(f(b'string'), 'text')
4164        self.assertEqual(f(True), 'bool')
4165        self.assertEqual(f(3), 'int')
4166        self.assertEqual(f(2.75), 'float')
4167        self.assertEqual(f(Decimal('4.25')), 'num')
4168        self.assertEqual(f(date(2016, 1, 30)), 'date')
4169        self.assertEqual(f([1, 2, 3]), 'int[]')
4170        self.assertEqual(f([[[123]]]), 'int[]')
4171        self.assertEqual(f(['a', 'b', 'c']), 'text[]')
4172        self.assertEqual(f([[['abc']]]), 'text[]')
4173        self.assertEqual(f([False, True]), 'bool[]')
4174        self.assertEqual(f([[[False]]]), 'bool[]')
4175        r = f(('string', True, 3, 2.75, [1], [False]))
4176        self.assertEqual(r, 'record')
4177        self.assertEqual(list(r.attnames.values()),
4178            ['text', 'bool', 'int', 'float', 'int[]', 'bool[]'])
4179
4180    def testAdaptQueryTypedList(self):
4181        format_query = self.adapter.format_query
4182        self.assertRaises(TypeError, format_query,
4183            '%s,%s', (1, 2), ('int2',))
4184        self.assertRaises(TypeError, format_query,
4185            '%s,%s', (1,), ('int2', 'int2'))
4186        values = (3, 7.5, 'hello', True)
4187        types = ('int4', 'float4', 'text', 'bool')
4188        sql, params = format_query("select %s,%s,%s,%s", values, types)
4189        self.assertEqual(sql, 'select $1,$2,$3,$4')
4190        self.assertEqual(params, [3, 7.5, 'hello', 't'])
4191        types = ('bool', 'bool', 'bool', 'bool')
4192        sql, params = format_query("select %s,%s,%s,%s", values, types)
4193        self.assertEqual(sql, 'select $1,$2,$3,$4')
4194        self.assertEqual(params, ['t', 't', 'f', 't'])
4195        values = ('2016-01-30', 'current_date')
4196        types = ('date', 'date')
4197        sql, params = format_query("values(%s,%s)", values, types)
4198        self.assertEqual(sql, 'values($1,current_date)')
4199        self.assertEqual(params, ['2016-01-30'])
4200        values = ([1, 2, 3], ['a', 'b', 'c'])
4201        types = ('_int4', '_text')
4202        sql, params = format_query("%s::int4[],%s::text[]", values, types)
4203        self.assertEqual(sql, '$1::int4[],$2::text[]')
4204        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
4205        types = ('_bool', '_bool')
4206        sql, params = format_query("%s::bool[],%s::bool[]", values, types)
4207        self.assertEqual(sql, '$1::bool[],$2::bool[]')
4208        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
4209        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
4210        t = self.adapter.simple_type
4211        typ = t('record')
4212        typ._get_attnames = lambda _self: pg.AttrDict([
4213            ('i', t('int')), ('f', t('float')),
4214            ('t', t('text')), ('b', t('bool')),
4215            ('i3', t('int[]')), ('t3', t('text[]'))])
4216        types = [typ]
4217        sql, params = format_query('select %s', values, types)
4218        self.assertEqual(sql, 'select $1')
4219        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4220
4221    def testAdaptQueryTypedDict(self):
4222        format_query = self.adapter.format_query
4223        self.assertRaises(TypeError, format_query,
4224            '%s,%s', dict(i1=1, i2=2), dict(i1='int2'))
4225        values = dict(i=3, f=7.5, t='hello', b=True)
4226        types = dict(i='int4', f='float4',
4227            t='text', b='bool')
4228        sql, params = format_query(
4229            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
4230        self.assertEqual(sql, 'select $3,$2,$4,$1')
4231        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
4232        types = dict(i='bool', f='bool',
4233            t='bool', b='bool')
4234        sql, params = format_query(
4235            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
4236        self.assertEqual(sql, 'select $3,$2,$4,$1')
4237        self.assertEqual(params, ['t', 't', 't', 'f'])
4238        values = dict(d1='2016-01-30', d2='current_date')
4239        types = dict(d1='date', d2='date')
4240        sql, params = format_query("values(%(d1)s,%(d2)s)", values, types)
4241        self.assertEqual(sql, 'values($1,current_date)')
4242        self.assertEqual(params, ['2016-01-30'])
4243        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
4244        types = dict(i='_int4', t='_text')
4245        sql, params = format_query(
4246            "%(i)s::int4[],%(t)s::text[]", values, types)
4247        self.assertEqual(sql, '$1::int4[],$2::text[]')
4248        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
4249        types = dict(i='_bool', t='_bool')
4250        sql, params = format_query(
4251            "%(i)s::bool[],%(t)s::bool[]", values, types)
4252        self.assertEqual(sql, '$1::bool[],$2::bool[]')
4253        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
4254        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
4255        t = self.adapter.simple_type
4256        typ = t('record')
4257        typ._get_attnames = lambda _self: pg.AttrDict([
4258            ('i', t('int')), ('f', t('float')),
4259            ('t', t('text')), ('b', t('bool')),
4260            ('i3', t('int[]')), ('t3', t('text[]'))])
4261        types = dict(record=typ)
4262        sql, params = format_query('select %(record)s', values, types)
4263        self.assertEqual(sql, 'select $1')
4264        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4265
4266    def testAdaptQueryUntypedList(self):
4267        format_query = self.adapter.format_query
4268        values = (3, 7.5, 'hello', True)
4269        sql, params = format_query("select %s,%s,%s,%s", values)
4270        self.assertEqual(sql, 'select $1,$2,$3,$4')
4271        self.assertEqual(params, [3, 7.5, 'hello', 't'])
4272        values = [date(2016, 1, 30), 'current_date']
4273        sql, params = format_query("values(%s,%s)", values)
4274        self.assertEqual(sql, 'values($1,$2)')
4275        self.assertEqual(params, values)
4276        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
4277        sql, params = format_query("%s,%s,%s", values)
4278        self.assertEqual(sql, "$1,$2,$3")
4279        self.assertEqual(params, ['{1,2,3}', '{a,b,c}', '{t,f,t}'])
4280        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
4281            [[True, False], [False, True]])
4282        sql, params = format_query("%s,%s,%s", values)
4283        self.assertEqual(sql, "$1,$2,$3")
4284        self.assertEqual(params, [
4285            '{{1,2},{3,4}}', '{{a,b},{c,d}}', '{{t,f},{f,t}}'])
4286        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
4287        sql, params = format_query('select %s', values)
4288        self.assertEqual(sql, 'select $1')
4289        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4290
4291    def testAdaptQueryUntypedDict(self):
4292        format_query = self.adapter.format_query
4293        values = dict(i=3, f=7.5, t='hello', b=True)
4294        sql, params = format_query(
4295            "select %(i)s,%(f)s,%(t)s,%(b)s", values)
4296        self.assertEqual(sql, 'select $3,$2,$4,$1')
4297        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
4298        values = dict(d1='2016-01-30', d2='current_date')
4299        sql, params = format_query("values(%(d1)s,%(d2)s)", values)
4300        self.assertEqual(sql, 'values($1,$2)')
4301        self.assertEqual(params, [values['d1'], values['d2']])
4302        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
4303        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
4304        self.assertEqual(sql, "$2,$3,$1")
4305        self.assertEqual(params, ['{t,f,t}', '{1,2,3}', '{a,b,c}'])
4306        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
4307            b=[[True, False], [False, True]])
4308        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
4309        self.assertEqual(sql, "$2,$3,$1")
4310        self.assertEqual(params, [
4311            '{{t,f},{f,t}}', '{{1,2},{3,4}}', '{{a,b},{c,d}}'])
4312        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
4313        sql, params = format_query('select %(record)s', values)
4314        self.assertEqual(sql, 'select $1')
4315        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4316
4317    def testAdaptQueryInlineList(self):
4318        format_query = self.adapter.format_query
4319        values = (3, 7.5, 'hello', True)
4320        sql, params = format_query("select %s,%s,%s,%s", values, inline=True)
4321        self.assertEqual(sql, "select 3,7.5,'hello',true")
4322        self.assertEqual(params, [])
4323        values = [date(2016, 1, 30), 'current_date']
4324        sql, params = format_query("values(%s,%s)", values, inline=True)
4325        self.assertEqual(sql, "values('2016-01-30','current_date')")
4326        self.assertEqual(params, [])
4327        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
4328        sql, params = format_query("%s,%s,%s", values, inline=True)
4329        self.assertEqual(sql,
4330            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
4331        self.assertEqual(params, [])
4332        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
4333            [[True, False], [False, True]])
4334        sql, params = format_query("%s,%s,%s", values, inline=True)
4335        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
4336            "ARRAY[[true,false],[false,true]]")
4337        self.assertEqual(params, [])
4338        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
4339        sql, params = format_query('select %s', values, inline=True)
4340        self.assertEqual(sql,
4341            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
4342        self.assertEqual(params, [])
4343
4344    def testAdaptQueryInlineDict(self):
4345        format_query = self.adapter.format_query
4346        values = dict(i=3, f=7.5, t='hello', b=True)
4347        sql, params = format_query(
4348            "select %(i)s,%(f)s,%(t)s,%(b)s", values, inline=True)
4349        self.assertEqual(sql, "select 3,7.5,'hello',true")
4350        self.assertEqual(params, [])
4351        values = dict(d1='2016-01-30', d2='current_date')
4352        sql, params = format_query(
4353            "values(%(d1)s,%(d2)s)", values, inline=True)
4354        self.assertEqual(sql, "values('2016-01-30','current_date')")
4355        self.assertEqual(params, [])
4356        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
4357        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
4358        self.assertEqual(sql,
4359            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
4360        self.assertEqual(params, [])
4361        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
4362            b=[[True, False], [False, True]])
4363        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
4364        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
4365            "ARRAY[[true,false],[false,true]]")
4366        self.assertEqual(params, [])
4367        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
4368        sql, params = format_query('select %(record)s', values, inline=True)
4369        self.assertEqual(sql,
4370            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
4371        self.assertEqual(params, [])
4372
4373    def testAdaptQueryWithPgRepr(self):
4374        format_query = self.adapter.format_query
4375        self.assertRaises(TypeError, format_query,
4376            '%s', object(), inline=True)
4377        class TestObject:
4378            def __pg_repr__(self):
4379                return "'adapted'"
4380        sql, params = format_query('select %s', [TestObject()], inline=True)
4381        self.assertEqual(sql, "select 'adapted'")
4382        self.assertEqual(params, [])
4383        sql, params = format_query('select %s', [[TestObject()]], inline=True)
4384        self.assertEqual(sql, "select ARRAY['adapted']")
4385        self.assertEqual(params, [])
4386
4387
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces).

    setUpClass creates, for n in 0..4, the tables t and t<n> in the schema
    "public" (n = 0) or "s<n>" (n > 0).  Every table holds the single row
    (n=1, d=<n>), so the value of d tells which schema a row came from.
    """

    cls_set_up = False  # becomes True once setUpClass has completed

    @staticmethod
    def _schema_name(num_schema):
        """Return the schema holding the tables with the given number."""
        return "s%d" % num_schema if num_schema else "public"

    @classmethod
    def setUpClass(cls):
        # NOTE(review): "with oids" was removed in PostgreSQL 12, so these
        # fixtures require an older server version.
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    # the public schema is kept; only drop leftover tables
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids"
                      " as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # release the connection even if creating the fixtures failed
            db.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema %s cascade" % (schema,))
                else:
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            # release the connection even if dropping the fixtures failed
            db.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.db = DB()

    def tearDown(self):
        # Run the cleanups now, while self.db is still open, because the
        # registered cleanups issue queries through this connection.
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        # get_tables() lists the tables of all schemas, schema-qualified
        tables = self.db.get_tables()
        for num_schema in range(5):
            schema = self._schema_name(num_schema)
            for t in (schema + ".t", schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        # attribute names resolve qualified names and follow search_path
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        # the value of d identifies the schema a row was fetched from;
        # unqualified names are looked up along the current search_path
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        # the oid column is returned under the munged key 'oid(<table>)'
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
4506
4507
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class."""

    def setUp(self):
        self.db = DB()
        # capture everything printed to stdout while the test runs
        self.output = StringIO()
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        sys.stdout = self.stdout
        self.output.close()
        # restore the module-level debug setting before closing
        self.db.debug = debug
        self.db.close()

    def get_output(self):
        """Return everything written to stdout since setUp."""
        return self.output.getvalue()

    def send_queries(self):
        """Run the two queries whose debug output the tests check."""
        self.db.query("select 1")
        self.db.query("select 2")

    def testDebugDefault(self):
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        # True prints every query to stdout
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        # a string is used as a format with the query substituted for %s
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        # a file-like object receives the output instead of stdout
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        # a callable is called with each query (without trailing newline)
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")

    def testDebugMultipleArgs(self):
        # multiple arguments are converted with str() and joined by newlines
        output = []
        self.db.debug = output.append
        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
        self.db._do_debug(*args)
        self.assertEqual(output, ['\n'.join(str(arg) for arg in args)])
        self.assertEqual(self.get_output(), "")
4577
4578
class TestMemoryLeaks(unittest.TestCase):
    """Test that the DB class does not leak memory."""

    def getLeaks(self, fut):
        """Run fut() and assert that it leaves no new gc-tracked objects.

        fut (presumably "function under test") is a callable taking no
        arguments.  The ids of all objects tracked by the garbage collector
        are recorded before calling it.  After a full collection, any
        tracked object whose id was not recorded counts as a leak, and the
        test fails unless there are none.

        Note: objects are compared by id only.  A new object that happens
        to reuse the id of an object recorded in the first snapshot is
        therefore not reported.
        """
        # The set, the list and the bound method are all created before the
        # first snapshot.  Presumably this keeps the bookkeeping objects
        # themselves from appearing as new objects in the second snapshot.
        ids = set()
        objs = []
        add_ids = ids.update
        gc.collect()
        # slice assignment refills the existing list instead of binding
        # a new list object (see the note above)
        objs[:] = gc.get_objects()
        add_ids(id(obj) for obj in objs)
        fut()
        gc.collect()
        objs[:] = gc.get_objects()
        objs[:] = [obj for obj in objs if id(obj) not in ids]
        if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)):
            # workaround for Python issue 26811: ignore the spurious
            # '(<NULL>,)' tuples reported on these Python versions
            objs[:] = [obj for obj in objs if repr(obj) != '(<NULL>,)']
        self.assertEqual(len(objs), 0)

    def testLeaksWithClose(self):
        # open a connection, fetch a parameterized query result as dicts,
        # then close the connection explicitly
        def fut():
            db = DB()
            db.query("select $1::int as r", 42).dictresult()
            db.close()
        self.getLeaks(fut)

    def testLeaksWithoutClose(self):
        # same as above, but without calling close(); the DB object must
        # still be fully released once fut() returns
        def fut():
            db = DB()
            db.query("select $1::int as r", 42).dictresult()
        self.getLeaks(fut)
4610
4611
if __name__ == '__main__':
    # run all test cases of this module when executed as a script
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.