source: trunk/tests/test_classic_dbwrapper.py

Last change on this file was 997, checked in by cito, 3 months ago

Some more IDE hints

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 186.2 KB
Line 
1#!/usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import gc
21import json
22import tempfile
23
24import pg  # the module under test
25
26from decimal import Decimal
27from datetime import date, time, datetime, timedelta
28from uuid import UUID
29from time import strftime
30from operator import itemgetter
31
# Connection parameters of the database the tests run against.  Every one
# of them may be overridden by a LOCAL_PyGreSQL.py module (imported below).
# The current user must be allowed to create schemas in that database.
dbname, dbhost, dbport = 'unittest', None, 5432

debug = False  # set to True to make the DB wrapper print debugging output

# Apply local overrides if available: first as a module of this package,
# then as a top-level module; keep the defaults above if neither exists.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass
48
# Python 3 merged ``long`` into ``int`` and ``unicode`` into ``str``;
# provide the old names as aliases so the tests run on both major versions.
try:  # noinspection PyUnresolvedReferences
    long, unicode
except NameError:  # Python >= 3.0
    long, unicode = int, str
58
59try:
60    from collections import OrderedDict
61except ImportError:  # Python 2.6 or 3.0
62    OrderedDict = dict
63
64if str is bytes:  # noinspection PyUnresolvedReferences
65    from StringIO import StringIO
66else:
67    from io import StringIO
68
windows = (os.name == 'nt')  # running on Microsoft Windows?

# There is a known bug in libpq under Windows which can cause the
# interface to crash when calling PQhost(), so tests that ask for the
# host are skipped there, with the following reason.
do_not_ask_for_host, do_not_ask_for_host_reason = (
    windows, 'libpq issue on Windows')
75
76
def DB():
    """Return a new DB wrapper object connected to the test database.

    The module-level ``debug`` flag is copied to the wrapper when set, and
    the session is told to send only warnings and more severe messages.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
84
85
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names."""

    cls = pg.AttrDict
    base = OrderedDict

    def testInit(self):
        a = self.cls()
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base())
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))
        # a one-shot iterator of items must be accepted as well
        iteritems = iter(items)
        a = self.cls(iteritems)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))

    def testIter(self):
        a = self.cls()
        self.assertEqual(list(a), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a), keys)

    def testKeys(self):
        a = self.cls()
        self.assertEqual(list(a.keys()), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a.keys()), keys)

    def testValues(self):
        a = self.cls()
        self.assertEqual(list(a.values()), [])
        items = [('id', 'int'), ('name', 'text')]
        values = [item[1] for item in items]
        a = self.cls(items)
        self.assertEqual(list(a.values()), values)

    def testItems(self):
        a = self.cls()
        self.assertEqual(list(a.items()), [])
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertEqual(list(a.items()), items)

    def testGet(self):
        a = self.cls([('id', 1)])
        try:
            self.assertEqual(a['id'], 1)
        except KeyError:
            self.fail('AttrDict should be readable')

    def testSet(self):
        a = self.cls()
        try:
            a['id'] = 1
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testDel(self):
        a = self.cls([('id', 1)])
        try:
            del a['id']
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        """All mutating dict methods must raise TypeError."""
        a = self.cls([('id', 1)])
        self.assertEqual(a['id'], 1)
        # Pass arguments that would be valid for an ordinary dict, so that
        # a TypeError can only come from AttrDict being read-only.  The
        # previous form ``method(a)`` raised TypeError even for a plain
        # dict (wrong arity or unhashable key), so it proved nothing.
        calls = (
            ('clear', ()),
            ('update', ({'name': 'text'},)),
            ('pop', ('id',)),
            ('setdefault', ('name', 'text')),
            ('popitem', ()))
        for name, args in calls:
            method = getattr(a, name)
            self.assertRaises(TypeError, method, *args)
        # the contents must not have been changed by any of the calls
        self.assertEqual(a, self.base([('id', 1)]))
167
168
class TestDBClassInit(unittest.TestCase):
    """Test proper handling of errors when creating DB instances."""

    def testBadParams(self):
        # an unknown keyword argument must be rejected
        with self.assertRaises(TypeError):
            pg.DB(invalid=True)

    def testDeleteDb(self):
        # once the underlying connection attribute is gone,
        # closing the wrapper must report an internal error
        wrapper = DB()
        del wrapper.db
        with self.assertRaises(pg.InternalError):
            wrapper.close()
        del wrapper
180
181
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            pass  # some tests close the connection themselves

    def testAllDBAttributes(self):
        attributes = [
            'abort', 'adapter',
            'backend_pid', 'begin',
            'cancel', 'clear', 'close', 'commit',
            'date_format', 'db', 'dbname', 'dbtypes',
            'debug', 'decode_json', 'delete',
            'delete_prepared', 'describe_prepared',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_cast_hook',
            'get_databases', 'get_notice_receiver',
            'get_parameter', 'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'prepare', 'protocol_version', 'putline',
            'query', 'query_formatted', 'query_prepared',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_cast_hook', 'set_notice_receiver',
            'set_parameter', 'socket', 'source',
            'ssl_attributes', 'ssl_in_use',
            'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        # __dir__ is not called in Python 2.6 for old-style classes
        db_attributes = dir(self.db) if hasattr(
            self.db.__class__, '__class__') else self.db.__dir__()
        db_attributes = [a for a in db_attributes
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        # Unix socket directories (starting with '/') and an unset host
        # are both expected to be reported as 'localhost'
        if dbhost and not dbhost.startswith('/'):
            host = dbhost
        else:
            host = 'localhost'
        self.assertIsInstance(self.db.host, str)
        self.assertEqual(self.db.host, host)
        self.assertEqual(self.db.db.host, host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(90000 <= server_version < 120000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeSocket(self):
        socket = self.db.socket
        self.assertIsInstance(socket, int)
        self.assertGreaterEqual(socket, 0)

    def testAttributeBackendPid(self):
        backend_pid = self.db.backend_pid
        self.assertIsInstance(backend_pid, int)
        self.assertGreaterEqual(backend_pid, 1)

    def testAttributeSslInUse(self):
        ssl_in_use = self.db.ssl_in_use
        self.assertIsInstance(ssl_in_use, bool)
        self.assertFalse(ssl_in_use)

    def testAttributeSslAttributes(self):
        ssl_attributes = self.db.ssl_attributes
        self.assertIsInstance(ssl_attributes, dict)
        self.assertEqual(ssl_attributes, {
            'cipher': None, 'compression': None, 'key_bits': None,
            'library': None, 'protocol': None})

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryDataError(self):
        """Division by zero must raise DataError with SQLSTATE 22012."""
        try:
            self.db.query("select 1/0")
        except pg.DataError as error:
            self.assertEqual(error.sqlstate, '22012')
        else:
            # previously the test passed silently when nothing was raised
            self.fail('Division by zero should raise a DataError')

    def testMethodEndcopy(self):
        # NOTE(review): no COPY is in progress here; an IOError is
        # tolerated, the test only checks that nothing else goes wrong
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertIsNotNone(db.db)
        self.assertEqual(self.db.db, db.db)
        db.close()
        self.assertIsNone(db.db)
        self.assertIsNotNone(self.db.db)
        db.reopen()
        self.assertIsNotNone(db.db)
        self.assertEqual(self.db.db, db.db)
        db.close()
        self.assertIsNone(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

    def testExistingDbApi2Connection(self):

        class DBApi2Con:
            """Minimal stand-in exposing a connection via ``_cnx``."""

            def __init__(self, cnx):
                self._cnx = cnx

            def close(self):
                self._cnx.close()

        db2 = DBApi2Con(self.db.db)
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
        db.close()
        self.assertIsNone(db.db)
        db.reopen()
        self.assertIsNotNone(db.db)
        self.assertEqual(self.db.db, db.db)
        db.close()
        self.assertIsNone(db.db)
        db2.close()
437
438
439class TestDBClass(unittest.TestCase):
440    """Test the methods of the DB class wrapped pg connection."""
441
442    maxDiff = 80 * 20
443
444    cls_set_up = False
445
446    regtypes = None
447
448    @classmethod
449    def setUpClass(cls):
450        db = DB()
451        db.query("drop table if exists test cascade")
452        db.query("create table test ("
453                 "i2 smallint, i4 integer, i8 bigint,"
454                 " d numeric, f4 real, f8 double precision, m money,"
455                 " v4 varchar(4), c4 char(4), t text)")
456        db.query("create or replace view test_view as"
457                 " select i4, v4 from test")
458        db.close()
459        cls.cls_set_up = True
460
461    @classmethod
462    def tearDownClass(cls):
463        db = DB()
464        db.query("drop table test cascade")
465        db.close()
466
467    def setUp(self):
468        self.assertTrue(self.cls_set_up)
469        self.db = DB()
470        if self.regtypes is None:
471            self.regtypes = self.db.use_regtypes()
472        else:
473            self.db.use_regtypes(self.regtypes)
474        query = self.db.query
475        query('set client_encoding=utf8')
476        query("set lc_monetary='C'")
477        query("set datestyle='ISO,YMD'")
478        query('set standard_conforming_strings=on')
479        try:
480            query('set bytea_output=hex')
481        except pg.ProgrammingError:
482            if self.db.server_version >= 90000:
483                raise  # ignore for older server versions
484
485    def tearDown(self):
486        self.doCleanups()
487        self.db.close()
488
489    def createTable(self, table, definition,
490                    temporary=True, oids=None, values=None):
491        query = self.db.query
492        if '"' not in table or '.' in table:
493            table = '"%s"' % table
494        if not temporary:
495            q = 'drop table if exists %s cascade' % table
496            query(q)
497            self.addCleanup(query, q)
498        temporary = 'temporary table' if temporary else 'table'
499        as_query = definition.startswith(('as ', 'AS '))
500        if not as_query and not definition.startswith('('):
501            definition = '(%s)' % definition
502        with_oids = 'with oids' if oids else 'without oids'
503        q = ['create', temporary, table]
504        if as_query:
505            q.extend([with_oids, definition])
506        else:
507            q.extend([definition, with_oids])
508        q = ' '.join(q)
509        query(q)
510        if values:
511            for params in values:
512                if not isinstance(params, (list, tuple)):
513                    params = [params]
514                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
515                q = "insert into %s values (%s)" % (table, values)
516                query(q, params)
517
518    def testClassName(self):
519        self.assertEqual(self.db.__class__.__name__, 'DB')
520
521    def testModuleName(self):
522        self.assertEqual(self.db.__module__, 'pg')
523        self.assertEqual(self.db.__class__.__module__, 'pg')
524
525    def testEscapeLiteral(self):
526        f = self.db.escape_literal
527        r = f(b"plain")
528        self.assertIsInstance(r, bytes)
529        self.assertEqual(r, b"'plain'")
530        r = f(u"plain")
531        self.assertIsInstance(r, unicode)
532        self.assertEqual(r, u"'plain'")
533        r = f(u"that's kÀse".encode('utf-8'))
534        self.assertIsInstance(r, bytes)
535        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
536        r = f(u"that's kÀse")
537        self.assertIsInstance(r, unicode)
538        self.assertEqual(r, u"'that''s kÀse'")
539        self.assertEqual(f(r"It's fine to have a \ inside."),
540                         r" E'It''s fine to have a \\ inside.'")
541        self.assertEqual(f('No "quotes" must be escaped.'),
542                         "'No \"quotes\" must be escaped.'")
543
544    def testEscapeIdentifier(self):
545        f = self.db.escape_identifier
546        r = f(b"plain")
547        self.assertIsInstance(r, bytes)
548        self.assertEqual(r, b'"plain"')
549        r = f(u"plain")
550        self.assertIsInstance(r, unicode)
551        self.assertEqual(r, u'"plain"')
552        r = f(u"that's kÀse".encode('utf-8'))
553        self.assertIsInstance(r, bytes)
554        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
555        r = f(u"that's kÀse")
556        self.assertIsInstance(r, unicode)
557        self.assertEqual(r, u'"that\'s kÀse"')
558        self.assertEqual(f(r"It's fine to have a \ inside."),
559                         '"It\'s fine to have a \\ inside."')
560        self.assertEqual(f('All "quotes" must be escaped.'),
561                         '"All ""quotes"" must be escaped."')
562
563    def testEscapeString(self):
564        f = self.db.escape_string
565        r = f(b"plain")
566        self.assertIsInstance(r, bytes)
567        self.assertEqual(r, b"plain")
568        r = f(u"plain")
569        self.assertIsInstance(r, unicode)
570        self.assertEqual(r, u"plain")
571        r = f(u"that's kÀse".encode('utf-8'))
572        self.assertIsInstance(r, bytes)
573        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
574        r = f(u"that's kÀse")
575        self.assertIsInstance(r, unicode)
576        self.assertEqual(r, u"that''s kÀse")
577        self.assertEqual(f(r"It's fine to have a \ inside."),
578                         r"It''s fine to have a \ inside.")
579
580    def testEscapeBytea(self):
581        f = self.db.escape_bytea
582        # note that escape_byte always returns hex output since Pg 9.0,
583        # regardless of the bytea_output setting
584        r = f(b'plain')
585        self.assertIsInstance(r, bytes)
586        self.assertEqual(r, b'\\x706c61696e')
587        r = f(u'plain')
588        self.assertIsInstance(r, unicode)
589        self.assertEqual(r, u'\\x706c61696e')
590        r = f(u"das is' kÀse".encode('utf-8'))
591        self.assertIsInstance(r, bytes)
592        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
593        r = f(u"das is' kÀse")
594        self.assertIsInstance(r, unicode)
595        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
596        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
597
598    def testUnescapeBytea(self):
599        f = self.db.unescape_bytea
600        r = f(b'plain')
601        self.assertIsInstance(r, bytes)
602        self.assertEqual(r, b'plain')
603        r = f(u'plain')
604        self.assertIsInstance(r, bytes)
605        self.assertEqual(r, b'plain')
606        r = f(b"das is' k\\303\\244se")
607        self.assertIsInstance(r, bytes)
608        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
609        r = f(u"das is' k\\303\\244se")
610        self.assertIsInstance(r, bytes)
611        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
612        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
613        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
614        self.assertEqual(f(r'\\x746861742773206be47365'),
615                         b'\\x746861742773206be47365')
616        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
617
618    def testDecodeJson(self):
619        f = self.db.decode_json
620        self.assertIsNone(f('null'))
621        data = {
622            "id": 1, "name": "Foo", "price": 1234.5,
623            "new": True, "note": None,
624            "tags": ["Bar", "Eek"],
625            "stock": {"warehouse": 300, "retail": 20}}
626        text = json.dumps(data)
627        r = f(text)
628        self.assertIsInstance(r, dict)
629        self.assertEqual(r, data)
630        self.assertIsInstance(r['id'], int)
631        self.assertIsInstance(r['name'], unicode)
632        self.assertIsInstance(r['price'], float)
633        self.assertIsInstance(r['new'], bool)
634        self.assertIsInstance(r['tags'], list)
635        self.assertIsInstance(r['stock'], dict)
636
637    def testEncodeJson(self):
638        f = self.db.encode_json
639        self.assertEqual(f(None), 'null')
640        data = {
641            "id": 1, "name": "Foo", "price": 1234.5,
642            "new": True, "note": None,
643            "tags": ["Bar", "Eek"],
644            "stock": {"warehouse": 300, "retail": 20}}
645        text = json.dumps(data)
646        r = f(data)
647        self.assertIsInstance(r, str)
648        self.assertEqual(r, text)
649
650    def testGetParameter(self):
651        f = self.db.get_parameter
652        self.assertRaises(TypeError, f)
653        self.assertRaises(TypeError, f, None)
654        self.assertRaises(TypeError, f, 42)
655        self.assertRaises(TypeError, f, '')
656        self.assertRaises(TypeError, f, [])
657        self.assertRaises(TypeError, f, [''])
658        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
659        r = f('standard_conforming_strings')
660        self.assertEqual(r, 'on')
661        r = f('lc_monetary')
662        self.assertEqual(r, 'C')
663        r = f('datestyle')
664        self.assertEqual(r, 'ISO, YMD')
665        r = f('bytea_output')
666        self.assertEqual(r, 'hex')
667        r = f(['bytea_output', 'lc_monetary'])
668        self.assertIsInstance(r, list)
669        self.assertEqual(r, ['hex', 'C'])
670        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
671        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
672        r = f(set(['bytea_output', 'lc_monetary']))
673        self.assertIsInstance(r, dict)
674        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
675        r = f(set(['Bytea_Output', ' LC_Monetary ']))
676        self.assertIsInstance(r, dict)
677        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
678        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
679        r = f(s)
680        self.assertIs(r, s)
681        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
682        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
683        r = f(s)
684        self.assertIs(r, s)
685        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
686
687    def testGetParameterServerVersion(self):
688        r = self.db.get_parameter('server_version_num')
689        self.assertIsInstance(r, str)
690        s = self.db.server_version
691        self.assertIsInstance(s, int)
692        self.assertEqual(r, str(s))
693
694    def testGetParameterAll(self):
695        f = self.db.get_parameter
696        r = f('all')
697        self.assertIsInstance(r, dict)
698        self.assertEqual(r['standard_conforming_strings'], 'on')
699        self.assertEqual(r['lc_monetary'], 'C')
700        self.assertEqual(r['DateStyle'], 'ISO, YMD')
701        self.assertEqual(r['bytea_output'], 'hex')
702
703    def testSetParameter(self):
704        f = self.db.set_parameter
705        g = self.db.get_parameter
706        self.assertRaises(TypeError, f)
707        self.assertRaises(TypeError, f, None)
708        self.assertRaises(TypeError, f, 42)
709        self.assertRaises(TypeError, f, '')
710        self.assertRaises(TypeError, f, [])
711        self.assertRaises(TypeError, f, [''])
712        self.assertRaises(ValueError, f, 'all', 'invalid')
713        self.assertRaises(ValueError, f, {
714            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
715        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
716        f('standard_conforming_strings', 'off')
717        self.assertEqual(g('standard_conforming_strings'), 'off')
718        f('datestyle', 'ISO, DMY')
719        self.assertEqual(g('datestyle'), 'ISO, DMY')
720        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
721        self.assertEqual(g('standard_conforming_strings'), 'on')
722        self.assertEqual(g('datestyle'), 'ISO, DMY')
723        f(['default_with_oids', 'standard_conforming_strings'], 'off')
724        self.assertEqual(g('default_with_oids'), 'off')
725        self.assertEqual(g('standard_conforming_strings'), 'off')
726        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
727        self.assertEqual(g('standard_conforming_strings'), 'on')
728        self.assertEqual(g('datestyle'), 'ISO, YMD')
729        f(('default_with_oids', 'standard_conforming_strings'), 'off')
730        self.assertEqual(g('default_with_oids'), 'off')
731        self.assertEqual(g('standard_conforming_strings'), 'off')
732        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
733        self.assertEqual(g('default_with_oids'), 'on')
734        self.assertEqual(g('standard_conforming_strings'), 'on')
735        self.assertRaises(ValueError, f, set(['default_with_oids',
736            'standard_conforming_strings']), ['off', 'on'])
737        f(set(['default_with_oids', 'standard_conforming_strings']),
738            ['off', 'off'])
739        self.assertEqual(g('default_with_oids'), 'off')
740        self.assertEqual(g('standard_conforming_strings'), 'off')
741        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
742        self.assertEqual(g('standard_conforming_strings'), 'on')
743        self.assertEqual(g('datestyle'), 'ISO, YMD')
744
745    def testResetParameter(self):
746        db = DB()
747        f = db.set_parameter
748        g = db.get_parameter
749        r = g('default_with_oids')
750        self.assertIn(r, ('on', 'off'))
751        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
752        r = g('standard_conforming_strings')
753        self.assertIn(r, ('on', 'off'))
754        scs, not_scs = r, 'off' if r == 'on' else 'on'
755        f('default_with_oids', not_dwi)
756        f('standard_conforming_strings', not_scs)
757        self.assertEqual(g('default_with_oids'), not_dwi)
758        self.assertEqual(g('standard_conforming_strings'), not_scs)
759        f('default_with_oids')
760        f('standard_conforming_strings', None)
761        self.assertEqual(g('default_with_oids'), dwi)
762        self.assertEqual(g('standard_conforming_strings'), scs)
763        f('default_with_oids', not_dwi)
764        f('standard_conforming_strings', not_scs)
765        self.assertEqual(g('default_with_oids'), not_dwi)
766        self.assertEqual(g('standard_conforming_strings'), not_scs)
767        f(['default_with_oids', 'standard_conforming_strings'], None)
768        self.assertEqual(g('default_with_oids'), dwi)
769        self.assertEqual(g('standard_conforming_strings'), scs)
770        f('default_with_oids', not_dwi)
771        f('standard_conforming_strings', not_scs)
772        self.assertEqual(g('default_with_oids'), not_dwi)
773        self.assertEqual(g('standard_conforming_strings'), not_scs)
774        f(('default_with_oids', 'standard_conforming_strings'))
775        self.assertEqual(g('default_with_oids'), dwi)
776        self.assertEqual(g('standard_conforming_strings'), scs)
777        f('default_with_oids', not_dwi)
778        f('standard_conforming_strings', not_scs)
779        self.assertEqual(g('default_with_oids'), not_dwi)
780        self.assertEqual(g('standard_conforming_strings'), not_scs)
781        f(set(['default_with_oids', 'standard_conforming_strings']))
782        self.assertEqual(g('default_with_oids'), dwi)
783        self.assertEqual(g('standard_conforming_strings'), scs)
784        db.close()
785
786    def testResetParameterAll(self):
787        db = DB()
788        f = db.set_parameter
789        self.assertRaises(ValueError, f, 'all', 0)
790        self.assertRaises(ValueError, f, 'all', 'off')
791        g = db.get_parameter
792        r = g('default_with_oids')
793        self.assertIn(r, ('on', 'off'))
794        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
795        r = g('standard_conforming_strings')
796        self.assertIn(r, ('on', 'off'))
797        scs, not_scs = r, 'off' if r == 'on' else 'on'
798        f('default_with_oids', not_dwi)
799        f('standard_conforming_strings', not_scs)
800        self.assertEqual(g('default_with_oids'), not_dwi)
801        self.assertEqual(g('standard_conforming_strings'), not_scs)
802        f('all')
803        self.assertEqual(g('default_with_oids'), dwi)
804        self.assertEqual(g('standard_conforming_strings'), scs)
805        db.close()
806
807    def testSetParameterLocal(self):
808        f = self.db.set_parameter
809        g = self.db.get_parameter
810        self.assertEqual(g('standard_conforming_strings'), 'on')
811        self.db.begin()
812        f('standard_conforming_strings', 'off', local=True)
813        self.assertEqual(g('standard_conforming_strings'), 'off')
814        self.db.end()
815        self.assertEqual(g('standard_conforming_strings'), 'on')
816
817    def testSetParameterSession(self):
818        f = self.db.set_parameter
819        g = self.db.get_parameter
820        self.assertEqual(g('standard_conforming_strings'), 'on')
821        self.db.begin()
822        f('standard_conforming_strings', 'off', local=False)
823        self.assertEqual(g('standard_conforming_strings'), 'off')
824        self.db.end()
825        self.assertEqual(g('standard_conforming_strings'), 'off')
826
827    def testReset(self):
828        db = DB()
829        default_datestyle = db.get_parameter('datestyle')
830        changed_datestyle = 'ISO, DMY'
831        if changed_datestyle == default_datestyle:
832            changed_datestyle == 'ISO, YMD'
833        self.db.set_parameter('datestyle', changed_datestyle)
834        r = self.db.get_parameter('datestyle')
835        self.assertEqual(r, changed_datestyle)
836        con = self.db.db
837        q = con.query("show datestyle")
838        self.db.reset()
839        r = q.getresult()[0][0]
840        self.assertEqual(r, changed_datestyle)
841        q = con.query("show datestyle")
842        r = q.getresult()[0][0]
843        self.assertEqual(r, default_datestyle)
844        r = self.db.get_parameter('datestyle')
845        self.assertEqual(r, default_datestyle)
846        db.close()
847
848    def testReopen(self):
849        db = DB()
850        default_datestyle = db.get_parameter('datestyle')
851        changed_datestyle = 'ISO, DMY'
852        if changed_datestyle == default_datestyle:
853            changed_datestyle == 'ISO, YMD'
854        self.db.set_parameter('datestyle', changed_datestyle)
855        r = self.db.get_parameter('datestyle')
856        self.assertEqual(r, changed_datestyle)
857        con = self.db.db
858        q = con.query("show datestyle")
859        self.db.reopen()
860        r = q.getresult()[0][0]
861        self.assertEqual(r, changed_datestyle)
862        self.assertRaises(TypeError, getattr, con, 'query')
863        r = self.db.get_parameter('datestyle')
864        self.assertEqual(r, default_datestyle)
865        db.close()
866
867    def testCreateTable(self):
868        table = 'test hello world'
869        values = [(2, "World!"), (1, "Hello")]
870        self.createTable(table, "n smallint, t varchar",
871                         temporary=True, oids=True, values=values)
872        r = self.db.query('select t from "%s" order by n' % table).getresult()
873        r = ', '.join(row[0] for row in r)
874        self.assertEqual(r, "Hello, World!")
875        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
876        self.assertIsInstance(r[0][0], int)
877
878    def testQuery(self):
879        query = self.db.query
880        table = 'test_table'
881        self.createTable(table, "n integer", oids=True)
882        q = "insert into test_table values (1)"
883        r = query(q)
884        self.assertIsInstance(r, int)
885        q = "insert into test_table select 2"
886        r = query(q)
887        self.assertIsInstance(r, int)
888        oid = r
889        q = "select oid from test_table where n=2"
890        r = query(q).getresult()
891        self.assertEqual(len(r), 1)
892        r = r[0]
893        self.assertEqual(len(r), 1)
894        r = r[0]
895        self.assertEqual(r, oid)
896        q = "insert into test_table select 3 union select 4 union select 5"
897        r = query(q)
898        self.assertIsInstance(r, str)
899        self.assertEqual(r, '3')
900        q = "update test_table set n=4 where n<5"
901        r = query(q)
902        self.assertIsInstance(r, str)
903        self.assertEqual(r, '4')
904        q = "delete from test_table"
905        r = query(q)
906        self.assertIsInstance(r, str)
907        self.assertEqual(r, '5')
908
909    def testMultipleQueries(self):
910        self.assertEqual(self.db.query(
911            "create temporary table test_multi (n integer);"
912            "insert into test_multi values (4711);"
913            "select n from test_multi").getresult()[0][0], 4711)
914
915    def testQueryWithParams(self):
916        query = self.db.query
917        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
918        q = "insert into test_table values ($1, $2)"
919        r = query(q, (1, 2))
920        self.assertIsInstance(r, int)
921        r = query(q, [3, 4])
922        self.assertIsInstance(r, int)
923        r = query(q, [5, 6])
924        self.assertIsInstance(r, int)
925        q = "select * from test_table order by 1, 2"
926        self.assertEqual(query(q).getresult(),
927                         [(1, 2), (3, 4), (5, 6)])
928        q = "select * from test_table where n1=$1 and n2=$2"
929        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
930        q = "update test_table set n2=$2 where n1=$1"
931        r = query(q, 3, 7)
932        self.assertEqual(r, '1')
933        q = "select * from test_table order by 1, 2"
934        self.assertEqual(query(q).getresult(),
935                         [(1, 2), (3, 7), (5, 6)])
936        q = "delete from test_table where n2!=$1"
937        r = query(q, 4)
938        self.assertEqual(r, '3')
939
940    def testEmptyQuery(self):
941        self.assertRaises(ValueError, self.db.query, '')
942
943    def testQueryDataError(self):
944        try:
945            self.db.query("select 1/0")
946        except pg.DataError as error:
947            self.assertEqual(error.sqlstate, '22012')
948
949    def testQueryFormatted(self):
950        f = self.db.query_formatted
951        t = True if pg.get_bool() else 't'
952        # test with tuple
953        q = f("select %s::int, %s::real, %s::text, %s::bool",
954              (3, 2.5, 'hello', True))
955        r = q.getresult()[0]
956        self.assertEqual(r, (3, 2.5, 'hello', t))
957        # test with tuple, inline
958        q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True)
959        r = q.getresult()[0]
960        if isinstance(r[1], Decimal):
961            # Python 2.6 cannot compare float and Decimal
962            r = list(r)
963            r[1] = float(r[1])
964            r = tuple(r)
965        self.assertEqual(r, (3, 2.5, 'hello', t))
966        # test with dict
967        q = f("select %(a)s::int, %(b)s::real, %(c)s::text, %(d)s::bool",
968              dict(a=3, b=2.5, c='hello', d=True))
969        r = q.getresult()[0]
970        self.assertEqual(r, (3, 2.5, 'hello', t))
971        # test with dict, inline
972        q = f("select %(a)s, %(b)s, %(c)s, %(d)s",
973              dict(a=3, b=2.5, c='hello', d=True), inline=True)
974        r = q.getresult()[0]
975        if isinstance(r[1], Decimal):
976            # Python 2.6 cannot compare float and Decimal
977            r = list(r)
978            r[1] = float(r[1])
979            r = tuple(r)
980        self.assertEqual(r, (3, 2.5, 'hello', t))
981        # test with dict and extra values
982        q = f("select %(a)s||%(b)s||%(c)s||%(d)s||'epsilon'",
983              dict(a='alpha', b='beta', c='gamma', d='delta', e='extra'))
984        r = q.getresult()[0][0]
985        self.assertEqual(r, 'alphabetagammadeltaepsilon')
986
987    def testQueryFormattedWithAny(self):
988        f = self.db.query_formatted
989        q = "select 2 = any(%s)"
990        r = f(q, [[1, 3]]).getresult()[0][0]
991        self.assertEqual(r, False if pg.get_bool() else 'f')
992        r = f(q, [[1, 2, 3]]).getresult()[0][0]
993        self.assertEqual(r, True if pg.get_bool() else 't')
994        r = f(q, [[]]).getresult()[0][0]
995        self.assertEqual(r, False if pg.get_bool() else 'f')
996        r = f(q, [[None]]).getresult()[0][0]
997        self.assertIsNone(r)
998
999    def testQueryFormattedWithoutParams(self):
1000        f = self.db.query_formatted
1001        q = "select 42"
1002        r = f(q).getresult()[0][0]
1003        self.assertEqual(r, 42)
1004        r = f(q, None).getresult()[0][0]
1005        self.assertEqual(r, 42)
1006        r = f(q, []).getresult()[0][0]
1007        self.assertEqual(r, 42)
1008        r = f(q, {}).getresult()[0][0]
1009        self.assertEqual(r, 42)
1010
1011    def testPrepare(self):
1012        p = self.db.prepare
1013        self.assertIsNone(p('my query', "select 'hello'"))
1014        self.assertIsNone(p('my other query', "select 'world'"))
1015        self.assertRaises(pg.ProgrammingError,
1016            p, 'my query', "select 'hello, too'")
1017
1018    def testPrepareUnnamed(self):
1019        p = self.db.prepare
1020        self.assertIsNone(p('', "select null"))
1021        self.assertIsNone(p(None, "select null"))
1022
1023    def testQueryPreparedWithoutParams(self):
1024        f = self.db.query_prepared
1025        self.assertRaises(pg.OperationalError, f, 'q')
1026        p = self.db.prepare
1027        p('q1', "select 17")
1028        p('q2', "select 42")
1029        r = f('q1').getresult()[0][0]
1030        self.assertEqual(r, 17)
1031        r = f('q2').getresult()[0][0]
1032        self.assertEqual(r, 42)
1033
1034    def testQueryPreparedWithParams(self):
1035        p = self.db.prepare
1036        p('sum', "select 1 + $1 + $2 + $3")
1037        p('cat', "select initcap($1) || ', ' || $2 || '!'")
1038        f = self.db.query_prepared
1039        r = f('sum', 2, 3, 5).getresult()[0][0]
1040        self.assertEqual(r, 11)
1041        r = f('cat', 'hello', 'world').getresult()[0][0]
1042        self.assertEqual(r, 'Hello, world!')
1043
1044    def testQueryPreparedUnnamedWithOutParams(self):
1045        f = self.db.query_prepared
1046        self.assertRaises(pg.OperationalError, f, None)
1047        self.assertRaises(pg.OperationalError, f, '')
1048        p = self.db.prepare
1049        # make sure all types are known so that we will not
1050        # generate other anonymous queries in the background
1051        p('', "select 'empty'::varchar")
1052        r = f(None).getresult()[0][0]
1053        self.assertEqual(r, 'empty')
1054        r = f('').getresult()[0][0]
1055        self.assertEqual(r, 'empty')
1056        p(None, "select 'none'::varchar")
1057        r = f(None).getresult()[0][0]
1058        self.assertEqual(r, 'none')
1059        r = f('').getresult()[0][0]
1060        self.assertEqual(r, 'none')
1061
1062    def testQueryPreparedUnnamedWithParams(self):
1063        p = self.db.prepare
1064        p('', "select 1 + $1 + $2")
1065        f = self.db.query_prepared
1066        r = f('', 2, 3).getresult()[0][0]
1067        self.assertEqual(r, 6)
1068        r = f(None, 2, 3).getresult()[0][0]
1069        self.assertEqual(r, 6)
1070        p(None, "select 2 + $1 + $2")
1071        f = self.db.query_prepared
1072        r = f('', 3, 4).getresult()[0][0]
1073        self.assertEqual(r, 9)
1074        r = f(None, 3, 4).getresult()[0][0]
1075        self.assertEqual(r, 9)
1076
1077    def testDescribePrepared(self):
1078        self.db.prepare('count', "select 1 as first, 2 as second")
1079        f = self.db.describe_prepared
1080        r = f('count').listfields()
1081        self.assertEqual(r, ('first', 'second'))
1082
1083    def testDescribePreparedUnnamed(self):
1084        self.db.prepare('', "select null as anon")
1085        f = self.db.describe_prepared
1086        r = f().listfields()
1087        self.assertEqual(r, ('anon',))
1088        r = f(None).listfields()
1089        self.assertEqual(r, ('anon',))
1090        r = f('').listfields()
1091        self.assertEqual(r, ('anon',))
1092
1093    def testDeletePrepared(self):
1094        f = self.db.delete_prepared
1095        f()
1096        e = pg.OperationalError
1097        self.assertRaises(e, f, 'myquery')
1098        p = self.db.prepare
1099        p('q1', "select 1")
1100        p('q2', "select 2")
1101        f('q1')
1102        f('q2')
1103        self.assertRaises(e, f, 'q1')
1104        self.assertRaises(e, f, 'q2')
1105        p('q1', "select 1")
1106        p('q2', "select 2")
1107        f()
1108        self.assertRaises(e, f, 'q1')
1109        self.assertRaises(e, f, 'q2')
1110
    def testPkey(self):
        """Check detection and caching of primary keys by db.pkey().

        Every table layout is created twice, once with a plain name and
        once with a name that contains spaces and must be quoted.
        """
        query = self.db.query
        pkey = self.db.pkey
        # no primary key can be found for the 'test' table
        self.assertRaises(KeyError, pkey, 'test')
        for t in ('pkeytest', 'primary key test'):
            self.createTable('%s0' % t, 'a smallint')
            self.createTable('%s1' % t, 'b smallint primary key')
            self.createTable('%s2' % t,
                'c smallint, d smallint primary key')
            self.createTable('%s3' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (f, h)')
            self.createTable('%s4' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (h, f)')
            self.createTable('%s5' % t,
                'more_than_one_letter varchar primary key')
            self.createTable('%s6' % t,
                '"with space" date primary key')
            self.createTable('%s7' % t,
                'a_very_long_column_name varchar, "with space" date, "42" int,'
                ' primary key (a_very_long_column_name, "with space", "42")')
            # a table without a primary key raises a KeyError
            self.assertRaises(KeyError, pkey, '%s0' % t)
            # a single-column key is returned as a string, or as a
            # one-element tuple when composite is requested
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s1' % t, True), ('b',))
            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
            self.assertEqual(pkey('%s2' % t), 'd')
            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
            # a multi-column key is always returned as a tuple,
            # in the order given in the key definition
            r = pkey('%s3' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s3' % t, composite=False)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s4' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('h', 'f'))
            # column names are returned without quotes
            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s6' % t), 'with space')
            r = pkey('%s7' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (
                'a_very_long_column_name', 'with space', '42'))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
1164
1165    def testGetDatabases(self):
1166        databases = self.db.get_databases()
1167        self.assertIn('template0', databases)
1168        self.assertIn('template1', databases)
1169        self.assertNotIn('not existing database', databases)
1170        self.assertIn('postgres', databases)
1171        self.assertIn(dbname, databases)
1172
1173    def testGetTables(self):
1174        get_tables = self.db.get_tables
1175        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
1176                  'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
1177                  'averyveryveryveryveryveryveryreallyreallylongtablename',
1178                  'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
1179        for t in tables:
1180            self.db.query('drop table if exists "%s" cascade' % t)
1181        before_tables = get_tables()
1182        self.assertIsInstance(before_tables, list)
1183        for t in before_tables:
1184            t = t.split('.', 1)
1185            self.assertGreaterEqual(len(t), 2)
1186            if len(t) > 2:
1187                self.assertTrue(t[1].startswith('"'))
1188            t = t[0]
1189            self.assertNotEqual(t, 'information_schema')
1190            self.assertFalse(t.startswith('pg_'))
1191        for t in tables:
1192            self.createTable(t, 'as select 0', temporary=False)
1193        current_tables = get_tables()
1194        new_tables = [t for t in current_tables if t not in before_tables]
1195        expected_new_tables = ['public.%s' % (
1196            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
1197        self.assertEqual(new_tables, expected_new_tables)
1198        self.doCleanups()
1199        after_tables = get_tables()
1200        self.assertEqual(after_tables, before_tables)
1201
1202    def testGetSystemTables(self):
1203        get_tables = self.db.get_tables
1204        result = get_tables()
1205        self.assertNotIn('pg_catalog.pg_class', result)
1206        self.assertNotIn('information_schema.tables', result)
1207        result = get_tables(system=False)
1208        self.assertNotIn('pg_catalog.pg_class', result)
1209        self.assertNotIn('information_schema.tables', result)
1210        result = get_tables(system=True)
1211        self.assertIn('pg_catalog.pg_class', result)
1212        self.assertNotIn('information_schema.tables', result)
1213
1214    def testGetRelations(self):
1215        get_relations = self.db.get_relations
1216        result = get_relations()
1217        self.assertIn('public.test', result)
1218        self.assertIn('public.test_view', result)
1219        result = get_relations('rv')
1220        self.assertIn('public.test', result)
1221        self.assertIn('public.test_view', result)
1222        result = get_relations('r')
1223        self.assertIn('public.test', result)
1224        self.assertNotIn('public.test_view', result)
1225        result = get_relations('v')
1226        self.assertNotIn('public.test', result)
1227        self.assertIn('public.test_view', result)
1228        result = get_relations('cisSt')
1229        self.assertNotIn('public.test', result)
1230        self.assertNotIn('public.test_view', result)
1231
1232    def testGetSystemRelations(self):
1233        get_relations = self.db.get_relations
1234        result = get_relations()
1235        self.assertNotIn('pg_catalog.pg_class', result)
1236        self.assertNotIn('information_schema.tables', result)
1237        result = get_relations(system=False)
1238        self.assertNotIn('pg_catalog.pg_class', result)
1239        self.assertNotIn('information_schema.tables', result)
1240        result = get_relations(system=True)
1241        self.assertIn('pg_catalog.pg_class', result)
1242        self.assertIn('information_schema.tables', result)
1243
1244    def testGetAttnames(self):
1245        get_attnames = self.db.get_attnames
1246        self.assertRaises(pg.ProgrammingError,
1247                          self.db.get_attnames, 'does_not_exist')
1248        self.assertRaises(pg.ProgrammingError,
1249                          self.db.get_attnames, 'has.too.many.dots')
1250        r = get_attnames('test')
1251        self.assertIsInstance(r, dict)
1252        if self.regtypes:
1253            self.assertEqual(r, dict(
1254                i2='smallint', i4='integer', i8='bigint', d='numeric',
1255                f4='real', f8='double precision', m='money',
1256                v4='character varying', c4='character', t='text'))
1257        else:
1258            self.assertEqual(r, dict(
1259                i2='int', i4='int', i8='int', d='num',
1260                f4='float', f8='float', m='money',
1261                v4='text', c4='text', t='text'))
1262        self.createTable('test_table',
1263                         'n int, alpha smallint, beta bool,'
1264                         ' gamma char(5), tau text, v varchar(3)')
1265        r = get_attnames('test_table')
1266        self.assertIsInstance(r, dict)
1267        if self.regtypes:
1268            self.assertEqual(r, dict(
1269                n='integer', alpha='smallint', beta='boolean',
1270                gamma='character', tau='text', v='character varying'))
1271        else:
1272            self.assertEqual(r, dict(
1273                n='int', alpha='int', beta='bool',
1274                gamma='text', tau='text', v='text'))
1275
1276    def testGetAttnamesWithQuotes(self):
1277        get_attnames = self.db.get_attnames
1278        table = 'test table for get_attnames()'
1279        self.createTable(table,
1280            '"Prime!" smallint, "much space" integer, "Questions?" text')
1281        r = get_attnames(table)
1282        self.assertIsInstance(r, dict)
1283        if self.regtypes:
1284            self.assertEqual(r, {
1285                'Prime!': 'smallint', 'much space': 'integer',
1286                'Questions?': 'text'})
1287        else:
1288            self.assertEqual(r, {
1289                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1290        table = 'yet another test table for get_attnames()'
1291        self.createTable(table,
1292                         'a smallint, b integer, c bigint,'
1293                         ' e numeric, f real, f2 double precision, m money,'
1294                         ' x smallint, y smallint, z smallint,'
1295                         ' Normal_NaMe smallint, "Special Name" smallint,'
1296                         ' t text, u char(2), v varchar(2),'
1297                         ' primary key (y, u)', oids=True)
1298        r = get_attnames(table)
1299        self.assertIsInstance(r, dict)
1300        if self.regtypes:
1301            self.assertEqual(r, {
1302                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1303                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1304                'm': 'money', 'normal_name': 'smallint',
1305                'Special Name': 'smallint', 'u': 'character',
1306                't': 'text', 'v': 'character varying', 'y': 'smallint',
1307                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1308        else:
1309            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1310                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1311                 'normal_name': 'int', 'Special Name': 'int',
1312                 'u': 'text', 't': 'text', 'v': 'text',
1313                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1314
1315    def testGetAttnamesWithRegtypes(self):
1316        get_attnames = self.db.get_attnames
1317        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1318            ' gamma char(5), tau text, v varchar(3)')
1319        use_regtypes = self.db.use_regtypes
1320        regtypes = use_regtypes()
1321        self.assertEqual(regtypes, self.regtypes)
1322        use_regtypes(True)
1323        try:
1324            r = get_attnames("test_table")
1325            self.assertIsInstance(r, dict)
1326        finally:
1327            use_regtypes(regtypes)
1328        self.assertEqual(r, dict(
1329            n='integer', alpha='smallint', beta='boolean',
1330            gamma='character', tau='text', v='character varying'))
1331
1332    def testGetAttnamesWithoutRegtypes(self):
1333        get_attnames = self.db.get_attnames
1334        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1335            ' gamma char(5), tau text, v varchar(3)')
1336        use_regtypes = self.db.use_regtypes
1337        regtypes = use_regtypes()
1338        self.assertEqual(regtypes, self.regtypes)
1339        use_regtypes(False)
1340        try:
1341            r = get_attnames("test_table")
1342            self.assertIsInstance(r, dict)
1343        finally:
1344            use_regtypes(regtypes)
1345        self.assertEqual(r, dict(
1346            n='int', alpha='int', beta='bool',
1347            gamma='text', tau='text', v='text'))
1348
    def testGetAttnamesIsCached(self):
        """Column info from get_attnames() is cached until flushed.

        After each table change, the outdated cached info is returned
        until flush=True forces a reload from the database.
        """
        get_attnames = self.db.get_attnames
        int_type = 'integer' if self.regtypes else 'int'
        text_type = 'text'
        query = self.db.query
        self.createTable('test_table', 'col int')
        r = get_attnames("test_table")
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(col=int_type))
        # change the table; the cached info is now outdated
        query("alter table test_table alter column col type text")
        query("alter table test_table add column col2 int")
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=int_type))
        # flushing the cache reveals the changes
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict(col=text_type, col2=int_type))
        query("alter table test_table drop column col2")
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=text_type, col2=int_type))
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict(col=text_type))
        query("alter table test_table drop column col")
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=text_type))
        # a table without any columns gives an empty dict
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict())
1374
1375    def testGetAttnamesIsOrdered(self):
1376        get_attnames = self.db.get_attnames
1377        r = get_attnames('test', flush=True)
1378        self.assertIsInstance(r, OrderedDict)
1379        if self.regtypes:
1380            self.assertEqual(r, OrderedDict([
1381                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1382                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1383                ('m', 'money'), ('v4', 'character varying'),
1384                ('c4', 'character'), ('t', 'text')]))
1385        else:
1386            self.assertEqual(r, OrderedDict([
1387                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1388                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1389                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1390        if OrderedDict is not dict:
1391            r = ' '.join(list(r.keys()))
1392            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1393        table = 'test table for get_attnames'
1394        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1395                ' gamma char(5), tau text, beta bool')
1396        r = get_attnames(table)
1397        self.assertIsInstance(r, OrderedDict)
1398        if self.regtypes:
1399            self.assertEqual(r, OrderedDict([
1400                ('n', 'integer'), ('alpha', 'smallint'),
1401                ('v', 'character varying'), ('gamma', 'character'),
1402                ('tau', 'text'), ('beta', 'boolean')]))
1403        else:
1404            self.assertEqual(r, OrderedDict([
1405                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1406                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1407        if OrderedDict is not dict:
1408            r = ' '.join(list(r.keys()))
1409            self.assertEqual(r, 'n alpha v gamma tau beta')
1410        else:
1411            self.skipTest('OrderedDict is not supported')
1412
1413    def testGetAttnamesIsAttrDict(self):
1414        AttrDict = pg.AttrDict
1415        get_attnames = self.db.get_attnames
1416        r = get_attnames('test', flush=True)
1417        self.assertIsInstance(r, AttrDict)
1418        if self.regtypes:
1419            self.assertEqual(r, AttrDict([
1420                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1421                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1422                ('m', 'money'), ('v4', 'character varying'),
1423                ('c4', 'character'), ('t', 'text')]))
1424        else:
1425            self.assertEqual(r, AttrDict([
1426                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1427                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1428                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1429        r = ' '.join(list(r.keys()))
1430        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1431        table = 'test table for get_attnames'
1432        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1433                ' gamma char(5), tau text, beta bool')
1434        r = get_attnames(table)
1435        self.assertIsInstance(r, AttrDict)
1436        if self.regtypes:
1437            self.assertEqual(r, AttrDict([
1438                ('n', 'integer'), ('alpha', 'smallint'),
1439                ('v', 'character varying'), ('gamma', 'character'),
1440                ('tau', 'text'), ('beta', 'boolean')]))
1441        else:
1442            self.assertEqual(r, AttrDict([
1443                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1444                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1445        r = ' '.join(list(r.keys()))
1446        self.assertEqual(r, 'n alpha v gamma tau beta')
1447
1448    def testHasTablePrivilege(self):
1449        can = self.db.has_table_privilege
1450        self.assertEqual(can('test'), True)
1451        self.assertEqual(can('test', 'select'), True)
1452        self.assertEqual(can('test', 'SeLeCt'), True)
1453        self.assertEqual(can('test', 'SELECT'), True)
1454        self.assertEqual(can('test', 'insert'), True)
1455        self.assertEqual(can('test', 'update'), True)
1456        self.assertEqual(can('test', 'delete'), True)
1457        self.assertRaises(pg.DataError, can, 'test', 'foobar')
1458        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1459        r = self.db.query('select rolsuper FROM pg_roles'
1460            ' where rolname=current_user').getresult()[0][0]
1461        if not pg.get_bool():
1462            r = r == 't'
1463        if r:
1464            self.skipTest('must not be superuser')
1465        self.assertEqual(can('pg_views', 'select'), True)
1466        self.assertEqual(can('pg_views', 'delete'), False)
1467
    def testGet(self):
        """Check fetching single rows with db.get().

        Rows are looked up by an explicit key column or, once the table
        has one, by its primary key.  A dict passed as the row is filled
        in place and returned.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # both the table name and the row are required
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        self.createTable(table, 'n integer, t text',
                         values=enumerate('xyz', start=1))
        # without a primary key, the key column must be specified
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        # key values and key names may also be given as tuples
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # these lookups must all fail with a DatabaseError
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        # a dict passed as the row is updated in place and returned
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(t='x')
        # looking up by column t overwrites n with the found row's value
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        # now the primary key is used when no key column is given
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # a trailing '*' after the table name (with or without a space)
        # is accepted as well
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # the number of key values must match the primary key
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        s.update(n=2)
        # r is the same object as s here
        self.assertEqual(get(table, r)['t'], 'y')
        # a row dict lacking the primary key column raises a KeyError
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1522
1523    def testGetWithOid(self):
1524        get = self.db.get
1525        query = self.db.query
1526        table = 'get_with_oid_test_table'
1527        self.createTable(table, 'n integer, t text', oids=True,
1528                         values=enumerate('xyz', start=1))
1529        self.assertRaises(pg.ProgrammingError, get, table, 2)
1530        self.assertRaises(KeyError, get, table, {}, 'oid')
1531        r = get(table, 2, 'n')
1532        qoid = 'oid(%s)' % table
1533        self.assertIn(qoid, r)
1534        oid = r[qoid]
1535        self.assertIsInstance(oid, int)
1536        result = {'t': 'y', 'n': 2, qoid: oid}
1537        self.assertEqual(r, result)
1538        r = get(table, oid, 'oid')
1539        self.assertEqual(r, result)
1540        r = get(table, dict(oid=oid))
1541        self.assertEqual(r, result)
1542        r = get(table, dict(oid=oid), 'oid')
1543        self.assertEqual(r, result)
1544        r = get(table, {qoid: oid})
1545        self.assertEqual(r, result)
1546        r = get(table, {qoid: oid}, 'oid')
1547        self.assertEqual(r, result)
1548        self.assertEqual(get(table + '*', 2, 'n'), r)
1549        self.assertEqual(get(table + ' *', 2, 'n'), r)
1550        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
1551        self.assertEqual(get(table, 1, 'n')['t'], 'x')
1552        self.assertEqual(get(table, 3, 'n')['t'], 'z')
1553        self.assertEqual(get(table, 2, 'n')['t'], 'y')
1554        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
1555        r['n'] = 3
1556        self.assertEqual(get(table, r, 'n')['t'], 'z')
1557        self.assertEqual(get(table, 1, 'n')['t'], 'x')
1558        self.assertEqual(get(table, r, 'oid')['t'], 'z')
1559        query('alter table "%s" alter n set not null' % table)
1560        query('alter table "%s" add primary key (n)' % table)
1561        self.assertEqual(get(table, 3)['t'], 'z')
1562        self.assertEqual(get(table, 1)['t'], 'x')
1563        self.assertEqual(get(table, 2)['t'], 'y')
1564        r['n'] = 1
1565        self.assertEqual(get(table, r)['t'], 'x')
1566        r['n'] = 3
1567        self.assertEqual(get(table, r)['t'], 'z')
1568        r['n'] = 2
1569        self.assertEqual(get(table, r)['t'], 'y')
1570        r = get(table, oid, 'oid')
1571        self.assertEqual(r, result)
1572        r = get(table, dict(oid=oid))
1573        self.assertEqual(r, result)
1574        r = get(table, dict(oid=oid), 'oid')
1575        self.assertEqual(r, result)
1576        r = get(table, {qoid: oid})
1577        self.assertEqual(r, result)
1578        r = get(table, {qoid: oid}, 'oid')
1579        self.assertEqual(r, result)
1580        r = get(table, dict(oid=oid, n=1))
1581        self.assertEqual(r['n'], 1)
1582        self.assertNotEqual(r[qoid], oid)
1583        r = get(table, dict(oid=oid, t='z'), 't')
1584        self.assertEqual(r['n'], 3)
1585        self.assertNotEqual(r[qoid], oid)
1586
1587    def testGetWithCompositeKey(self):
1588        get = self.db.get
1589        query = self.db.query
1590        table = 'get_test_table_1'
1591        self.createTable(table, 'n integer primary key, t text',
1592                         values=enumerate('abc', start=1))
1593        self.assertEqual(get(table, 2)['t'], 'b')
1594        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1595        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1596        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1597        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1598        self.assertEqual(get(table, 'b', 't')['n'], 2)
1599        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1600        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1601        table = 'get_test_table_2'
1602        self.createTable(table,
1603                         'n integer, m integer, t text, primary key (n, m)',
1604                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1605                                 for n in range(3) for m in range(2)])
1606        self.assertRaises(KeyError, get, table, 2)
1607        self.assertEqual(get(table, (1, 1))['t'], 'a')
1608        self.assertEqual(get(table, (1, 2))['t'], 'b')
1609        self.assertEqual(get(table, (2, 1))['t'], 'c')
1610        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1611        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1612        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1613        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1614        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1615        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1616        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1617        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1618
1619    def testGetWithQuotedNames(self):
1620        get = self.db.get
1621        query = self.db.query
1622        table = 'test table for get()'
1623        self.createTable(table, '"Prime!" smallint primary key,'
1624                                ' "much space" integer, "Questions?" text',
1625                         values=[(17, 1001, 'No!')])
1626        r = get(table, 17)
1627        self.assertIsInstance(r, dict)
1628        self.assertEqual(r['Prime!'], 17)
1629        self.assertEqual(r['much space'], 1001)
1630        self.assertEqual(r['Questions?'], 'No!')
1631
1632    def testGetFromView(self):
1633        self.db.query('delete from test where i4=14')
1634        self.db.query('insert into test (i4, v4) values('
1635                      "14, 'abc4')")
1636        r = self.db.get('test_view', 14, 'i4')
1637        self.assertIn('v4', r)
1638        self.assertEqual(r['v4'], 'abc4')
1639
1640    def testGetLittleBobbyTables(self):
1641        get = self.db.get
1642        query = self.db.query
1643        self.createTable('test_students',
1644                         'firstname varchar primary key, nickname varchar, grade char(2)',
1645                         values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1646                                 ('Robert', 'Little Bobby Tables', 'D-')])
1647        r = get('test_students', 'Sheldon')
1648        self.assertEqual(r, dict(
1649            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1650        r = get('test_students', 'Robert')
1651        self.assertEqual(r, dict(
1652            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1653        r = get('test_students', "D'Arcy")
1654        self.assertEqual(r, dict(
1655            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1656        try:
1657            get('test_students', "D' Arcy")
1658        except pg.DatabaseError as error:
1659            self.assertEqual(str(error),
1660                'No such record in test_students\nwhere "firstname" = $1\n'
1661                'with $1="D\' Arcy"')
1662        try:
1663            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1664        except pg.DatabaseError as error:
1665            self.assertEqual(str(error),
1666                'No such record in test_students\nwhere "firstname" = $1\n'
1667                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1668        q = "select * from test_students order by 1 limit 4"
1669        r = query(q).getresult()
1670        self.assertEqual(len(r), 3)
1671        self.assertEqual(r[1][2], 'D-')
1672
1673    def testInsert(self):
1674        insert = self.db.insert
1675        query = self.db.query
1676        bool_on = pg.get_bool()
1677        decimal = pg.get_decimal()
1678        table = 'insert_test_table'
1679        self.createTable(table,
1680            'i2 smallint, i4 integer, i8 bigint,'
1681            ' d numeric, f4 real, f8 double precision, m money,'
1682            ' v4 varchar(4), c4 char(4), t text,'
1683            ' b boolean, ts timestamp', oids=True)
1684        oid_table = 'oid(%s)' % table
1685        tests = [dict(i2=None, i4=None, i8=None),
1686             (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1687             (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1688             dict(i2=42, i4=123456, i8=9876543210),
1689             dict(i2=2 ** 15 - 1,
1690                  i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1691             dict(d=None), (dict(d=''), dict(d=None)),
1692             dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1693             dict(f4=None, f8=None), dict(f4=0, f8=0),
1694             (dict(f4='', f8=''), dict(f4=None, f8=None)),
1695             (dict(d=1234.5, f4=1234.5, f8=1234.5),
1696              dict(d=Decimal('1234.5'))),
1697             dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1698             dict(d=Decimal('123456789.9876543212345678987654321')),
1699             dict(m=None), (dict(m=''), dict(m=None)),
1700             dict(m=Decimal('-1234.56')),
1701             (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1702             dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1703             (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1704             (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1705             (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1706             (dict(m=123456), dict(m=Decimal('123456'))),
1707             (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1708             dict(b=None), (dict(b=''), dict(b=None)),
1709             dict(b='f'), dict(b='t'),
1710             (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1711             (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1712             (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1713             (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1714             (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1715             (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1716             dict(v4=None, c4=None, t=None),
1717             (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1718             dict(v4='1234', c4='1234', t='1234' * 10),
1719             dict(v4='abcd', c4='abcd', t='abcdefg'),
1720             (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1721             dict(ts=None), (dict(ts=''), dict(ts=None)),
1722             (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1723             dict(ts='2012-12-21 00:00:00'),
1724             (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1725             dict(ts='2012-12-21 12:21:12'),
1726             dict(ts='2013-01-05 12:13:14'),
1727             dict(ts='current_timestamp')]
1728        for test in tests:
1729            if isinstance(test, dict):
1730                data = test
1731                change = {}
1732            else:
1733                data, change = test
1734            expect = data.copy()
1735            expect.update(change)
1736            if bool_on:
1737                b = expect.get('b')
1738                if b is not None:
1739                    expect['b'] = b == 't'
1740            if decimal is not Decimal:
1741                d = expect.get('d')
1742                if d is not None:
1743                    expect['d'] = decimal(d)
1744                m = expect.get('m')
1745                if m is not None:
1746                    expect['m'] = decimal(m)
1747            self.assertEqual(insert(table, data), data)
1748            self.assertIn(oid_table, data)
1749            oid = data[oid_table]
1750            self.assertIsInstance(oid, int)
1751            data = dict(item for item in data.items()
1752                        if item[0] in expect)
1753            ts = expect.get('ts')
1754            if ts:
1755                if ts == 'current_timestamp':
1756                    ts = data['ts']
1757                    self.assertIsInstance(ts, datetime)
1758                    self.assertEqual(ts.strftime('%Y-%m-%d'),
1759                        strftime('%Y-%m-%d'))
1760                else:
1761                    ts = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S')
1762                expect['ts'] = ts
1763            self.assertEqual(data, expect)
1764            data = query(
1765                'select oid,* from "%s"' % table).dictresult()[0]
1766            self.assertEqual(data['oid'], oid)
1767            data = dict(item for item in data.items()
1768                        if item[0] in expect)
1769            self.assertEqual(data, expect)
1770            query('delete from "%s"' % table)
1771
1772    def testInsertWithOid(self):
1773        insert = self.db.insert
1774        query = self.db.query
1775        self.createTable('test_table', 'n int', oids=True)
1776        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
1777        r = insert('test_table', n=1)
1778        self.assertIsInstance(r, dict)
1779        self.assertEqual(r['n'], 1)
1780        self.assertNotIn('oid', r)
1781        qoid = 'oid(test_table)'
1782        self.assertIn(qoid, r)
1783        oid = r[qoid]
1784        self.assertEqual(sorted(r.keys()), ['n', qoid])
1785        r = insert('test_table', n=2, oid=oid)
1786        self.assertIsInstance(r, dict)
1787        self.assertEqual(r['n'], 2)
1788        self.assertIn(qoid, r)
1789        self.assertNotEqual(r[qoid], oid)
1790        self.assertNotIn('oid', r)
1791        r = insert('test_table', None, n=3)
1792        self.assertIsInstance(r, dict)
1793        self.assertEqual(r['n'], 3)
1794        s = r
1795        r = insert('test_table', r)
1796        self.assertIs(r, s)
1797        self.assertEqual(r['n'], 3)
1798        r = insert('test_table *', r)
1799        self.assertIs(r, s)
1800        self.assertEqual(r['n'], 3)
1801        r = insert('test_table', r, n=4)
1802        self.assertIs(r, s)
1803        self.assertEqual(r['n'], 4)
1804        self.assertNotIn('oid', r)
1805        self.assertIn(qoid, r)
1806        oid = r[qoid]
1807        r = insert('test_table', r, n=5, oid=oid)
1808        self.assertIs(r, s)
1809        self.assertEqual(r['n'], 5)
1810        self.assertIn(qoid, r)
1811        self.assertNotEqual(r[qoid], oid)
1812        self.assertNotIn('oid', r)
1813        r['oid'] = oid = r[qoid]
1814        r = insert('test_table', r, n=6)
1815        self.assertIs(r, s)
1816        self.assertEqual(r['n'], 6)
1817        self.assertIn(qoid, r)
1818        self.assertNotEqual(r[qoid], oid)
1819        self.assertNotIn('oid', r)
1820        q = 'select n from test_table order by 1 limit 9'
1821        r = ' '.join(str(row[0]) for row in query(q).getresult())
1822        self.assertEqual(r, '1 2 3 3 3 4 5 6')
1823        query("truncate test_table")
1824        query("alter table test_table add unique (n)")
1825        r = insert('test_table', dict(n=7))
1826        self.assertIsInstance(r, dict)
1827        self.assertEqual(r['n'], 7)
1828        self.assertRaises(pg.IntegrityError, insert, 'test_table', r)
1829        r['n'] = 6
1830        self.assertRaises(pg.IntegrityError, insert, 'test_table', r, n=7)
1831        self.assertIsInstance(r, dict)
1832        self.assertEqual(r['n'], 7)
1833        r['n'] = 6
1834        r = insert('test_table', r)
1835        self.assertIsInstance(r, dict)
1836        self.assertEqual(r['n'], 6)
1837        r = ' '.join(str(row[0]) for row in query(q).getresult())
1838        self.assertEqual(r, '6 7')
1839
1840    def testInsertWithQuotedNames(self):
1841        insert = self.db.insert
1842        query = self.db.query
1843        table = 'test table for insert()'
1844        self.createTable(table, '"Prime!" smallint primary key,'
1845                                ' "much space" integer, "Questions?" text')
1846        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1847        r = insert(table, r)
1848        self.assertIsInstance(r, dict)
1849        self.assertEqual(r['Prime!'], 11)
1850        self.assertEqual(r['much space'], 2002)
1851        self.assertEqual(r['Questions?'], 'What?')
1852        r = query('select * from "%s" limit 2' % table).dictresult()
1853        self.assertEqual(len(r), 1)
1854        r = r[0]
1855        self.assertEqual(r['Prime!'], 11)
1856        self.assertEqual(r['much space'], 2002)
1857        self.assertEqual(r['Questions?'], 'What?')
1858
1859    def testInsertIntoView(self):
1860        insert = self.db.insert
1861        query = self.db.query
1862        query("truncate test")
1863        q = 'select * from test_view order by i4 limit 3'
1864        r = query(q).getresult()
1865        self.assertEqual(r, [])
1866        r = dict(i4=1234, v4='abcd')
1867        insert('test', r)
1868        self.assertIsNone(r['i2'])
1869        self.assertEqual(r['i4'], 1234)
1870        self.assertIsNone(r['i8'])
1871        self.assertEqual(r['v4'], 'abcd')
1872        self.assertIsNone(r['c4'])
1873        r = query(q).getresult()
1874        self.assertEqual(r, [(1234, 'abcd')])
1875        r = dict(i4=5678, v4='efgh')
1876        try:
1877            insert('test_view', r)
1878        except (pg.OperationalError, pg.NotSupportedError) as error:
1879            if self.db.server_version < 90300:
1880                # must setup rules in older PostgreSQL versions
1881                self.skipTest('database cannot insert into view')
1882            self.fail(str(error))
1883        self.assertNotIn('i2', r)
1884        self.assertEqual(r['i4'], 5678)
1885        self.assertNotIn('i8', r)
1886        self.assertEqual(r['v4'], 'efgh')
1887        self.assertNotIn('c4', r)
1888        r = query(q).getresult()
1889        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1890
1891    def testUpdate(self):
1892        update = self.db.update
1893        query = self.db.query
1894        self.assertRaises(pg.ProgrammingError, update,
1895                          'test', i2=2, i4=4, i8=8)
1896        table = 'update_test_table'
1897        self.createTable(table, 'n integer, t text', oids=True,
1898                         values=enumerate('xyz', start=1))
1899        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1900        r = self.db.get(table, 2, 'n')
1901        r['t'] = 'u'
1902        s = update(table, r)
1903        self.assertEqual(s, r)
1904        q = 'select t from "%s" where n=2' % table
1905        r = query(q).getresult()[0][0]
1906        self.assertEqual(r, 'u')
1907
1908    def testUpdateWithOid(self):
1909        update = self.db.update
1910        get = self.db.get
1911        query = self.db.query
1912        self.createTable('test_table', 'n int', oids=True, values=[1])
1913        s = get('test_table', 1, 'n')
1914        self.assertIsInstance(s, dict)
1915        self.assertEqual(s['n'], 1)
1916        s['n'] = 2
1917        r = update('test_table', s)
1918        self.assertIs(r, s)
1919        self.assertEqual(r['n'], 2)
1920        qoid = 'oid(test_table)'
1921        self.assertIn(qoid, r)
1922        self.assertNotIn('oid', r)
1923        self.assertEqual(sorted(r.keys()), ['n', qoid])
1924        r['n'] = 3
1925        oid = r.pop(qoid)
1926        r = update('test_table', r, oid=oid)
1927        self.assertIs(r, s)
1928        self.assertEqual(r['n'], 3)
1929        r.pop(qoid)
1930        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
1931        s = get('test_table', 3, 'n')
1932        self.assertIsInstance(s, dict)
1933        self.assertEqual(s['n'], 3)
1934        s.pop('n')
1935        r = update('test_table', s)
1936        oid = r.pop(qoid)
1937        self.assertEqual(r, {})
1938        q = "select n from test_table limit 2"
1939        r = query(q).getresult()
1940        self.assertEqual(r, [(3,)])
1941        query("insert into test_table values (1)")
1942        self.assertRaises(pg.ProgrammingError,
1943                          update, 'test_table', dict(oid=oid, n=4))
1944        r = update('test_table', dict(n=4), oid=oid)
1945        self.assertEqual(r['n'], 4)
1946        r = update('test_table *', dict(n=5), oid=oid)
1947        self.assertEqual(r['n'], 5)
1948        query("alter table test_table add column m int")
1949        query("alter table test_table add primary key (n)")
1950        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
1951        self.assertEqual('n', self.db.pkey('test_table', flush=True))
1952        s = dict(n=1, m=4)
1953        r = update('test_table', s)
1954        self.assertIs(r, s)
1955        self.assertEqual(r['n'], 1)
1956        self.assertEqual(r['m'], 4)
1957        s = dict(m=7)
1958        r = update('test_table', s, n=5)
1959        self.assertIs(r, s)
1960        self.assertEqual(r['n'], 5)
1961        self.assertEqual(r['m'], 7)
1962        q = "select n, m from test_table order by 1 limit 3"
1963        r = query(q).getresult()
1964        self.assertEqual(r, [(1, 4), (5, 7)])
1965        s = dict(m=9, oid=oid)
1966        self.assertRaises(KeyError, update, 'test_table', s)
1967        r = update('test_table', s, oid=oid)
1968        self.assertIs(r, s)
1969        self.assertEqual(r['n'], 5)
1970        self.assertEqual(r['m'], 9)
1971        s = dict(n=1, m=3, oid=oid)
1972        r = update('test_table', s)
1973        self.assertIs(r, s)
1974        self.assertEqual(r['n'], 1)
1975        self.assertEqual(r['m'], 3)
1976        r = query(q).getresult()
1977        self.assertEqual(r, [(1, 3), (5, 9)])
1978        s.update(n=4, m=7)
1979        r = update('test_table', s, oid=oid)
1980        self.assertIs(r, s)
1981        self.assertEqual(r['n'], 4)
1982        self.assertEqual(r['m'], 7)
1983        r = query(q).getresult()
1984        self.assertEqual(r, [(1, 3), (4, 7)])
1985
1986    def testUpdateWithoutOid(self):
1987        update = self.db.update
1988        query = self.db.query
1989        self.assertRaises(pg.ProgrammingError, update,
1990                          'test', i2=2, i4=4, i8=8)
1991        table = 'update_test_table'
1992        self.createTable(table, 'n integer primary key, t text', oids=False,
1993                         values=enumerate('xyz', start=1))
1994        r = self.db.get(table, 2)
1995        r['t'] = 'u'
1996        s = update(table, r)
1997        self.assertEqual(s, r)
1998        q = 'select t from "%s" where n=2' % table
1999        r = query(q).getresult()[0][0]
2000        self.assertEqual(r, 'u')
2001
2002    def testUpdateWithCompositeKey(self):
2003        update = self.db.update
2004        query = self.db.query
2005        table = 'update_test_table_1'
2006        self.createTable(table, 'n integer primary key, t text',
2007                         values=enumerate('abc', start=1))
2008        self.assertRaises(KeyError, update, table, dict(t='b'))
2009        s = dict(n=2, t='d')
2010        r = update(table, s)
2011        self.assertIs(r, s)
2012        self.assertEqual(r['n'], 2)
2013        self.assertEqual(r['t'], 'd')
2014        q = 'select t from "%s" where n=2' % table
2015        r = query(q).getresult()[0][0]
2016        self.assertEqual(r, 'd')
2017        s.update(dict(n=4, t='e'))
2018        r = update(table, s)
2019        self.assertEqual(r['n'], 4)
2020        self.assertEqual(r['t'], 'e')
2021        q = 'select t from "%s" where n=2' % table
2022        r = query(q).getresult()[0][0]
2023        self.assertEqual(r, 'd')
2024        q = 'select t from "%s" where n=4' % table
2025        r = query(q).getresult()
2026        self.assertEqual(len(r), 0)
2027        query('drop table "%s"' % table)
2028        table = 'update_test_table_2'
2029        self.createTable(table,
2030                         'n integer, m integer, t text, primary key (n, m)',
2031                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2032                                 for n in range(3) for m in range(2)])
2033        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
2034        self.assertEqual(update(table,
2035                                dict(n=2, m=2, t='x'))['t'], 'x')
2036        q = 'select t from "%s" where n=2 order by m' % table
2037        r = [r[0] for r in query(q).getresult()]
2038        self.assertEqual(r, ['c', 'x'])
2039
2040    def testUpdateWithQuotedNames(self):
2041        update = self.db.update
2042        query = self.db.query
2043        table = 'test table for update()'
2044        self.createTable(table, '"Prime!" smallint primary key,'
2045                                ' "much space" integer, "Questions?" text',
2046                         values=[(13, 3003, 'Why!')])
2047        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
2048        r = update(table, r)
2049        self.assertIsInstance(r, dict)
2050        self.assertEqual(r['Prime!'], 13)
2051        self.assertEqual(r['much space'], 7007)
2052        self.assertEqual(r['Questions?'], 'When?')
2053        r = query('select * from "%s" limit 2' % table).dictresult()
2054        self.assertEqual(len(r), 1)
2055        r = r[0]
2056        self.assertEqual(r['Prime!'], 13)
2057        self.assertEqual(r['much space'], 7007)
2058        self.assertEqual(r['Questions?'], 'When?')
2059
2060    def testUpsert(self):
2061        upsert = self.db.upsert
2062        query = self.db.query
2063        self.assertRaises(pg.ProgrammingError, upsert,
2064                          'test', i2=2, i4=4, i8=8)
2065        table = 'upsert_test_table'
2066        self.createTable(table, 'n integer primary key, t text', oids=True)
2067        s = dict(n=1, t='x')
2068        try:
2069            r = upsert(table, s)
2070        except pg.ProgrammingError as error:
2071            if self.db.server_version < 90500:
2072                self.skipTest('database does not support upsert')
2073            self.fail(str(error))
2074        self.assertIs(r, s)
2075        self.assertEqual(r['n'], 1)
2076        self.assertEqual(r['t'], 'x')
2077        s.update(n=2, t='y')
2078        r = upsert(table, s, **dict.fromkeys(s))
2079        self.assertIs(r, s)
2080        self.assertEqual(r['n'], 2)
2081        self.assertEqual(r['t'], 'y')
2082        q = 'select n, t from "%s" order by n limit 3' % table
2083        r = query(q).getresult()
2084        self.assertEqual(r, [(1, 'x'), (2, 'y')])
2085        s.update(t='z')
2086        r = upsert(table, s)
2087        self.assertIs(r, s)
2088        self.assertEqual(r['n'], 2)
2089        self.assertEqual(r['t'], 'z')
2090        r = query(q).getresult()
2091        self.assertEqual(r, [(1, 'x'), (2, 'z')])
2092        s.update(t='n')
2093        r = upsert(table, s, t=False)
2094        self.assertIs(r, s)
2095        self.assertEqual(r['n'], 2)
2096        self.assertEqual(r['t'], 'z')
2097        r = query(q).getresult()
2098        self.assertEqual(r, [(1, 'x'), (2, 'z')])
2099        s.update(t='y')
2100        r = upsert(table, s, t=True)
2101        self.assertIs(r, s)
2102        self.assertEqual(r['n'], 2)
2103        self.assertEqual(r['t'], 'y')
2104        r = query(q).getresult()
2105        self.assertEqual(r, [(1, 'x'), (2, 'y')])
2106        s.update(t='n')
2107        r = upsert(table, s, t="included.t || '2'")
2108        self.assertIs(r, s)
2109        self.assertEqual(r['n'], 2)
2110        self.assertEqual(r['t'], 'y2')
2111        r = query(q).getresult()
2112        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
2113        s.update(t='y')
2114        r = upsert(table, s, t="excluded.t || '3'")
2115        self.assertIs(r, s)
2116        self.assertEqual(r['n'], 2)
2117        self.assertEqual(r['t'], 'y3')
2118        r = query(q).getresult()
2119        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
2120        s.update(n=1, t='2')
2121        r = upsert(table, s, t="included.t || excluded.t")
2122        self.assertIs(r, s)
2123        self.assertEqual(r['n'], 1)
2124        self.assertEqual(r['t'], 'x2')
2125        r = query(q).getresult()
2126        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
2127        # not existing columns and oid parameter should be ignored
2128        s = dict(m=3, u='z')
2129        r = upsert(table, s, oid='invalid')
2130        self.assertIs(r, s)
2131
2132    def testUpsertWithOid(self):
2133        upsert = self.db.upsert
2134        get = self.db.get
2135        query = self.db.query
2136        self.createTable('test_table', 'n int', oids=True, values=[1])
2137        self.assertRaises(pg.ProgrammingError,
2138                          upsert, 'test_table', dict(n=2))
2139        r = get('test_table', 1, 'n')
2140        self.assertIsInstance(r, dict)
2141        self.assertEqual(r['n'], 1)
2142        qoid = 'oid(test_table)'
2143        self.assertIn(qoid, r)
2144        self.assertNotIn('oid', r)
2145        oid = r[qoid]
2146        self.assertRaises(pg.ProgrammingError,
2147                          upsert, 'test_table', dict(n=2, oid=oid))
2148        query("alter table test_table add column m int")
2149        query("alter table test_table add primary key (n)")
2150        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
2151        self.assertEqual('n', self.db.pkey('test_table', flush=True))
2152        s = dict(n=2)
2153        try:
2154            r = upsert('test_table', s)
2155        except pg.ProgrammingError as error:
2156            if self.db.server_version < 90500:
2157                self.skipTest('database does not support upsert')
2158            self.fail(str(error))
2159        self.assertIs(r, s)
2160        self.assertEqual(r['n'], 2)
2161        self.assertIsNone(r['m'])
2162        q = query("select n, m from test_table order by n limit 3")
2163        self.assertEqual(q.getresult(), [(1, None), (2, None)])
2164        r['oid'] = oid
2165        r = upsert('test_table', r)
2166        self.assertIs(r, s)
2167        self.assertEqual(r['n'], 2)
2168        self.assertIsNone(r['m'])
2169        self.assertIn(qoid, r)
2170        self.assertNotIn('oid', r)
2171        self.assertNotEqual(r[qoid], oid)
2172        r['m'] = 7
2173        r = upsert('test_table', r)
2174        self.assertIs(r, s)
2175        self.assertEqual(r['n'], 2)
2176        self.assertEqual(r['m'], 7)
2177        r.update(n=1, m=3)
2178        r = upsert('test_table', r)
2179        self.assertIs(r, s)
2180        self.assertEqual(r['n'], 1)
2181        self.assertEqual(r['m'], 3)
2182        q = query("select n, m from test_table order by n limit 3")
2183        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
2184        r = upsert('test_table', r, oid='invalid')
2185        self.assertIs(r, s)
2186        self.assertEqual(r['n'], 1)
2187        self.assertEqual(r['m'], 3)
2188        r['m'] = 5
2189        r = upsert('test_table', r, m=False)
2190        self.assertIs(r, s)
2191        self.assertEqual(r['n'], 1)
2192        self.assertEqual(r['m'], 3)
2193        r['m'] = 5
2194        r = upsert('test_table', r, m=True)
2195        self.assertIs(r, s)
2196        self.assertEqual(r['n'], 1)
2197        self.assertEqual(r['m'], 5)
2198        r.update(n=2, m=1)
2199        r = upsert('test_table', r, m='included.m')
2200        self.assertIs(r, s)
2201        self.assertEqual(r['n'], 2)
2202        self.assertEqual(r['m'], 7)
2203        r['m'] = 9
2204        r = upsert('test_table', r, m='excluded.m')
2205        self.assertIs(r, s)
2206        self.assertEqual(r['n'], 2)
2207        self.assertEqual(r['m'], 9)
2208        r['m'] = 8
2209        r = upsert('test_table *', r, m='included.m + 1')
2210        self.assertIs(r, s)
2211        self.assertEqual(r['n'], 2)
2212        self.assertEqual(r['m'], 10)
2213        q = query("select n, m from test_table order by n limit 3")
2214        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
2215
2216    def testUpsertWithCompositeKey(self):
2217        upsert = self.db.upsert
2218        query = self.db.query
2219        table = 'upsert_test_table_2'
2220        self.createTable(table,
2221                         'n integer, m integer, t text, primary key (n, m)')
2222        s = dict(n=1, m=2, t='x')
2223        try:
2224            r = upsert(table, s)
2225        except pg.ProgrammingError as error:
2226            if self.db.server_version < 90500:
2227                self.skipTest('database does not support upsert')
2228            self.fail(str(error))
2229        self.assertIs(r, s)
2230        self.assertEqual(r['n'], 1)
2231        self.assertEqual(r['m'], 2)
2232        self.assertEqual(r['t'], 'x')
2233        s.update(m=3, t='y')
2234        r = upsert(table, s, **dict.fromkeys(s))
2235        self.assertIs(r, s)
2236        self.assertEqual(r['n'], 1)
2237        self.assertEqual(r['m'], 3)
2238        self.assertEqual(r['t'], 'y')
2239        q = 'select n, m, t from "%s" order by n, m limit 3' % table
2240        r = query(q).getresult()
2241        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
2242        s.update(t='z')
2243        r = upsert(table, s)
2244        self.assertIs(r, s)
2245        self.assertEqual(r['n'], 1)
2246        self.assertEqual(r['m'], 3)
2247        self.assertEqual(r['t'], 'z')
2248        r = query(q).getresult()
2249        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
2250        s.update(t='n')
2251        r = upsert(table, s, t=False)
2252        self.assertIs(r, s)
2253        self.assertEqual(r['n'], 1)
2254        self.assertEqual(r['m'], 3)
2255        self.assertEqual(r['t'], 'z')
2256        r = query(q).getresult()
2257        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
2258        s.update(t='n')
2259        r = upsert(table, s, t=True)
2260        self.assertIs(r, s)
2261        self.assertEqual(r['n'], 1)
2262        self.assertEqual(r['m'], 3)
2263        self.assertEqual(r['t'], 'n')
2264        r = query(q).getresult()
2265        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
2266        s.update(n=2, t='y')
2267        r = upsert(table, s, t="'z'")
2268        self.assertIs(r, s)
2269        self.assertEqual(r['n'], 2)
2270        self.assertEqual(r['m'], 3)
2271        self.assertEqual(r['t'], 'y')
2272        r = query(q).getresult()
2273        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
2274        s.update(n=1, t='m')
2275        r = upsert(table, s, t='included.t || excluded.t')
2276        self.assertIs(r, s)
2277        self.assertEqual(r['n'], 1)
2278        self.assertEqual(r['m'], 3)
2279        self.assertEqual(r['t'], 'nm')
2280        r = query(q).getresult()
2281        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2282
2283    def testUpsertWithQuotedNames(self):
2284        upsert = self.db.upsert
2285        query = self.db.query
2286        table = 'test table for upsert()'
2287        self.createTable(table, '"Prime!" smallint primary key,'
2288                                ' "much space" integer, "Questions?" text')
2289        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2290        try:
2291            r = upsert(table, s)
2292        except pg.ProgrammingError as error:
2293            if self.db.server_version < 90500:
2294                self.skipTest('database does not support upsert')
2295            self.fail(str(error))
2296        self.assertIs(r, s)
2297        self.assertEqual(r['Prime!'], 31)
2298        self.assertEqual(r['much space'], 9009)
2299        self.assertEqual(r['Questions?'], 'Yes.')
2300        q = 'select * from "%s" limit 2' % table
2301        r = query(q).getresult()
2302        self.assertEqual(r, [(31, 9009, 'Yes.')])
2303        s.update({'Questions?': 'No.'})
2304        r = upsert(table, s)
2305        self.assertIs(r, s)
2306        self.assertEqual(r['Prime!'], 31)
2307        self.assertEqual(r['much space'], 9009)
2308        self.assertEqual(r['Questions?'], 'No.')
2309        r = query(q).getresult()
2310        self.assertEqual(r, [(31, 9009, 'No.')])
2311
2312    def testClear(self):
2313        clear = self.db.clear
2314        f = False if pg.get_bool() else 'f'
2315        r = clear('test')
2316        result = dict(
2317            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2318        self.assertEqual(r, result)
2319        table = 'clear_test_table'
2320        self.createTable(table,
2321            'n integer, f float, b boolean, d date, t text', oids=True)
2322        r = clear(table)
2323        result = dict(n=0, f=0, b=f, d='', t='')
2324        self.assertEqual(r, result)
2325        r['a'] = r['f'] = r['n'] = 1
2326        r['d'] = r['t'] = 'x'
2327        r['b'] = 't'
2328        r['oid'] = long(1)
2329        r = clear(table, r)
2330        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2331        self.assertEqual(r, result)
2332
2333    def testClearWithQuotedNames(self):
2334        clear = self.db.clear
2335        table = 'test table for clear()'
2336        self.createTable(table, '"Prime!" smallint primary key,'
2337            ' "much space" integer, "Questions?" text')
2338        r = clear(table)
2339        self.assertIsInstance(r, dict)
2340        self.assertEqual(r['Prime!'], 0)
2341        self.assertEqual(r['much space'], 0)
2342        self.assertEqual(r['Questions?'], '')
2343
2344    def testDelete(self):
2345        delete = self.db.delete
2346        query = self.db.query
2347        self.assertRaises(pg.ProgrammingError, delete,
2348                          'test', dict(i2=2, i4=4, i8=8))
2349        table = 'delete_test_table'
2350        self.createTable(table, 'n integer, t text', oids=True,
2351                         values=enumerate('xyz', start=1))
2352        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
2353        r = self.db.get(table, 1, 'n')
2354        s = delete(table, r)
2355        self.assertEqual(s, 1)
2356        r = self.db.get(table, 3, 'n')
2357        s = delete(table, r)
2358        self.assertEqual(s, 1)
2359        s = delete(table, r)
2360        self.assertEqual(s, 0)
2361        r = query('select * from "%s"' % table).dictresult()
2362        self.assertEqual(len(r), 1)
2363        r = r[0]
2364        result = {'n': 2, 't': 'y'}
2365        self.assertEqual(r, result)
2366        r = self.db.get(table, 2, 'n')
2367        s = delete(table, r)
2368        self.assertEqual(s, 1)
2369        s = delete(table, r)
2370        self.assertEqual(s, 0)
2371        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
2372        # not existing columns and oid parameter should be ignored
2373        r.update(m=3, u='z', oid='invalid')
2374        s = delete(table, r)
2375        self.assertEqual(s, 0)
2376
    def testDeleteWithOid(self):
        """Delete rows identified by oid, by keyword args and by primary key.

        The table starts with oids but without a primary key.  Later on a
        column m and a primary key on n are added, and deletion is checked
        again.  The query q tracks (min(n), count(n)) of the remaining rows.
        """
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
        # neither a primary key nor an oid is available for this row
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        s = get('test_table', 1, 'n')
        # get() returns the oid under the table-qualified key 'oid(table)'
        qoid = 'oid(test_table)'
        self.assertIn(qoid, s)
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        q = "select min(n),count(n) from test_table"
        self.assertEqual(query(q).getresult()[0], (2, 5))
        oid = get('test_table', 2, 'n')[qoid]
        # a plain 'oid' key in the row dict is not used as the oid
        s = dict(oid=oid, n=2)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        # the oid can be passed as keyword argument instead of a row
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # s still holds the oid of the already deleted row n=2,
        # while the oid variable now refers to row n=3
        s = dict(oid=oid, n=2)
        oid = get('test_table', 3, 'n')[qoid]
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        # the oid keyword argument decides which row gets deleted
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        # the "table *" form (table including descendants) works as well
        s = get('test_table', 4, 'n')
        r = delete('test_table *', s)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        oid = get('test_table', 5, 'n')[qoid]
        # the qualified oid in the dict is used; the column m does not
        # exist yet, and the m key and m=6 keyword do not prevent deletion
        s = {qoid: oid, 'm': 4}
        r = delete('test_table', s, m=6)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
        # from here on the table has a column m and a primary key on n
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        for i in range(5):
            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
        # rows without the primary key value now raise KeyError,
        # even if they contain a plain 'oid' key
        s = dict(m=2)
        self.assertRaises(KeyError, delete, 'test_table', s)
        s = dict(m=2, oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # with the oid keyword there is no KeyError; this oid belongs
        # to the row n=5 that was deleted above, so nothing is deleted
        r = delete('test_table', dict(m=2), oid=oid)
        self.assertEqual(r, 0)
        oid = get('test_table', 1, 'n')[qoid]
        s = dict(oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # a row from get() can be deleted even without its primary key
        # value, since it carries the qualified oid
        s = get('test_table', 2, 'n')
        del s['n']
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the primary key can be passed as keyword argument only
        r = delete('test_table', n=3)
        self.assertEqual(r, 1)
        r = delete('test_table', n=3)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 1)
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # keyword arguments take precedence over values in the row dict:
        # n=5 is deleted, the row n=6 survives (checked by get() below)
        s = dict(n=6)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 1)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 0)
        s = get('test_table', 6, 'n')
        self.assertEqual(s['n'], 6)
        # the qualified oid from get() wins over the (changed) key value
        s['n'] = 7
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        self.assertEqual(query(q).getresult()[0], (None, 0))
2470
2471    def testDeleteWithCompositeKey(self):
2472        query = self.db.query
2473        table = 'delete_test_table_1'
2474        self.createTable(table, 'n integer primary key, t text',
2475                         values=enumerate('abc', start=1))
2476        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2477        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2478        r = query('select t from "%s" where n=2' % table).getresult()
2479        self.assertEqual(r, [])
2480        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2481        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2482        self.assertEqual(r, 'c')
2483        table = 'delete_test_table_2'
2484        self.createTable(table,
2485             'n integer, m integer, t text, primary key (n, m)',
2486             values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2487                     for n in range(3) for m in range(2)])
2488        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2489        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2490        r = [r[0] for r in query('select t from "%s" where n=2'
2491            ' order by m' % table).getresult()]
2492        self.assertEqual(r, ['c'])
2493        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2494        r = [r[0] for r in query('select t from "%s" where n=3'
2495             ' order by m' % table).getresult()]
2496        self.assertEqual(r, ['e', 'f'])
2497        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2498        r = [r[0] for r in query('select t from "%s" where n=3'
2499             ' order by m' % table).getresult()]
2500        self.assertEqual(r, ['f'])
2501
2502    def testDeleteWithQuotedNames(self):
2503        delete = self.db.delete
2504        query = self.db.query
2505        table = 'test table for delete()'
2506        self.createTable(table, '"Prime!" smallint primary key,'
2507            ' "much space" integer, "Questions?" text',
2508            values=[(19, 5005, 'Yes!')])
2509        r = {'Prime!': 17}
2510        r = delete(table, r)
2511        self.assertEqual(r, 0)
2512        r = query('select count(*) from "%s"' % table).getresult()
2513        self.assertEqual(r[0][0], 1)
2514        r = {'Prime!': 19}
2515        r = delete(table, r)
2516        self.assertEqual(r, 1)
2517        r = query('select count(*) from "%s"' % table).getresult()
2518        self.assertEqual(r[0][0], 0)
2519
2520    def testDeleteReferenced(self):
2521        delete = self.db.delete
2522        query = self.db.query
2523        self.createTable('test_parent',
2524            'n smallint primary key', values=range(3))
2525        self.createTable('test_child',
2526            'n smallint primary key references test_parent', values=range(3))
2527        q = ("select (select count(*) from test_parent),"
2528             " (select count(*) from test_child)")
2529        self.assertEqual(query(q).getresult()[0], (3, 3))
2530        self.assertRaises(pg.IntegrityError,
2531                          delete, 'test_parent', None, n=2)
2532        self.assertRaises(pg.IntegrityError,
2533                          delete, 'test_parent *', None, n=2)
2534        r = delete('test_child', None, n=2)
2535        self.assertEqual(r, 1)
2536        self.assertEqual(query(q).getresult()[0], (3, 2))
2537        r = delete('test_parent', None, n=2)
2538        self.assertEqual(r, 1)
2539        self.assertEqual(query(q).getresult()[0], (2, 2))
2540        self.assertRaises(pg.IntegrityError,
2541                          delete, 'test_parent', dict(n=0))
2542        self.assertRaises(pg.IntegrityError,
2543                          delete, 'test_parent *', dict(n=0))
2544        r = delete('test_child', dict(n=0))
2545        self.assertEqual(r, 1)
2546        self.assertEqual(query(q).getresult()[0], (2, 1))
2547        r = delete('test_child', dict(n=0))
2548        self.assertEqual(r, 0)
2549        r = delete('test_parent', dict(n=0))
2550        self.assertEqual(r, 1)
2551        self.assertEqual(query(q).getresult()[0], (1, 1))
2552        r = delete('test_parent', None, n=0)
2553        self.assertEqual(r, 0)
2554        q = "select n from test_parent natural join test_child limit 2"
2555        self.assertEqual(query(q).getresult(), [(1,)])
2556
2557    def testTempCrud(self):
2558        table = 'test_temp_table'
2559        self.createTable(table, "n int primary key, t varchar", temporary=True)
2560        self.db.insert(table, dict(n=1, t='one'))
2561        self.db.insert(table, dict(n=2, t='too'))
2562        self.db.insert(table, dict(n=3, t='three'))
2563        r = self.db.get(table, 2)
2564        self.assertEqual(r['t'], 'too')
2565        self.db.update(table, dict(n=2, t='two'))
2566        r = self.db.get(table, 2)
2567        self.assertEqual(r['t'], 'two')
2568        self.db.delete(table, r)
2569        r = self.db.query('select n, t from %s order by 1' % table).getresult()
2570        self.assertEqual(r, [(1, 'one'), (3, 'three')])
2571
2572    def testTruncate(self):
2573        truncate = self.db.truncate
2574        self.assertRaises(TypeError, truncate, None)
2575        self.assertRaises(TypeError, truncate, 42)
2576        self.assertRaises(TypeError, truncate, dict(test_table=None))
2577        query = self.db.query
2578        self.createTable('test_table', 'n smallint',
2579                         temporary=False, values=[1] * 3)
2580        q = "select count(*) from test_table"
2581        r = query(q).getresult()[0][0]
2582        self.assertEqual(r, 3)
2583        truncate('test_table')
2584        r = query(q).getresult()[0][0]
2585        self.assertEqual(r, 0)
2586        for i in range(3):
2587            query("insert into test_table values (1)")
2588        r = query(q).getresult()[0][0]
2589        self.assertEqual(r, 3)
2590        truncate('public.test_table')
2591        r = query(q).getresult()[0][0]
2592        self.assertEqual(r, 0)
2593        self.createTable('test_table_2', 'n smallint', temporary=True)
2594        for t in (list, tuple, set):
2595            for i in range(3):
2596                query("insert into test_table values (1)")
2597                query("insert into test_table_2 values (2)")
2598            q = ("select (select count(*) from test_table),"
2599                 " (select count(*) from test_table_2)")
2600            r = query(q).getresult()[0]
2601            self.assertEqual(r, (3, 3))
2602            truncate(t(['test_table', 'test_table_2']))
2603            r = query(q).getresult()[0]
2604            self.assertEqual(r, (0, 0))
2605
2606    def testTruncateRestart(self):
2607        truncate = self.db.truncate
2608        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2609        query = self.db.query
2610        self.createTable('test_table', 'n serial, t text')
2611        for n in range(3):
2612            query("insert into test_table (t) values ('test')")
2613        q = "select count(n), min(n), max(n) from test_table"
2614        r = query(q).getresult()[0]
2615        self.assertEqual(r, (3, 1, 3))
2616        truncate('test_table')
2617        r = query(q).getresult()[0]
2618        self.assertEqual(r, (0, None, None))
2619        for n in range(3):
2620            query("insert into test_table (t) values ('test')")
2621        r = query(q).getresult()[0]
2622        self.assertEqual(r, (3, 4, 6))
2623        truncate('test_table', restart=True)
2624        r = query(q).getresult()[0]
2625        self.assertEqual(r, (0, None, None))
2626        for n in range(3):
2627            query("insert into test_table (t) values ('test')")
2628        r = query(q).getresult()[0]
2629        self.assertEqual(r, (3, 1, 3))
2630
2631    def testTruncateCascade(self):
2632        truncate = self.db.truncate
2633        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2634        query = self.db.query
2635        self.createTable('test_parent', 'n smallint primary key',
2636                         values=range(3))
2637        self.createTable('test_child',
2638                         'n smallint primary key references test_parent (n)',
2639                         values=range(3))
2640        q = ("select (select count(*) from test_parent),"
2641             " (select count(*) from test_child)")
2642        r = query(q).getresult()[0]
2643        self.assertEqual(r, (3, 3))
2644        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2645        truncate(['test_parent', 'test_child'])
2646        r = query(q).getresult()[0]
2647        self.assertEqual(r, (0, 0))
2648        for n in range(3):
2649            query("insert into test_parent (n) values (%d)" % n)
2650            query("insert into test_child (n) values (%d)" % n)
2651        r = query(q).getresult()[0]
2652        self.assertEqual(r, (3, 3))
2653        truncate('test_parent', cascade=True)
2654        r = query(q).getresult()[0]
2655        self.assertEqual(r, (0, 0))
2656        for n in range(3):
2657            query("insert into test_parent (n) values (%d)" % n)
2658            query("insert into test_child (n) values (%d)" % n)
2659        r = query(q).getresult()[0]
2660        self.assertEqual(r, (3, 3))
2661        truncate('test_child')
2662        r = query(q).getresult()[0]
2663        self.assertEqual(r, (3, 0))
2664        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2665        truncate('test_parent', cascade=True)
2666        r = query(q).getresult()[0]
2667        self.assertEqual(r, (0, 0))
2668
    def testTruncateOnly(self):
        """Check truncate() with and without the only option on inheritance.

        Truncating a parent table also empties its child tables unless
        only=True is given.  Combining only=True with a "table*" name is
        rejected with a ValueError.
        """
        truncate = self.db.truncate
        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
        query = self.db.query
        self.createTable('test_parent', 'n smallint')
        # NOTE(review): the unbalanced parentheses in this column spec
        # presumably close createTable's column list early, so that an
        # INHERITS clause is appended -- depends on createTable's SQL
        self.createTable('test_child', 'm smallint) inherits (test_parent')
        for n in range(3):
            query("insert into test_parent (n) values (1)")
            query("insert into test_child (n, m) values (2, 3)")
        # the parent count includes the rows of the child table
        q = ("select (select count(*) from test_parent),"
             " (select count(*) from test_child)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (6, 3))
        # by default the child table is truncated along with the parent
        truncate('test_parent')
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        for n in range(3):
            query("insert into test_parent (n) values (1)")
            query("insert into test_child (n, m) values (2, 3)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (6, 3))
        # the explicit "table*" form has the same effect
        truncate('test_parent*')
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        for n in range(3):
            query("insert into test_parent (n) values (1)")
            query("insert into test_child (n, m) values (2, 3)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (6, 3))
        # only=True removes the parent's own rows but keeps the child rows
        truncate('test_parent', only=True)
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 3))
        truncate('test_parent', only=False)
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        # "table*" contradicts only=True
        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
        truncate('test_parent*', only=False)
        self.createTable('test_parent_2', 'n smallint')
        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
        for t in '', '_2':
            for n in range(3):
                query("insert into test_parent%s (n) values (1)" % t)
                query("insert into test_child%s (n, m) values (2, 3)" % t)
        q = ("select (select count(*) from test_parent),"
             " (select count(*) from test_child),"
             " (select count(*) from test_parent_2),"
             " (select count(*) from test_child_2)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (6, 3, 6, 3))
        # with a list of tables, only can be given per table
        truncate(['test_parent', 'test_parent_2'], only=[False, True])
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0, 3, 3))
        # ... or once for all tables
        truncate(['test_parent', 'test_parent_2'], only=False)
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0, 0, 0))
        self.assertRaises(ValueError, truncate,
            ['test_parent*', 'test_child'], only=[True, False])
        # NOTE(review): all tables are already empty here, so this call
        # only checks that the combination is accepted
        truncate(['test_parent*', 'test_child'], only=[False, True])
2727
2728    def testTruncateQuoted(self):
2729        truncate = self.db.truncate
2730        query = self.db.query
2731        table = "test table for truncate()"
2732        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2733        q = 'select count(*) from "%s"' % table
2734        r = query(q).getresult()[0][0]
2735        self.assertEqual(r, 3)
2736        truncate(table)
2737        r = query(q).getresult()[0][0]
2738        self.assertEqual(r, 0)
2739        for i in range(3):
2740            query('insert into "%s" values (1)' % table)
2741        r = query(q).getresult()[0][0]
2742        self.assertEqual(r, 3)
2743        truncate('public."%s"' % table)
2744        r = query(q).getresult()[0][0]
2745        self.assertEqual(r, 0)
2746
    def testGetAsList(self):
        """Check get_as_list() with the what, where, order, limit, offset
        and scalar options, with and without a primary key, and with an
        arbitrary from clause instead of a table name."""
        get_as_list = self.db.get_as_list
        self.assertRaises(TypeError, get_as_list)
        self.assertRaises(TypeError, get_as_list, None)
        query = self.db.query
        table = 'test_aslist'
        # attribute access on result rows is only checked if namedresult()
        # really returns rows with named fields in this setup
        r = query('select 1 as colname').namedresult()[0]
        self.assertIsInstance(r, tuple)
        named = hasattr(r, 'colname')
        names = [(1, 'Homer'), (2, 'Marge'),
                 (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
        self.createTable(table,
            'id smallint primary key, name varchar', values=names)
        # without options, all rows are expected in primary key order
        r = get_as_list(table)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names)
        for t, n in zip(r, names):
            self.assertIsInstance(t, tuple)
            self.assertEqual(t, n)
            if named:
                self.assertEqual(t.id, n[0])
                self.assertEqual(t.name, n[1])
                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
        # with only what given, rows are expected ordered by those columns
        r = get_as_list(table, what='name')
        self.assertIsInstance(r, list)
        expected = sorted((row[1],) for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(table, what='name, id')
        self.assertIsInstance(r, list)
        expected = sorted(tuple(reversed(row)) for row in names)
        self.assertEqual(r, expected)
        # what can also be a list of column names
        r = get_as_list(table, what=['name', 'id'])
        self.assertIsInstance(r, list)
        self.assertEqual(r, expected)
        r = get_as_list(table, where="name like 'Ba%'")
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[2:3])
        r = get_as_list(table, what='name', where="name like 'Ma%'")
        self.assertIsInstance(r, list)
        self.assertEqual(r, [('Maggie',), ('Marge',)])
        # a list of where conditions must all be satisfied
        r = get_as_list(table, what='name',
            where=["name like 'Ma%'", "name like '%r%'"])
        self.assertIsInstance(r, list)
        self.assertEqual(r, [('Marge',)])
        r = get_as_list(table, what='name', order='id')
        self.assertIsInstance(r, list)
        expected = [(row[1],) for row in names]
        self.assertEqual(r, expected)
        r = get_as_list(table, what=['name'], order=['id'])
        self.assertIsInstance(r, list)
        self.assertEqual(r, expected)
        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
        self.assertIsInstance(r, list)
        self.assertEqual(r, names)
        # what may contain SQL expressions, order may contain desc
        r = get_as_list(table, what='id * 2 as num', order='id desc')
        self.assertIsInstance(r, list)
        expected = [(n,) for n in range(10, 0, -2)]
        self.assertEqual(r, expected)
        r = get_as_list(table, limit=2)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[:2])
        r = get_as_list(table, offset=3)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[3:])
        r = get_as_list(table, limit=1, offset=2)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[2:3])
        # scalar=True returns the first column values instead of tuples
        r = get_as_list(table, scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, list(range(1, 6)))
        r = get_as_list(table, what='name', scalar=True)
        self.assertIsInstance(r, list)
        expected = sorted(row[1] for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(table, what='name', limit=1, scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, expected[:1])
        # drop the primary key (assumes the default "<table>_pkey" name)
        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
        # a duplicate id is now possible; rows are then expected ordered
        # by all columns, so Snowball follows Homer
        names.insert(1, (1, 'Snowball'))
        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
        r = get_as_list(table)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names)
        r = get_as_list(table, what='name', where='id=1', scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, ['Homer', 'Snowball'])
        # test with unordered query
        r = get_as_list(table, order=False)
        self.assertIsInstance(r, list)
        self.assertEqual(set(r), set(names))
        # test with arbitrary from clause
        from_table = '(select lower(name) as n2 from "%s") as t2' % table
        r = get_as_list(from_table)
        self.assertIsInstance(r, list)
        r = set(row[0] for row in r)
        expected = set(row[1].lower() for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(from_table, order='n2', scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, sorted(expected))
        r = get_as_list(from_table, order='n2', limit=1)
        self.assertIsInstance(r, list)
        self.assertEqual(len(r), 1)
        t = r[0]
        self.assertIsInstance(t, tuple)
        if named:
            self.assertEqual(t.n2, 'bart')
            self.assertEqual(t._asdict(), dict(n2='bart'))
        else:
            self.assertEqual(t, ('bart',))
2858
2859    def testGetAsDict(self):
2860        get_as_dict = self.db.get_as_dict
2861        self.assertRaises(TypeError, get_as_dict)
2862        self.assertRaises(TypeError, get_as_dict, None)
2863        # the test table has no primary key
2864        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2865        query = self.db.query
2866        table = 'test_asdict'
2867        r = query('select 1 as colname').namedresult()[0]
2868        self.assertIsInstance(r, tuple)
2869        named = hasattr(r, 'colname')
2870        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2871                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2872        self.createTable(table,
2873            'id smallint primary key, rgb char(7), name varchar',
2874            values=colors)
2875        # keyname must be string, list or tuple
2876        self.assertRaises(KeyError, get_as_dict, table, 3)
2877        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2878        # missing keyname in row
2879        self.assertRaises(KeyError, get_as_dict, table,
2880                          keyname='rgb', what='name')
2881        r = get_as_dict(table)
2882        self.assertIsInstance(r, OrderedDict)
2883        expected = OrderedDict((row[0], row[1:]) for row in colors)
2884        self.assertEqual(r, expected)
2885        for key in r:
2886            self.assertIsInstance(key, int)
2887            self.assertIn(key, expected)
2888            row = r[key]
2889            self.assertIsInstance(row, tuple)
2890            t = expected[key]
2891            self.assertEqual(row, t)
2892            if named:
2893                self.assertEqual(row.rgb, t[0])
2894                self.assertEqual(row.name, t[1])
2895                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2896        if OrderedDict is not dict:  # Python > 2.6
2897            self.assertEqual(r.keys(), expected.keys())
2898        r = get_as_dict(table, keyname='rgb')
2899        self.assertIsInstance(r, OrderedDict)
2900        expected = OrderedDict((row[1], (row[0], row[2]))
2901                               for row in sorted(colors, key=itemgetter(1)))
2902        self.assertEqual(r, expected)
2903        for key in r:
2904            self.assertIsInstance(key, str)
2905            self.assertIn(key, expected)
2906            row = r[key]
2907            self.assertIsInstance(row, tuple)
2908            t = expected[key]
2909            self.assertEqual(row, t)
2910            if named:
2911                self.assertEqual(row.id, t[0])
2912                self.assertEqual(row.name, t[1])
2913                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2914        if OrderedDict is not dict:  # Python > 2.6
2915            self.assertEqual(r.keys(), expected.keys())
2916        r = get_as_dict(table, keyname=['id', 'rgb'])
2917        self.assertIsInstance(r, OrderedDict)
2918        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2919        self.assertEqual(r, expected)
2920        for key in r:
2921            self.assertIsInstance(key, tuple)
2922            self.assertIsInstance(key[0], int)
2923            self.assertIsInstance(key[1], str)
2924            if named:
2925                self.assertEqual(key, (key.id, key.rgb))
2926                self.assertEqual(key._fields, ('id', 'rgb'))
2927            row = r[key]
2928            self.assertIsInstance(row, tuple)
2929            self.assertIsInstance(row[0], str)
2930            t = expected[key]
2931            self.assertEqual(row, t)
2932            if named:
2933                self.assertEqual(row.name, t[0])
2934                self.assertEqual(row._asdict(), dict(name=t[0]))
2935        if OrderedDict is not dict:  # Python > 2.6
2936            self.assertEqual(r.keys(), expected.keys())
2937        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2938        self.assertIsInstance(r, OrderedDict)
2939        expected = OrderedDict((row[:2], row[2]) for row in colors)
2940        self.assertEqual(r, expected)
2941        for key in r:
2942            self.assertIsInstance(key, tuple)
2943            row = r[key]
2944            self.assertIsInstance(row, str)
2945            t = expected[key]
2946            self.assertEqual(row, t)
2947        if OrderedDict is not dict:  # Python > 2.6
2948            self.assertEqual(r.keys(), expected.keys())
2949        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2950        self.assertIsInstance(r, OrderedDict)
2951        expected = OrderedDict((row[1], row[2])
2952            for row in sorted(colors, key=itemgetter(1)))
2953        self.assertEqual(r, expected)
2954        for key in r:
2955            self.assertIsInstance(key, str)
2956            row = r[key]
2957            self.assertIsInstance(row, str)
2958            t = expected[key]
2959            self.assertEqual(row, t)
2960        if OrderedDict is not dict:  # Python > 2.6
2961            self.assertEqual(r.keys(), expected.keys())
2962        r = get_as_dict(table, what='id, name',
2963            where="rgb like '#b%'", scalar=True)
2964        self.assertIsInstance(r, OrderedDict)
2965        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2966        self.assertEqual(r, expected)
2967        for key in r:
2968            self.assertIsInstance(key, int)
2969            row = r[key]
2970            self.assertIsInstance(row, str)
2971            t = expected[key]
2972            self.assertEqual(row, t)
2973        if OrderedDict is not dict:  # Python > 2.6
2974            self.assertEqual(r.keys(), expected.keys())
2975        expected = r
2976        r = get_as_dict(table, what=['name', 'id'],
2977            where=['id > 1', 'id < 4', "rgb like '#b%'",
2978                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2979        self.assertEqual(r, expected)
2980        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2981        self.assertEqual(r, expected)
2982        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2983            where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2984        self.assertEqual(r, expected)
2985        r = get_as_dict(table, limit=1)
2986        self.assertEqual(len(r), 1)
2987        self.assertEqual(r[1][1], 'Aero')
2988        r = get_as_dict(table, offset=3)
2989        self.assertEqual(len(r), 1)
2990        self.assertEqual(r[4][1], 'Desert')
2991        r = get_as_dict(table, order='id desc')
2992        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2993        self.assertEqual(r, expected)
2994        r = get_as_dict(table, where='id > 5')
2995        self.assertIsInstance(r, OrderedDict)
2996        self.assertEqual(len(r), 0)
2997        # test with unordered query
2998        expected = dict((row[0], row[1:]) for row in colors)
2999        r = get_as_dict(table, order=False)
3000        self.assertIsInstance(r, dict)
3001        self.assertEqual(r, expected)
3002        if dict is not OrderedDict:  # Python > 2.6
3003            self.assertNotIsInstance(self, OrderedDict)
3004        # test with arbitrary from clause
3005        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
3006        # primary key must be passed explicitly in this case
3007        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
3008        r = get_as_dict(from_table, 'id')
3009        self.assertIsInstance(r, OrderedDict)
3010        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
3011        self.assertEqual(r, expected)
3012        # test without a primary key
3013        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
3014        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
3015        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
3016        r = get_as_dict(table, keyname='id')
3017        expected = OrderedDict((row[0], row[1:]) for row in colors)
3018        self.assertIsInstance(r, dict)
3019        self.assertEqual(r, expected)
3020        r = (1, '#007fff', 'Azure')
3021        query('insert into "%s" values ($1, $2, $3)' % table, r)
3022        # the last entry will win
3023        expected[1] = r[1:]
3024        r = get_as_dict(table, keyname='id')
3025        self.assertEqual(r, expected)
3026
3027    def testTransaction(self):
3028        query = self.db.query
3029        self.createTable('test_table', 'n integer', temporary=False)
3030        self.db.begin()
3031        query("insert into test_table values (1)")
3032        query("insert into test_table values (2)")
3033        self.db.commit()
3034        self.db.begin()
3035        query("insert into test_table values (3)")
3036        query("insert into test_table values (4)")
3037        self.db.rollback()
3038        self.db.begin()
3039        query("insert into test_table values (5)")
3040        self.db.savepoint('before6')
3041        query("insert into test_table values (6)")
3042        self.db.rollback('before6')
3043        query("insert into test_table values (7)")
3044        self.db.commit()
3045        self.db.begin()
3046        self.db.savepoint('before8')
3047        query("insert into test_table values (8)")
3048        self.db.release('before8')
3049        self.assertRaises(pg.InternalError, self.db.rollback, 'before8')
3050        self.db.commit()
3051        self.db.start()
3052        query("insert into test_table values (9)")
3053        self.db.end()
3054        r = [r[0] for r in query(
3055            "select * from test_table order by 1").getresult()]
3056        self.assertEqual(r, [1, 2, 5, 7, 9])
3057        self.db.begin(mode='read only')
3058        self.assertRaises(pg.InternalError,
3059                          query, "insert into test_table values (0)")
3060        self.db.rollback()
3061        self.db.start(mode='Read Only')
3062        self.assertRaises(pg.InternalError,
3063                          query, "insert into test_table values (0)")
3064        self.db.abort()
3065
3066    def testTransactionAliases(self):
3067        self.assertEqual(self.db.begin, self.db.start)
3068        self.assertEqual(self.db.commit, self.db.end)
3069        self.assertEqual(self.db.rollback, self.db.abort)
3070
3071    def testContextManager(self):
3072        query = self.db.query
3073        self.createTable('test_table', 'n integer check(n>0)')
3074        with self.db:
3075            query("insert into test_table values (1)")
3076            query("insert into test_table values (2)")
3077        try:
3078            with self.db:
3079                query("insert into test_table values (3)")
3080                query("insert into test_table values (4)")
3081                raise ValueError('test transaction should rollback')
3082        except ValueError as error:
3083            self.assertEqual(str(error), 'test transaction should rollback')
3084        with self.db:
3085            query("insert into test_table values (5)")
3086        try:
3087            with self.db:
3088                query("insert into test_table values (6)")
3089                query("insert into test_table values (-1)")
3090        except pg.IntegrityError as error:
3091            self.assertTrue('check' in str(error))
3092        with self.db:
3093            query("insert into test_table values (7)")
3094        r = [r[0] for r in query(
3095            "select * from test_table order by 1").getresult()]
3096        self.assertEqual(r, [1, 2, 5, 7])
3097
3098    def testBytea(self):
3099        query = self.db.query
3100        self.createTable('bytea_test', 'n smallint primary key, data bytea')
3101        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
3102        r = self.db.escape_bytea(s)
3103        query('insert into bytea_test values(3, $1)', (r,))
3104        r = query('select * from bytea_test where n=3').getresult()
3105        self.assertEqual(len(r), 1)
3106        r = r[0]
3107        self.assertEqual(len(r), 2)
3108        self.assertEqual(r[0], 3)
3109        r = r[1]
3110        if pg.get_bytea_escaped():
3111            self.assertNotEqual(r, s)
3112            r = pg.unescape_bytea(r)
3113        self.assertIsInstance(r, bytes)
3114        self.assertEqual(r, s)
3115
3116    def testInsertUpdateGetBytea(self):
3117        query = self.db.query
3118        unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None
3119        self.createTable('bytea_test', 'n smallint primary key, data bytea')
3120        # insert null value
3121        r = self.db.insert('bytea_test', n=0, data=None)
3122        self.assertIsInstance(r, dict)
3123        self.assertIn('n', r)
3124        self.assertEqual(r['n'], 0)
3125        self.assertIn('data', r)
3126        self.assertIsNone(r['data'])
3127        s = b'None'
3128        r = self.db.update('bytea_test', n=0, data=s)
3129        self.assertIsInstance(r, dict)
3130        self.assertIn('n', r)
3131        self.assertEqual(r['n'], 0)
3132        self.assertIn('data', r)
3133        r = r['data']
3134        if unescape:
3135            self.assertNotEqual(r, s)
3136            r = unescape(r)
3137        self.assertIsInstance(r, bytes)
3138        self.assertEqual(r, s)
3139        r = self.db.update('bytea_test', n=0, data=None)
3140        self.assertIsNone(r['data'])
3141        # insert as bytes
3142        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
3143        r = self.db.insert('bytea_test', n=5, data=s)
3144        self.assertIsInstance(r, dict)
3145        self.assertIn('n', r)
3146        self.assertEqual(r['n'], 5)
3147        self.assertIn('data', r)
3148        r = r['data']
3149        if unescape:
3150            self.assertNotEqual(r, s)
3151            r = unescape(r)
3152        self.assertIsInstance(r, bytes)
3153        self.assertEqual(r, s)
3154        # update as bytes
3155        s += b"and now even more \x00 nasty \t stuff!\f"
3156        r = self.db.update('bytea_test', n=5, data=s)
3157        self.assertIsInstance(r, dict)
3158        self.assertIn('n', r)
3159        self.assertEqual(r['n'], 5)
3160        self.assertIn('data', r)
3161        r = r['data']
3162        if unescape:
3163            self.assertNotEqual(r, s)
3164            r = unescape(r)
3165        self.assertIsInstance(r, bytes)
3166        self.assertEqual(r, s)
3167        r = query('select * from bytea_test where n=5').getresult()
3168        self.assertEqual(len(r), 1)
3169        r = r[0]
3170        self.assertEqual(len(r), 2)
3171        self.assertEqual(r[0], 5)
3172        r = r[1]
3173        if unescape:
3174            self.assertNotEqual(r, s)
3175            r = unescape(r)
3176        self.assertIsInstance(r, bytes)
3177        self.assertEqual(r, s)
3178        r = self.db.get('bytea_test', dict(n=5))
3179        self.assertIsInstance(r, dict)
3180        self.assertIn('n', r)
3181        self.assertEqual(r['n'], 5)
3182        self.assertIn('data', r)
3183        r = r['data']
3184        if unescape:
3185            self.assertNotEqual(r, s)
3186            r = pg.unescape_bytea(r)
3187        self.assertIsInstance(r, bytes)
3188        self.assertEqual(r, s)
3189
3190    def testUpsertBytea(self):
3191        self.createTable('bytea_test', 'n smallint primary key, data bytea')
3192        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
3193        r = dict(n=7, data=s)
3194        try:
3195            r = self.db.upsert('bytea_test', r)
3196        except pg.ProgrammingError as error:
3197            if self.db.server_version < 90500:
3198                self.skipTest('database does not support upsert')
3199            self.fail(str(error))
3200        self.assertIsInstance(r, dict)
3201        self.assertIn('n', r)
3202        self.assertEqual(r['n'], 7)
3203        self.assertIn('data', r)
3204        if pg.get_bytea_escaped():
3205            self.assertNotEqual(r['data'], s)
3206            r['data'] = pg.unescape_bytea(r['data'])
3207        self.assertIsInstance(r['data'], bytes)
3208        self.assertEqual(r['data'], s)
3209        r['data'] = None
3210        r = self.db.upsert('bytea_test', r)
3211        self.assertIsInstance(r, dict)
3212        self.assertIn('n', r)
3213        self.assertEqual(r['n'], 7)
3214        self.assertIn('data', r)
3215        self.assertIsNone(r['data'])
3216
3217    def testInsertGetJson(self):
3218        try:
3219            self.createTable('json_test', 'n smallint primary key, data json')
3220        except pg.ProgrammingError as error:
3221            if self.db.server_version < 90200:
3222                self.skipTest('database does not support json')
3223            self.fail(str(error))
3224        jsondecode = pg.get_jsondecode()
3225        # insert null value
3226        r = self.db.insert('json_test', n=0, data=None)
3227        self.assertIsInstance(r, dict)
3228        self.assertIn('n', r)
3229        self.assertEqual(r['n'], 0)
3230        self.assertIn('data', r)
3231        self.assertIsNone(r['data'])
3232        r = self.db.get('json_test', 0)
3233        self.assertIsInstance(r, dict)
3234        self.assertIn('n', r)
3235        self.assertEqual(r['n'], 0)
3236        self.assertIn('data', r)
3237        self.assertIsNone(r['data'])
3238        # insert JSON object
3239        data = {
3240            "id": 1, "name": "Foo", "price": 1234.5,
3241            "new": True, "note": None,
3242            "tags": ["Bar", "Eek"],
3243            "stock": {"warehouse": 300, "retail": 20}}
3244        r = self.db.insert('json_test', n=1, data=data)
3245        self.assertIsInstance(r, dict)
3246        self.assertIn('n', r)
3247        self.assertEqual(r['n'], 1)
3248        self.assertIn('data', r)
3249        r = r['data']
3250        if jsondecode is None:
3251            self.assertIsInstance(r, str)
3252            r = json.loads(r)
3253        self.assertIsInstance(r, dict)
3254        self.assertEqual(r, data)
3255        self.assertIsInstance(r['id'], int)
3256        self.assertIsInstance(r['name'], unicode)
3257        self.assertIsInstance(r['price'], float)
3258        self.assertIsInstance(r['new'], bool)
3259        self.assertIsInstance(r['tags'], list)
3260        self.assertIsInstance(r['stock'], dict)
3261        r = self.db.get('json_test', 1)
3262        self.assertIsInstance(r, dict)
3263        self.assertIn('n', r)
3264        self.assertEqual(r['n'], 1)
3265        self.assertIn('data', r)
3266        r = r['data']
3267        if jsondecode is None:
3268            self.assertIsInstance(r, str)
3269            r = json.loads(r)
3270        self.assertIsInstance(r, dict)
3271        self.assertEqual(r, data)
3272        self.assertIsInstance(r['id'], int)
3273        self.assertIsInstance(r['name'], unicode)
3274        self.assertIsInstance(r['price'], float)
3275        self.assertIsInstance(r['new'], bool)
3276        self.assertIsInstance(r['tags'], list)
3277        self.assertIsInstance(r['stock'], dict)
3278        # insert JSON object as text
3279        self.db.insert('json_test', n=2, data=json.dumps(data))
3280        q = "select data from json_test where n in (1, 2) order by n"
3281        r = self.db.query(q).getresult()
3282        self.assertEqual(len(r), 2)
3283        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
3284        self.assertEqual(r[0][0], r[1][0])
3285
3286    def testInsertGetJsonb(self):
3287        try:
3288            self.createTable('jsonb_test',
3289                             'n smallint primary key, data jsonb')
3290        except pg.ProgrammingError as error:
3291            if self.db.server_version < 90400:
3292                self.skipTest('database does not support jsonb')
3293            self.fail(str(error))
3294        jsondecode = pg.get_jsondecode()
3295        # insert null value
3296        r = self.db.insert('jsonb_test', n=0, data=None)
3297        self.assertIsInstance(r, dict)
3298        self.assertIn('n', r)
3299        self.assertEqual(r['n'], 0)
3300        self.assertIn('data', r)
3301        self.assertIsNone(r['data'])
3302        r = self.db.get('jsonb_test', 0)
3303        self.assertIsInstance(r, dict)
3304        self.assertIn('n', r)
3305        self.assertEqual(r['n'], 0)
3306        self.assertIn('data', r)
3307        self.assertIsNone(r['data'])
3308        # insert JSON object
3309        data = {
3310            "id": 1, "name": "Foo", "price": 1234.5,
3311            "new": True, "note": None,
3312            "tags": ["Bar", "Eek"],
3313            "stock": {"warehouse": 300, "retail": 20}}
3314        r = self.db.insert('jsonb_test', n=1, data=data)
3315        self.assertIsInstance(r, dict)
3316        self.assertIn('n', r)
3317        self.assertEqual(r['n'], 1)
3318        self.assertIn('data', r)
3319        r = r['data']
3320        if jsondecode is None:
3321            self.assertIsInstance(r, str)
3322            r = json.loads(r)
3323        self.assertIsInstance(r, dict)
3324        self.assertEqual(r, data)
3325        self.assertIsInstance(r['id'], int)
3326        self.assertIsInstance(r['name'], unicode)
3327        self.assertIsInstance(r['price'], float)
3328        self.assertIsInstance(r['new'], bool)
3329        self.assertIsInstance(r['tags'], list)
3330        self.assertIsInstance(r['stock'], dict)
3331        r = self.db.get('jsonb_test', 1)
3332        self.assertIsInstance(r, dict)
3333        self.assertIn('n', r)
3334        self.assertEqual(r['n'], 1)
3335        self.assertIn('data', r)
3336        r = r['data']
3337        if jsondecode is None:
3338            self.assertIsInstance(r, str)
3339            r = json.loads(r)
3340        self.assertIsInstance(r, dict)
3341        self.assertEqual(r, data)
3342        self.assertIsInstance(r['id'], int)
3343        self.assertIsInstance(r['name'], unicode)
3344        self.assertIsInstance(r['price'], float)
3345        self.assertIsInstance(r['new'], bool)
3346        self.assertIsInstance(r['tags'], list)
3347        self.assertIsInstance(r['stock'], dict)
3348
3349    def testArray(self):
3350        returns_arrays = pg.get_array()
3351        self.createTable('arraytest',
3352            'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
3353            ' d numeric[], f4 real[], f8 double precision[], m money[],'
3354            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
3355        r = self.db.get_attnames('arraytest')
3356        if self.regtypes:
3357            self.assertEqual(r, dict(
3358                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
3359                d='numeric[]', f4='real[]', f8='double precision[]',
3360                m='money[]', b='boolean[]',
3361                v4='character varying[]', c4='character[]', t='text[]'))
3362        else:
3363            self.assertEqual(r, dict(
3364                id='int', i2='int[]', i4='int[]', i8='int[]',
3365                d='num[]', f4='float[]', f8='float[]', m='money[]',
3366                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
3367        decimal = pg.get_decimal()
3368        if decimal is Decimal:
3369            long_decimal = decimal('123456789.123456789')
3370            odd_money = decimal('1234567891234567.89')
3371        else:
3372            long_decimal = decimal('12345671234.5')
3373            odd_money = decimal('1234567123.25')
3374        t, f = (True, False) if pg.get_bool() else ('t', 'f')
3375        data = dict(id=42, i2=[42, 1234, None, 0, -1],
3376            i4=[42, 123456789, None, 0, 1, -1],
3377            i8=[long(42), long(123456789123456789), None,
3378                long(0), long(1), long(-1)],
3379            d=[decimal(42), long_decimal, None,
3380               decimal(0), decimal(1), decimal(-1), -long_decimal],
3381            f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
3382                float('inf'), float('-inf')],
3383            f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
3384                float('inf'), float('-inf')],
3385            m=[decimal('42.00'), odd_money, None,
3386               decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
3387            b=[t, f, t, None, f, t, None, None, t],
3388            v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
3389            t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
3390        r = data.copy()
3391        self.db.insert('arraytest', r)
3392        if returns_arrays:
3393            self.assertEqual(r, data)
3394        else:
3395            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3396        self.db.insert('arraytest', r)
3397        r = self.db.get('arraytest', 42, 'id')
3398        if returns_arrays:
3399            self.assertEqual(r, data)
3400        else:
3401            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3402        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
3403        if returns_arrays:
3404            self.assertEqual(r, data)
3405        else:
3406            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3407
3408    def testArrayLiteral(self):
3409        insert = self.db.insert
3410        returns_arrays = pg.get_array()
3411        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3412        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3413        insert('arraytest', r)
3414        if returns_arrays:
3415            self.assertEqual(r['i'], [1, 2, 3])
3416            self.assertEqual(r['t'], ['a', 'b', 'c'])
3417        else:
3418            self.assertEqual(r['i'], '{1,2,3}')
3419            self.assertEqual(r['t'], '{a,b,c}')
3420        r = dict(i='{1,2,3}', t='{a,b,c}')
3421        self.db.insert('arraytest', r)
3422        if returns_arrays:
3423            self.assertEqual(r['i'], [1, 2, 3])
3424            self.assertEqual(r['t'], ['a', 'b', 'c'])
3425        else:
3426            self.assertEqual(r['i'], '{1,2,3}')
3427            self.assertEqual(r['t'], '{a,b,c}')
3428        L = pg.Literal
3429        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3430        self.db.insert('arraytest', r)
3431        if returns_arrays:
3432            self.assertEqual(r['i'], [1, 2, 3])
3433            self.assertEqual(r['t'], ['a', 'b', 'c'])
3434        else:
3435            self.assertEqual(r['i'], '{1,2,3}')
3436            self.assertEqual(r['t'], '{a,b,c}')
3437        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3438        self.assertRaises(pg.DataError, self.db.insert, 'arraytest', r)
3439
3440    def testArrayOfIds(self):
3441        array_on = pg.get_array()
3442        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3443        r = self.db.get_attnames('arraytest')
3444        if self.regtypes:
3445            self.assertEqual(r, dict(
3446                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3447        else:
3448            self.assertEqual(r, dict(
3449                oid='int', c='int[]', o='int[]', x='int[]'))
3450        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3451        r = data.copy()
3452        self.db.insert('arraytest', r)
3453        qoid = 'oid(arraytest)'
3454        oid = r.pop(qoid)
3455        if array_on:
3456            self.assertEqual(r, data)
3457        else:
3458            self.assertEqual(r['o'], '{21,22,23}')
3459        r = {qoid: oid}
3460        self.db.get('arraytest', r)
3461        self.assertEqual(oid, r.pop(qoid))
3462        if array_on:
3463            self.assertEqual(r, data)
3464        else:
3465            self.assertEqual(r['o'], '{21,22,23}')
3466
3467    def testArrayOfText(self):
3468        array_on = pg.get_array()
3469        self.createTable('arraytest', 'data text[]', oids=True)
3470        r = self.db.get_attnames('arraytest')
3471        self.assertEqual(r['data'], 'text[]')
3472        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3473                'null', 'NULL', 'Null', 'nulL',
3474                "It's all \\ kinds of\r nasty stuff!\n"]
3475        r = dict(data=data)
3476        self.db.insert('arraytest', r)
3477        if not array_on:
3478            r['data'] = pg.cast_array(r['data'])
3479        self.assertEqual(r['data'], data)
3480        self.assertIsInstance(r['data'][1], str)
3481        self.assertIsNone(r['data'][2])
3482        r['data'] = None
3483        self.db.get('arraytest', r)
3484        if not array_on:
3485            r['data'] = pg.cast_array(r['data'])
3486        self.assertEqual(r['data'], data)
3487        self.assertIsInstance(r['data'][1], str)
3488        self.assertIsNone(r['data'][2])
3489
3490    def testArrayOfBytea(self):
3491        array_on = pg.get_array()
3492        bytea_escaped = pg.get_bytea_escaped()
3493        self.createTable('arraytest', 'data bytea[]', oids=True)
3494        r = self.db.get_attnames('arraytest')
3495        self.assertEqual(r['data'], 'bytea[]')
3496        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3497                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3498        r = dict(data=data)
3499        self.db.insert('arraytest', r)
3500        if array_on:
3501            self.assertIsInstance(r['data'], list)
3502        if array_on and not bytea_escaped:
3503            self.assertEqual(r['data'], data)
3504            self.assertIsInstance(r['data'][1], bytes)
3505            self.assertIsNone(r['data'][2])
3506        else:
3507            self.assertNotEqual(r['data'], data)
3508        r['data'] = None
3509        self.db.get('arraytest', r)
3510        if array_on:
3511            self.assertIsInstance(r['data'], list)
3512        if array_on and not bytea_escaped:
3513            self.assertEqual(r['data'], data)
3514            self.assertIsInstance(r['data'][1], bytes)
3515            self.assertIsNone(r['data'][2])
3516        else:
3517            self.assertNotEqual(r['data'], data)
3518
3519    def testArrayOfJson(self):
3520        try:
3521            self.createTable('arraytest', 'data json[]', oids=True)
3522        except pg.ProgrammingError as error:
3523            if self.db.server_version < 90200:
3524                self.skipTest('database does not support json')
3525            self.fail(str(error))
3526        r = self.db.get_attnames('arraytest')
3527        self.assertEqual(r['data'], 'json[]')
3528        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3529        array_on = pg.get_array()
3530        jsondecode = pg.get_jsondecode()
3531        r = dict(data=data)
3532        self.db.insert('arraytest', r)
3533        if not array_on:
3534            r['data'] = pg.cast_array(r['data'], jsondecode)
3535        if jsondecode is None:
3536            r['data'] = [json.loads(d) for d in r['data']]
3537        self.assertEqual(r['data'], data)
3538        r['data'] = None
3539        self.db.get('arraytest', r)
3540        if not array_on:
3541            r['data'] = pg.cast_array(r['data'], jsondecode)
3542        if jsondecode is None:
3543            r['data'] = [json.loads(d) for d in r['data']]
3544        self.assertEqual(r['data'], data)
3545        r = dict(data=[json.dumps(d) for d in data])
3546        self.db.insert('arraytest', r)
3547        if not array_on:
3548            r['data'] = pg.cast_array(r['data'], jsondecode)
3549        if jsondecode is None:
3550            r['data'] = [json.loads(d) for d in r['data']]
3551        self.assertEqual(r['data'], data)
3552        r['data'] = None
3553        self.db.get('arraytest', r)
3554        # insert empty json values
3555        r = dict(data=['', None])
3556        self.db.insert('arraytest', r)
3557        r = r['data']
3558        if array_on:
3559            self.assertIsInstance(r, list)
3560            self.assertEqual(len(r), 2)
3561            self.assertIsNone(r[0])
3562            self.assertIsNone(r[1])
3563        else:
3564            self.assertEqual(r, '{NULL,NULL}')
3565
3566    def testArrayOfJsonb(self):
3567        try:
3568            self.createTable('arraytest', 'data jsonb[]', oids=True)
3569        except pg.ProgrammingError as error:
3570            if self.db.server_version < 90400:
3571                self.skipTest('database does not support jsonb')
3572            self.fail(str(error))
3573        r = self.db.get_attnames('arraytest')
3574        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3575        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3576        array_on = pg.get_array()
3577        jsondecode = pg.get_jsondecode()
3578        r = dict(data=data)
3579        self.db.insert('arraytest', r)
3580        if not array_on:
3581            r['data'] = pg.cast_array(r['data'], jsondecode)
3582        if jsondecode is None:
3583            r['data'] = [json.loads(d) for d in r['data']]
3584        self.assertEqual(r['data'], data)
3585        r['data'] = None
3586        self.db.get('arraytest', r)
3587        if not array_on:
3588            r['data'] = pg.cast_array(r['data'], jsondecode)
3589        if jsondecode is None:
3590            r['data'] = [json.loads(d) for d in r['data']]
3591        self.assertEqual(r['data'], data)
3592        r = dict(data=[json.dumps(d) for d in data])
3593        self.db.insert('arraytest', r)
3594        if not array_on:
3595            r['data'] = pg.cast_array(r['data'], jsondecode)
3596        if jsondecode is None:
3597            r['data'] = [json.loads(d) for d in r['data']]
3598        self.assertEqual(r['data'], data)
3599        r['data'] = None
3600        self.db.get('arraytest', r)
3601        # insert empty json values
3602        r = dict(data=['', None])
3603        self.db.insert('arraytest', r)
3604        r = r['data']
3605        if array_on:
3606            self.assertIsInstance(r, list)
3607            self.assertEqual(len(r), 2)
3608            self.assertIsNone(r[0])
3609            self.assertIsNone(r[1])
3610        else:
3611            self.assertEqual(r, '{NULL,NULL}')
3612
3613    def testDeepArray(self):
3614        array_on = pg.get_array()
3615        self.createTable('arraytest', 'data text[][][]', oids=True)
3616        r = self.db.get_attnames('arraytest')
3617        self.assertEqual(r['data'], 'text[]')
3618        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3619        r = dict(data=data)
3620        self.db.insert('arraytest', r)
3621        if array_on:
3622            self.assertEqual(r['data'], data)
3623        else:
3624            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3625        r['data'] = None
3626        self.db.get('arraytest', r)
3627        if array_on:
3628            self.assertEqual(r['data'], data)
3629        else:
3630            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3631
3632    def testInsertUpdateGetRecord(self):
3633        query = self.db.query
3634        query('create type test_person_type as'
3635              ' (name varchar, age smallint, married bool,'
3636              ' weight real, salary money)')
3637        self.addCleanup(query, 'drop type test_person_type')
3638        self.createTable('test_person', 'person test_person_type',
3639                         temporary=False, oids=True)
3640        attnames = self.db.get_attnames('test_person')
3641        self.assertEqual(len(attnames), 2)
3642        self.assertIn('oid', attnames)
3643        self.assertIn('person', attnames)
3644        person_typ = attnames['person']
3645        if self.regtypes:
3646            self.assertEqual(person_typ, 'test_person_type')
3647        else:
3648            self.assertEqual(person_typ, 'record')
3649        if self.regtypes:
3650            self.assertEqual(person_typ.attnames,
3651                dict(name='character varying', age='smallint',
3652                    married='boolean', weight='real', salary='money'))
3653        else:
3654            self.assertEqual(person_typ.attnames,
3655                dict(name='text', age='int', married='bool',
3656                    weight='float', salary='money'))
3657        decimal = pg.get_decimal()
3658        if pg.get_bool():
3659            bool_class = bool
3660            t, f = True, False
3661        else:
3662            bool_class = str
3663            t, f = 't', 'f'
3664        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3665        r = self.db.insert('test_person', None, person=person)
3666        p = r['person']
3667        self.assertIsInstance(p, tuple)
3668        self.assertEqual(p, person)
3669        self.assertEqual(p.name, 'John Doe')
3670        self.assertIsInstance(p.name, str)
3671        self.assertIsInstance(p.age, int)
3672        self.assertIsInstance(p.married, bool_class)
3673        self.assertIsInstance(p.weight, float)
3674        self.assertIsInstance(p.salary, decimal)
3675        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3676        r['person'] = person
3677        self.db.update('test_person', r)
3678        p = r['person']
3679        self.assertIsInstance(p, tuple)
3680        self.assertEqual(p, person)
3681        self.assertEqual(p.name, 'Jane Roe')
3682        self.assertIsInstance(p.name, str)
3683        self.assertIsInstance(p.age, int)
3684        self.assertIsInstance(p.married, bool_class)
3685        self.assertIsInstance(p.weight, float)
3686        self.assertIsInstance(p.salary, decimal)
3687        r['person'] = None
3688        self.db.get('test_person', r)
3689        p = r['person']
3690        self.assertIsInstance(p, tuple)
3691        self.assertEqual(p, person)
3692        self.assertEqual(p.name, 'Jane Roe')
3693        self.assertIsInstance(p.name, str)
3694        self.assertIsInstance(p.age, int)
3695        self.assertIsInstance(p.married, bool_class)
3696        self.assertIsInstance(p.weight, float)
3697        self.assertIsInstance(p.salary, decimal)
3698        person = (None,) * 5
3699        r = self.db.insert('test_person', None, person=person)
3700        p = r['person']
3701        self.assertIsInstance(p, tuple)
3702        self.assertIsNone(p.name)
3703        self.assertIsNone(p.age)
3704        self.assertIsNone(p.married)
3705        self.assertIsNone(p.weight)
3706        self.assertIsNone(p.salary)
3707        r['person'] = None
3708        self.db.get('test_person', r)
3709        p = r['person']
3710        self.assertIsInstance(p, tuple)
3711        self.assertIsNone(p.name)
3712        self.assertIsNone(p.age)
3713        self.assertIsNone(p.married)
3714        self.assertIsNone(p.weight)
3715        self.assertIsNone(p.salary)
3716        r = self.db.insert('test_person', None, person=None)
3717        self.assertIsNone(r['person'])
3718        r['person'] = None
3719        self.db.get('test_person', r)
3720        self.assertIsNone(r['person'])
3721
3722    def testRecordInsertBytea(self):
3723        query = self.db.query
3724        query('create type test_person_type as'
3725              ' (name text, picture bytea)')
3726        self.addCleanup(query, 'drop type test_person_type')
3727        self.createTable('test_person', 'person test_person_type',
3728                         temporary=False, oids=True)
3729        person_typ = self.db.get_attnames('test_person')['person']
3730        self.assertEqual(person_typ.attnames,
3731                         dict(name='text', picture='bytea'))
3732        person = ('John Doe', b'O\x00ps\xff!')
3733        r = self.db.insert('test_person', None, person=person)
3734        p = r['person']
3735        self.assertIsInstance(p, tuple)
3736        self.assertEqual(p, person)
3737        self.assertEqual(p.name, 'John Doe')
3738        self.assertIsInstance(p.name, str)
3739        self.assertEqual(p.picture, person[1])
3740        self.assertIsInstance(p.picture, bytes)
3741
3742    def testRecordInsertJson(self):
3743        query = self.db.query
3744        try:
3745            query('create type test_person_type as'
3746                  ' (name text, data json)')
3747        except pg.ProgrammingError as error:
3748            if self.db.server_version < 90200:
3749                self.skipTest('database does not support json')
3750            self.fail(str(error))
3751        self.addCleanup(query, 'drop type test_person_type')
3752        self.createTable('test_person', 'person test_person_type',
3753                         temporary=False, oids=True)
3754        person_typ = self.db.get_attnames('test_person')['person']
3755        self.assertEqual(person_typ.attnames,
3756                         dict(name='text', data='json'))
3757        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3758        r = self.db.insert('test_person', None, person=person)
3759        p = r['person']
3760        self.assertIsInstance(p, tuple)
3761        if pg.get_jsondecode() is None:
3762            p = p._replace(data=json.loads(p.data))
3763        self.assertEqual(p, person)
3764        self.assertEqual(p.name, 'John Doe')
3765        self.assertIsInstance(p.name, str)
3766        self.assertEqual(p.data, person[1])
3767        self.assertIsInstance(p.data, dict)
3768
3769    def testRecordLiteral(self):
3770        query = self.db.query
3771        query('create type test_person_type as'
3772              ' (name varchar, age smallint)')
3773        self.addCleanup(query, 'drop type test_person_type')
3774        self.createTable('test_person', 'person test_person_type',
3775                         temporary=False, oids=True)
3776        person_typ = self.db.get_attnames('test_person')['person']
3777        if self.regtypes:
3778            self.assertEqual(person_typ, 'test_person_type')
3779        else:
3780            self.assertEqual(person_typ, 'record')
3781        if self.regtypes:
3782            self.assertEqual(person_typ.attnames,
3783                             dict(name='character varying', age='smallint'))
3784        else:
3785            self.assertEqual(person_typ.attnames,
3786                             dict(name='text', age='int'))
3787        person = pg.Literal("('John Doe', 61)")
3788        r = self.db.insert('test_person', None, person=person)
3789        p = r['person']
3790        self.assertIsInstance(p, tuple)
3791        self.assertEqual(p.name, 'John Doe')
3792        self.assertIsInstance(p.name, str)
3793        self.assertEqual(p.age, 61)
3794        self.assertIsInstance(p.age, int)
3795
3796    def testDate(self):
3797        query = self.db.query
3798        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3799                'SQL, MDY', 'SQL, DMY', 'German'):
3800            self.db.set_parameter('datestyle', datestyle)
3801            d = date(2016, 3, 14)
3802            q = "select $1::date"
3803            r = query(q, (d,)).getresult()[0][0]
3804            self.assertIsInstance(r, date)
3805            self.assertEqual(r, d)
3806            q = "select '10000-08-01'::date, '0099-01-08 BC'::date"
3807            r = query(q).getresult()[0]
3808            self.assertIsInstance(r[0], date)
3809            self.assertIsInstance(r[1], date)
3810            self.assertEqual(r[0], date.max)
3811            self.assertEqual(r[1], date.min)
3812        q = "select 'infinity'::date, '-infinity'::date"
3813        r = query(q).getresult()[0]
3814        self.assertIsInstance(r[0], date)
3815        self.assertIsInstance(r[1], date)
3816        self.assertEqual(r[0], date.max)
3817        self.assertEqual(r[1], date.min)
3818
3819    def testTime(self):
3820        query = self.db.query
3821        d = time(15, 9, 26)
3822        q = "select $1::time"
3823        r = query(q, (d,)).getresult()[0][0]
3824        self.assertIsInstance(r, time)
3825        self.assertEqual(r, d)
3826        d = time(15, 9, 26, 535897)
3827        q = "select $1::time"
3828        r = query(q, (d,)).getresult()[0][0]
3829        self.assertIsInstance(r, time)
3830        self.assertEqual(r, d)
3831
3832    def testTimetz(self):
3833        query = self.db.query
3834        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3835        for timezone in sorted(timezones):
3836            tz = '%+03d00' % timezones[timezone]
3837            try:
3838                tzinfo = datetime.strptime(tz, '%z').tzinfo
3839            except ValueError:  # Python < 3.2
3840                tzinfo = pg._get_timezone(tz)
3841            self.db.set_parameter('timezone', timezone)
3842            d = time(15, 9, 26, tzinfo=tzinfo)
3843            q = "select $1::timetz"
3844            r = query(q, (d,)).getresult()[0][0]
3845            self.assertIsInstance(r, time)
3846            self.assertEqual(r, d)
3847            d = time(15, 9, 26, 535897, tzinfo)
3848            q = "select $1::timetz"
3849            r = query(q, (d,)).getresult()[0][0]
3850            self.assertIsInstance(r, time)
3851            self.assertEqual(r, d)
3852
3853    def testTimestamp(self):
3854        query = self.db.query
3855        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3856                'SQL, MDY', 'SQL, DMY', 'German'):
3857            self.db.set_parameter('datestyle', datestyle)
3858            d = datetime(2016, 3, 14)
3859            q = "select $1::timestamp"
3860            r = query(q, (d,)).getresult()[0][0]
3861            self.assertIsInstance(r, datetime)
3862            self.assertEqual(r, d)
3863            d = datetime(2016, 3, 14, 15, 9, 26)
3864            q = "select $1::timestamp"
3865            r = query(q, (d,)).getresult()[0][0]
3866            self.assertIsInstance(r, datetime)
3867            self.assertEqual(r, d)
3868            d = datetime(2016, 3, 14, 15, 9, 26, 535897)
3869            q = "select $1::timestamp"
3870            r = query(q, (d,)).getresult()[0][0]
3871            self.assertIsInstance(r, datetime)
3872            self.assertEqual(r, d)
3873            q = ("select '10000-08-01 AD'::timestamp,"
3874                " '0099-01-08 BC'::timestamp")
3875            r = query(q).getresult()[0]
3876            self.assertIsInstance(r[0], datetime)
3877            self.assertIsInstance(r[1], datetime)
3878            self.assertEqual(r[0], datetime.max)
3879            self.assertEqual(r[1], datetime.min)
3880        q = "select 'infinity'::timestamp, '-infinity'::timestamp"
3881        r = query(q).getresult()[0]
3882        self.assertIsInstance(r[0], datetime)
3883        self.assertIsInstance(r[1], datetime)
3884        self.assertEqual(r[0], datetime.max)
3885        self.assertEqual(r[1], datetime.min)
3886
3887    def testTimestamptz(self):
3888        query = self.db.query
3889        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3890        for timezone in sorted(timezones):
3891            tz = '%+03d00' % timezones[timezone]
3892            try:
3893                tzinfo = datetime.strptime(tz, '%z').tzinfo
3894            except ValueError:  # Python < 3.2
3895                tzinfo = pg._get_timezone(tz)
3896            self.db.set_parameter('timezone', timezone)
3897            for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3898                    'SQL, MDY', 'SQL, DMY', 'German'):
3899                self.db.set_parameter('datestyle', datestyle)
3900                d = datetime(2016, 3, 14, tzinfo=tzinfo)
3901                q = "select $1::timestamptz"
3902                r = query(q, (d,)).getresult()[0][0]
3903                self.assertIsInstance(r, datetime)
3904                self.assertEqual(r, d)
3905                d = datetime(2016, 3, 14, 15, 9, 26, tzinfo=tzinfo)
3906                q = "select $1::timestamptz"
3907                r = query(q, (d,)).getresult()[0][0]
3908                self.assertIsInstance(r, datetime)
3909                self.assertEqual(r, d)
3910                d = datetime(2016, 3, 14, 15, 9, 26, 535897, tzinfo)
3911                q = "select $1::timestamptz"
3912                r = query(q, (d,)).getresult()[0][0]
3913                self.assertIsInstance(r, datetime)
3914                self.assertEqual(r, d)
3915                q = ("select '10000-08-01 AD'::timestamptz,"
3916                    " '0099-01-08 BC'::timestamptz")
3917                r = query(q).getresult()[0]
3918                self.assertIsInstance(r[0], datetime)
3919                self.assertIsInstance(r[1], datetime)
3920                self.assertEqual(r[0], datetime.max)
3921                self.assertEqual(r[1], datetime.min)
3922        q = "select 'infinity'::timestamptz, '-infinity'::timestamptz"
3923        r = query(q).getresult()[0]
3924        self.assertIsInstance(r[0], datetime)
3925        self.assertIsInstance(r[1], datetime)
3926        self.assertEqual(r[0], datetime.max)
3927        self.assertEqual(r[1], datetime.min)
3928
3929    def testInterval(self):
3930        query = self.db.query
3931        for intervalstyle in (
3932                'sql_standard', 'postgres', 'postgres_verbose', 'iso_8601'):
3933            self.db.set_parameter('intervalstyle', intervalstyle)
3934            d = timedelta(3)
3935            q = "select $1::interval"
3936            r = query(q, (d,)).getresult()[0][0]
3937            self.assertIsInstance(r, timedelta)
3938            self.assertEqual(r, d)
3939            d = timedelta(-30)
3940            r = query(q, (d,)).getresult()[0][0]
3941            self.assertIsInstance(r, timedelta)
3942            self.assertEqual(r, d)
3943            d = timedelta(hours=3, minutes=31, seconds=42, microseconds=5678)
3944            q = "select $1::interval"
3945            r = query(q, (d,)).getresult()[0][0]
3946            self.assertIsInstance(r, timedelta)
3947            self.assertEqual(r, d)
3948
3949    def testDateAndTimeArrays(self):
3950        dt = (date(2016, 3, 14), time(15, 9, 26))
3951        q = "select ARRAY[$1::date], ARRAY[$2::time]"
3952        r = self.db.query(q, dt).getresult()[0]
3953        self.assertIsInstance(r[0], list)
3954        self.assertEqual(r[0][0], dt[0])
3955        self.assertIsInstance(r[1], list)
3956        self.assertEqual(r[1][0], dt[1])
3957
3958    def testHstore(self):
3959        try:
3960            self.db.query("select 'k=>v'::hstore")
3961        except pg.DatabaseError:
3962            try:
3963                self.db.query("create extension hstore")
3964            except pg.DatabaseError:
3965                self.skipTest("hstore extension not enabled")
3966        d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever',
3967            '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3',
3968            '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"',
3969            'None': None, 'NULL': 'NULL', 'empty': ''}
3970        q = "select $1::hstore"
3971        r = self.db.query(q, (pg.Hstore(d),)).getresult()[0][0]
3972        self.assertIsInstance(r, dict)
3973        self.assertEqual(r, d)
3974
3975    def testUuid(self):
3976        d = UUID('{12345678-1234-5678-1234-567812345678}')
3977        q = 'select $1::uuid'
3978        r = self.db.query(q, (d,)).getresult()[0][0]
3979        self.assertIsInstance(r, UUID)
3980        self.assertEqual(r, d)
3981
3982    def testDbTypesInfo(self):
3983        dbtypes = self.db.dbtypes
3984        self.assertIsInstance(dbtypes, dict)
3985        self.assertNotIn('numeric', dbtypes)
3986        typ = dbtypes['numeric']
3987        self.assertIn('numeric', dbtypes)
3988        self.assertEqual(typ, 'numeric' if self.regtypes else 'num')
3989        self.assertEqual(typ.oid, 1700)
3990        self.assertEqual(typ.pgtype, 'numeric')
3991        self.assertEqual(typ.regtype, 'numeric')
3992        self.assertEqual(typ.simple, 'num')
3993        self.assertEqual(typ.typtype, 'b')
3994        self.assertEqual(typ.category, 'N')
3995        self.assertEqual(typ.delim, ',')
3996        self.assertEqual(typ.relid, 0)
3997        self.assertIs(dbtypes[1700], typ)
3998        self.assertNotIn('pg_type', dbtypes)
3999        typ = dbtypes['pg_type']
4000        self.assertIn('pg_type', dbtypes)
4001        self.assertEqual(typ, 'pg_type' if self.regtypes else 'record')
4002        self.assertIsInstance(typ.oid, int)
4003        self.assertEqual(typ.pgtype, 'pg_type')
4004        self.assertEqual(typ.regtype, 'pg_type')
4005        self.assertEqual(typ.simple, 'record')
4006        self.assertEqual(typ.typtype, 'c')
4007        self.assertEqual(typ.category, 'C')
4008        self.assertEqual(typ.delim, ',')
4009        self.assertNotEqual(typ.relid, 0)
4010        attnames = typ.attnames
4011        self.assertIsInstance(attnames, dict)
4012        self.assertIs(attnames, dbtypes.get_attnames('pg_type'))
4013        self.assertEqual(list(attnames)[0], 'typname')
4014        typname = attnames['typname']
4015        self.assertEqual(typname, 'name' if self.regtypes else 'text')
4016        self.assertEqual(typname.typtype, 'b')  # base
4017        self.assertEqual(typname.category, 'S')  # string
4018        self.assertEqual(list(attnames)[3], 'typlen')
4019        typlen = attnames['typlen']
4020        self.assertEqual(typlen, 'smallint' if self.regtypes else 'int')
4021        self.assertEqual(typlen.typtype, 'b')  # base
4022        self.assertEqual(typlen.category, 'N')  # numeric
4023
4024    def testDbTypesTypecast(self):
4025        dbtypes = self.db.dbtypes
4026        self.assertIsInstance(dbtypes, dict)
4027        self.assertNotIn('int4', dbtypes)
4028        self.assertIs(dbtypes.get_typecast('int4'), int)
4029        dbtypes.set_typecast('int4', float)
4030        self.assertIs(dbtypes.get_typecast('int4'), float)
4031        dbtypes.reset_typecast('int4')
4032        self.assertIs(dbtypes.get_typecast('int4'), int)
4033        dbtypes.set_typecast('int4', float)
4034        self.assertIs(dbtypes.get_typecast('int4'), float)
4035        dbtypes.reset_typecast()
4036        self.assertIs(dbtypes.get_typecast('int4'), int)
4037        self.assertNotIn('circle', dbtypes)
4038        self.assertIsNone(dbtypes.get_typecast('circle'))
4039        squared_circle = lambda v: 'Squared Circle: %s' % v
4040        dbtypes.set_typecast('circle', squared_circle)
4041        self.assertIs(dbtypes.get_typecast('circle'), squared_circle)
4042        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
4043        self.assertIn('circle', dbtypes)
4044        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
4045        self.assertEqual(dbtypes.typecast('Impossible', 'circle'),
4046            'Squared Circle: Impossible')
4047        dbtypes.reset_typecast('circle')
4048        self.assertIsNone(dbtypes.get_typecast('circle'))
4049
4050    def testGetSetTypeCast(self):
4051        get_typecast = pg.get_typecast
4052        set_typecast = pg.set_typecast
4053        dbtypes = self.db.dbtypes
4054        self.assertIsInstance(dbtypes, dict)
4055        self.assertNotIn('int4', dbtypes)
4056        self.assertNotIn('real', dbtypes)
4057        self.assertNotIn('bool', dbtypes)
4058        self.assertIs(get_typecast('int4'), int)
4059        self.assertIs(get_typecast('float4'), float)
4060        self.assertIs(get_typecast('bool'), pg.cast_bool)
4061        cast_circle = get_typecast('circle')
4062        self.addCleanup(set_typecast, 'circle', cast_circle)
4063        squared_circle = lambda v: 'Squared Circle: %s' % v
4064        self.assertNotIn('circle', dbtypes)
4065        set_typecast('circle', squared_circle)
4066        self.assertNotIn('circle', dbtypes)
4067        self.assertIs(get_typecast('circle'), squared_circle)
4068        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
4069        self.assertIn('circle', dbtypes)
4070        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
4071        set_typecast('circle', cast_circle)
4072        self.assertIs(get_typecast('circle'), cast_circle)
4073
4074    def testNotificationHandler(self):
4075        # the notification handler itself is tested separately
4076        f = self.db.notification_handler
4077        callback = lambda arg_dict: None
4078        handler = f('test', callback)
4079        self.assertIsInstance(handler, pg.NotificationHandler)
4080        self.assertIs(handler.db, self.db)
4081        self.assertEqual(handler.event, 'test')
4082        self.assertEqual(handler.stop_event, 'stop_test')
4083        self.assertIs(handler.callback, callback)
4084        self.assertIsInstance(handler.arg_dict, dict)
4085        self.assertEqual(handler.arg_dict, {})
4086        self.assertIsNone(handler.timeout)
4087        self.assertFalse(handler.listening)
4088        handler.close()
4089        self.assertIsNone(handler.db)
4090        self.db.reopen()
4091        self.assertIsNone(handler.db)
4092        handler = f('test2', callback, timeout=2)
4093        self.assertIsInstance(handler, pg.NotificationHandler)
4094        self.assertIs(handler.db, self.db)
4095        self.assertEqual(handler.event, 'test2')
4096        self.assertEqual(handler.stop_event, 'stop_test2')
4097        self.assertIs(handler.callback, callback)
4098        self.assertIsInstance(handler.arg_dict, dict)
4099        self.assertEqual(handler.arg_dict, {})
4100        self.assertEqual(handler.timeout, 2)
4101        self.assertFalse(handler.listening)
4102        handler.close()
4103        self.assertIsNone(handler.db)
4104        self.db.reopen()
4105        self.assertIsNone(handler.db)
4106        arg_dict = {'testing': 3}
4107        handler = f('test3', callback, arg_dict=arg_dict)
4108        self.assertIsInstance(handler, pg.NotificationHandler)
4109        self.assertIs(handler.db, self.db)
4110        self.assertEqual(handler.event, 'test3')
4111        self.assertEqual(handler.stop_event, 'stop_test3')
4112        self.assertIs(handler.callback, callback)
4113        self.assertIs(handler.arg_dict, arg_dict)
4114        self.assertEqual(arg_dict['testing'], 3)
4115        self.assertIsNone(handler.timeout)
4116        self.assertFalse(handler.listening)
4117        handler.close()
4118        self.assertIsNone(handler.db)
4119        self.db.reopen()
4120        self.assertIsNone(handler.db)
4121        handler = f('test4', callback, stop_event='stop4')
4122        self.assertIsInstance(handler, pg.NotificationHandler)
4123        self.assertIs(handler.db, self.db)
4124        self.assertEqual(handler.event, 'test4')
4125        self.assertEqual(handler.stop_event, 'stop4')
4126        self.assertIs(handler.callback, callback)
4127        self.assertIsInstance(handler.arg_dict, dict)
4128        self.assertEqual(handler.arg_dict, {})
4129        self.assertIsNone(handler.timeout)
4130        self.assertFalse(handler.listening)
4131        handler.close()
4132        self.assertIsNone(handler.db)
4133        self.db.reopen()
4134        self.assertIsNone(handler.db)
4135        arg_dict = {'testing': 5}
4136        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
4137        self.assertIsInstance(handler, pg.NotificationHandler)
4138        self.assertIs(handler.db, self.db)
4139        self.assertEqual(handler.event, 'test5')
4140        self.assertEqual(handler.stop_event, 'stop5')
4141        self.assertIs(handler.callback, callback)
4142        self.assertIs(handler.arg_dict, arg_dict)
4143        self.assertEqual(arg_dict['testing'], 5)
4144        self.assertEqual(handler.timeout, 1.5)
4145        self.assertFalse(handler.listening)
4146        handler.close()
4147        self.assertIsNone(handler.db)
4148        self.db.reopen()
4149        self.assertIsNone(handler.db)
4150
4151
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    The global pg options are changed in setUpClass() and restored in
    tearDownClass(), so the inherited tests run with the changed settings.
    """

    @classmethod
    def setUpClass(cls):
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            not_array = not pg.get_array()
            cls.set_option('array', not_array)
            not_bytea_escaped = not pg.get_bytea_escaped()
            cls.set_option('bytea_escaped', not_bytea_escaped)
            cls.set_option('jsondecode', None)
            db = DB()
            try:
                # store the inverse of the connection's use_regtypes()
                # result; the inherited tests branch on self.regtypes
                cls.regtypes = not db.use_regtypes()
            finally:
                db.close()
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # unittest does not call tearDownClass() when setUpClass()
            # raises, so restore the global options here; otherwise they
            # would leak into all test classes that run afterwards
            cls.reset_options()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # restore the global options even if the base cleanup fails
            cls.reset_options()

    @classmethod
    def set_option(cls, option, value):
        """Set the global pg option, remembering its previous value."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore the global pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_options(cls):
        """Restore all global pg options that have been saved so far."""
        for option in ('jsondecode', 'bool', 'array', 'bytea_escaped',
                       'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
4188
4189
4190class TestDBClassAdapter(unittest.TestCase):
4191    """Test the adapter object associated with the DB class."""
4192
4193    def setUp(self):
4194        self.db = DB()
4195        self.adapter = self.db.adapter
4196
4197    def tearDown(self):
4198        try:
4199            self.db.close()
4200        except pg.InternalError:
4201            pass
4202
4203    def testGuessSimpleType(self):
4204        f = self.adapter.guess_simple_type
4205        self.assertEqual(f(pg.Bytea(b'test')), 'bytea')
4206        self.assertEqual(f('string'), 'text')
4207        self.assertEqual(f(b'string'), 'text')
4208        self.assertEqual(f(True), 'bool')
4209        self.assertEqual(f(3), 'int')
4210        self.assertEqual(f(2.75), 'float')
4211        self.assertEqual(f(Decimal('4.25')), 'num')
4212        self.assertEqual(f(date(2016, 1, 30)), 'date')
4213        self.assertEqual(f([1, 2, 3]), 'int[]')
4214        self.assertEqual(f([[[123]]]), 'int[]')
4215        self.assertEqual(f(['a', 'b', 'c']), 'text[]')
4216        self.assertEqual(f([[['abc']]]), 'text[]')
4217        self.assertEqual(f([False, True]), 'bool[]')
4218        self.assertEqual(f([[[False]]]), 'bool[]')
4219        r = f(('string', True, 3, 2.75, [1], [False]))
4220        self.assertEqual(r, 'record')
4221        self.assertEqual(list(r.attnames.values()),
4222            ['text', 'bool', 'int', 'float', 'int[]', 'bool[]'])
4223
4224    def testAdaptQueryTypedList(self):
4225        format_query = self.adapter.format_query
4226        self.assertRaises(TypeError, format_query,
4227            '%s,%s', (1, 2), ('int2',))
4228        self.assertRaises(TypeError, format_query,
4229            '%s,%s', (1,), ('int2', 'int2'))
4230        values = (3, 7.5, 'hello', True)
4231        types = ('int4', 'float4', 'text', 'bool')
4232        sql, params = format_query("select %s,%s,%s,%s", values, types)
4233        self.assertEqual(sql, 'select $1,$2,$3,$4')
4234        self.assertEqual(params, [3, 7.5, 'hello', 't'])
4235        types = ('bool', 'bool', 'bool', 'bool')
4236        sql, params = format_query("select %s,%s,%s,%s", values, types)
4237        self.assertEqual(sql, 'select $1,$2,$3,$4')
4238        self.assertEqual(params, ['t', 't', 'f', 't'])
4239        values = ('2016-01-30', 'current_date')
4240        types = ('date', 'date')
4241        sql, params = format_query("values(%s,%s)", values, types)
4242        self.assertEqual(sql, 'values($1,current_date)')
4243        self.assertEqual(params, ['2016-01-30'])
4244        values = ([1, 2, 3], ['a', 'b', 'c'])
4245        types = ('_int4', '_text')
4246        sql, params = format_query("%s::int4[],%s::text[]", values, types)
4247        self.assertEqual(sql, '$1::int4[],$2::text[]')
4248        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
4249        types = ('_bool', '_bool')
4250        sql, params = format_query("%s::bool[],%s::bool[]", values, types)
4251        self.assertEqual(sql, '$1::bool[],$2::bool[]')
4252        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
4253        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
4254        t = self.adapter.simple_type
4255        typ = t('record')
4256        typ._get_attnames = lambda _self: pg.AttrDict([
4257            ('i', t('int')), ('f', t('float')),
4258            ('t', t('text')), ('b', t('bool')),
4259            ('i3', t('int[]')), ('t3', t('text[]'))])
4260        types = [typ]
4261        sql, params = format_query('select %s', values, types)
4262        self.assertEqual(sql, 'select $1')
4263        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4264
4265    def testAdaptQueryTypedDict(self):
4266        format_query = self.adapter.format_query
4267        self.assertRaises(TypeError, format_query,
4268            '%s,%s', dict(i1=1, i2=2), dict(i1='int2'))
4269        values = dict(i=3, f=7.5, t='hello', b=True)
4270        types = dict(i='int4', f='float4',
4271            t='text', b='bool')
4272        sql, params = format_query(
4273            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
4274        self.assertEqual(sql, 'select $3,$2,$4,$1')
4275        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
4276        types = dict(i='bool', f='bool',
4277            t='bool', b='bool')
4278        sql, params = format_query(
4279            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
4280        self.assertEqual(sql, 'select $3,$2,$4,$1')
4281        self.assertEqual(params, ['t', 't', 't', 'f'])
4282        values = dict(d1='2016-01-30', d2='current_date')
4283        types = dict(d1='date', d2='date')
4284        sql, params = format_query("values(%(d1)s,%(d2)s)", values, types)
4285        self.assertEqual(sql, 'values($1,current_date)')
4286        self.assertEqual(params, ['2016-01-30'])
4287        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
4288        types = dict(i='_int4', t='_text')
4289        sql, params = format_query(
4290            "%(i)s::int4[],%(t)s::text[]", values, types)
4291        self.assertEqual(sql, '$1::int4[],$2::text[]')
4292        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
4293        types = dict(i='_bool', t='_bool')
4294        sql, params = format_query(
4295            "%(i)s::bool[],%(t)s::bool[]", values, types)
4296        self.assertEqual(sql, '$1::bool[],$2::bool[]')
4297        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
4298        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
4299        t = self.adapter.simple_type
4300        typ = t('record')
4301        typ._get_attnames = lambda _self: pg.AttrDict([
4302            ('i', t('int')), ('f', t('float')),
4303            ('t', t('text')), ('b', t('bool')),
4304            ('i3', t('int[]')), ('t3', t('text[]'))])
4305        types = dict(record=typ)
4306        sql, params = format_query('select %(record)s', values, types)
4307        self.assertEqual(sql, 'select $1')
4308        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4309
4310    def testAdaptQueryUntypedList(self):
4311        format_query = self.adapter.format_query
4312        values = (3, 7.5, 'hello', True)
4313        sql, params = format_query("select %s,%s,%s,%s", values)
4314        self.assertEqual(sql, 'select $1,$2,$3,$4')
4315        self.assertEqual(params, [3, 7.5, 'hello', 't'])
4316        values = [date(2016, 1, 30), 'current_date']
4317        sql, params = format_query("values(%s,%s)", values)
4318        self.assertEqual(sql, 'values($1,$2)')
4319        self.assertEqual(params, values)
4320        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
4321        sql, params = format_query("%s,%s,%s", values)
4322        self.assertEqual(sql, "$1,$2,$3")
4323        self.assertEqual(params, ['{1,2,3}', '{a,b,c}', '{t,f,t}'])
4324        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
4325            [[True, False], [False, True]])
4326        sql, params = format_query("%s,%s,%s", values)
4327        self.assertEqual(sql, "$1,$2,$3")
4328        self.assertEqual(params, [
4329            '{{1,2},{3,4}}', '{{a,b},{c,d}}', '{{t,f},{f,t}}'])
4330        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
4331        sql, params = format_query('select %s', values)
4332        self.assertEqual(sql, 'select $1')
4333        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4334
4335    def testAdaptQueryUntypedDict(self):
4336        format_query = self.adapter.format_query
4337        values = dict(i=3, f=7.5, t='hello', b=True)
4338        sql, params = format_query(
4339            "select %(i)s,%(f)s,%(t)s,%(b)s", values)
4340        self.assertEqual(sql, 'select $3,$2,$4,$1')
4341        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
4342        values = dict(d1='2016-01-30', d2='current_date')
4343        sql, params = format_query("values(%(d1)s,%(d2)s)", values)
4344        self.assertEqual(sql, 'values($1,$2)')
4345        self.assertEqual(params, [values['d1'], values['d2']])
4346        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
4347        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
4348        self.assertEqual(sql, "$2,$3,$1")
4349        self.assertEqual(params, ['{t,f,t}', '{1,2,3}', '{a,b,c}'])
4350        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
4351            b=[[True, False], [False, True]])
4352        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
4353        self.assertEqual(sql, "$2,$3,$1")
4354        self.assertEqual(params, [
4355            '{{t,f},{f,t}}', '{{1,2},{3,4}}', '{{a,b},{c,d}}'])
4356        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
4357        sql, params = format_query('select %(record)s', values)
4358        self.assertEqual(sql, 'select $1')
4359        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4360
4361    def testAdaptQueryInlineList(self):
4362        format_query = self.adapter.format_query
4363        values = (3, 7.5, 'hello', True)
4364        sql, params = format_query("select %s,%s,%s,%s", values, inline=True)
4365        self.assertEqual(sql, "select 3,7.5,'hello',true")
4366        self.assertEqual(params, [])
4367        values = [date(2016, 1, 30), 'current_date']
4368        sql, params = format_query("values(%s,%s)", values, inline=True)
4369        self.assertEqual(sql, "values('2016-01-30','current_date')")
4370        self.assertEqual(params, [])
4371        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
4372        sql, params = format_query("%s,%s,%s", values, inline=True)
4373        self.assertEqual(sql,
4374            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
4375        self.assertEqual(params, [])
4376        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
4377            [[True, False], [False, True]])
4378        sql, params = format_query("%s,%s,%s", values, inline=True)
4379        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
4380            "ARRAY[[true,false],[false,true]]")
4381        self.assertEqual(params, [])
4382        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
4383        sql, params = format_query('select %s', values, inline=True)
4384        self.assertEqual(sql,
4385            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
4386        self.assertEqual(params, [])
4387
4388    def testAdaptQueryInlineDict(self):
4389        format_query = self.adapter.format_query
4390        values = dict(i=3, f=7.5, t='hello', b=True)
4391        sql, params = format_query(
4392            "select %(i)s,%(f)s,%(t)s,%(b)s", values, inline=True)
4393        self.assertEqual(sql, "select 3,7.5,'hello',true")
4394        self.assertEqual(params, [])
4395        values = dict(d1='2016-01-30', d2='current_date')
4396        sql, params = format_query(
4397            "values(%(d1)s,%(d2)s)", values, inline=True)
4398        self.assertEqual(sql, "values('2016-01-30','current_date')")
4399        self.assertEqual(params, [])
4400        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
4401        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
4402        self.assertEqual(sql,
4403            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
4404        self.assertEqual(params, [])
4405        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
4406            b=[[True, False], [False, True]])
4407        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
4408        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
4409            "ARRAY[[true,false],[false,true]]")
4410        self.assertEqual(params, [])
4411        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
4412        sql, params = format_query('select %(record)s', values, inline=True)
4413        self.assertEqual(sql,
4414            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
4415        self.assertEqual(params, [])
4416
4417    def testAdaptQueryWithPgRepr(self):
4418        format_query = self.adapter.format_query
4419        self.assertRaises(TypeError, format_query,
4420            '%s', object(), inline=True)
4421        class TestObject:
4422            def __pg_repr__(self):
4423                return "'adapted'"
4424        sql, params = format_query('select %s', [TestObject()], inline=True)
4425        self.assertEqual(sql, "select 'adapted'")
4426        self.assertEqual(params, [])
4427        sql, params = format_query('select %s', [[TestObject()]], inline=True)
4428        self.assertEqual(sql, "select ARRAY['adapted']")
4429        self.assertEqual(params, [])
4430
4431
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    # Set by setUpClass once all fixtures have been created successfully.
    cls_set_up = False

    @staticmethod
    def _schema_name(num_schema):
        """Return the name of the schema with the given number.

        Number 0 stands for the public schema, all others for "s<number>".
        """
        return "s%d" % num_schema if num_schema else "public"

    @classmethod
    def _create_fixtures(cls, query):
        """Create tables t and t<N> in public (N=0) and in schemas s1 to s4.

        Both tables of schema N get one row with n = 1 and d = N.
        Raises RuntimeError if the test user may not create schemas.
        """
        for num_schema in range(5):
            schema = cls._schema_name(num_schema)
            if num_schema:
                query("drop schema if exists %s cascade" % (schema,))
                try:
                    query("create schema %s" % (schema,))
                except pg.ProgrammingError:
                    raise RuntimeError("The test user cannot create schemas.\n"
                        "Grant create on database %s to the user"
                        " for running these tests." % dbname)
            else:
                # The public schema itself is kept; only drop our tables.
                query("drop table if exists %s.t" % (schema,))
                query("drop table if exists %s.t%d" % (schema, num_schema))
            query("create table %s.t with oids as select 1 as n, %d as d"
                  % (schema, num_schema))
            query("create table %s.t%d with oids as select 1 as n, %d as d"
                  % (schema, num_schema, num_schema))

    @classmethod
    def _drop_fixtures(cls, query):
        """Drop the schemas s1 to s4 and the test tables in public."""
        for num_schema in range(5):
            schema = cls._schema_name(num_schema)
            if num_schema:
                query("drop schema %s cascade" % (schema,))
            else:
                query("drop table %s.t" % (schema,))
                query("drop table %s.t%d" % (schema, num_schema))

    @classmethod
    def setUpClass(cls):
        db = DB()
        try:
            cls._create_fixtures(db.query)
        finally:
            # Close the connection even when creating the fixtures failed;
            # tearDownClass is not run if setUpClass raises an exception.
            db.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        db = DB()
        try:
            cls._drop_fixtures(db.query)
        finally:
            db.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.db = DB()

    def tearDown(self):
        # Run the registered cleanups while the connection is still open,
        # since they send queries through it.
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(5):
            schema = self._schema_name(num_schema)
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        # Unqualified names must now be resolved via the search path.
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        # Only tables in schemas on the search path are found unqualified;
        # the first schema on the path wins for the ambiguous table t.
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        get = self.db.get
        query = self.db.query
        # The oid key of a fetched row is named after the unqualified table.
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
4550
4551
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class."""

    def setUp(self):
        self.db = DB()
        self.query = self.db.query
        # Remember the initial debug setting so that tearDown can restore it.
        self.debug = self.db.debug
        self.output = StringIO()
        # Capture everything printed to stdout while the test is running.
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        sys.stdout = self.stdout
        self.output.close()
        # Restore the value DB() started with (None when debugging is off,
        # see testDebugDefault) rather than the module-level default.
        self.db.debug = self.debug
        self.db.close()

    def get_output(self):
        """Return everything written to the captured stdout so far."""
        return self.output.getvalue()

    def send_queries(self):
        """Run the two queries "select 1" and "select 2"."""
        self.query("select 1")
        self.query("select 2")

    def testDebugDefault(self):
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")

    def testDebugMultipleArgs(self):
        output = []
        self.db.debug = output.append
        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
        self.db._do_debug(*args)
        # Multiple arguments are joined into one newline-separated message.
        self.assertEqual(output, ['\n'.join(str(arg) for arg in args)])
        self.assertEqual(self.get_output(), "")
4621
4622
class TestMemoryLeaks(unittest.TestCase):
    """Test that the DB class does not leak memory."""

    def getLeaks(self, fut):
        """Assert that calling fut() leaves no new gc-tracked objects behind.

        Takes a snapshot of gc.get_objects() before and after calling fut
        (each time after a full gc.collect()) and fails if any object of
        the second snapshot was not already in the first one (by id()).
        """
        # ids and objs are created before the first snapshot, so they
        # count as pre-existing and are never reported themselves.
        ids = set()
        objs = []
        add_ids = ids.update
        gc.collect()
        # objs keeps the whole first snapshot referenced while fut() runs;
        # those objects stay alive, so their ids cannot be reused by
        # objects created inside fut() and mistaken for pre-existing ones.
        objs[:] = gc.get_objects()
        add_ids(id(obj) for obj in objs)
        fut()
        gc.collect()
        objs[:] = gc.get_objects()
        # Whatever remains was created by fut() and survived collection.
        objs[:] = [obj for obj in objs if id(obj) not in ids]
        if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)):
            # workaround for Python issue 26811
            objs[:] = [obj for obj in objs if repr(obj) != '(<NULL>,)']
        self.assertEqual(len(objs), 0)

    def testLeaksWithClose(self):
        # Open a connection, run a parameterized query, close explicitly.
        def fut():
            db = DB()
            db.query("select $1::int as r", 42).dictresult()
            db.close()
        self.getLeaks(fut)

    def testLeaksWithoutClose(self):
        # Same as above, but the connection is only dropped, never closed.
        def fut():
            db = DB()
            db.query("select $1::int as r", 42).dictresult()
        self.getLeaks(fut)
4654
4655
if __name__ == "__main__":
    # Run all test cases of this module when it is executed as a script.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.