source: trunk/tests/test_classic_dbwrapper.py @ 928

Last change on this file since 928 was 928, checked in by cito, 21 months ago

Adapt tests for Postgres 10

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 181.3 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import gc
21import json
22import tempfile
23
24import pg  # the module under test
25
26from decimal import Decimal
27from datetime import date, time, datetime, timedelta
28from uuid import UUID
29from time import strftime
30from operator import itemgetter
31
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
# These defaults must be defined before the star imports below so that a
# LOCAL_PyGreSQL module can override any of them.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Try a package-relative import first, then a plain top-level import.
# ValueError is caught as well because the relative form can fail with it
# (presumably on older Pythons when not run inside a package).
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Compatibility names so the tests can refer to long/unicode uniformly.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

# str is bytes only on Python 2, where StringIO lives in its own module.
if str is bytes:
    from StringIO import StringIO
else:
    from io import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
75
76
def DB():
    """Return a pg.DB wrapper connected to the test database.

    Connection parameters are taken from the module-level settings
    dbname, dbhost and dbport.  Debug output is switched on when the
    module-level debug flag is set, and the session is configured to
    only report messages of level warning and above.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
84
85
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names."""

    cls = pg.AttrDict
    base = OrderedDict

    def testInit(self):
        a = self.cls()
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base())
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))
        iteritems = iter(items)
        a = self.cls(iteritems)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))

    def testIter(self):
        a = self.cls()
        self.assertEqual(list(a), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a), keys)

    def testKeys(self):
        a = self.cls()
        self.assertEqual(list(a.keys()), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a.keys()), keys)

    def testValues(self):
        a = self.cls()
        self.assertEqual(list(a.values()), [])
        items = [('id', 'int'), ('name', 'text')]
        values = [item[1] for item in items]
        a = self.cls(items)
        self.assertEqual(list(a.values()), values)

    def testItems(self):
        a = self.cls()
        self.assertEqual(list(a.items()), [])
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertEqual(list(a.items()), items)

    def testGet(self):
        a = self.cls([('id', 1)])
        try:
            self.assertEqual(a['id'], 1)
        except KeyError:
            self.fail('AttrDict should be readable')

    def testSet(self):
        a = self.cls()
        try:
            a['id'] = 1
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testDel(self):
        a = self.cls([('id', 1)])
        try:
            del a['id']
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        a = self.cls([('id', 1)])
        self.assertEqual(a['id'], 1)
        # Call every mutating method with arguments that would be valid
        # for a plain ordered dictionary, so that the TypeError can only
        # come from the read-only guard and not from a wrong number of
        # arguments or an unhashable argument.
        calls = [
            ('clear', ()),
            ('update', ({'id': 2},)),
            ('pop', ('id',)),
            ('setdefault', ('name', 'text')),
            ('popitem', ()),
        ]
        for name, args in calls:
            method = getattr(a, name)
            self.assertRaises(TypeError, method, *args)
        # none of the attempted writes may have changed the contents
        self.assertEqual(list(a.items()), [('id', 1)])
167
168
class TestDBClassInit(unittest.TestCase):
    """Check error handling when DB instances are created or misused."""

    def testBadParams(self):
        # unknown keyword arguments must be rejected by the constructor
        self.assertRaises(TypeError, pg.DB, invalid=True)

    def testDeleteDb(self):
        # closing a wrapper whose connection attribute is gone must fail
        wrapper = DB()
        del wrapper.db
        self.assertRaises(pg.InternalError, wrapper.close)
        del wrapper
180
181
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # Some tests close the connection themselves; closing it a second
        # time raises InternalError, which is expected and ignored here.
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        attributes = [
            'abort', 'adapter',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'date_format', 'db', 'dbname', 'dbtypes',
            'debug', 'decode_json', 'delete',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_cast_hook',
            'get_databases', 'get_notice_receiver',
            'get_parameter', 'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query', 'query_formatted',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_cast_hook', 'set_notice_receiver',
            'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        # __dir__ is not called in Python 2.6 for old-style classes
        db_attributes = dir(self.db) if hasattr(
            self.db.__class__, '__class__') else self.db.__dir__()
        db_attributes = [a for a in db_attributes
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        if dbhost and not dbhost.startswith('/'):
            host = dbhost
        else:
            host = 'localhost'
        self.assertIsInstance(self.db.host, str)
        self.assertEqual(self.db.host, host)
        self.assertEqual(self.db.db.host, host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        # NOTE(review): the upper bound must be raised for newer servers
        self.assertTrue(90000 <= server_version < 110000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryDataError(self):
        try:
            self.db.query("select 1/0")
        except pg.DataError as error:
            # 22012 is the SQLSTATE for division_by_zero
            self.assertEqual(error.sqlstate, '22012')
        else:
            # without this, the test would silently pass if no error
            # (or the wrong kind of behavior) occurred
            self.fail('Division by zero should raise a DataError')

    def testMethodEndcopy(self):
        # only checks that endcopy can be called; an IOError (presumably
        # because no copy operation is in progress) is tolerated
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
397
398
399class TestDBClass(unittest.TestCase):
400    """Test the methods of the DB class wrapped pg connection."""
401
402    maxDiff = 80 * 20
403
404    cls_set_up = False
405
406    regtypes = None
407
408    @classmethod
409    def setUpClass(cls):
410        db = DB()
411        db.query("drop table if exists test cascade")
412        db.query("create table test ("
413                 "i2 smallint, i4 integer, i8 bigint,"
414                 " d numeric, f4 real, f8 double precision, m money,"
415                 " v4 varchar(4), c4 char(4), t text)")
416        db.query("create or replace view test_view as"
417                 " select i4, v4 from test")
418        db.close()
419        cls.cls_set_up = True
420
421    @classmethod
422    def tearDownClass(cls):
423        db = DB()
424        db.query("drop table test cascade")
425        db.close()
426
427    def setUp(self):
428        self.assertTrue(self.cls_set_up)
429        self.db = DB()
430        if self.regtypes is None:
431            self.regtypes = self.db.use_regtypes()
432        else:
433            self.db.use_regtypes(self.regtypes)
434        query = self.db.query
435        query('set client_encoding=utf8')
436        query("set lc_monetary='C'")
437        query("set datestyle='ISO,YMD'")
438        query('set standard_conforming_strings=on')
439        try:
440            query('set bytea_output=hex')
441        except pg.ProgrammingError:
442            if self.db.server_version >= 90000:
443                raise  # ignore for older server versions
444
445    def tearDown(self):
446        self.doCleanups()
447        self.db.close()
448
449    def createTable(self, table, definition,
450                    temporary=True, oids=None, values=None):
451        query = self.db.query
452        if not '"' in table or '.' in table:
453            table = '"%s"' % table
454        if not temporary:
455            q = 'drop table if exists %s cascade' % table
456            query(q)
457            self.addCleanup(query, q)
458        temporary = 'temporary table' if temporary else 'table'
459        as_query = definition.startswith(('as ', 'AS '))
460        if not as_query and not definition.startswith('('):
461            definition = '(%s)' % definition
462        with_oids = 'with oids' if oids else 'without oids'
463        q = ['create', temporary, table]
464        if as_query:
465            q.extend([with_oids, definition])
466        else:
467            q.extend([definition, with_oids])
468        q = ' '.join(q)
469        query(q)
470        if values:
471            for params in values:
472                if not isinstance(params, (list, tuple)):
473                    params = [params]
474                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
475                q = "insert into %s values (%s)" % (table, values)
476                query(q, params)
477
478    def testClassName(self):
479        self.assertEqual(self.db.__class__.__name__, 'DB')
480
481    def testModuleName(self):
482        self.assertEqual(self.db.__module__, 'pg')
483        self.assertEqual(self.db.__class__.__module__, 'pg')
484
485    def testEscapeLiteral(self):
486        f = self.db.escape_literal
487        r = f(b"plain")
488        self.assertIsInstance(r, bytes)
489        self.assertEqual(r, b"'plain'")
490        r = f(u"plain")
491        self.assertIsInstance(r, unicode)
492        self.assertEqual(r, u"'plain'")
493        r = f(u"that's kÀse".encode('utf-8'))
494        self.assertIsInstance(r, bytes)
495        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
496        r = f(u"that's kÀse")
497        self.assertIsInstance(r, unicode)
498        self.assertEqual(r, u"'that''s kÀse'")
499        self.assertEqual(f(r"It's fine to have a \ inside."),
500                         r" E'It''s fine to have a \\ inside.'")
501        self.assertEqual(f('No "quotes" must be escaped.'),
502                         "'No \"quotes\" must be escaped.'")
503
504    def testEscapeIdentifier(self):
505        f = self.db.escape_identifier
506        r = f(b"plain")
507        self.assertIsInstance(r, bytes)
508        self.assertEqual(r, b'"plain"')
509        r = f(u"plain")
510        self.assertIsInstance(r, unicode)
511        self.assertEqual(r, u'"plain"')
512        r = f(u"that's kÀse".encode('utf-8'))
513        self.assertIsInstance(r, bytes)
514        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
515        r = f(u"that's kÀse")
516        self.assertIsInstance(r, unicode)
517        self.assertEqual(r, u'"that\'s kÀse"')
518        self.assertEqual(f(r"It's fine to have a \ inside."),
519                         '"It\'s fine to have a \\ inside."')
520        self.assertEqual(f('All "quotes" must be escaped.'),
521                         '"All ""quotes"" must be escaped."')
522
523    def testEscapeString(self):
524        f = self.db.escape_string
525        r = f(b"plain")
526        self.assertIsInstance(r, bytes)
527        self.assertEqual(r, b"plain")
528        r = f(u"plain")
529        self.assertIsInstance(r, unicode)
530        self.assertEqual(r, u"plain")
531        r = f(u"that's kÀse".encode('utf-8'))
532        self.assertIsInstance(r, bytes)
533        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
534        r = f(u"that's kÀse")
535        self.assertIsInstance(r, unicode)
536        self.assertEqual(r, u"that''s kÀse")
537        self.assertEqual(f(r"It's fine to have a \ inside."),
538                         r"It''s fine to have a \ inside.")
539
540    def testEscapeBytea(self):
541        f = self.db.escape_bytea
542        # note that escape_byte always returns hex output since Pg 9.0,
543        # regardless of the bytea_output setting
544        r = f(b'plain')
545        self.assertIsInstance(r, bytes)
546        self.assertEqual(r, b'\\x706c61696e')
547        r = f(u'plain')
548        self.assertIsInstance(r, unicode)
549        self.assertEqual(r, u'\\x706c61696e')
550        r = f(u"das is' kÀse".encode('utf-8'))
551        self.assertIsInstance(r, bytes)
552        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
553        r = f(u"das is' kÀse")
554        self.assertIsInstance(r, unicode)
555        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
556        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
557
558    def testUnescapeBytea(self):
559        f = self.db.unescape_bytea
560        r = f(b'plain')
561        self.assertIsInstance(r, bytes)
562        self.assertEqual(r, b'plain')
563        r = f(u'plain')
564        self.assertIsInstance(r, bytes)
565        self.assertEqual(r, b'plain')
566        r = f(b"das is' k\\303\\244se")
567        self.assertIsInstance(r, bytes)
568        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
569        r = f(u"das is' k\\303\\244se")
570        self.assertIsInstance(r, bytes)
571        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
572        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
573        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
574        self.assertEqual(f(r'\\x746861742773206be47365'),
575                         b'\\x746861742773206be47365')
576        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
577
578    def testDecodeJson(self):
579        f = self.db.decode_json
580        self.assertIsNone(f('null'))
581        data = {
582            "id": 1, "name": "Foo", "price": 1234.5,
583            "new": True, "note": None,
584            "tags": ["Bar", "Eek"],
585            "stock": {"warehouse": 300, "retail": 20}}
586        text = json.dumps(data)
587        r = f(text)
588        self.assertIsInstance(r, dict)
589        self.assertEqual(r, data)
590        self.assertIsInstance(r['id'], int)
591        self.assertIsInstance(r['name'], unicode)
592        self.assertIsInstance(r['price'], float)
593        self.assertIsInstance(r['new'], bool)
594        self.assertIsInstance(r['tags'], list)
595        self.assertIsInstance(r['stock'], dict)
596
597    def testEncodeJson(self):
598        f = self.db.encode_json
599        self.assertEqual(f(None), 'null')
600        data = {
601            "id": 1, "name": "Foo", "price": 1234.5,
602            "new": True, "note": None,
603            "tags": ["Bar", "Eek"],
604            "stock": {"warehouse": 300, "retail": 20}}
605        text = json.dumps(data)
606        r = f(data)
607        self.assertIsInstance(r, str)
608        self.assertEqual(r, text)
609
610    def testGetParameter(self):
611        f = self.db.get_parameter
612        self.assertRaises(TypeError, f)
613        self.assertRaises(TypeError, f, None)
614        self.assertRaises(TypeError, f, 42)
615        self.assertRaises(TypeError, f, '')
616        self.assertRaises(TypeError, f, [])
617        self.assertRaises(TypeError, f, [''])
618        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
619        r = f('standard_conforming_strings')
620        self.assertEqual(r, 'on')
621        r = f('lc_monetary')
622        self.assertEqual(r, 'C')
623        r = f('datestyle')
624        self.assertEqual(r, 'ISO, YMD')
625        r = f('bytea_output')
626        self.assertEqual(r, 'hex')
627        r = f(['bytea_output', 'lc_monetary'])
628        self.assertIsInstance(r, list)
629        self.assertEqual(r, ['hex', 'C'])
630        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
631        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
632        r = f(set(['bytea_output', 'lc_monetary']))
633        self.assertIsInstance(r, dict)
634        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
635        r = f(set(['Bytea_Output', ' LC_Monetary ']))
636        self.assertIsInstance(r, dict)
637        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
638        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
639        r = f(s)
640        self.assertIs(r, s)
641        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
642        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
643        r = f(s)
644        self.assertIs(r, s)
645        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
646
647    def testGetParameterServerVersion(self):
648        r = self.db.get_parameter('server_version_num')
649        self.assertIsInstance(r, str)
650        s = self.db.server_version
651        self.assertIsInstance(s, int)
652        self.assertEqual(r, str(s))
653
654    def testGetParameterAll(self):
655        f = self.db.get_parameter
656        r = f('all')
657        self.assertIsInstance(r, dict)
658        self.assertEqual(r['standard_conforming_strings'], 'on')
659        self.assertEqual(r['lc_monetary'], 'C')
660        self.assertEqual(r['DateStyle'], 'ISO, YMD')
661        self.assertEqual(r['bytea_output'], 'hex')
662
663    def testSetParameter(self):
664        f = self.db.set_parameter
665        g = self.db.get_parameter
666        self.assertRaises(TypeError, f)
667        self.assertRaises(TypeError, f, None)
668        self.assertRaises(TypeError, f, 42)
669        self.assertRaises(TypeError, f, '')
670        self.assertRaises(TypeError, f, [])
671        self.assertRaises(TypeError, f, [''])
672        self.assertRaises(ValueError, f, 'all', 'invalid')
673        self.assertRaises(ValueError, f, {
674            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
675        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
676        f('standard_conforming_strings', 'off')
677        self.assertEqual(g('standard_conforming_strings'), 'off')
678        f('datestyle', 'ISO, DMY')
679        self.assertEqual(g('datestyle'), 'ISO, DMY')
680        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
681        self.assertEqual(g('standard_conforming_strings'), 'on')
682        self.assertEqual(g('datestyle'), 'ISO, DMY')
683        f(['default_with_oids', 'standard_conforming_strings'], 'off')
684        self.assertEqual(g('default_with_oids'), 'off')
685        self.assertEqual(g('standard_conforming_strings'), 'off')
686        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
687        self.assertEqual(g('standard_conforming_strings'), 'on')
688        self.assertEqual(g('datestyle'), 'ISO, YMD')
689        f(('default_with_oids', 'standard_conforming_strings'), 'off')
690        self.assertEqual(g('default_with_oids'), 'off')
691        self.assertEqual(g('standard_conforming_strings'), 'off')
692        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
693        self.assertEqual(g('default_with_oids'), 'on')
694        self.assertEqual(g('standard_conforming_strings'), 'on')
695        self.assertRaises(ValueError, f, set(['default_with_oids',
696            'standard_conforming_strings']), ['off', 'on'])
697        f(set(['default_with_oids', 'standard_conforming_strings']),
698            ['off', 'off'])
699        self.assertEqual(g('default_with_oids'), 'off')
700        self.assertEqual(g('standard_conforming_strings'), 'off')
701        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
702        self.assertEqual(g('standard_conforming_strings'), 'on')
703        self.assertEqual(g('datestyle'), 'ISO, YMD')
704
705    def testResetParameter(self):
706        db = DB()
707        f = db.set_parameter
708        g = db.get_parameter
709        r = g('default_with_oids')
710        self.assertIn(r, ('on', 'off'))
711        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
712        r = g('standard_conforming_strings')
713        self.assertIn(r, ('on', 'off'))
714        scs, not_scs = r, 'off' if r == 'on' else 'on'
715        f('default_with_oids', not_dwi)
716        f('standard_conforming_strings', not_scs)
717        self.assertEqual(g('default_with_oids'), not_dwi)
718        self.assertEqual(g('standard_conforming_strings'), not_scs)
719        f('default_with_oids')
720        f('standard_conforming_strings', None)
721        self.assertEqual(g('default_with_oids'), dwi)
722        self.assertEqual(g('standard_conforming_strings'), scs)
723        f('default_with_oids', not_dwi)
724        f('standard_conforming_strings', not_scs)
725        self.assertEqual(g('default_with_oids'), not_dwi)
726        self.assertEqual(g('standard_conforming_strings'), not_scs)
727        f(['default_with_oids', 'standard_conforming_strings'], None)
728        self.assertEqual(g('default_with_oids'), dwi)
729        self.assertEqual(g('standard_conforming_strings'), scs)
730        f('default_with_oids', not_dwi)
731        f('standard_conforming_strings', not_scs)
732        self.assertEqual(g('default_with_oids'), not_dwi)
733        self.assertEqual(g('standard_conforming_strings'), not_scs)
734        f(('default_with_oids', 'standard_conforming_strings'))
735        self.assertEqual(g('default_with_oids'), dwi)
736        self.assertEqual(g('standard_conforming_strings'), scs)
737        f('default_with_oids', not_dwi)
738        f('standard_conforming_strings', not_scs)
739        self.assertEqual(g('default_with_oids'), not_dwi)
740        self.assertEqual(g('standard_conforming_strings'), not_scs)
741        f(set(['default_with_oids', 'standard_conforming_strings']))
742        self.assertEqual(g('default_with_oids'), dwi)
743        self.assertEqual(g('standard_conforming_strings'), scs)
744        db.close()
745
746    def testResetParameterAll(self):
747        db = DB()
748        f = db.set_parameter
749        self.assertRaises(ValueError, f, 'all', 0)
750        self.assertRaises(ValueError, f, 'all', 'off')
751        g = db.get_parameter
752        r = g('default_with_oids')
753        self.assertIn(r, ('on', 'off'))
754        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
755        r = g('standard_conforming_strings')
756        self.assertIn(r, ('on', 'off'))
757        scs, not_scs = r, 'off' if r == 'on' else 'on'
758        f('default_with_oids', not_dwi)
759        f('standard_conforming_strings', not_scs)
760        self.assertEqual(g('default_with_oids'), not_dwi)
761        self.assertEqual(g('standard_conforming_strings'), not_scs)
762        f('all')
763        self.assertEqual(g('default_with_oids'), dwi)
764        self.assertEqual(g('standard_conforming_strings'), scs)
765        db.close()
766
767    def testSetParameterLocal(self):
768        f = self.db.set_parameter
769        g = self.db.get_parameter
770        self.assertEqual(g('standard_conforming_strings'), 'on')
771        self.db.begin()
772        f('standard_conforming_strings', 'off', local=True)
773        self.assertEqual(g('standard_conforming_strings'), 'off')
774        self.db.end()
775        self.assertEqual(g('standard_conforming_strings'), 'on')
776
777    def testSetParameterSession(self):
778        f = self.db.set_parameter
779        g = self.db.get_parameter
780        self.assertEqual(g('standard_conforming_strings'), 'on')
781        self.db.begin()
782        f('standard_conforming_strings', 'off', local=False)
783        self.assertEqual(g('standard_conforming_strings'), 'off')
784        self.db.end()
785        self.assertEqual(g('standard_conforming_strings'), 'off')
786
787    def testReset(self):
788        db = DB()
789        default_datestyle = db.get_parameter('datestyle')
790        changed_datestyle = 'ISO, DMY'
791        if changed_datestyle == default_datestyle:
792            changed_datestyle == 'ISO, YMD'
793        self.db.set_parameter('datestyle', changed_datestyle)
794        r = self.db.get_parameter('datestyle')
795        self.assertEqual(r, changed_datestyle)
796        con = self.db.db
797        q = con.query("show datestyle")
798        self.db.reset()
799        r = q.getresult()[0][0]
800        self.assertEqual(r, changed_datestyle)
801        q = con.query("show datestyle")
802        r = q.getresult()[0][0]
803        self.assertEqual(r, default_datestyle)
804        r = self.db.get_parameter('datestyle')
805        self.assertEqual(r, default_datestyle)
806        db.close()
807
808    def testReopen(self):
809        db = DB()
810        default_datestyle = db.get_parameter('datestyle')
811        changed_datestyle = 'ISO, DMY'
812        if changed_datestyle == default_datestyle:
813            changed_datestyle == 'ISO, YMD'
814        self.db.set_parameter('datestyle', changed_datestyle)
815        r = self.db.get_parameter('datestyle')
816        self.assertEqual(r, changed_datestyle)
817        con = self.db.db
818        q = con.query("show datestyle")
819        self.db.reopen()
820        r = q.getresult()[0][0]
821        self.assertEqual(r, changed_datestyle)
822        self.assertRaises(TypeError, getattr, con, 'query')
823        r = self.db.get_parameter('datestyle')
824        self.assertEqual(r, default_datestyle)
825        db.close()
826
827    def testCreateTable(self):
828        table = 'test hello world'
829        values = [(2, "World!"), (1, "Hello")]
830        self.createTable(table, "n smallint, t varchar",
831                         temporary=True, oids=True, values=values)
832        r = self.db.query('select t from "%s" order by n' % table).getresult()
833        r = ', '.join(row[0] for row in r)
834        self.assertEqual(r, "Hello, World!")
835        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
836        self.assertIsInstance(r[0][0], int)
837
838    def testQuery(self):
839        query = self.db.query
840        table = 'test_table'
841        self.createTable(table, "n integer", oids=True)
842        q = "insert into test_table values (1)"
843        r = query(q)
844        self.assertIsInstance(r, int)
845        q = "insert into test_table select 2"
846        r = query(q)
847        self.assertIsInstance(r, int)
848        oid = r
849        q = "select oid from test_table where n=2"
850        r = query(q).getresult()
851        self.assertEqual(len(r), 1)
852        r = r[0]
853        self.assertEqual(len(r), 1)
854        r = r[0]
855        self.assertEqual(r, oid)
856        q = "insert into test_table select 3 union select 4 union select 5"
857        r = query(q)
858        self.assertIsInstance(r, str)
859        self.assertEqual(r, '3')
860        q = "update test_table set n=4 where n<5"
861        r = query(q)
862        self.assertIsInstance(r, str)
863        self.assertEqual(r, '4')
864        q = "delete from test_table"
865        r = query(q)
866        self.assertIsInstance(r, str)
867        self.assertEqual(r, '5')
868
869    def testMultipleQueries(self):
870        self.assertEqual(self.db.query(
871            "create temporary table test_multi (n integer);"
872            "insert into test_multi values (4711);"
873            "select n from test_multi").getresult()[0][0], 4711)
874
875    def testQueryWithParams(self):
876        query = self.db.query
877        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
878        q = "insert into test_table values ($1, $2)"
879        r = query(q, (1, 2))
880        self.assertIsInstance(r, int)
881        r = query(q, [3, 4])
882        self.assertIsInstance(r, int)
883        r = query(q, [5, 6])
884        self.assertIsInstance(r, int)
885        q = "select * from test_table order by 1, 2"
886        self.assertEqual(query(q).getresult(),
887                         [(1, 2), (3, 4), (5, 6)])
888        q = "select * from test_table where n1=$1 and n2=$2"
889        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
890        q = "update test_table set n2=$2 where n1=$1"
891        r = query(q, 3, 7)
892        self.assertEqual(r, '1')
893        q = "select * from test_table order by 1, 2"
894        self.assertEqual(query(q).getresult(),
895                         [(1, 2), (3, 7), (5, 6)])
896        q = "delete from test_table where n2!=$1"
897        r = query(q, 4)
898        self.assertEqual(r, '3')
899
900    def testEmptyQuery(self):
901        self.assertRaises(ValueError, self.db.query, '')
902
903    def testQueryDataError(self):
904        try:
905            self.db.query("select 1/0")
906        except pg.DataError as error:
907            self.assertEqual(error.sqlstate, '22012')
908
909    def testQueryFormatted(self):
910        f = self.db.query_formatted
911        t = True if pg.get_bool() else 't'
912        # test with tuple
913        q = f("select %s::int, %s::real, %s::text, %s::bool",
914              (3, 2.5, 'hello', True))
915        r = q.getresult()[0]
916        self.assertEqual(r, (3, 2.5, 'hello', t))
917        # test with tuple, inline
918        q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True)
919        r = q.getresult()[0]
920        if isinstance(r[1], Decimal):
921            # Python 2.6 cannot compare float and Decimal
922            r = list(r)
923            r[1] = float(r[1])
924            r = tuple(r)
925        self.assertEqual(r, (3, 2.5, 'hello', t))
926        # test with dict
927        q = f("select %(a)s::int, %(b)s::real, %(c)s::text, %(d)s::bool",
928              dict(a=3, b=2.5, c='hello', d=True))
929        r = q.getresult()[0]
930        self.assertEqual(r, (3, 2.5, 'hello', t))
931        # test with dict, inline
932        q = f("select %(a)s, %(b)s, %(c)s, %(d)s",
933              dict(a=3, b=2.5, c='hello', d=True), inline=True)
934        r = q.getresult()[0]
935        if isinstance(r[1], Decimal):
936            # Python 2.6 cannot compare float and Decimal
937            r = list(r)
938            r[1] = float(r[1])
939            r = tuple(r)
940        self.assertEqual(r, (3, 2.5, 'hello', t))
941        # test with dict and extra values
942        q = f("select %(a)s||%(b)s||%(c)s||%(d)s||'epsilon'",
943              dict(a='alpha', b='beta', c='gamma', d='delta', e='extra'))
944        r = q.getresult()[0][0]
945        self.assertEqual(r, 'alphabetagammadeltaepsilon')
946
947    def testQueryFormattedWithAny(self):
948        f = self.db.query_formatted
949        q = "select 2 = any(%s)"
950        r = f(q, [[1, 3]]).getresult()[0][0]
951        self.assertEqual(r, False if pg.get_bool() else 'f')
952        r = f(q, [[1, 2, 3]]).getresult()[0][0]
953        self.assertEqual(r, True if pg.get_bool() else 't')
954        r = f(q, [[]]).getresult()[0][0]
955        self.assertEqual(r, False if pg.get_bool() else 'f')
956        r = f(q, [[None]]).getresult()[0][0]
957        self.assertIsNone(r)
958
959    def testQueryFormattedWithoutParams(self):
960        f = self.db.query_formatted
961        q = "select 42"
962        r = f(q).getresult()[0][0]
963        self.assertEqual(r, 42)
964        r = f(q, None).getresult()[0][0]
965        self.assertEqual(r, 42)
966        r = f(q, []).getresult()[0][0]
967        self.assertEqual(r, 42)
968        r = f(q, {}).getresult()[0][0]
969        self.assertEqual(r, 42)
970
    def testPkey(self):
        """Test pkey() for simple, composite and quoted primary keys.

        Also checks that pkey() results are cached until flush=True
        is passed.
        """
        query = self.db.query
        pkey = self.db.pkey
        # pkey() raises KeyError for a table without a primary key
        self.assertRaises(KeyError, pkey, 'test')
        # run all checks with a plain name and a name that needs quoting
        for t in ('pkeytest', 'primary key test'):
            # table 0 has no primary key, tables 1-7 have various keys
            self.createTable('%s0' % t, 'a smallint')
            self.createTable('%s1' % t, 'b smallint primary key')
            self.createTable('%s2' % t,
                'c smallint, d smallint primary key')
            self.createTable('%s3' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (f, h)')
            self.createTable('%s4' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (h, f)')
            self.createTable('%s5' % t,
                'more_than_one_letter varchar primary key')
            self.createTable('%s6' % t,
                '"with space" date primary key')
            self.createTable('%s7' % t,
                'a_very_long_column_name varchar, "with space" date, "42" int,'
                ' primary key (a_very_long_column_name, "with space", "42")')
            self.assertRaises(KeyError, pkey, '%s0' % t)
            # a single-column key is returned as a string,
            # or as a 1-tuple when composite=True is passed
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s1' % t, True), ('b',))
            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
            self.assertEqual(pkey('%s2' % t), 'd')
            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
            # a multi-column key is always a tuple (even with
            # composite=False), in the order of the key definition
            r = pkey('%s3' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s3' % t, composite=False)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s4' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('h', 'f'))
            # column names are returned unquoted
            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s6' % t), 'with space')
            r = pkey('%s7' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (
                'a_very_long_column_name', 'with space', '42'))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
1024
1025    def testGetDatabases(self):
1026        databases = self.db.get_databases()
1027        self.assertIn('template0', databases)
1028        self.assertIn('template1', databases)
1029        self.assertNotIn('not existing database', databases)
1030        self.assertIn('postgres', databases)
1031        self.assertIn(dbname, databases)
1032
1033    def testGetTables(self):
1034        get_tables = self.db.get_tables
1035        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
1036                  'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
1037                  'averyveryveryveryveryveryveryreallyreallylongtablename',
1038                  'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
1039        for t in tables:
1040            self.db.query('drop table if exists "%s" cascade' % t)
1041        before_tables = get_tables()
1042        self.assertIsInstance(before_tables, list)
1043        for t in before_tables:
1044            t = t.split('.', 1)
1045            self.assertGreaterEqual(len(t), 2)
1046            if len(t) > 2:
1047                self.assertTrue(t[1].startswith('"'))
1048            t = t[0]
1049            self.assertNotEqual(t, 'information_schema')
1050            self.assertFalse(t.startswith('pg_'))
1051        for t in tables:
1052            self.createTable(t, 'as select 0', temporary=False)
1053        current_tables = get_tables()
1054        new_tables = [t for t in current_tables if t not in before_tables]
1055        expected_new_tables = ['public.%s' % (
1056            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
1057        self.assertEqual(new_tables, expected_new_tables)
1058        self.doCleanups()
1059        after_tables = get_tables()
1060        self.assertEqual(after_tables, before_tables)
1061
1062    def testGetSystemTables(self):
1063        get_tables = self.db.get_tables
1064        result = get_tables()
1065        self.assertNotIn('pg_catalog.pg_class', result)
1066        self.assertNotIn('information_schema.tables', result)
1067        result = get_tables(system=False)
1068        self.assertNotIn('pg_catalog.pg_class', result)
1069        self.assertNotIn('information_schema.tables', result)
1070        result = get_tables(system=True)
1071        self.assertIn('pg_catalog.pg_class', result)
1072        self.assertNotIn('information_schema.tables', result)
1073
1074    def testGetRelations(self):
1075        get_relations = self.db.get_relations
1076        result = get_relations()
1077        self.assertIn('public.test', result)
1078        self.assertIn('public.test_view', result)
1079        result = get_relations('rv')
1080        self.assertIn('public.test', result)
1081        self.assertIn('public.test_view', result)
1082        result = get_relations('r')
1083        self.assertIn('public.test', result)
1084        self.assertNotIn('public.test_view', result)
1085        result = get_relations('v')
1086        self.assertNotIn('public.test', result)
1087        self.assertIn('public.test_view', result)
1088        result = get_relations('cisSt')
1089        self.assertNotIn('public.test', result)
1090        self.assertNotIn('public.test_view', result)
1091
1092    def testGetSystemRelations(self):
1093        get_relations = self.db.get_relations
1094        result = get_relations()
1095        self.assertNotIn('pg_catalog.pg_class', result)
1096        self.assertNotIn('information_schema.tables', result)
1097        result = get_relations(system=False)
1098        self.assertNotIn('pg_catalog.pg_class', result)
1099        self.assertNotIn('information_schema.tables', result)
1100        result = get_relations(system=True)
1101        self.assertIn('pg_catalog.pg_class', result)
1102        self.assertIn('information_schema.tables', result)
1103
1104    def testGetAttnames(self):
1105        get_attnames = self.db.get_attnames
1106        self.assertRaises(pg.ProgrammingError,
1107                          self.db.get_attnames, 'does_not_exist')
1108        self.assertRaises(pg.ProgrammingError,
1109                          self.db.get_attnames, 'has.too.many.dots')
1110        r = get_attnames('test')
1111        self.assertIsInstance(r, dict)
1112        if self.regtypes:
1113            self.assertEqual(r, dict(
1114                i2='smallint', i4='integer', i8='bigint', d='numeric',
1115                f4='real', f8='double precision', m='money',
1116                v4='character varying', c4='character', t='text'))
1117        else:
1118            self.assertEqual(r, dict(
1119                i2='int', i4='int', i8='int', d='num',
1120                f4='float', f8='float', m='money',
1121                v4='text', c4='text', t='text'))
1122        self.createTable('test_table',
1123                         'n int, alpha smallint, beta bool,'
1124                         ' gamma char(5), tau text, v varchar(3)')
1125        r = get_attnames('test_table')
1126        self.assertIsInstance(r, dict)
1127        if self.regtypes:
1128            self.assertEqual(r, dict(
1129                n='integer', alpha='smallint', beta='boolean',
1130                gamma='character', tau='text', v='character varying'))
1131        else:
1132            self.assertEqual(r, dict(
1133                n='int', alpha='int', beta='bool',
1134                gamma='text', tau='text', v='text'))
1135
1136    def testGetAttnamesWithQuotes(self):
1137        get_attnames = self.db.get_attnames
1138        table = 'test table for get_attnames()'
1139        self.createTable(table,
1140            '"Prime!" smallint, "much space" integer, "Questions?" text')
1141        r = get_attnames(table)
1142        self.assertIsInstance(r, dict)
1143        if self.regtypes:
1144            self.assertEqual(r, {
1145                'Prime!': 'smallint', 'much space': 'integer',
1146                'Questions?': 'text'})
1147        else:
1148            self.assertEqual(r, {
1149                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1150        table = 'yet another test table for get_attnames()'
1151        self.createTable(table,
1152                         'a smallint, b integer, c bigint,'
1153                         ' e numeric, f real, f2 double precision, m money,'
1154                         ' x smallint, y smallint, z smallint,'
1155                         ' Normal_NaMe smallint, "Special Name" smallint,'
1156                         ' t text, u char(2), v varchar(2),'
1157                         ' primary key (y, u)', oids=True)
1158        r = get_attnames(table)
1159        self.assertIsInstance(r, dict)
1160        if self.regtypes:
1161            self.assertEqual(r, {
1162                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1163                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1164                'm': 'money', 'normal_name': 'smallint',
1165                'Special Name': 'smallint', 'u': 'character',
1166                't': 'text', 'v': 'character varying', 'y': 'smallint',
1167                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1168        else:
1169            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1170                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1171                 'normal_name': 'int', 'Special Name': 'int',
1172                 'u': 'text', 't': 'text', 'v': 'text',
1173                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1174
1175    def testGetAttnamesWithRegtypes(self):
1176        get_attnames = self.db.get_attnames
1177        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1178            ' gamma char(5), tau text, v varchar(3)')
1179        use_regtypes = self.db.use_regtypes
1180        regtypes = use_regtypes()
1181        self.assertEqual(regtypes, self.regtypes)
1182        use_regtypes(True)
1183        try:
1184            r = get_attnames("test_table")
1185            self.assertIsInstance(r, dict)
1186        finally:
1187            use_regtypes(regtypes)
1188        self.assertEqual(r, dict(
1189            n='integer', alpha='smallint', beta='boolean',
1190            gamma='character', tau='text', v='character varying'))
1191
1192    def testGetAttnamesWithoutRegtypes(self):
1193        get_attnames = self.db.get_attnames
1194        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1195            ' gamma char(5), tau text, v varchar(3)')
1196        use_regtypes = self.db.use_regtypes
1197        regtypes = use_regtypes()
1198        self.assertEqual(regtypes, self.regtypes)
1199        use_regtypes(False)
1200        try:
1201            r = get_attnames("test_table")
1202            self.assertIsInstance(r, dict)
1203        finally:
1204            use_regtypes(regtypes)
1205        self.assertEqual(r, dict(
1206            n='int', alpha='int', beta='bool',
1207            gamma='text', tau='text', v='text'))
1208
1209    def testGetAttnamesIsCached(self):
1210        get_attnames = self.db.get_attnames
1211        int_type = 'integer' if self.regtypes else 'int'
1212        text_type = 'text'
1213        query = self.db.query
1214        self.createTable('test_table', 'col int')
1215        r = get_attnames("test_table")
1216        self.assertIsInstance(r, dict)
1217        self.assertEqual(r, dict(col=int_type))
1218        query("alter table test_table alter column col type text")
1219        query("alter table test_table add column col2 int")
1220        r = get_attnames("test_table")
1221        self.assertEqual(r, dict(col=int_type))
1222        r = get_attnames("test_table", flush=True)
1223        self.assertEqual(r, dict(col=text_type, col2=int_type))
1224        query("alter table test_table drop column col2")
1225        r = get_attnames("test_table")
1226        self.assertEqual(r, dict(col=text_type, col2=int_type))
1227        r = get_attnames("test_table", flush=True)
1228        self.assertEqual(r, dict(col=text_type))
1229        query("alter table test_table drop column col")
1230        r = get_attnames("test_table")
1231        self.assertEqual(r, dict(col=text_type))
1232        r = get_attnames("test_table", flush=True)
1233        self.assertEqual(r, dict())
1234
1235    def testGetAttnamesIsOrdered(self):
1236        get_attnames = self.db.get_attnames
1237        r = get_attnames('test', flush=True)
1238        self.assertIsInstance(r, OrderedDict)
1239        if self.regtypes:
1240            self.assertEqual(r, OrderedDict([
1241                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1242                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1243                ('m', 'money'), ('v4', 'character varying'),
1244                ('c4', 'character'), ('t', 'text')]))
1245        else:
1246            self.assertEqual(r, OrderedDict([
1247                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1248                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1249                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1250        if OrderedDict is not dict:
1251            r = ' '.join(list(r.keys()))
1252            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1253        table = 'test table for get_attnames'
1254        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1255                ' gamma char(5), tau text, beta bool')
1256        r = get_attnames(table)
1257        self.assertIsInstance(r, OrderedDict)
1258        if self.regtypes:
1259            self.assertEqual(r, OrderedDict([
1260                ('n', 'integer'), ('alpha', 'smallint'),
1261                ('v', 'character varying'), ('gamma', 'character'),
1262                ('tau', 'text'), ('beta', 'boolean')]))
1263        else:
1264            self.assertEqual(r, OrderedDict([
1265                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1266                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1267        if OrderedDict is not dict:
1268            r = ' '.join(list(r.keys()))
1269            self.assertEqual(r, 'n alpha v gamma tau beta')
1270        else:
1271            self.skipTest('OrderedDict is not supported')
1272
1273    def testGetAttnamesIsAttrDict(self):
1274        AttrDict = pg.AttrDict
1275        get_attnames = self.db.get_attnames
1276        r = get_attnames('test', flush=True)
1277        self.assertIsInstance(r, AttrDict)
1278        if self.regtypes:
1279            self.assertEqual(r, AttrDict([
1280                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1281                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1282                ('m', 'money'), ('v4', 'character varying'),
1283                ('c4', 'character'), ('t', 'text')]))
1284        else:
1285            self.assertEqual(r, AttrDict([
1286                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1287                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1288                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1289        r = ' '.join(list(r.keys()))
1290        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1291        table = 'test table for get_attnames'
1292        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1293                ' gamma char(5), tau text, beta bool')
1294        r = get_attnames(table)
1295        self.assertIsInstance(r, AttrDict)
1296        if self.regtypes:
1297            self.assertEqual(r, AttrDict([
1298                ('n', 'integer'), ('alpha', 'smallint'),
1299                ('v', 'character varying'), ('gamma', 'character'),
1300                ('tau', 'text'), ('beta', 'boolean')]))
1301        else:
1302            self.assertEqual(r, AttrDict([
1303                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1304                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1305        r = ' '.join(list(r.keys()))
1306        self.assertEqual(r, 'n alpha v gamma tau beta')
1307
1308    def testHasTablePrivilege(self):
1309        can = self.db.has_table_privilege
1310        self.assertEqual(can('test'), True)
1311        self.assertEqual(can('test', 'select'), True)
1312        self.assertEqual(can('test', 'SeLeCt'), True)
1313        self.assertEqual(can('test', 'SELECT'), True)
1314        self.assertEqual(can('test', 'insert'), True)
1315        self.assertEqual(can('test', 'update'), True)
1316        self.assertEqual(can('test', 'delete'), True)
1317        self.assertRaises(pg.DataError, can, 'test', 'foobar')
1318        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1319        r = self.db.query('select rolsuper FROM pg_roles'
1320            ' where rolname=current_user').getresult()[0][0]
1321        if not pg.get_bool():
1322            r = r == 't'
1323        if r:
1324            self.skipTest('must not be superuser')
1325        self.assertEqual(can('pg_views', 'select'), True)
1326        self.assertEqual(can('pg_views', 'delete'), False)
1327
    def testGet(self):
        """Test get() on a table before and after adding a primary key.

        Without a primary key, the key column(s) must be passed explicitly.
        When a dict is passed as the row, it is updated in place and
        returned.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # the table name and the row are required arguments
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        self.createTable(table, 'n integer, t text',
                         values=enumerate('xyz', start=1))
        # the table has no primary key yet, so a key column is needed
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # each of these fails, either because no key column
        # is known or because no matching row exists
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        # a dict passed as row is filled in place and returned
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(t='x')
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        # with a primary key, the key column can be omitted
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # table names with a trailing '*' are accepted as well
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # more key values than primary key columns
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        s.update(n=2)
        # r is the same object as s here
        self.assertEqual(get(table, r)['t'], 'y')
        # a dict without the primary key column
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1382
    def testGetWithOid(self):
        """Test get() on a table with oids, with and without a primary key.

        The returned row contains the oid under the key 'oid(<table>)'.
        Rows can be fetched by oid via the key name 'oid', a dict with an
        'oid' entry, or a dict with the qualified oid key.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        # no primary key yet, so a key column is needed
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        # an empty dict contains no oid to look up
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        # the different ways of looking up a row by its oid
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        # r was filled in place by get(table, r, 'n') above,
        # so its oid now refers to the row with n=3
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        # after adding a primary key, it is used to look up dicts
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # lookups by oid still work with a primary key
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # the primary key takes precedence over a plain 'oid' entry
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # so does an explicitly given key column
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1446
1447    def testGetWithCompositeKey(self):
1448        get = self.db.get
1449        query = self.db.query
1450        table = 'get_test_table_1'
1451        self.createTable(table, 'n integer primary key, t text',
1452                         values=enumerate('abc', start=1))
1453        self.assertEqual(get(table, 2)['t'], 'b')
1454        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1455        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1456        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1457        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1458        self.assertEqual(get(table, 'b', 't')['n'], 2)
1459        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1460        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1461        table = 'get_test_table_2'
1462        self.createTable(table,
1463                         'n integer, m integer, t text, primary key (n, m)',
1464                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1465                                 for n in range(3) for m in range(2)])
1466        self.assertRaises(KeyError, get, table, 2)
1467        self.assertEqual(get(table, (1, 1))['t'], 'a')
1468        self.assertEqual(get(table, (1, 2))['t'], 'b')
1469        self.assertEqual(get(table, (2, 1))['t'], 'c')
1470        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1471        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1472        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1473        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1474        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1475        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1476        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1477        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1478
1479    def testGetWithQuotedNames(self):
1480        get = self.db.get
1481        query = self.db.query
1482        table = 'test table for get()'
1483        self.createTable(table, '"Prime!" smallint primary key,'
1484                                ' "much space" integer, "Questions?" text',
1485                         values=[(17, 1001, 'No!')])
1486        r = get(table, 17)
1487        self.assertIsInstance(r, dict)
1488        self.assertEqual(r['Prime!'], 17)
1489        self.assertEqual(r['much space'], 1001)
1490        self.assertEqual(r['Questions?'], 'No!')
1491
1492    def testGetFromView(self):
1493        self.db.query('delete from test where i4=14')
1494        self.db.query('insert into test (i4, v4) values('
1495                      "14, 'abc4')")
1496        r = self.db.get('test_view', 14, 'i4')
1497        self.assertIn('v4', r)
1498        self.assertEqual(r['v4'], 'abc4')
1499
1500    def testGetLittleBobbyTables(self):
1501        get = self.db.get
1502        query = self.db.query
1503        self.createTable('test_students',
1504                         'firstname varchar primary key, nickname varchar, grade char(2)',
1505                         values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1506                                 ('Robert', 'Little Bobby Tables', 'D-')])
1507        r = get('test_students', 'Sheldon')
1508        self.assertEqual(r, dict(
1509            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1510        r = get('test_students', 'Robert')
1511        self.assertEqual(r, dict(
1512            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1513        r = get('test_students', "D'Arcy")
1514        self.assertEqual(r, dict(
1515            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1516        try:
1517            get('test_students', "D' Arcy")
1518        except pg.DatabaseError as error:
1519            self.assertEqual(str(error),
1520                'No such record in test_students\nwhere "firstname" = $1\n'
1521                'with $1="D\' Arcy"')
1522        try:
1523            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1524        except pg.DatabaseError as error:
1525            self.assertEqual(str(error),
1526                'No such record in test_students\nwhere "firstname" = $1\n'
1527                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1528        q = "select * from test_students order by 1 limit 4"
1529        r = query(q).getresult()
1530        self.assertEqual(len(r), 3)
1531        self.assertEqual(r[1][2], 'D-')
1532
1533    def testInsert(self):
1534        insert = self.db.insert
1535        query = self.db.query
1536        bool_on = pg.get_bool()
1537        decimal = pg.get_decimal()
1538        table = 'insert_test_table'
1539        self.createTable(table,
1540            'i2 smallint, i4 integer, i8 bigint,'
1541            ' d numeric, f4 real, f8 double precision, m money,'
1542            ' v4 varchar(4), c4 char(4), t text,'
1543            ' b boolean, ts timestamp', oids=True)
1544        oid_table = 'oid(%s)' % table
1545        tests = [dict(i2=None, i4=None, i8=None),
1546             (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1547             (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1548             dict(i2=42, i4=123456, i8=9876543210),
1549             dict(i2=2 ** 15 - 1,
1550                  i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1551             dict(d=None), (dict(d=''), dict(d=None)),
1552             dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1553             dict(f4=None, f8=None), dict(f4=0, f8=0),
1554             (dict(f4='', f8=''), dict(f4=None, f8=None)),
1555             (dict(d=1234.5, f4=1234.5, f8=1234.5),
1556              dict(d=Decimal('1234.5'))),
1557             dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1558             dict(d=Decimal('123456789.9876543212345678987654321')),
1559             dict(m=None), (dict(m=''), dict(m=None)),
1560             dict(m=Decimal('-1234.56')),
1561             (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1562             dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1563             (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1564             (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1565             (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1566             (dict(m=123456), dict(m=Decimal('123456'))),
1567             (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1568             dict(b=None), (dict(b=''), dict(b=None)),
1569             dict(b='f'), dict(b='t'),
1570             (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1571             (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1572             (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1573             (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1574             (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1575             (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1576             dict(v4=None, c4=None, t=None),
1577             (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1578             dict(v4='1234', c4='1234', t='1234' * 10),
1579             dict(v4='abcd', c4='abcd', t='abcdefg'),
1580             (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1581             dict(ts=None), (dict(ts=''), dict(ts=None)),
1582             (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1583             dict(ts='2012-12-21 00:00:00'),
1584             (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1585             dict(ts='2012-12-21 12:21:12'),
1586             dict(ts='2013-01-05 12:13:14'),
1587             dict(ts='current_timestamp')]
1588        for test in tests:
1589            if isinstance(test, dict):
1590                data = test
1591                change = {}
1592            else:
1593                data, change = test
1594            expect = data.copy()
1595            expect.update(change)
1596            if bool_on:
1597                b = expect.get('b')
1598                if b is not None:
1599                    expect['b'] = b == 't'
1600            if decimal is not Decimal:
1601                d = expect.get('d')
1602                if d is not None:
1603                    expect['d'] = decimal(d)
1604                m = expect.get('m')
1605                if m is not None:
1606                    expect['m'] = decimal(m)
1607            self.assertEqual(insert(table, data), data)
1608            self.assertIn(oid_table, data)
1609            oid = data[oid_table]
1610            self.assertIsInstance(oid, int)
1611            data = dict(item for item in data.items()
1612                        if item[0] in expect)
1613            ts = expect.get('ts')
1614            if ts:
1615                if ts == 'current_timestamp':
1616                    ts = data['ts']
1617                    self.assertIsInstance(ts, datetime)
1618                    self.assertEqual(ts.strftime('%Y-%m-%d'),
1619                        strftime('%Y-%m-%d'))
1620                else:
1621                    ts = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S')
1622                expect['ts'] = ts
1623            self.assertEqual(data, expect)
1624            data = query(
1625                'select oid,* from "%s"' % table).dictresult()[0]
1626            self.assertEqual(data['oid'], oid)
1627            data = dict(item for item in data.items()
1628                        if item[0] in expect)
1629            self.assertEqual(data, expect)
1630            query('delete from "%s"' % table)
1631
1632    def testInsertWithOid(self):
1633        insert = self.db.insert
1634        query = self.db.query
1635        self.createTable('test_table', 'n int', oids=True)
1636        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
1637        r = insert('test_table', n=1)
1638        self.assertIsInstance(r, dict)
1639        self.assertEqual(r['n'], 1)
1640        self.assertNotIn('oid', r)
1641        qoid = 'oid(test_table)'
1642        self.assertIn(qoid, r)
1643        oid = r[qoid]
1644        self.assertEqual(sorted(r.keys()), ['n', qoid])
1645        r = insert('test_table', n=2, oid=oid)
1646        self.assertIsInstance(r, dict)
1647        self.assertEqual(r['n'], 2)
1648        self.assertIn(qoid, r)
1649        self.assertNotEqual(r[qoid], oid)
1650        self.assertNotIn('oid', r)
1651        r = insert('test_table', None, n=3)
1652        self.assertIsInstance(r, dict)
1653        self.assertEqual(r['n'], 3)
1654        s = r
1655        r = insert('test_table', r)
1656        self.assertIs(r, s)
1657        self.assertEqual(r['n'], 3)
1658        r = insert('test_table *', r)
1659        self.assertIs(r, s)
1660        self.assertEqual(r['n'], 3)
1661        r = insert('test_table', r, n=4)
1662        self.assertIs(r, s)
1663        self.assertEqual(r['n'], 4)
1664        self.assertNotIn('oid', r)
1665        self.assertIn(qoid, r)
1666        oid = r[qoid]
1667        r = insert('test_table', r, n=5, oid=oid)
1668        self.assertIs(r, s)
1669        self.assertEqual(r['n'], 5)
1670        self.assertIn(qoid, r)
1671        self.assertNotEqual(r[qoid], oid)
1672        self.assertNotIn('oid', r)
1673        r['oid'] = oid = r[qoid]
1674        r = insert('test_table', r, n=6)
1675        self.assertIs(r, s)
1676        self.assertEqual(r['n'], 6)
1677        self.assertIn(qoid, r)
1678        self.assertNotEqual(r[qoid], oid)
1679        self.assertNotIn('oid', r)
1680        q = 'select n from test_table order by 1 limit 9'
1681        r = ' '.join(str(row[0]) for row in query(q).getresult())
1682        self.assertEqual(r, '1 2 3 3 3 4 5 6')
1683        query("truncate test_table")
1684        query("alter table test_table add unique (n)")
1685        r = insert('test_table', dict(n=7))
1686        self.assertIsInstance(r, dict)
1687        self.assertEqual(r['n'], 7)
1688        self.assertRaises(pg.IntegrityError, insert, 'test_table', r)
1689        r['n'] = 6
1690        self.assertRaises(pg.IntegrityError, insert, 'test_table', r, n=7)
1691        self.assertIsInstance(r, dict)
1692        self.assertEqual(r['n'], 7)
1693        r['n'] = 6
1694        r = insert('test_table', r)
1695        self.assertIsInstance(r, dict)
1696        self.assertEqual(r['n'], 6)
1697        r = ' '.join(str(row[0]) for row in query(q).getresult())
1698        self.assertEqual(r, '6 7')
1699
1700    def testInsertWithQuotedNames(self):
1701        insert = self.db.insert
1702        query = self.db.query
1703        table = 'test table for insert()'
1704        self.createTable(table, '"Prime!" smallint primary key,'
1705                                ' "much space" integer, "Questions?" text')
1706        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1707        r = insert(table, r)
1708        self.assertIsInstance(r, dict)
1709        self.assertEqual(r['Prime!'], 11)
1710        self.assertEqual(r['much space'], 2002)
1711        self.assertEqual(r['Questions?'], 'What?')
1712        r = query('select * from "%s" limit 2' % table).dictresult()
1713        self.assertEqual(len(r), 1)
1714        r = r[0]
1715        self.assertEqual(r['Prime!'], 11)
1716        self.assertEqual(r['much space'], 2002)
1717        self.assertEqual(r['Questions?'], 'What?')
1718
1719    def testInsertIntoView(self):
1720        insert = self.db.insert
1721        query = self.db.query
1722        query("truncate test")
1723        q = 'select * from test_view order by i4 limit 3'
1724        r = query(q).getresult()
1725        self.assertEqual(r, [])
1726        r = dict(i4=1234, v4='abcd')
1727        insert('test', r)
1728        self.assertIsNone(r['i2'])
1729        self.assertEqual(r['i4'], 1234)
1730        self.assertIsNone(r['i8'])
1731        self.assertEqual(r['v4'], 'abcd')
1732        self.assertIsNone(r['c4'])
1733        r = query(q).getresult()
1734        self.assertEqual(r, [(1234, 'abcd')])
1735        r = dict(i4=5678, v4='efgh')
1736        try:
1737            insert('test_view', r)
1738        except (pg.OperationalError, pg.NotSupportedError) as error:
1739            if self.db.server_version < 90300:
1740                # must setup rules in older PostgreSQL versions
1741                self.skipTest('database cannot insert into view')
1742            self.fail(str(error))
1743        self.assertNotIn('i2', r)
1744        self.assertEqual(r['i4'], 5678)
1745        self.assertNotIn('i8', r)
1746        self.assertEqual(r['v4'], 'efgh')
1747        self.assertNotIn('c4', r)
1748        r = query(q).getresult()
1749        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1750
1751    def testUpdate(self):
1752        update = self.db.update
1753        query = self.db.query
1754        self.assertRaises(pg.ProgrammingError, update,
1755                          'test', i2=2, i4=4, i8=8)
1756        table = 'update_test_table'
1757        self.createTable(table, 'n integer, t text', oids=True,
1758                         values=enumerate('xyz', start=1))
1759        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1760        r = self.db.get(table, 2, 'n')
1761        r['t'] = 'u'
1762        s = update(table, r)
1763        self.assertEqual(s, r)
1764        q = 'select t from "%s" where n=2' % table
1765        r = query(q).getresult()[0][0]
1766        self.assertEqual(r, 'u')
1767
1768    def testUpdateWithOid(self):
1769        update = self.db.update
1770        get = self.db.get
1771        query = self.db.query
1772        self.createTable('test_table', 'n int', oids=True, values=[1])
1773        s = get('test_table', 1, 'n')
1774        self.assertIsInstance(s, dict)
1775        self.assertEqual(s['n'], 1)
1776        s['n'] = 2
1777        r = update('test_table', s)
1778        self.assertIs(r, s)
1779        self.assertEqual(r['n'], 2)
1780        qoid = 'oid(test_table)'
1781        self.assertIn(qoid, r)
1782        self.assertNotIn('oid', r)
1783        self.assertEqual(sorted(r.keys()), ['n', qoid])
1784        r['n'] = 3
1785        oid = r.pop(qoid)
1786        r = update('test_table', r, oid=oid)
1787        self.assertIs(r, s)
1788        self.assertEqual(r['n'], 3)
1789        r.pop(qoid)
1790        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
1791        s = get('test_table', 3, 'n')
1792        self.assertIsInstance(s, dict)
1793        self.assertEqual(s['n'], 3)
1794        s.pop('n')
1795        r = update('test_table', s)
1796        oid = r.pop(qoid)
1797        self.assertEqual(r, {})
1798        q = "select n from test_table limit 2"
1799        r = query(q).getresult()
1800        self.assertEqual(r, [(3,)])
1801        query("insert into test_table values (1)")
1802        self.assertRaises(pg.ProgrammingError,
1803                          update, 'test_table', dict(oid=oid, n=4))
1804        r = update('test_table', dict(n=4), oid=oid)
1805        self.assertEqual(r['n'], 4)
1806        r = update('test_table *', dict(n=5), oid=oid)
1807        self.assertEqual(r['n'], 5)
1808        query("alter table test_table add column m int")
1809        query("alter table test_table add primary key (n)")
1810        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
1811        self.assertEqual('n', self.db.pkey('test_table', flush=True))
1812        s = dict(n=1, m=4)
1813        r = update('test_table', s)
1814        self.assertIs(r, s)
1815        self.assertEqual(r['n'], 1)
1816        self.assertEqual(r['m'], 4)
1817        s = dict(m=7)
1818        r = update('test_table', s, n=5)
1819        self.assertIs(r, s)
1820        self.assertEqual(r['n'], 5)
1821        self.assertEqual(r['m'], 7)
1822        q = "select n, m from test_table order by 1 limit 3"
1823        r = query(q).getresult()
1824        self.assertEqual(r, [(1, 4), (5, 7)])
1825        s = dict(m=9, oid=oid)
1826        self.assertRaises(KeyError, update, 'test_table', s)
1827        r = update('test_table', s, oid=oid)
1828        self.assertIs(r, s)
1829        self.assertEqual(r['n'], 5)
1830        self.assertEqual(r['m'], 9)
1831        s = dict(n=1, m=3, oid=oid)
1832        r = update('test_table', s)
1833        self.assertIs(r, s)
1834        self.assertEqual(r['n'], 1)
1835        self.assertEqual(r['m'], 3)
1836        r = query(q).getresult()
1837        self.assertEqual(r, [(1, 3), (5, 9)])
1838        s.update(n=4, m=7)
1839        r = update('test_table', s, oid=oid)
1840        self.assertIs(r, s)
1841        self.assertEqual(r['n'], 4)
1842        self.assertEqual(r['m'], 7)
1843        r = query(q).getresult()
1844        self.assertEqual(r, [(1, 3), (4, 7)])
1845
1846    def testUpdateWithoutOid(self):
1847        update = self.db.update
1848        query = self.db.query
1849        self.assertRaises(pg.ProgrammingError, update,
1850                          'test', i2=2, i4=4, i8=8)
1851        table = 'update_test_table'
1852        self.createTable(table, 'n integer primary key, t text', oids=False,
1853                         values=enumerate('xyz', start=1))
1854        r = self.db.get(table, 2)
1855        r['t'] = 'u'
1856        s = update(table, r)
1857        self.assertEqual(s, r)
1858        q = 'select t from "%s" where n=2' % table
1859        r = query(q).getresult()[0][0]
1860        self.assertEqual(r, 'u')
1861
1862    def testUpdateWithCompositeKey(self):
1863        update = self.db.update
1864        query = self.db.query
1865        table = 'update_test_table_1'
1866        self.createTable(table, 'n integer primary key, t text',
1867                         values=enumerate('abc', start=1))
1868        self.assertRaises(KeyError, update, table, dict(t='b'))
1869        s = dict(n=2, t='d')
1870        r = update(table, s)
1871        self.assertIs(r, s)
1872        self.assertEqual(r['n'], 2)
1873        self.assertEqual(r['t'], 'd')
1874        q = 'select t from "%s" where n=2' % table
1875        r = query(q).getresult()[0][0]
1876        self.assertEqual(r, 'd')
1877        s.update(dict(n=4, t='e'))
1878        r = update(table, s)
1879        self.assertEqual(r['n'], 4)
1880        self.assertEqual(r['t'], 'e')
1881        q = 'select t from "%s" where n=2' % table
1882        r = query(q).getresult()[0][0]
1883        self.assertEqual(r, 'd')
1884        q = 'select t from "%s" where n=4' % table
1885        r = query(q).getresult()
1886        self.assertEqual(len(r), 0)
1887        query('drop table "%s"' % table)
1888        table = 'update_test_table_2'
1889        self.createTable(table,
1890                         'n integer, m integer, t text, primary key (n, m)',
1891                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1892                                 for n in range(3) for m in range(2)])
1893        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1894        self.assertEqual(update(table,
1895                                dict(n=2, m=2, t='x'))['t'], 'x')
1896        q = 'select t from "%s" where n=2 order by m' % table
1897        r = [r[0] for r in query(q).getresult()]
1898        self.assertEqual(r, ['c', 'x'])
1899
1900    def testUpdateWithQuotedNames(self):
1901        update = self.db.update
1902        query = self.db.query
1903        table = 'test table for update()'
1904        self.createTable(table, '"Prime!" smallint primary key,'
1905                                ' "much space" integer, "Questions?" text',
1906                         values=[(13, 3003, 'Why!')])
1907        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1908        r = update(table, r)
1909        self.assertIsInstance(r, dict)
1910        self.assertEqual(r['Prime!'], 13)
1911        self.assertEqual(r['much space'], 7007)
1912        self.assertEqual(r['Questions?'], 'When?')
1913        r = query('select * from "%s" limit 2' % table).dictresult()
1914        self.assertEqual(len(r), 1)
1915        r = r[0]
1916        self.assertEqual(r['Prime!'], 13)
1917        self.assertEqual(r['much space'], 7007)
1918        self.assertEqual(r['Questions?'], 'When?')
1919
1920    def testUpsert(self):
1921        upsert = self.db.upsert
1922        query = self.db.query
1923        self.assertRaises(pg.ProgrammingError, upsert,
1924                          'test', i2=2, i4=4, i8=8)
1925        table = 'upsert_test_table'
1926        self.createTable(table, 'n integer primary key, t text', oids=True)
1927        s = dict(n=1, t='x')
1928        try:
1929            r = upsert(table, s)
1930        except pg.ProgrammingError as error:
1931            if self.db.server_version < 90500:
1932                self.skipTest('database does not support upsert')
1933            self.fail(str(error))
1934        self.assertIs(r, s)
1935        self.assertEqual(r['n'], 1)
1936        self.assertEqual(r['t'], 'x')
1937        s.update(n=2, t='y')
1938        r = upsert(table, s, **dict.fromkeys(s))
1939        self.assertIs(r, s)
1940        self.assertEqual(r['n'], 2)
1941        self.assertEqual(r['t'], 'y')
1942        q = 'select n, t from "%s" order by n limit 3' % table
1943        r = query(q).getresult()
1944        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1945        s.update(t='z')
1946        r = upsert(table, s)
1947        self.assertIs(r, s)
1948        self.assertEqual(r['n'], 2)
1949        self.assertEqual(r['t'], 'z')
1950        r = query(q).getresult()
1951        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1952        s.update(t='n')
1953        r = upsert(table, s, t=False)
1954        self.assertIs(r, s)
1955        self.assertEqual(r['n'], 2)
1956        self.assertEqual(r['t'], 'z')
1957        r = query(q).getresult()
1958        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1959        s.update(t='y')
1960        r = upsert(table, s, t=True)
1961        self.assertIs(r, s)
1962        self.assertEqual(r['n'], 2)
1963        self.assertEqual(r['t'], 'y')
1964        r = query(q).getresult()
1965        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1966        s.update(t='n')
1967        r = upsert(table, s, t="included.t || '2'")
1968        self.assertIs(r, s)
1969        self.assertEqual(r['n'], 2)
1970        self.assertEqual(r['t'], 'y2')
1971        r = query(q).getresult()
1972        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1973        s.update(t='y')
1974        r = upsert(table, s, t="excluded.t || '3'")
1975        self.assertIs(r, s)
1976        self.assertEqual(r['n'], 2)
1977        self.assertEqual(r['t'], 'y3')
1978        r = query(q).getresult()
1979        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1980        s.update(n=1, t='2')
1981        r = upsert(table, s, t="included.t || excluded.t")
1982        self.assertIs(r, s)
1983        self.assertEqual(r['n'], 1)
1984        self.assertEqual(r['t'], 'x2')
1985        r = query(q).getresult()
1986        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1987        # not existing columns and oid parameter should be ignored
1988        s = dict(m=3, u='z')
1989        r = upsert(table, s, oid='invalid')
1990        self.assertIs(r, s)
1991
1992    def testUpsertWithOid(self):
1993        upsert = self.db.upsert
1994        get = self.db.get
1995        query = self.db.query
1996        self.createTable('test_table', 'n int', oids=True, values=[1])
1997        self.assertRaises(pg.ProgrammingError,
1998                          upsert, 'test_table', dict(n=2))
1999        r = get('test_table', 1, 'n')
2000        self.assertIsInstance(r, dict)
2001        self.assertEqual(r['n'], 1)
2002        qoid = 'oid(test_table)'
2003        self.assertIn(qoid, r)
2004        self.assertNotIn('oid', r)
2005        oid = r[qoid]
2006        self.assertRaises(pg.ProgrammingError,
2007                          upsert, 'test_table', dict(n=2, oid=oid))
2008        query("alter table test_table add column m int")
2009        query("alter table test_table add primary key (n)")
2010        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
2011        self.assertEqual('n', self.db.pkey('test_table', flush=True))
2012        s = dict(n=2)
2013        try:
2014            r = upsert('test_table', s)
2015        except pg.ProgrammingError as error:
2016            if self.db.server_version < 90500:
2017                self.skipTest('database does not support upsert')
2018            self.fail(str(error))
2019        self.assertIs(r, s)
2020        self.assertEqual(r['n'], 2)
2021        self.assertIsNone(r['m'])
2022        q = query("select n, m from test_table order by n limit 3")
2023        self.assertEqual(q.getresult(), [(1, None), (2, None)])
2024        r['oid'] = oid
2025        r = upsert('test_table', r)
2026        self.assertIs(r, s)
2027        self.assertEqual(r['n'], 2)
2028        self.assertIsNone(r['m'])
2029        self.assertIn(qoid, r)
2030        self.assertNotIn('oid', r)
2031        self.assertNotEqual(r[qoid], oid)
2032        r['m'] = 7
2033        r = upsert('test_table', r)
2034        self.assertIs(r, s)
2035        self.assertEqual(r['n'], 2)
2036        self.assertEqual(r['m'], 7)
2037        r.update(n=1, m=3)
2038        r = upsert('test_table', r)
2039        self.assertIs(r, s)
2040        self.assertEqual(r['n'], 1)
2041        self.assertEqual(r['m'], 3)
2042        q = query("select n, m from test_table order by n limit 3")
2043        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
2044        r = upsert('test_table', r, oid='invalid')
2045        self.assertIs(r, s)
2046        self.assertEqual(r['n'], 1)
2047        self.assertEqual(r['m'], 3)
2048        r['m'] = 5
2049        r = upsert('test_table', r, m=False)
2050        self.assertIs(r, s)
2051        self.assertEqual(r['n'], 1)
2052        self.assertEqual(r['m'], 3)
2053        r['m'] = 5
2054        r = upsert('test_table', r, m=True)
2055        self.assertIs(r, s)
2056        self.assertEqual(r['n'], 1)
2057        self.assertEqual(r['m'], 5)
2058        r.update(n=2, m=1)
2059        r = upsert('test_table', r, m='included.m')
2060        self.assertIs(r, s)
2061        self.assertEqual(r['n'], 2)
2062        self.assertEqual(r['m'], 7)
2063        r['m'] = 9
2064        r = upsert('test_table', r, m='excluded.m')
2065        self.assertIs(r, s)
2066        self.assertEqual(r['n'], 2)
2067        self.assertEqual(r['m'], 9)
2068        r['m'] = 8
2069        r = upsert('test_table *', r, m='included.m + 1')
2070        self.assertIs(r, s)
2071        self.assertEqual(r['n'], 2)
2072        self.assertEqual(r['m'], 10)
2073        q = query("select n, m from test_table order by n limit 3")
2074        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
2075
2076    def testUpsertWithCompositeKey(self):
2077        upsert = self.db.upsert
2078        query = self.db.query
2079        table = 'upsert_test_table_2'
2080        self.createTable(table,
2081                         'n integer, m integer, t text, primary key (n, m)')
2082        s = dict(n=1, m=2, t='x')
2083        try:
2084            r = upsert(table, s)
2085        except pg.ProgrammingError as error:
2086            if self.db.server_version < 90500:
2087                self.skipTest('database does not support upsert')
2088            self.fail(str(error))
2089        self.assertIs(r, s)
2090        self.assertEqual(r['n'], 1)
2091        self.assertEqual(r['m'], 2)
2092        self.assertEqual(r['t'], 'x')
2093        s.update(m=3, t='y')
2094        r = upsert(table, s, **dict.fromkeys(s))
2095        self.assertIs(r, s)
2096        self.assertEqual(r['n'], 1)
2097        self.assertEqual(r['m'], 3)
2098        self.assertEqual(r['t'], 'y')
2099        q = 'select n, m, t from "%s" order by n, m limit 3' % table
2100        r = query(q).getresult()
2101        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
2102        s.update(t='z')
2103        r = upsert(table, s)
2104        self.assertIs(r, s)
2105        self.assertEqual(r['n'], 1)
2106        self.assertEqual(r['m'], 3)
2107        self.assertEqual(r['t'], 'z')
2108        r = query(q).getresult()
2109        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
2110        s.update(t='n')
2111        r = upsert(table, s, t=False)
2112        self.assertIs(r, s)
2113        self.assertEqual(r['n'], 1)
2114        self.assertEqual(r['m'], 3)
2115        self.assertEqual(r['t'], 'z')
2116        r = query(q).getresult()
2117        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
2118        s.update(t='n')
2119        r = upsert(table, s, t=True)
2120        self.assertIs(r, s)
2121        self.assertEqual(r['n'], 1)
2122        self.assertEqual(r['m'], 3)
2123        self.assertEqual(r['t'], 'n')
2124        r = query(q).getresult()
2125        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
2126        s.update(n=2, t='y')
2127        r = upsert(table, s, t="'z'")
2128        self.assertIs(r, s)
2129        self.assertEqual(r['n'], 2)
2130        self.assertEqual(r['m'], 3)
2131        self.assertEqual(r['t'], 'y')
2132        r = query(q).getresult()
2133        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
2134        s.update(n=1, t='m')
2135        r = upsert(table, s, t='included.t || excluded.t')
2136        self.assertIs(r, s)
2137        self.assertEqual(r['n'], 1)
2138        self.assertEqual(r['m'], 3)
2139        self.assertEqual(r['t'], 'nm')
2140        r = query(q).getresult()
2141        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2142
2143    def testUpsertWithQuotedNames(self):
2144        upsert = self.db.upsert
2145        query = self.db.query
2146        table = 'test table for upsert()'
2147        self.createTable(table, '"Prime!" smallint primary key,'
2148                                ' "much space" integer, "Questions?" text')
2149        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2150        try:
2151            r = upsert(table, s)
2152        except pg.ProgrammingError as error:
2153            if self.db.server_version < 90500:
2154                self.skipTest('database does not support upsert')
2155            self.fail(str(error))
2156        self.assertIs(r, s)
2157        self.assertEqual(r['Prime!'], 31)
2158        self.assertEqual(r['much space'], 9009)
2159        self.assertEqual(r['Questions?'], 'Yes.')
2160        q = 'select * from "%s" limit 2' % table
2161        r = query(q).getresult()
2162        self.assertEqual(r, [(31, 9009, 'Yes.')])
2163        s.update({'Questions?': 'No.'})
2164        r = upsert(table, s)
2165        self.assertIs(r, s)
2166        self.assertEqual(r['Prime!'], 31)
2167        self.assertEqual(r['much space'], 9009)
2168        self.assertEqual(r['Questions?'], 'No.')
2169        r = query(q).getresult()
2170        self.assertEqual(r, [(31, 9009, 'No.')])
2171
2172    def testClear(self):
2173        clear = self.db.clear
2174        f = False if pg.get_bool() else 'f'
2175        r = clear('test')
2176        result = dict(
2177            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2178        self.assertEqual(r, result)
2179        table = 'clear_test_table'
2180        self.createTable(table,
2181            'n integer, f float, b boolean, d date, t text', oids=True)
2182        r = clear(table)
2183        result = dict(n=0, f=0, b=f, d='', t='')
2184        self.assertEqual(r, result)
2185        r['a'] = r['f'] = r['n'] = 1
2186        r['d'] = r['t'] = 'x'
2187        r['b'] = 't'
2188        r['oid'] = long(1)
2189        r = clear(table, r)
2190        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2191        self.assertEqual(r, result)
2192
2193    def testClearWithQuotedNames(self):
2194        clear = self.db.clear
2195        table = 'test table for clear()'
2196        self.createTable(table, '"Prime!" smallint primary key,'
2197            ' "much space" integer, "Questions?" text')
2198        r = clear(table)
2199        self.assertIsInstance(r, dict)
2200        self.assertEqual(r['Prime!'], 0)
2201        self.assertEqual(r['much space'], 0)
2202        self.assertEqual(r['Questions?'], '')
2203
2204    def testDelete(self):
2205        delete = self.db.delete
2206        query = self.db.query
2207        self.assertRaises(pg.ProgrammingError, delete,
2208                          'test', dict(i2=2, i4=4, i8=8))
2209        table = 'delete_test_table'
2210        self.createTable(table, 'n integer, t text', oids=True,
2211                         values=enumerate('xyz', start=1))
2212        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
2213        r = self.db.get(table, 1, 'n')
2214        s = delete(table, r)
2215        self.assertEqual(s, 1)
2216        r = self.db.get(table, 3, 'n')
2217        s = delete(table, r)
2218        self.assertEqual(s, 1)
2219        s = delete(table, r)
2220        self.assertEqual(s, 0)
2221        r = query('select * from "%s"' % table).dictresult()
2222        self.assertEqual(len(r), 1)
2223        r = r[0]
2224        result = {'n': 2, 't': 'y'}
2225        self.assertEqual(r, result)
2226        r = self.db.get(table, 2, 'n')
2227        s = delete(table, r)
2228        self.assertEqual(s, 1)
2229        s = delete(table, r)
2230        self.assertEqual(s, 0)
2231        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
2232        # not existing columns and oid parameter should be ignored
2233        r.update(m=3, u='z', oid='invalid')
2234        s = delete(table, r)
2235        self.assertEqual(s, 0)
2236
2237    def testDeleteWithOid(self):
2238        delete = self.db.delete
2239        get = self.db.get
2240        query = self.db.query
2241        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
2242        r = dict(n=3)
2243        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
2244        s = get('test_table', 1, 'n')
2245        qoid = 'oid(test_table)'
2246        self.assertIn(qoid, s)
2247        r = delete('test_table', s)
2248        self.assertEqual(r, 1)
2249        r = delete('test_table', s)
2250        self.assertEqual(r, 0)
2251        q = "select min(n),count(n) from test_table"
2252        self.assertEqual(query(q).getresult()[0], (2, 5))
2253        oid = get('test_table', 2, 'n')[qoid]
2254        s = dict(oid=oid, n=2)
2255        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
2256        r = delete('test_table', None, oid=oid)
2257        self.assertEqual(r, 1)
2258        r = delete('test_table', None, oid=oid)
2259        self.assertEqual(r, 0)
2260        self.assertEqual(query(q).getresult()[0], (3, 4))
2261        s = dict(oid=oid, n=2)
2262        oid = get('test_table', 3, 'n')[qoid]
2263        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
2264        r = delete('test_table', s, oid=oid)
2265        self.assertEqual(r, 1)
2266        r = delete('test_table', s, oid=oid)
2267        self.assertEqual(r, 0)
2268        self.assertEqual(query(q).getresult()[0], (4, 3))
2269        s = get('test_table', 4, 'n')
2270        r = delete('test_table *', s)
2271        self.assertEqual(r, 1)
2272        r = delete('test_table *', s)
2273        self.assertEqual(r, 0)
2274        self.assertEqual(query(q).getresult()[0], (5, 2))
2275        oid = get('test_table', 5, 'n')[qoid]
2276        s = {qoid: oid, 'm': 4}
2277        r = delete('test_table', s, m=6)
2278        self.assertEqual(r, 1)
2279        r = delete('test_table *', s)
2280        self.assertEqual(r, 0)
2281        self.assertEqual(query(q).getresult()[0], (6, 1))
2282        query("alter table test_table add column m int")
2283        query("alter table test_table add primary key (n)")
2284        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
2285        self.assertEqual('n', self.db.pkey('test_table', flush=True))
2286        for i in range(5):
2287            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
2288        s = dict(m=2)
2289        self.assertRaises(KeyError, delete, 'test_table', s)
2290        s = dict(m=2, oid=oid)
2291        self.assertRaises(KeyError, delete, 'test_table', s)
2292        r = delete('test_table', dict(m=2), oid=oid)
2293        self.assertEqual(r, 0)
2294        oid = get('test_table', 1, 'n')[qoid]
2295        s = dict(oid=oid)
2296        self.assertRaises(KeyError, delete, 'test_table', s)
2297        r = delete('test_table', s, oid=oid)
2298        self.assertEqual(r, 1)
2299        r = delete('test_table', s, oid=oid)
2300        self.assertEqual(r, 0)
2301        self.assertEqual(query(q).getresult()[0], (2, 5))
2302        s = get('test_table', 2, 'n')
2303        del s['n']
2304        r = delete('test_table', s)
2305        self.assertEqual(r, 1)
2306        r = delete('test_table', s)
2307        self.assertEqual(r, 0)
2308        self.assertEqual(query(q).getresult()[0], (3, 4))
2309        r = delete('test_table', n=3)
2310        self.assertEqual(r, 1)
2311        r = delete('test_table', n=3)
2312        self.assertEqual(r, 0)
2313        self.assertEqual(query(q).getresult()[0], (4, 3))
2314        r = delete('test_table', None, n=4)
2315        self.assertEqual(r, 1)
2316        r = delete('test_table', None, n=4)
2317        self.assertEqual(r, 0)
2318        self.assertEqual(query(q).getresult()[0], (5, 2))
2319        s = dict(n=6)
2320        r = delete('test_table', s, n=5)
2321        self.assertEqual(r, 1)
2322        r = delete('test_table', s, n=5)
2323        self.assertEqual(r, 0)
2324        s = get('test_table', 6, 'n')
2325        self.assertEqual(s['n'], 6)
2326        s['n'] = 7
2327        r = delete('test_table', s)
2328        self.assertEqual(r, 1)
2329        self.assertEqual(query(q).getresult()[0], (None, 0))
2330
2331    def testDeleteWithCompositeKey(self):
2332        query = self.db.query
2333        table = 'delete_test_table_1'
2334        self.createTable(table, 'n integer primary key, t text',
2335                         values=enumerate('abc', start=1))
2336        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2337        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2338        r = query('select t from "%s" where n=2' % table).getresult()
2339        self.assertEqual(r, [])
2340        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2341        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2342        self.assertEqual(r, 'c')
2343        table = 'delete_test_table_2'
2344        self.createTable(table,
2345             'n integer, m integer, t text, primary key (n, m)',
2346             values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2347                     for n in range(3) for m in range(2)])
2348        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2349        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2350        r = [r[0] for r in query('select t from "%s" where n=2'
2351            ' order by m' % table).getresult()]
2352        self.assertEqual(r, ['c'])
2353        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2354        r = [r[0] for r in query('select t from "%s" where n=3'
2355             ' order by m' % table).getresult()]
2356        self.assertEqual(r, ['e', 'f'])
2357        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2358        r = [r[0] for r in query('select t from "%s" where n=3'
2359             ' order by m' % table).getresult()]
2360        self.assertEqual(r, ['f'])
2361
2362    def testDeleteWithQuotedNames(self):
2363        delete = self.db.delete
2364        query = self.db.query
2365        table = 'test table for delete()'
2366        self.createTable(table, '"Prime!" smallint primary key,'
2367            ' "much space" integer, "Questions?" text',
2368            values=[(19, 5005, 'Yes!')])
2369        r = {'Prime!': 17}
2370        r = delete(table, r)
2371        self.assertEqual(r, 0)
2372        r = query('select count(*) from "%s"' % table).getresult()
2373        self.assertEqual(r[0][0], 1)
2374        r = {'Prime!': 19}
2375        r = delete(table, r)
2376        self.assertEqual(r, 1)
2377        r = query('select count(*) from "%s"' % table).getresult()
2378        self.assertEqual(r[0][0], 0)
2379
2380    def testDeleteReferenced(self):
2381        delete = self.db.delete
2382        query = self.db.query
2383        self.createTable('test_parent',
2384            'n smallint primary key', values=range(3))
2385        self.createTable('test_child',
2386            'n smallint primary key references test_parent', values=range(3))
2387        q = ("select (select count(*) from test_parent),"
2388             " (select count(*) from test_child)")
2389        self.assertEqual(query(q).getresult()[0], (3, 3))
2390        self.assertRaises(pg.IntegrityError,
2391                          delete, 'test_parent', None, n=2)
2392        self.assertRaises(pg.IntegrityError,
2393                          delete, 'test_parent *', None, n=2)
2394        r = delete('test_child', None, n=2)
2395        self.assertEqual(r, 1)
2396        self.assertEqual(query(q).getresult()[0], (3, 2))
2397        r = delete('test_parent', None, n=2)
2398        self.assertEqual(r, 1)
2399        self.assertEqual(query(q).getresult()[0], (2, 2))
2400        self.assertRaises(pg.IntegrityError,
2401                          delete, 'test_parent', dict(n=0))
2402        self.assertRaises(pg.IntegrityError,
2403                          delete, 'test_parent *', dict(n=0))
2404        r = delete('test_child', dict(n=0))
2405        self.assertEqual(r, 1)
2406        self.assertEqual(query(q).getresult()[0], (2, 1))
2407        r = delete('test_child', dict(n=0))
2408        self.assertEqual(r, 0)
2409        r = delete('test_parent', dict(n=0))
2410        self.assertEqual(r, 1)
2411        self.assertEqual(query(q).getresult()[0], (1, 1))
2412        r = delete('test_parent', None, n=0)
2413        self.assertEqual(r, 0)
2414        q = "select n from test_parent natural join test_child limit 2"
2415        self.assertEqual(query(q).getresult(), [(1,)])
2416
2417    def testTempCrud(self):
2418        table = 'test_temp_table'
2419        self.createTable(table, "n int primary key, t varchar", temporary=True)
2420        self.db.insert(table, dict(n=1, t='one'))
2421        self.db.insert(table, dict(n=2, t='too'))
2422        self.db.insert(table, dict(n=3, t='three'))
2423        r = self.db.get(table, 2)
2424        self.assertEqual(r['t'], 'too')
2425        self.db.update(table, dict(n=2, t='two'))
2426        r = self.db.get(table, 2)
2427        self.assertEqual(r['t'], 'two')
2428        self.db.delete(table, r)
2429        r = self.db.query('select n, t from %s order by 1' % table).getresult()
2430        self.assertEqual(r, [(1, 'one'), (3, 'three')])
2431
2432    def testTruncate(self):
2433        truncate = self.db.truncate
2434        self.assertRaises(TypeError, truncate, None)
2435        self.assertRaises(TypeError, truncate, 42)
2436        self.assertRaises(TypeError, truncate, dict(test_table=None))
2437        query = self.db.query
2438        self.createTable('test_table', 'n smallint',
2439                         temporary=False, values=[1] * 3)
2440        q = "select count(*) from test_table"
2441        r = query(q).getresult()[0][0]
2442        self.assertEqual(r, 3)
2443        truncate('test_table')
2444        r = query(q).getresult()[0][0]
2445        self.assertEqual(r, 0)
2446        for i in range(3):
2447            query("insert into test_table values (1)")
2448        r = query(q).getresult()[0][0]
2449        self.assertEqual(r, 3)
2450        truncate('public.test_table')
2451        r = query(q).getresult()[0][0]
2452        self.assertEqual(r, 0)
2453        self.createTable('test_table_2', 'n smallint', temporary=True)
2454        for t in (list, tuple, set):
2455            for i in range(3):
2456                query("insert into test_table values (1)")
2457                query("insert into test_table_2 values (2)")
2458            q = ("select (select count(*) from test_table),"
2459                 " (select count(*) from test_table_2)")
2460            r = query(q).getresult()[0]
2461            self.assertEqual(r, (3, 3))
2462            truncate(t(['test_table', 'test_table_2']))
2463            r = query(q).getresult()[0]
2464            self.assertEqual(r, (0, 0))
2465
2466    def testTruncateRestart(self):
2467        truncate = self.db.truncate
2468        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2469        query = self.db.query
2470        self.createTable('test_table', 'n serial, t text')
2471        for n in range(3):
2472            query("insert into test_table (t) values ('test')")
2473        q = "select count(n), min(n), max(n) from test_table"
2474        r = query(q).getresult()[0]
2475        self.assertEqual(r, (3, 1, 3))
2476        truncate('test_table')
2477        r = query(q).getresult()[0]
2478        self.assertEqual(r, (0, None, None))
2479        for n in range(3):
2480            query("insert into test_table (t) values ('test')")
2481        r = query(q).getresult()[0]
2482        self.assertEqual(r, (3, 4, 6))
2483        truncate('test_table', restart=True)
2484        r = query(q).getresult()[0]
2485        self.assertEqual(r, (0, None, None))
2486        for n in range(3):
2487            query("insert into test_table (t) values ('test')")
2488        r = query(q).getresult()[0]
2489        self.assertEqual(r, (3, 1, 3))
2490
2491    def testTruncateCascade(self):
2492        truncate = self.db.truncate
2493        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2494        query = self.db.query
2495        self.createTable('test_parent', 'n smallint primary key',
2496                         values=range(3))
2497        self.createTable('test_child',
2498                         'n smallint primary key references test_parent (n)',
2499                         values=range(3))
2500        q = ("select (select count(*) from test_parent),"
2501             " (select count(*) from test_child)")
2502        r = query(q).getresult()[0]
2503        self.assertEqual(r, (3, 3))
2504        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2505        truncate(['test_parent', 'test_child'])
2506        r = query(q).getresult()[0]
2507        self.assertEqual(r, (0, 0))
2508        for n in range(3):
2509            query("insert into test_parent (n) values (%d)" % n)
2510            query("insert into test_child (n) values (%d)" % n)
2511        r = query(q).getresult()[0]
2512        self.assertEqual(r, (3, 3))
2513        truncate('test_parent', cascade=True)
2514        r = query(q).getresult()[0]
2515        self.assertEqual(r, (0, 0))
2516        for n in range(3):
2517            query("insert into test_parent (n) values (%d)" % n)
2518            query("insert into test_child (n) values (%d)" % n)
2519        r = query(q).getresult()[0]
2520        self.assertEqual(r, (3, 3))
2521        truncate('test_child')
2522        r = query(q).getresult()[0]
2523        self.assertEqual(r, (3, 0))
2524        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2525        truncate('test_parent', cascade=True)
2526        r = query(q).getresult()[0]
2527        self.assertEqual(r, (0, 0))
2528
2529    def testTruncateOnly(self):
2530        truncate = self.db.truncate
2531        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2532        query = self.db.query
2533        self.createTable('test_parent', 'n smallint')
2534        self.createTable('test_child', 'm smallint) inherits (test_parent')
2535        for n in range(3):
2536            query("insert into test_parent (n) values (1)")
2537            query("insert into test_child (n, m) values (2, 3)")
2538        q = ("select (select count(*) from test_parent),"
2539             " (select count(*) from test_child)")
2540        r = query(q).getresult()[0]
2541        self.assertEqual(r, (6, 3))
2542        truncate('test_parent')
2543        r = query(q).getresult()[0]
2544        self.assertEqual(r, (0, 0))
2545        for n in range(3):
2546            query("insert into test_parent (n) values (1)")
2547            query("insert into test_child (n, m) values (2, 3)")
2548        r = query(q).getresult()[0]
2549        self.assertEqual(r, (6, 3))
2550        truncate('test_parent*')
2551        r = query(q).getresult()[0]
2552        self.assertEqual(r, (0, 0))
2553        for n in range(3):
2554            query("insert into test_parent (n) values (1)")
2555            query("insert into test_child (n, m) values (2, 3)")
2556        r = query(q).getresult()[0]
2557        self.assertEqual(r, (6, 3))
2558        truncate('test_parent', only=True)
2559        r = query(q).getresult()[0]
2560        self.assertEqual(r, (3, 3))
2561        truncate('test_parent', only=False)
2562        r = query(q).getresult()[0]
2563        self.assertEqual(r, (0, 0))
2564        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2565        truncate('test_parent*', only=False)
2566        self.createTable('test_parent_2', 'n smallint')
2567        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2568        for t in '', '_2':
2569            for n in range(3):
2570                query("insert into test_parent%s (n) values (1)" % t)
2571                query("insert into test_child%s (n, m) values (2, 3)" % t)
2572        q = ("select (select count(*) from test_parent),"
2573             " (select count(*) from test_child),"
2574             " (select count(*) from test_parent_2),"
2575             " (select count(*) from test_child_2)")
2576        r = query(q).getresult()[0]
2577        self.assertEqual(r, (6, 3, 6, 3))
2578        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2579        r = query(q).getresult()[0]
2580        self.assertEqual(r, (0, 0, 3, 3))
2581        truncate(['test_parent', 'test_parent_2'], only=False)
2582        r = query(q).getresult()[0]
2583        self.assertEqual(r, (0, 0, 0, 0))
2584        self.assertRaises(ValueError, truncate,
2585            ['test_parent*', 'test_child'], only=[True, False])
2586        truncate(['test_parent*', 'test_child'], only=[False, True])
2587
2588    def testTruncateQuoted(self):
2589        truncate = self.db.truncate
2590        query = self.db.query
2591        table = "test table for truncate()"
2592        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2593        q = 'select count(*) from "%s"' % table
2594        r = query(q).getresult()[0][0]
2595        self.assertEqual(r, 3)
2596        truncate(table)
2597        r = query(q).getresult()[0][0]
2598        self.assertEqual(r, 0)
2599        for i in range(3):
2600            query('insert into "%s" values (1)' % table)
2601        r = query(q).getresult()[0][0]
2602        self.assertEqual(r, 3)
2603        truncate('public."%s"' % table)
2604        r = query(q).getresult()[0][0]
2605        self.assertEqual(r, 0)
2606
2607    def testGetAsList(self):
2608        get_as_list = self.db.get_as_list
2609        self.assertRaises(TypeError, get_as_list)
2610        self.assertRaises(TypeError, get_as_list, None)
2611        query = self.db.query
2612        table = 'test_aslist'
2613        r = query('select 1 as colname').namedresult()[0]
2614        self.assertIsInstance(r, tuple)
2615        named = hasattr(r, 'colname')
2616        names = [(1, 'Homer'), (2, 'Marge'),
2617                 (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
2618        self.createTable(table,
2619            'id smallint primary key, name varchar', values=names)
2620        r = get_as_list(table)
2621        self.assertIsInstance(r, list)
2622        self.assertEqual(r, names)
2623        for t, n in zip(r, names):
2624            self.assertIsInstance(t, tuple)
2625            self.assertEqual(t, n)
2626            if named:
2627                self.assertEqual(t.id, n[0])
2628                self.assertEqual(t.name, n[1])
2629                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
2630        r = get_as_list(table, what='name')
2631        self.assertIsInstance(r, list)
2632        expected = sorted((row[1],) for row in names)
2633        self.assertEqual(r, expected)
2634        r = get_as_list(table, what='name, id')
2635        self.assertIsInstance(r, list)
2636        expected = sorted(tuple(reversed(row)) for row in names)
2637        self.assertEqual(r, expected)
2638        r = get_as_list(table, what=['name', 'id'])
2639        self.assertIsInstance(r, list)
2640        self.assertEqual(r, expected)
2641        r = get_as_list(table, where="name like 'Ba%'")
2642        self.assertIsInstance(r, list)
2643        self.assertEqual(r, names[2:3])
2644        r = get_as_list(table, what='name', where="name like 'Ma%'")
2645        self.assertIsInstance(r, list)
2646        self.assertEqual(r, [('Maggie',), ('Marge',)])
2647        r = get_as_list(table, what='name',
2648            where=["name like 'Ma%'", "name like '%r%'"])
2649        self.assertIsInstance(r, list)
2650        self.assertEqual(r, [('Marge',)])
2651        r = get_as_list(table, what='name', order='id')
2652        self.assertIsInstance(r, list)
2653        expected = [(row[1],) for row in names]
2654        self.assertEqual(r, expected)
2655        r = get_as_list(table, what=['name'], order=['id'])
2656        self.assertIsInstance(r, list)
2657        self.assertEqual(r, expected)
2658        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
2659        self.assertIsInstance(r, list)
2660        self.assertEqual(r, names)
2661        r = get_as_list(table, what='id * 2 as num', order='id desc')
2662        self.assertIsInstance(r, list)
2663        expected = [(n,) for n in range(10, 0, -2)]
2664        self.assertEqual(r, expected)
2665        r = get_as_list(table, limit=2)
2666        self.assertIsInstance(r, list)
2667        self.assertEqual(r, names[:2])
2668        r = get_as_list(table, offset=3)
2669        self.assertIsInstance(r, list)
2670        self.assertEqual(r, names[3:])
2671        r = get_as_list(table, limit=1, offset=2)
2672        self.assertIsInstance(r, list)
2673        self.assertEqual(r, names[2:3])
2674        r = get_as_list(table, scalar=True)
2675        self.assertIsInstance(r, list)
2676        self.assertEqual(r, list(range(1, 6)))
2677        r = get_as_list(table, what='name', scalar=True)
2678        self.assertIsInstance(r, list)
2679        expected = sorted(row[1] for row in names)
2680        self.assertEqual(r, expected)
2681        r = get_as_list(table, what='name', limit=1, scalar=True)
2682        self.assertIsInstance(r, list)
2683        self.assertEqual(r, expected[:1])
2684        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2685        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2686        names.insert(1, (1, 'Snowball'))
2687        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
2688        r = get_as_list(table)
2689        self.assertIsInstance(r, list)
2690        self.assertEqual(r, names)
2691        r = get_as_list(table, what='name', where='id=1', scalar=True)
2692        self.assertIsInstance(r, list)
2693        self.assertEqual(r, ['Homer', 'Snowball'])
2694        # test with unordered query
2695        r = get_as_list(table, order=False)
2696        self.assertIsInstance(r, list)
2697        self.assertEqual(set(r), set(names))
2698        # test with arbitrary from clause
2699        from_table = '(select lower(name) as n2 from "%s") as t2' % table
2700        r = get_as_list(from_table)
2701        self.assertIsInstance(r, list)
2702        r = set(row[0] for row in r)
2703        expected = set(row[1].lower() for row in names)
2704        self.assertEqual(r, expected)
2705        r = get_as_list(from_table, order='n2', scalar=True)
2706        self.assertIsInstance(r, list)
2707        self.assertEqual(r, sorted(expected))
2708        r = get_as_list(from_table, order='n2', limit=1)
2709        self.assertIsInstance(r, list)
2710        self.assertEqual(len(r), 1)
2711        t = r[0]
2712        self.assertIsInstance(t, tuple)
2713        if named:
2714            self.assertEqual(t.n2, 'bart')
2715            self.assertEqual(t._asdict(), dict(n2='bart'))
2716        else:
2717            self.assertEqual(t, ('bart',))
2718
2719    def testGetAsDict(self):
2720        get_as_dict = self.db.get_as_dict
2721        self.assertRaises(TypeError, get_as_dict)
2722        self.assertRaises(TypeError, get_as_dict, None)
2723        # the test table has no primary key
2724        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2725        query = self.db.query
2726        table = 'test_asdict'
2727        r = query('select 1 as colname').namedresult()[0]
2728        self.assertIsInstance(r, tuple)
2729        named = hasattr(r, 'colname')
2730        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2731                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2732        self.createTable(table,
2733            'id smallint primary key, rgb char(7), name varchar',
2734            values=colors)
2735        # keyname must be string, list or tuple
2736        self.assertRaises(KeyError, get_as_dict, table, 3)
2737        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2738        # missing keyname in row
2739        self.assertRaises(KeyError, get_as_dict, table,
2740                          keyname='rgb', what='name')
2741        r = get_as_dict(table)
2742        self.assertIsInstance(r, OrderedDict)
2743        expected = OrderedDict((row[0], row[1:]) for row in colors)
2744        self.assertEqual(r, expected)
2745        for key in r:
2746            self.assertIsInstance(key, int)
2747            self.assertIn(key, expected)
2748            row = r[key]
2749            self.assertIsInstance(row, tuple)
2750            t = expected[key]
2751            self.assertEqual(row, t)
2752            if named:
2753                self.assertEqual(row.rgb, t[0])
2754                self.assertEqual(row.name, t[1])
2755                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2756        if OrderedDict is not dict:  # Python > 2.6
2757            self.assertEqual(r.keys(), expected.keys())
2758        r = get_as_dict(table, keyname='rgb')
2759        self.assertIsInstance(r, OrderedDict)
2760        expected = OrderedDict((row[1], (row[0], row[2]))
2761                               for row in sorted(colors, key=itemgetter(1)))
2762        self.assertEqual(r, expected)
2763        for key in r:
2764            self.assertIsInstance(key, str)
2765            self.assertIn(key, expected)
2766            row = r[key]
2767            self.assertIsInstance(row, tuple)
2768            t = expected[key]
2769            self.assertEqual(row, t)
2770            if named:
2771                self.assertEqual(row.id, t[0])
2772                self.assertEqual(row.name, t[1])
2773                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2774        if OrderedDict is not dict:  # Python > 2.6
2775            self.assertEqual(r.keys(), expected.keys())
2776        r = get_as_dict(table, keyname=['id', 'rgb'])
2777        self.assertIsInstance(r, OrderedDict)
2778        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2779        self.assertEqual(r, expected)
2780        for key in r:
2781            self.assertIsInstance(key, tuple)
2782            self.assertIsInstance(key[0], int)
2783            self.assertIsInstance(key[1], str)
2784            if named:
2785                self.assertEqual(key, (key.id, key.rgb))
2786                self.assertEqual(key._fields, ('id', 'rgb'))
2787            row = r[key]
2788            self.assertIsInstance(row, tuple)
2789            self.assertIsInstance(row[0], str)
2790            t = expected[key]
2791            self.assertEqual(row, t)
2792            if named:
2793                self.assertEqual(row.name, t[0])
2794                self.assertEqual(row._asdict(), dict(name=t[0]))
2795        if OrderedDict is not dict:  # Python > 2.6
2796            self.assertEqual(r.keys(), expected.keys())
2797        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2798        self.assertIsInstance(r, OrderedDict)
2799        expected = OrderedDict((row[:2], row[2]) for row in colors)
2800        self.assertEqual(r, expected)
2801        for key in r:
2802            self.assertIsInstance(key, tuple)
2803            row = r[key]
2804            self.assertIsInstance(row, str)
2805            t = expected[key]
2806            self.assertEqual(row, t)
2807        if OrderedDict is not dict:  # Python > 2.6
2808            self.assertEqual(r.keys(), expected.keys())
2809        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2810        self.assertIsInstance(r, OrderedDict)
2811        expected = OrderedDict((row[1], row[2])
2812            for row in sorted(colors, key=itemgetter(1)))
2813        self.assertEqual(r, expected)
2814        for key in r:
2815            self.assertIsInstance(key, str)
2816            row = r[key]
2817            self.assertIsInstance(row, str)
2818            t = expected[key]
2819            self.assertEqual(row, t)
2820        if OrderedDict is not dict:  # Python > 2.6
2821            self.assertEqual(r.keys(), expected.keys())
2822        r = get_as_dict(table, what='id, name',
2823            where="rgb like '#b%'", scalar=True)
2824        self.assertIsInstance(r, OrderedDict)
2825        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2826        self.assertEqual(r, expected)
2827        for key in r:
2828            self.assertIsInstance(key, int)
2829            row = r[key]
2830            self.assertIsInstance(row, str)
2831            t = expected[key]
2832            self.assertEqual(row, t)
2833        if OrderedDict is not dict:  # Python > 2.6
2834            self.assertEqual(r.keys(), expected.keys())
2835        expected = r
2836        r = get_as_dict(table, what=['name', 'id'],
2837            where=['id > 1', 'id < 4', "rgb like '#b%'",
2838                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2839        self.assertEqual(r, expected)
2840        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2841        self.assertEqual(r, expected)
2842        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2843            where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2844        self.assertEqual(r, expected)
2845        r = get_as_dict(table, limit=1)
2846        self.assertEqual(len(r), 1)
2847        self.assertEqual(r[1][1], 'Aero')
2848        r = get_as_dict(table, offset=3)
2849        self.assertEqual(len(r), 1)
2850        self.assertEqual(r[4][1], 'Desert')
2851        r = get_as_dict(table, order='id desc')
2852        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2853        self.assertEqual(r, expected)
2854        r = get_as_dict(table, where='id > 5')
2855        self.assertIsInstance(r, OrderedDict)
2856        self.assertEqual(len(r), 0)
2857        # test with unordered query
2858        expected = dict((row[0], row[1:]) for row in colors)
2859        r = get_as_dict(table, order=False)
2860        self.assertIsInstance(r, dict)
2861        self.assertEqual(r, expected)
2862        if dict is not OrderedDict:  # Python > 2.6
2863            self.assertNotIsInstance(self, OrderedDict)
2864        # test with arbitrary from clause
2865        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2866        # primary key must be passed explicitly in this case
2867        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2868        r = get_as_dict(from_table, 'id')
2869        self.assertIsInstance(r, OrderedDict)
2870        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2871        self.assertEqual(r, expected)
2872        # test without a primary key
2873        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2874        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2875        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2876        r = get_as_dict(table, keyname='id')
2877        expected = OrderedDict((row[0], row[1:]) for row in colors)
2878        self.assertIsInstance(r, dict)
2879        self.assertEqual(r, expected)
2880        r = (1, '#007fff', 'Azure')
2881        query('insert into "%s" values ($1, $2, $3)' % table, r)
2882        # the last entry will win
2883        expected[1] = r[1:]
2884        r = get_as_dict(table, keyname='id')
2885        self.assertEqual(r, expected)
2886
    def testTransaction(self):
        """Test begin/commit/rollback, savepoints and read-only mode.

        Runs several transactions against test_table and checks that only
        the values from committed work (1, 2, 5, 7, 9) end up in the table.
        """
        query = self.db.query
        self.createTable('test_table', 'n integer', temporary=False)
        # committed transaction: 1 and 2 are kept
        self.db.begin()
        query("insert into test_table values (1)")
        query("insert into test_table values (2)")
        self.db.commit()
        # rolled back transaction: 3 and 4 are discarded
        self.db.begin()
        query("insert into test_table values (3)")
        query("insert into test_table values (4)")
        self.db.rollback()
        # rolling back to a savepoint discards only 6; 5 and 7 are committed
        self.db.begin()
        query("insert into test_table values (5)")
        self.db.savepoint('before6')
        query("insert into test_table values (6)")
        self.db.rollback('before6')
        query("insert into test_table values (7)")
        self.db.commit()
        # a released savepoint can no longer be rolled back to; presumably
        # the failed rollback aborts the transaction, since 8 does not show
        # up in the final check of the table contents below
        self.db.begin()
        self.db.savepoint('before8')
        query("insert into test_table values (8)")
        self.db.release('before8')
        self.assertRaises(pg.InternalError, self.db.rollback, 'before8')
        self.db.commit()
        # start/end used instead of begin/commit: 9 is kept
        self.db.start()
        query("insert into test_table values (9)")
        self.db.end()
        r = [r[0] for r in query(
            "select * from test_table order by 1").getresult()]
        self.assertEqual(r, [1, 2, 5, 7, 9])
        # writes inside a read-only transaction must fail
        self.db.begin(mode='read only')
        self.assertRaises(pg.InternalError,
                          query, "insert into test_table values (0)")
        self.db.rollback()
        # same again, with the mode given in different capitalization
        self.db.start(mode='Read Only')
        self.assertRaises(pg.InternalError,
                          query, "insert into test_table values (0)")
        self.db.abort()
2925
2926    def testTransactionAliases(self):
2927        self.assertEqual(self.db.begin, self.db.start)
2928        self.assertEqual(self.db.commit, self.db.end)
2929        self.assertEqual(self.db.rollback, self.db.abort)
2930
2931    def testContextManager(self):
2932        query = self.db.query
2933        self.createTable('test_table', 'n integer check(n>0)')
2934        with self.db:
2935            query("insert into test_table values (1)")
2936            query("insert into test_table values (2)")
2937        try:
2938            with self.db:
2939                query("insert into test_table values (3)")
2940                query("insert into test_table values (4)")
2941                raise ValueError('test transaction should rollback')
2942        except ValueError as error:
2943            self.assertEqual(str(error), 'test transaction should rollback')
2944        with self.db:
2945            query("insert into test_table values (5)")
2946        try:
2947            with self.db:
2948                query("insert into test_table values (6)")
2949                query("insert into test_table values (-1)")
2950        except pg.IntegrityError as error:
2951            self.assertTrue('check' in str(error))
2952        with self.db:
2953            query("insert into test_table values (7)")
2954        r = [r[0] for r in query(
2955            "select * from test_table order by 1").getresult()]
2956        self.assertEqual(r, [1, 2, 5, 7])
2957
2958    def testBytea(self):
2959        query = self.db.query
2960        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2961        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2962        r = self.db.escape_bytea(s)
2963        query('insert into bytea_test values(3, $1)', (r,))
2964        r = query('select * from bytea_test where n=3').getresult()
2965        self.assertEqual(len(r), 1)
2966        r = r[0]
2967        self.assertEqual(len(r), 2)
2968        self.assertEqual(r[0], 3)
2969        r = r[1]
2970        if pg.get_bytea_escaped():
2971            self.assertNotEqual(r, s)
2972            r = pg.unescape_bytea(r)
2973        self.assertIsInstance(r, bytes)
2974        self.assertEqual(r, s)
2975
2976    def testInsertUpdateGetBytea(self):
2977        query = self.db.query
2978        unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None
2979        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2980        # insert null value
2981        r = self.db.insert('bytea_test', n=0, data=None)
2982        self.assertIsInstance(r, dict)
2983        self.assertIn('n', r)
2984        self.assertEqual(r['n'], 0)
2985        self.assertIn('data', r)
2986        self.assertIsNone(r['data'])
2987        s = b'None'
2988        r = self.db.update('bytea_test', n=0, data=s)
2989        self.assertIsInstance(r, dict)
2990        self.assertIn('n', r)
2991        self.assertEqual(r['n'], 0)
2992        self.assertIn('data', r)
2993        r = r['data']
2994        if unescape:
2995            self.assertNotEqual(r, s)
2996            r = unescape(r)
2997        self.assertIsInstance(r, bytes)
2998        self.assertEqual(r, s)
2999        r = self.db.update('bytea_test', n=0, data=None)
3000        self.assertIsNone(r['data'])
3001        # insert as bytes
3002        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
3003        r = self.db.insert('bytea_test', n=5, data=s)
3004        self.assertIsInstance(r, dict)
3005        self.assertIn('n', r)
3006        self.assertEqual(r['n'], 5)
3007        self.assertIn('data', r)
3008        r = r['data']
3009        if unescape:
3010            self.assertNotEqual(r, s)
3011            r = unescape(r)
3012        self.assertIsInstance(r, bytes)
3013        self.assertEqual(r, s)
3014        # update as bytes
3015        s += b"and now even more \x00 nasty \t stuff!\f"
3016        r = self.db.update('bytea_test', n=5, data=s)
3017        self.assertIsInstance(r, dict)
3018        self.assertIn('n', r)
3019        self.assertEqual(r['n'], 5)
3020        self.assertIn('data', r)
3021        r = r['data']
3022        if unescape:
3023            self.assertNotEqual(r, s)
3024            r = unescape(r)
3025        self.assertIsInstance(r, bytes)
3026        self.assertEqual(r, s)
3027        r = query('select * from bytea_test where n=5').getresult()
3028        self.assertEqual(len(r), 1)
3029        r = r[0]
3030        self.assertEqual(len(r), 2)
3031        self.assertEqual(r[0], 5)
3032        r = r[1]
3033        if unescape:
3034            self.assertNotEqual(r, s)
3035            r = unescape(r)
3036        self.assertIsInstance(r, bytes)
3037        self.assertEqual(r, s)
3038        r = self.db.get('bytea_test', dict(n=5))
3039        self.assertIsInstance(r, dict)
3040        self.assertIn('n', r)
3041        self.assertEqual(r['n'], 5)
3042        self.assertIn('data', r)
3043        r = r['data']
3044        if unescape:
3045            self.assertNotEqual(r, s)
3046            r = pg.unescape_bytea(r)
3047        self.assertIsInstance(r, bytes)
3048        self.assertEqual(r, s)
3049
3050    def testUpsertBytea(self):
3051        self.createTable('bytea_test', 'n smallint primary key, data bytea')
3052        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
3053        r = dict(n=7, data=s)
3054        try:
3055            r = self.db.upsert('bytea_test', r)
3056        except pg.ProgrammingError as error:
3057            if self.db.server_version < 90500:
3058                self.skipTest('database does not support upsert')
3059            self.fail(str(error))
3060        self.assertIsInstance(r, dict)
3061        self.assertIn('n', r)
3062        self.assertEqual(r['n'], 7)
3063        self.assertIn('data', r)
3064        if pg.get_bytea_escaped():
3065            self.assertNotEqual(r['data'], s)
3066            r['data'] = pg.unescape_bytea(r['data'])
3067        self.assertIsInstance(r['data'], bytes)
3068        self.assertEqual(r['data'], s)
3069        r['data'] = None
3070        r = self.db.upsert('bytea_test', r)
3071        self.assertIsInstance(r, dict)
3072        self.assertIn('n', r)
3073        self.assertEqual(r['n'], 7)
3074        self.assertIn('data', r)
3075        self.assertIsNone(r['data'])
3076
3077    def testInsertGetJson(self):
3078        try:
3079            self.createTable('json_test', 'n smallint primary key, data json')
3080        except pg.ProgrammingError as error:
3081            if self.db.server_version < 90200:
3082                self.skipTest('database does not support json')
3083            self.fail(str(error))
3084        jsondecode = pg.get_jsondecode()
3085        # insert null value
3086        r = self.db.insert('json_test', n=0, data=None)
3087        self.assertIsInstance(r, dict)
3088        self.assertIn('n', r)
3089        self.assertEqual(r['n'], 0)
3090        self.assertIn('data', r)
3091        self.assertIsNone(r['data'])
3092        r = self.db.get('json_test', 0)
3093        self.assertIsInstance(r, dict)
3094        self.assertIn('n', r)
3095        self.assertEqual(r['n'], 0)
3096        self.assertIn('data', r)
3097        self.assertIsNone(r['data'])
3098        # insert JSON object
3099        data = {
3100            "id": 1, "name": "Foo", "price": 1234.5,
3101            "new": True, "note": None,
3102            "tags": ["Bar", "Eek"],
3103            "stock": {"warehouse": 300, "retail": 20}}
3104        r = self.db.insert('json_test', n=1, data=data)
3105        self.assertIsInstance(r, dict)
3106        self.assertIn('n', r)
3107        self.assertEqual(r['n'], 1)
3108        self.assertIn('data', r)
3109        r = r['data']
3110        if jsondecode is None:
3111            self.assertIsInstance(r, str)
3112            r = json.loads(r)
3113        self.assertIsInstance(r, dict)
3114        self.assertEqual(r, data)
3115        self.assertIsInstance(r['id'], int)
3116        self.assertIsInstance(r['name'], unicode)
3117        self.assertIsInstance(r['price'], float)
3118        self.assertIsInstance(r['new'], bool)
3119        self.assertIsInstance(r['tags'], list)
3120        self.assertIsInstance(r['stock'], dict)
3121        r = self.db.get('json_test', 1)
3122        self.assertIsInstance(r, dict)
3123        self.assertIn('n', r)
3124        self.assertEqual(r['n'], 1)
3125        self.assertIn('data', r)
3126        r = r['data']
3127        if jsondecode is None:
3128            self.assertIsInstance(r, str)
3129            r = json.loads(r)
3130        self.assertIsInstance(r, dict)
3131        self.assertEqual(r, data)
3132        self.assertIsInstance(r['id'], int)
3133        self.assertIsInstance(r['name'], unicode)
3134        self.assertIsInstance(r['price'], float)
3135        self.assertIsInstance(r['new'], bool)
3136        self.assertIsInstance(r['tags'], list)
3137        self.assertIsInstance(r['stock'], dict)
3138        # insert JSON object as text
3139        self.db.insert('json_test', n=2, data=json.dumps(data))
3140        q = "select data from json_test where n in (1, 2) order by n"
3141        r = self.db.query(q).getresult()
3142        self.assertEqual(len(r), 2)
3143        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
3144        self.assertEqual(r[0][0], r[1][0])
3145
3146    def testInsertGetJsonb(self):
3147        try:
3148            self.createTable('jsonb_test',
3149                             'n smallint primary key, data jsonb')
3150        except pg.ProgrammingError as error:
3151            if self.db.server_version < 90400:
3152                self.skipTest('database does not support jsonb')
3153            self.fail(str(error))
3154        jsondecode = pg.get_jsondecode()
3155        # insert null value
3156        r = self.db.insert('jsonb_test', n=0, data=None)
3157        self.assertIsInstance(r, dict)
3158        self.assertIn('n', r)
3159        self.assertEqual(r['n'], 0)
3160        self.assertIn('data', r)
3161        self.assertIsNone(r['data'])
3162        r = self.db.get('jsonb_test', 0)
3163        self.assertIsInstance(r, dict)
3164        self.assertIn('n', r)
3165        self.assertEqual(r['n'], 0)
3166        self.assertIn('data', r)
3167        self.assertIsNone(r['data'])
3168        # insert JSON object
3169        data = {
3170            "id": 1, "name": "Foo", "price": 1234.5,
3171            "new": True, "note": None,
3172            "tags": ["Bar", "Eek"],
3173            "stock": {"warehouse": 300, "retail": 20}}
3174        r = self.db.insert('jsonb_test', n=1, data=data)
3175        self.assertIsInstance(r, dict)
3176        self.assertIn('n', r)
3177        self.assertEqual(r['n'], 1)
3178        self.assertIn('data', r)
3179        r = r['data']
3180        if jsondecode is None:
3181            self.assertIsInstance(r, str)
3182            r = json.loads(r)
3183        self.assertIsInstance(r, dict)
3184        self.assertEqual(r, data)
3185        self.assertIsInstance(r['id'], int)
3186        self.assertIsInstance(r['name'], unicode)
3187        self.assertIsInstance(r['price'], float)
3188        self.assertIsInstance(r['new'], bool)
3189        self.assertIsInstance(r['tags'], list)
3190        self.assertIsInstance(r['stock'], dict)
3191        r = self.db.get('jsonb_test', 1)
3192        self.assertIsInstance(r, dict)
3193        self.assertIn('n', r)
3194        self.assertEqual(r['n'], 1)
3195        self.assertIn('data', r)
3196        r = r['data']
3197        if jsondecode is None:
3198            self.assertIsInstance(r, str)
3199            r = json.loads(r)
3200        self.assertIsInstance(r, dict)
3201        self.assertEqual(r, data)
3202        self.assertIsInstance(r['id'], int)
3203        self.assertIsInstance(r['name'], unicode)
3204        self.assertIsInstance(r['price'], float)
3205        self.assertIsInstance(r['new'], bool)
3206        self.assertIsInstance(r['tags'], list)
3207        self.assertIsInstance(r['stock'], dict)
3208
3209    def testArray(self):
3210        returns_arrays = pg.get_array()
3211        self.createTable('arraytest',
3212            'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
3213            ' d numeric[], f4 real[], f8 double precision[], m money[],'
3214            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
3215        r = self.db.get_attnames('arraytest')
3216        if self.regtypes:
3217            self.assertEqual(r, dict(
3218                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
3219                d='numeric[]', f4='real[]', f8='double precision[]',
3220                m='money[]', b='boolean[]',
3221                v4='character varying[]', c4='character[]', t='text[]'))
3222        else:
3223            self.assertEqual(r, dict(
3224                id='int', i2='int[]', i4='int[]', i8='int[]',
3225                d='num[]', f4='float[]', f8='float[]', m='money[]',
3226                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
3227        decimal = pg.get_decimal()
3228        if decimal is Decimal:
3229            long_decimal = decimal('123456789.123456789')
3230            odd_money = decimal('1234567891234567.89')
3231        else:
3232            long_decimal = decimal('12345671234.5')
3233            odd_money = decimal('1234567123.25')
3234        t, f = (True, False) if pg.get_bool() else ('t', 'f')
3235        data = dict(id=42, i2=[42, 1234, None, 0, -1],
3236            i4=[42, 123456789, None, 0, 1, -1],
3237            i8=[long(42), long(123456789123456789), None,
3238                long(0), long(1), long(-1)],
3239            d=[decimal(42), long_decimal, None,
3240               decimal(0), decimal(1), decimal(-1), -long_decimal],
3241            f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
3242                float('inf'), float('-inf')],
3243            f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
3244                float('inf'), float('-inf')],
3245            m=[decimal('42.00'), odd_money, None,
3246               decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
3247            b=[t, f, t, None, f, t, None, None, t],
3248            v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
3249            t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
3250        r = data.copy()
3251        self.db.insert('arraytest', r)
3252        if returns_arrays:
3253            self.assertEqual(r, data)
3254        else:
3255            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3256        self.db.insert('arraytest', r)
3257        r = self.db.get('arraytest', 42, 'id')
3258        if returns_arrays:
3259            self.assertEqual(r, data)
3260        else:
3261            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3262        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
3263        if returns_arrays:
3264            self.assertEqual(r, data)
3265        else:
3266            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3267
3268    def testArrayLiteral(self):
3269        insert = self.db.insert
3270        returns_arrays = pg.get_array()
3271        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3272        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3273        insert('arraytest', r)
3274        if returns_arrays:
3275            self.assertEqual(r['i'], [1, 2, 3])
3276            self.assertEqual(r['t'], ['a', 'b', 'c'])
3277        else:
3278            self.assertEqual(r['i'], '{1,2,3}')
3279            self.assertEqual(r['t'], '{a,b,c}')
3280        r = dict(i='{1,2,3}', t='{a,b,c}')
3281        self.db.insert('arraytest', r)
3282        if returns_arrays:
3283            self.assertEqual(r['i'], [1, 2, 3])
3284            self.assertEqual(r['t'], ['a', 'b', 'c'])
3285        else:
3286            self.assertEqual(r['i'], '{1,2,3}')
3287            self.assertEqual(r['t'], '{a,b,c}')
3288        L = pg.Literal
3289        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3290        self.db.insert('arraytest', r)
3291        if returns_arrays:
3292            self.assertEqual(r['i'], [1, 2, 3])
3293            self.assertEqual(r['t'], ['a', 'b', 'c'])
3294        else:
3295            self.assertEqual(r['i'], '{1,2,3}')
3296            self.assertEqual(r['t'], '{a,b,c}')
3297        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3298        self.assertRaises(pg.DataError, self.db.insert, 'arraytest', r)
3299
3300    def testArrayOfIds(self):
3301        array_on = pg.get_array()
3302        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3303        r = self.db.get_attnames('arraytest')
3304        if self.regtypes:
3305            self.assertEqual(r, dict(
3306                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3307        else:
3308            self.assertEqual(r, dict(
3309                oid='int', c='int[]', o='int[]', x='int[]'))
3310        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3311        r = data.copy()
3312        self.db.insert('arraytest', r)
3313        qoid = 'oid(arraytest)'
3314        oid = r.pop(qoid)
3315        if array_on:
3316            self.assertEqual(r, data)
3317        else:
3318            self.assertEqual(r['o'], '{21,22,23}')
3319        r = {qoid: oid}
3320        self.db.get('arraytest', r)
3321        self.assertEqual(oid, r.pop(qoid))
3322        if array_on:
3323            self.assertEqual(r, data)
3324        else:
3325            self.assertEqual(r['o'], '{21,22,23}')
3326
3327    def testArrayOfText(self):
3328        array_on = pg.get_array()
3329        self.createTable('arraytest', 'data text[]', oids=True)
3330        r = self.db.get_attnames('arraytest')
3331        self.assertEqual(r['data'], 'text[]')
3332        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3333                'null', 'NULL', 'Null', 'nulL',
3334                "It's all \\ kinds of\r nasty stuff!\n"]
3335        r = dict(data=data)
3336        self.db.insert('arraytest', r)
3337        if not array_on:
3338            r['data'] = pg.cast_array(r['data'])
3339        self.assertEqual(r['data'], data)
3340        self.assertIsInstance(r['data'][1], str)
3341        self.assertIsNone(r['data'][2])
3342        r['data'] = None
3343        self.db.get('arraytest', r)
3344        if not array_on:
3345            r['data'] = pg.cast_array(r['data'])
3346        self.assertEqual(r['data'], data)
3347        self.assertIsInstance(r['data'][1], str)
3348        self.assertIsNone(r['data'][2])
3349
3350    def testArrayOfBytea(self):
3351        array_on = pg.get_array()
3352        bytea_escaped = pg.get_bytea_escaped()
3353        self.createTable('arraytest', 'data bytea[]', oids=True)
3354        r = self.db.get_attnames('arraytest')
3355        self.assertEqual(r['data'], 'bytea[]')
3356        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3357                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3358        r = dict(data=data)
3359        self.db.insert('arraytest', r)
3360        if array_on:
3361            self.assertIsInstance(r['data'], list)
3362        if array_on and not bytea_escaped:
3363            self.assertEqual(r['data'], data)
3364            self.assertIsInstance(r['data'][1], bytes)
3365            self.assertIsNone(r['data'][2])
3366        else:
3367            self.assertNotEqual(r['data'], data)
3368        r['data'] = None
3369        self.db.get('arraytest', r)
3370        if array_on:
3371            self.assertIsInstance(r['data'], list)
3372        if array_on and not bytea_escaped:
3373            self.assertEqual(r['data'], data)
3374            self.assertIsInstance(r['data'][1], bytes)
3375            self.assertIsNone(r['data'][2])
3376        else:
3377            self.assertNotEqual(r['data'], data)
3378
3379    def testArrayOfJson(self):
3380        try:
3381            self.createTable('arraytest', 'data json[]', oids=True)
3382        except pg.ProgrammingError as error:
3383            if self.db.server_version < 90200:
3384                self.skipTest('database does not support json')
3385            self.fail(str(error))
3386        r = self.db.get_attnames('arraytest')
3387        self.assertEqual(r['data'], 'json[]')
3388        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3389        array_on = pg.get_array()
3390        jsondecode = pg.get_jsondecode()
3391        r = dict(data=data)
3392        self.db.insert('arraytest', r)
3393        if not array_on:
3394            r['data'] = pg.cast_array(r['data'], jsondecode)
3395        if jsondecode is None:
3396            r['data'] = [json.loads(d) for d in r['data']]
3397        self.assertEqual(r['data'], data)
3398        r['data'] = None
3399        self.db.get('arraytest', r)
3400        if not array_on:
3401            r['data'] = pg.cast_array(r['data'], jsondecode)
3402        if jsondecode is None:
3403            r['data'] = [json.loads(d) for d in r['data']]
3404        self.assertEqual(r['data'], data)
3405        r = dict(data=[json.dumps(d) for d in data])
3406        self.db.insert('arraytest', r)
3407        if not array_on:
3408            r['data'] = pg.cast_array(r['data'], jsondecode)
3409        if jsondecode is None:
3410            r['data'] = [json.loads(d) for d in r['data']]
3411        self.assertEqual(r['data'], data)
3412        r['data'] = None
3413        self.db.get('arraytest', r)
3414        # insert empty json values
3415        r = dict(data=['', None])
3416        self.db.insert('arraytest', r)
3417        r = r['data']
3418        if array_on:
3419            self.assertIsInstance(r, list)
3420            self.assertEqual(len(r), 2)
3421            self.assertIsNone(r[0])
3422            self.assertIsNone(r[1])
3423        else:
3424            self.assertEqual(r, '{NULL,NULL}')
3425
3426    def testArrayOfJsonb(self):
3427        try:
3428            self.createTable('arraytest', 'data jsonb[]', oids=True)
3429        except pg.ProgrammingError as error:
3430            if self.db.server_version < 90400:
3431                self.skipTest('database does not support jsonb')
3432            self.fail(str(error))
3433        r = self.db.get_attnames('arraytest')
3434        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3435        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3436        array_on = pg.get_array()
3437        jsondecode = pg.get_jsondecode()
3438        r = dict(data=data)
3439        self.db.insert('arraytest', r)
3440        if not array_on:
3441            r['data'] = pg.cast_array(r['data'], jsondecode)
3442        if jsondecode is None:
3443            r['data'] = [json.loads(d) for d in r['data']]
3444        self.assertEqual(r['data'], data)
3445        r['data'] = None
3446        self.db.get('arraytest', r)
3447        if not array_on:
3448            r['data'] = pg.cast_array(r['data'], jsondecode)
3449        if jsondecode is None:
3450            r['data'] = [json.loads(d) for d in r['data']]
3451        self.assertEqual(r['data'], data)
3452        r = dict(data=[json.dumps(d) for d in data])
3453        self.db.insert('arraytest', r)
3454        if not array_on:
3455            r['data'] = pg.cast_array(r['data'], jsondecode)
3456        if jsondecode is None:
3457            r['data'] = [json.loads(d) for d in r['data']]
3458        self.assertEqual(r['data'], data)
3459        r['data'] = None
3460        self.db.get('arraytest', r)
3461        # insert empty json values
3462        r = dict(data=['', None])
3463        self.db.insert('arraytest', r)
3464        r = r['data']
3465        if array_on:
3466            self.assertIsInstance(r, list)
3467            self.assertEqual(len(r), 2)
3468            self.assertIsNone(r[0])
3469            self.assertIsNone(r[1])
3470        else:
3471            self.assertEqual(r, '{NULL,NULL}')
3472
3473    def testDeepArray(self):
3474        array_on = pg.get_array()
3475        self.createTable('arraytest', 'data text[][][]', oids=True)
3476        r = self.db.get_attnames('arraytest')
3477        self.assertEqual(r['data'], 'text[]')
3478        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3479        r = dict(data=data)
3480        self.db.insert('arraytest', r)
3481        if array_on:
3482            self.assertEqual(r['data'], data)
3483        else:
3484            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3485        r['data'] = None
3486        self.db.get('arraytest', r)
3487        if array_on:
3488            self.assertEqual(r['data'], data)
3489        else:
3490            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3491
3492    def testInsertUpdateGetRecord(self):
3493        query = self.db.query
3494        query('create type test_person_type as'
3495              ' (name varchar, age smallint, married bool,'
3496              ' weight real, salary money)')
3497        self.addCleanup(query, 'drop type test_person_type')
3498        self.createTable('test_person', 'person test_person_type',
3499                         temporary=False, oids=True)
3500        attnames = self.db.get_attnames('test_person')
3501        self.assertEqual(len(attnames), 2)
3502        self.assertIn('oid', attnames)
3503        self.assertIn('person', attnames)
3504        person_typ = attnames['person']
3505        if self.regtypes:
3506            self.assertEqual(person_typ, 'test_person_type')
3507        else:
3508            self.assertEqual(person_typ, 'record')
3509        if self.regtypes:
3510            self.assertEqual(person_typ.attnames,
3511                dict(name='character varying', age='smallint',
3512                    married='boolean', weight='real', salary='money'))
3513        else:
3514            self.assertEqual(person_typ.attnames,
3515                dict(name='text', age='int', married='bool',
3516                    weight='float', salary='money'))
3517        decimal = pg.get_decimal()
3518        if pg.get_bool():
3519            bool_class = bool
3520            t, f = True, False
3521        else:
3522            bool_class = str
3523            t, f = 't', 'f'
3524        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3525        r = self.db.insert('test_person', None, person=person)
3526        p = r['person']
3527        self.assertIsInstance(p, tuple)
3528        self.assertEqual(p, person)
3529        self.assertEqual(p.name, 'John Doe')
3530        self.assertIsInstance(p.name, str)
3531        self.assertIsInstance(p.age, int)
3532        self.assertIsInstance(p.married, bool_class)
3533        self.assertIsInstance(p.weight, float)
3534        self.assertIsInstance(p.salary, decimal)
3535        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3536        r['person'] = person
3537        self.db.update('test_person', r)
3538        p = r['person']
3539        self.assertIsInstance(p, tuple)
3540        self.assertEqual(p, person)
3541        self.assertEqual(p.name, 'Jane Roe')
3542        self.assertIsInstance(p.name, str)
3543        self.assertIsInstance(p.age, int)
3544        self.assertIsInstance(p.married, bool_class)
3545        self.assertIsInstance(p.weight, float)
3546        self.assertIsInstance(p.salary, decimal)
3547        r['person'] = None
3548        self.db.get('test_person', r)
3549        p = r['person']
3550        self.assertIsInstance(p, tuple)
3551        self.assertEqual(p, person)
3552        self.assertEqual(p.name, 'Jane Roe')
3553        self.assertIsInstance(p.name, str)
3554        self.assertIsInstance(p.age, int)
3555        self.assertIsInstance(p.married, bool_class)
3556        self.assertIsInstance(p.weight, float)
3557        self.assertIsInstance(p.salary, decimal)
3558        person = (None,) * 5
3559        r = self.db.insert('test_person', None, person=person)
3560        p = r['person']
3561        self.assertIsInstance(p, tuple)
3562        self.assertIsNone(p.name)
3563        self.assertIsNone(p.age)
3564        self.assertIsNone(p.married)
3565        self.assertIsNone(p.weight)
3566        self.assertIsNone(p.salary)
3567        r['person'] = None
3568        self.db.get('test_person', r)
3569        p = r['person']
3570        self.assertIsInstance(p, tuple)
3571        self.assertIsNone(p.name)
3572        self.assertIsNone(p.age)
3573        self.assertIsNone(p.married)
3574        self.assertIsNone(p.weight)
3575        self.assertIsNone(p.salary)
3576        r = self.db.insert('test_person', None, person=None)
3577        self.assertIsNone(r['person'])
3578        r['person'] = None
3579        self.db.get('test_person', r)
3580        self.assertIsNone(r['person'])
3581
3582    def testRecordInsertBytea(self):
3583        query = self.db.query
3584        query('create type test_person_type as'
3585              ' (name text, picture bytea)')
3586        self.addCleanup(query, 'drop type test_person_type')
3587        self.createTable('test_person', 'person test_person_type',
3588                         temporary=False, oids=True)
3589        person_typ = self.db.get_attnames('test_person')['person']
3590        self.assertEqual(person_typ.attnames,
3591                         dict(name='text', picture='bytea'))
3592        person = ('John Doe', b'O\x00ps\xff!')
3593        r = self.db.insert('test_person', None, person=person)
3594        p = r['person']
3595        self.assertIsInstance(p, tuple)
3596        self.assertEqual(p, person)
3597        self.assertEqual(p.name, 'John Doe')
3598        self.assertIsInstance(p.name, str)
3599        self.assertEqual(p.picture, person[1])
3600        self.assertIsInstance(p.picture, bytes)
3601
3602    def testRecordInsertJson(self):
3603        query = self.db.query
3604        try:
3605            query('create type test_person_type as'
3606                  ' (name text, data json)')
3607        except pg.ProgrammingError as error:
3608            if self.db.server_version < 90200:
3609                self.skipTest('database does not support json')
3610            self.fail(str(error))
3611        self.addCleanup(query, 'drop type test_person_type')
3612        self.createTable('test_person', 'person test_person_type',
3613                         temporary=False, oids=True)
3614        person_typ = self.db.get_attnames('test_person')['person']
3615        self.assertEqual(person_typ.attnames,
3616                         dict(name='text', data='json'))
3617        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3618        r = self.db.insert('test_person', None, person=person)
3619        p = r['person']
3620        self.assertIsInstance(p, tuple)
3621        if pg.get_jsondecode() is None:
3622            p = p._replace(data=json.loads(p.data))
3623        self.assertEqual(p, person)
3624        self.assertEqual(p.name, 'John Doe')
3625        self.assertIsInstance(p.name, str)
3626        self.assertEqual(p.data, person[1])
3627        self.assertIsInstance(p.data, dict)
3628
3629    def testRecordLiteral(self):
3630        query = self.db.query
3631        query('create type test_person_type as'
3632              ' (name varchar, age smallint)')
3633        self.addCleanup(query, 'drop type test_person_type')
3634        self.createTable('test_person', 'person test_person_type',
3635                         temporary=False, oids=True)
3636        person_typ = self.db.get_attnames('test_person')['person']
3637        if self.regtypes:
3638            self.assertEqual(person_typ, 'test_person_type')
3639        else:
3640            self.assertEqual(person_typ, 'record')
3641        if self.regtypes:
3642            self.assertEqual(person_typ.attnames,
3643                             dict(name='character varying', age='smallint'))
3644        else:
3645            self.assertEqual(person_typ.attnames,
3646                             dict(name='text', age='int'))
3647        person = pg.Literal("('John Doe', 61)")
3648        r = self.db.insert('test_person', None, person=person)
3649        p = r['person']
3650        self.assertIsInstance(p, tuple)
3651        self.assertEqual(p.name, 'John Doe')
3652        self.assertIsInstance(p.name, str)
3653        self.assertEqual(p.age, 61)
3654        self.assertIsInstance(p.age, int)
3655
3656    def testDate(self):
3657        query = self.db.query
3658        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3659                'SQL, MDY', 'SQL, DMY', 'German'):
3660            self.db.set_parameter('datestyle', datestyle)
3661            d = date(2016, 3, 14)
3662            q = "select $1::date"
3663            r = query(q, (d,)).getresult()[0][0]
3664            self.assertIsInstance(r, date)
3665            self.assertEqual(r, d)
3666            q = "select '10000-08-01'::date, '0099-01-08 BC'::date"
3667            r = query(q).getresult()[0]
3668            self.assertIsInstance(r[0], date)
3669            self.assertIsInstance(r[1], date)
3670            self.assertEqual(r[0], date.max)
3671            self.assertEqual(r[1], date.min)
3672        q = "select 'infinity'::date, '-infinity'::date"
3673        r = query(q).getresult()[0]
3674        self.assertIsInstance(r[0], date)
3675        self.assertIsInstance(r[1], date)
3676        self.assertEqual(r[0], date.max)
3677        self.assertEqual(r[1], date.min)
3678
3679    def testTime(self):
3680        query = self.db.query
3681        d = time(15, 9, 26)
3682        q = "select $1::time"
3683        r = query(q, (d,)).getresult()[0][0]
3684        self.assertIsInstance(r, time)
3685        self.assertEqual(r, d)
3686        d = time(15, 9, 26, 535897)
3687        q = "select $1::time"
3688        r = query(q, (d,)).getresult()[0][0]
3689        self.assertIsInstance(r, time)
3690        self.assertEqual(r, d)
3691
3692    def testTimetz(self):
3693        query = self.db.query
3694        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3695        for timezone in sorted(timezones):
3696            tz = '%+03d00' % timezones[timezone]
3697            try:
3698                tzinfo = datetime.strptime(tz, '%z').tzinfo
3699            except ValueError:  # Python < 3.2
3700                tzinfo = pg._get_timezone(tz)
3701            self.db.set_parameter('timezone', timezone)
3702            d = time(15, 9, 26, tzinfo=tzinfo)
3703            q = "select $1::timetz"
3704            r = query(q, (d,)).getresult()[0][0]
3705            self.assertIsInstance(r, time)
3706            self.assertEqual(r, d)
3707            d = time(15, 9, 26, 535897, tzinfo)
3708            q = "select $1::timetz"
3709            r = query(q, (d,)).getresult()[0][0]
3710            self.assertIsInstance(r, time)
3711            self.assertEqual(r, d)
3712
3713    def testTimestamp(self):
3714        query = self.db.query
3715        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3716                'SQL, MDY', 'SQL, DMY', 'German'):
3717            self.db.set_parameter('datestyle', datestyle)
3718            d = datetime(2016, 3, 14)
3719            q = "select $1::timestamp"
3720            r = query(q, (d,)).getresult()[0][0]
3721            self.assertIsInstance(r, datetime)
3722            self.assertEqual(r, d)
3723            d = datetime(2016, 3, 14, 15, 9, 26)
3724            q = "select $1::timestamp"
3725            r = query(q, (d,)).getresult()[0][0]
3726            self.assertIsInstance(r, datetime)
3727            self.assertEqual(r, d)
3728            d = datetime(2016, 3, 14, 15, 9, 26, 535897)
3729            q = "select $1::timestamp"
3730            r = query(q, (d,)).getresult()[0][0]
3731            self.assertIsInstance(r, datetime)
3732            self.assertEqual(r, d)
3733            q = ("select '10000-08-01 AD'::timestamp,"
3734                " '0099-01-08 BC'::timestamp")
3735            r = query(q).getresult()[0]
3736            self.assertIsInstance(r[0], datetime)
3737            self.assertIsInstance(r[1], datetime)
3738            self.assertEqual(r[0], datetime.max)
3739            self.assertEqual(r[1], datetime.min)
3740        q = "select 'infinity'::timestamp, '-infinity'::timestamp"
3741        r = query(q).getresult()[0]
3742        self.assertIsInstance(r[0], datetime)
3743        self.assertIsInstance(r[1], datetime)
3744        self.assertEqual(r[0], datetime.max)
3745        self.assertEqual(r[1], datetime.min)
3746
3747    def testTimestamptz(self):
3748        query = self.db.query
3749        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3750        for timezone in sorted(timezones):
3751            tz = '%+03d00' % timezones[timezone]
3752            try:
3753                tzinfo = datetime.strptime(tz, '%z').tzinfo
3754            except ValueError:  # Python < 3.2
3755                tzinfo = pg._get_timezone(tz)
3756            self.db.set_parameter('timezone', timezone)
3757            for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3758                    'SQL, MDY', 'SQL, DMY', 'German'):
3759                self.db.set_parameter('datestyle', datestyle)
3760                d = datetime(2016, 3, 14, tzinfo=tzinfo)
3761                q = "select $1::timestamptz"
3762                r = query(q, (d,)).getresult()[0][0]
3763                self.assertIsInstance(r, datetime)
3764                self.assertEqual(r, d)
3765                d = datetime(2016, 3, 14, 15, 9, 26, tzinfo=tzinfo)
3766                q = "select $1::timestamptz"
3767                r = query(q, (d,)).getresult()[0][0]
3768                self.assertIsInstance(r, datetime)
3769                self.assertEqual(r, d)
3770                d = datetime(2016, 3, 14, 15, 9, 26, 535897, tzinfo)
3771                q = "select $1::timestamptz"
3772                r = query(q, (d,)).getresult()[0][0]
3773                self.assertIsInstance(r, datetime)
3774                self.assertEqual(r, d)
3775                q = ("select '10000-08-01 AD'::timestamptz,"
3776                    " '0099-01-08 BC'::timestamptz")
3777                r = query(q).getresult()[0]
3778                self.assertIsInstance(r[0], datetime)
3779                self.assertIsInstance(r[1], datetime)
3780                self.assertEqual(r[0], datetime.max)
3781                self.assertEqual(r[1], datetime.min)
3782        q = "select 'infinity'::timestamptz, '-infinity'::timestamptz"
3783        r = query(q).getresult()[0]
3784        self.assertIsInstance(r[0], datetime)
3785        self.assertIsInstance(r[1], datetime)
3786        self.assertEqual(r[0], datetime.max)
3787        self.assertEqual(r[1], datetime.min)
3788
3789    def testInterval(self):
3790        query = self.db.query
3791        for intervalstyle in (
3792                'sql_standard', 'postgres', 'postgres_verbose', 'iso_8601'):
3793            self.db.set_parameter('intervalstyle', intervalstyle)
3794            d = timedelta(3)
3795            q = "select $1::interval"
3796            r = query(q, (d,)).getresult()[0][0]
3797            self.assertIsInstance(r, timedelta)
3798            self.assertEqual(r, d)
3799            d = timedelta(-30)
3800            r = query(q, (d,)).getresult()[0][0]
3801            self.assertIsInstance(r, timedelta)
3802            self.assertEqual(r, d)
3803            d = timedelta(hours=3, minutes=31, seconds=42, microseconds=5678)
3804            q = "select $1::interval"
3805            r = query(q, (d,)).getresult()[0][0]
3806            self.assertIsInstance(r, timedelta)
3807            self.assertEqual(r, d)
3808
3809    def testDateAndTimeArrays(self):
3810        dt = (date(2016, 3, 14), time(15, 9, 26))
3811        q = "select ARRAY[$1::date], ARRAY[$2::time]"
3812        r = self.db.query(q, dt).getresult()[0]
3813        self.assertIsInstance(r[0], list)
3814        self.assertEqual(r[0][0], dt[0])
3815        self.assertIsInstance(r[1], list)
3816        self.assertEqual(r[1][0], dt[1])
3817
3818    def testHstore(self):
3819        try:
3820            self.db.query("select 'k=>v'::hstore")
3821        except pg.DatabaseError:
3822            try:
3823                self.db.query("create extension hstore")
3824            except pg.DatabaseError:
3825                self.skipTest("hstore extension not enabled")
3826        d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever',
3827            '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3',
3828            '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"',
3829            'None': None, 'NULL': 'NULL', 'empty': ''}
3830        q = "select $1::hstore"
3831        r = self.db.query(q, (pg.Hstore(d),)).getresult()[0][0]
3832        self.assertIsInstance(r, dict)
3833        self.assertEqual(r, d)
3834
3835    def testUuid(self):
3836        d = UUID('{12345678-1234-5678-1234-567812345678}')
3837        q = 'select $1::uuid'
3838        r = self.db.query(q, (d,)).getresult()[0][0]
3839        self.assertIsInstance(r, UUID)
3840        self.assertEqual(r, d)
3841
3842    def testDbTypesInfo(self):
3843        dbtypes = self.db.dbtypes
3844        self.assertIsInstance(dbtypes, dict)
3845        self.assertNotIn('numeric', dbtypes)
3846        typ = dbtypes['numeric']
3847        self.assertIn('numeric', dbtypes)
3848        self.assertEqual(typ, 'numeric' if self.regtypes else 'num')
3849        self.assertEqual(typ.oid, 1700)
3850        self.assertEqual(typ.pgtype, 'numeric')
3851        self.assertEqual(typ.regtype, 'numeric')
3852        self.assertEqual(typ.simple, 'num')
3853        self.assertEqual(typ.typtype, 'b')
3854        self.assertEqual(typ.category, 'N')
3855        self.assertEqual(typ.delim, ',')
3856        self.assertEqual(typ.relid, 0)
3857        self.assertIs(dbtypes[1700], typ)
3858        self.assertNotIn('pg_type', dbtypes)
3859        typ = dbtypes['pg_type']
3860        self.assertIn('pg_type', dbtypes)
3861        self.assertEqual(typ, 'pg_type' if self.regtypes else 'record')
3862        self.assertIsInstance(typ.oid, int)
3863        self.assertEqual(typ.pgtype, 'pg_type')
3864        self.assertEqual(typ.regtype, 'pg_type')
3865        self.assertEqual(typ.simple, 'record')
3866        self.assertEqual(typ.typtype, 'c')
3867        self.assertEqual(typ.category, 'C')
3868        self.assertEqual(typ.delim, ',')
3869        self.assertNotEqual(typ.relid, 0)
3870        attnames = typ.attnames
3871        self.assertIsInstance(attnames, dict)
3872        self.assertIs(attnames, dbtypes.get_attnames('pg_type'))
3873        self.assertEqual(list(attnames)[0], 'typname')
3874        typname = attnames['typname']
3875        self.assertEqual(typname, 'name' if self.regtypes else 'text')
3876        self.assertEqual(typname.typtype, 'b')  # base
3877        self.assertEqual(typname.category, 'S')  # string
3878        self.assertEqual(list(attnames)[3], 'typlen')
3879        typlen = attnames['typlen']
3880        self.assertEqual(typlen, 'smallint' if self.regtypes else 'int')
3881        self.assertEqual(typlen.typtype, 'b')  # base
3882        self.assertEqual(typlen.category, 'N')  # numeric
3883
3884    def testDbTypesTypecast(self):
3885        dbtypes = self.db.dbtypes
3886        self.assertIsInstance(dbtypes, dict)
3887        self.assertNotIn('int4', dbtypes)
3888        self.assertIs(dbtypes.get_typecast('int4'), int)
3889        dbtypes.set_typecast('int4', float)
3890        self.assertIs(dbtypes.get_typecast('int4'), float)
3891        dbtypes.reset_typecast('int4')
3892        self.assertIs(dbtypes.get_typecast('int4'), int)
3893        dbtypes.set_typecast('int4', float)
3894        self.assertIs(dbtypes.get_typecast('int4'), float)
3895        dbtypes.reset_typecast()
3896        self.assertIs(dbtypes.get_typecast('int4'), int)
3897        self.assertNotIn('circle', dbtypes)
3898        self.assertIsNone(dbtypes.get_typecast('circle'))
3899        squared_circle = lambda v: 'Squared Circle: %s' % v
3900        dbtypes.set_typecast('circle', squared_circle)
3901        self.assertIs(dbtypes.get_typecast('circle'), squared_circle)
3902        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3903        self.assertIn('circle', dbtypes)
3904        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3905        self.assertEqual(dbtypes.typecast('Impossible', 'circle'),
3906            'Squared Circle: Impossible')
3907        dbtypes.reset_typecast('circle')
3908        self.assertIsNone(dbtypes.get_typecast('circle'))
3909
3910    def testGetSetTypeCast(self):
3911        get_typecast = pg.get_typecast
3912        set_typecast = pg.set_typecast
3913        dbtypes = self.db.dbtypes
3914        self.assertIsInstance(dbtypes, dict)
3915        self.assertNotIn('int4', dbtypes)
3916        self.assertNotIn('real', dbtypes)
3917        self.assertNotIn('bool', dbtypes)
3918        self.assertIs(get_typecast('int4'), int)
3919        self.assertIs(get_typecast('float4'), float)
3920        self.assertIs(get_typecast('bool'), pg.cast_bool)
3921        cast_circle = get_typecast('circle')
3922        self.addCleanup(set_typecast, 'circle', cast_circle)
3923        squared_circle = lambda v: 'Squared Circle: %s' % v
3924        self.assertNotIn('circle', dbtypes)
3925        set_typecast('circle', squared_circle)
3926        self.assertNotIn('circle', dbtypes)
3927        self.assertIs(get_typecast('circle'), squared_circle)
3928        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3929        self.assertIn('circle', dbtypes)
3930        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3931        set_typecast('circle', cast_circle)
3932        self.assertIs(get_typecast('circle'), cast_circle)
3933
3934    def testNotificationHandler(self):
3935        # the notification handler itself is tested separately
3936        f = self.db.notification_handler
3937        callback = lambda arg_dict: None
3938        handler = f('test', callback)
3939        self.assertIsInstance(handler, pg.NotificationHandler)
3940        self.assertIs(handler.db, self.db)
3941        self.assertEqual(handler.event, 'test')
3942        self.assertEqual(handler.stop_event, 'stop_test')
3943        self.assertIs(handler.callback, callback)
3944        self.assertIsInstance(handler.arg_dict, dict)
3945        self.assertEqual(handler.arg_dict, {})
3946        self.assertIsNone(handler.timeout)
3947        self.assertFalse(handler.listening)
3948        handler.close()
3949        self.assertIsNone(handler.db)
3950        self.db.reopen()
3951        self.assertIsNone(handler.db)
3952        handler = f('test2', callback, timeout=2)
3953        self.assertIsInstance(handler, pg.NotificationHandler)
3954        self.assertIs(handler.db, self.db)
3955        self.assertEqual(handler.event, 'test2')
3956        self.assertEqual(handler.stop_event, 'stop_test2')
3957        self.assertIs(handler.callback, callback)
3958        self.assertIsInstance(handler.arg_dict, dict)
3959        self.assertEqual(handler.arg_dict, {})
3960        self.assertEqual(handler.timeout, 2)
3961        self.assertFalse(handler.listening)
3962        handler.close()
3963        self.assertIsNone(handler.db)
3964        self.db.reopen()
3965        self.assertIsNone(handler.db)
3966        arg_dict = {'testing': 3}
3967        handler = f('test3', callback, arg_dict=arg_dict)
3968        self.assertIsInstance(handler, pg.NotificationHandler)
3969        self.assertIs(handler.db, self.db)
3970        self.assertEqual(handler.event, 'test3')
3971        self.assertEqual(handler.stop_event, 'stop_test3')
3972        self.assertIs(handler.callback, callback)
3973        self.assertIs(handler.arg_dict, arg_dict)
3974        self.assertEqual(arg_dict['testing'], 3)
3975        self.assertIsNone(handler.timeout)
3976        self.assertFalse(handler.listening)
3977        handler.close()
3978        self.assertIsNone(handler.db)
3979        self.db.reopen()
3980        self.assertIsNone(handler.db)
3981        handler = f('test4', callback, stop_event='stop4')
3982        self.assertIsInstance(handler, pg.NotificationHandler)
3983        self.assertIs(handler.db, self.db)
3984        self.assertEqual(handler.event, 'test4')
3985        self.assertEqual(handler.stop_event, 'stop4')
3986        self.assertIs(handler.callback, callback)
3987        self.assertIsInstance(handler.arg_dict, dict)
3988        self.assertEqual(handler.arg_dict, {})
3989        self.assertIsNone(handler.timeout)
3990        self.assertFalse(handler.listening)
3991        handler.close()
3992        self.assertIsNone(handler.db)
3993        self.db.reopen()
3994        self.assertIsNone(handler.db)
3995        arg_dict = {'testing': 5}
3996        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
3997        self.assertIsInstance(handler, pg.NotificationHandler)
3998        self.assertIs(handler.db, self.db)
3999        self.assertEqual(handler.event, 'test5')
4000        self.assertEqual(handler.stop_event, 'stop5')
4001        self.assertIs(handler.callback, callback)
4002        self.assertIs(handler.arg_dict, arg_dict)
4003        self.assertEqual(arg_dict['testing'], 5)
4004        self.assertEqual(handler.timeout, 1.5)
4005        self.assertFalse(handler.listening)
4006        handler.close()
4007        self.assertIsNone(handler.db)
4008        self.db.reopen()
4009        self.assertIsNone(handler.db)
4010
4011
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        """Flip the global pg options, then run the inherited class setup.

        unittest does not call tearDownClass() when setUpClass() raises
        (this includes SkipTest). Without the restore below, the changed
        global options would leak into every test class that runs after
        this one, so any options changed so far are restored on failure.
        """
        cls.saved_options = {}
        succeeded = False
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            not_array = not pg.get_array()
            cls.set_option('array', not_array)
            not_bytea_escaped = not pg.get_bytea_escaped()
            cls.set_option('bytea_escaped', not_bytea_escaped)
            cls.set_option('namedresult', None)
            cls.set_option('jsondecode', None)
            db = DB()
            try:
                cls.regtypes = not db.use_regtypes()
            finally:
                # do not leak the probe connection if use_regtypes() fails
                db.close()
            super(TestDBClassNonStdOpts, cls).setUpClass()
            succeeded = True
        finally:
            if not succeeded:
                cls.reset_options()

    @classmethod
    def tearDownClass(cls):
        """Run the inherited class teardown, then restore the options."""
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # restore the global options even if the teardown fails
            cls.reset_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option, then set it.

        Returns whatever the corresponding pg.set_<option>() returns.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_options(cls):
        """Restore every global pg option that set_option() has saved."""
        for option in ('jsondecode', 'namedresult', 'bool',
                       'array', 'bytea_escaped', 'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
4050
4051
4052class TestDBClassAdapter(unittest.TestCase):
4053    """Test the adapter object associatd with the DB class."""
4054
4055    def setUp(self):
4056        self.db = DB()
4057        self.adapter = self.db.adapter
4058
4059    def tearDown(self):
4060        try:
4061            self.db.close()
4062        except pg.InternalError:
4063            pass
4064
4065    def testGuessSimpleType(self):
4066        f = self.adapter.guess_simple_type
4067        self.assertEqual(f(pg.Bytea(b'test')), 'bytea')
4068        self.assertEqual(f('string'), 'text')
4069        self.assertEqual(f(b'string'), 'text')
4070        self.assertEqual(f(True), 'bool')
4071        self.assertEqual(f(3), 'int')
4072        self.assertEqual(f(2.75), 'float')
4073        self.assertEqual(f(Decimal('4.25')), 'num')
4074        self.assertEqual(f(date(2016, 1, 30)), 'date')
4075        self.assertEqual(f([1, 2, 3]), 'int[]')
4076        self.assertEqual(f([[[123]]]), 'int[]')
4077        self.assertEqual(f(['a', 'b', 'c']), 'text[]')
4078        self.assertEqual(f([[['abc']]]), 'text[]')
4079        self.assertEqual(f([False, True]), 'bool[]')
4080        self.assertEqual(f([[[False]]]), 'bool[]')
4081        r = f(('string', True, 3, 2.75, [1], [False]))
4082        self.assertEqual(r, 'record')
4083        self.assertEqual(list(r.attnames.values()),
4084            ['text', 'bool', 'int', 'float', 'int[]', 'bool[]'])
4085
4086    def testAdaptQueryTypedList(self):
4087        format_query = self.adapter.format_query
4088        self.assertRaises(TypeError, format_query,
4089            '%s,%s', (1, 2), ('int2',))
4090        self.assertRaises(TypeError, format_query,
4091            '%s,%s', (1,), ('int2', 'int2'))
4092        values = (3, 7.5, 'hello', True)
4093        types = ('int4', 'float4', 'text', 'bool')
4094        sql, params = format_query("select %s,%s,%s,%s", values, types)
4095        self.assertEqual(sql, 'select $1,$2,$3,$4')
4096        self.assertEqual(params, [3, 7.5, 'hello', 't'])
4097        types = ('bool', 'bool', 'bool', 'bool')
4098        sql, params = format_query("select %s,%s,%s,%s", values, types)
4099        self.assertEqual(sql, 'select $1,$2,$3,$4')
4100        self.assertEqual(params, ['t', 't', 'f', 't'])
4101        values = ('2016-01-30', 'current_date')
4102        types = ('date', 'date')
4103        sql, params = format_query("values(%s,%s)", values, types)
4104        self.assertEqual(sql, 'values($1,current_date)')
4105        self.assertEqual(params, ['2016-01-30'])
4106        values = ([1, 2, 3], ['a', 'b', 'c'])
4107        types = ('_int4', '_text')
4108        sql, params = format_query("%s::int4[],%s::text[]", values, types)
4109        self.assertEqual(sql, '$1::int4[],$2::text[]')
4110        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
4111        types = ('_bool', '_bool')
4112        sql, params = format_query("%s::bool[],%s::bool[]", values, types)
4113        self.assertEqual(sql, '$1::bool[],$2::bool[]')
4114        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
4115        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
4116        t = self.adapter.simple_type
4117        typ = t('record')
4118        typ._get_attnames = lambda _self: pg.AttrDict([
4119            ('i', t('int')), ('f', t('float')),
4120            ('t', t('text')), ('b', t('bool')),
4121            ('i3', t('int[]')), ('t3', t('text[]'))])
4122        types = [typ]
4123        sql, params = format_query('select %s', values, types)
4124        self.assertEqual(sql, 'select $1')
4125        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4126
4127    def testAdaptQueryTypedDict(self):
4128        format_query = self.adapter.format_query
4129        self.assertRaises(TypeError, format_query,
4130            '%s,%s', dict(i1=1, i2=2), dict(i1='int2'))
4131        values = dict(i=3, f=7.5, t='hello', b=True)
4132        types = dict(i='int4', f='float4',
4133            t='text', b='bool')
4134        sql, params = format_query(
4135            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
4136        self.assertEqual(sql, 'select $3,$2,$4,$1')
4137        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
4138        types = dict(i='bool', f='bool',
4139            t='bool', b='bool')
4140        sql, params = format_query(
4141            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
4142        self.assertEqual(sql, 'select $3,$2,$4,$1')
4143        self.assertEqual(params, ['t', 't', 't', 'f'])
4144        values = dict(d1='2016-01-30', d2='current_date')
4145        types = dict(d1='date', d2='date')
4146        sql, params = format_query("values(%(d1)s,%(d2)s)", values, types)
4147        self.assertEqual(sql, 'values($1,current_date)')
4148        self.assertEqual(params, ['2016-01-30'])
4149        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
4150        types = dict(i='_int4', t='_text')
4151        sql, params = format_query(
4152            "%(i)s::int4[],%(t)s::text[]", values, types)
4153        self.assertEqual(sql, '$1::int4[],$2::text[]')
4154        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
4155        types = dict(i='_bool', t='_bool')
4156        sql, params = format_query(
4157            "%(i)s::bool[],%(t)s::bool[]", values, types)
4158        self.assertEqual(sql, '$1::bool[],$2::bool[]')
4159        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
4160        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
4161        t = self.adapter.simple_type
4162        typ = t('record')
4163        typ._get_attnames = lambda _self: pg.AttrDict([
4164            ('i', t('int')), ('f', t('float')),
4165            ('t', t('text')), ('b', t('bool')),
4166            ('i3', t('int[]')), ('t3', t('text[]'))])
4167        types = dict(record=typ)
4168        sql, params = format_query('select %(record)s', values, types)
4169        self.assertEqual(sql, 'select $1')
4170        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4171
4172    def testAdaptQueryUntypedList(self):
4173        format_query = self.adapter.format_query
4174        values = (3, 7.5, 'hello', True)
4175        sql, params = format_query("select %s,%s,%s,%s", values)
4176        self.assertEqual(sql, 'select $1,$2,$3,$4')
4177        self.assertEqual(params, [3, 7.5, 'hello', 't'])
4178        values = [date(2016, 1, 30), 'current_date']
4179        sql, params = format_query("values(%s,%s)", values)
4180        self.assertEqual(sql, 'values($1,$2)')
4181        self.assertEqual(params, values)
4182        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
4183        sql, params = format_query("%s,%s,%s", values)
4184        self.assertEqual(sql, "$1,$2,$3")
4185        self.assertEqual(params, ['{1,2,3}', '{a,b,c}', '{t,f,t}'])
4186        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
4187            [[True, False], [False, True]])
4188        sql, params = format_query("%s,%s,%s", values)
4189        self.assertEqual(sql, "$1,$2,$3")
4190        self.assertEqual(params, [
4191            '{{1,2},{3,4}}', '{{a,b},{c,d}}', '{{t,f},{f,t}}'])
4192        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
4193        sql, params = format_query('select %s', values)
4194        self.assertEqual(sql, 'select $1')
4195        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4196
4197    def testAdaptQueryUntypedDict(self):
4198        format_query = self.adapter.format_query
4199        values = dict(i=3, f=7.5, t='hello', b=True)
4200        sql, params = format_query(
4201            "select %(i)s,%(f)s,%(t)s,%(b)s", values)
4202        self.assertEqual(sql, 'select $3,$2,$4,$1')
4203        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
4204        values = dict(d1='2016-01-30', d2='current_date')
4205        sql, params = format_query("values(%(d1)s,%(d2)s)", values)
4206        self.assertEqual(sql, 'values($1,$2)')
4207        self.assertEqual(params, [values['d1'], values['d2']])
4208        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
4209        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
4210        self.assertEqual(sql, "$2,$3,$1")
4211        self.assertEqual(params, ['{t,f,t}', '{1,2,3}', '{a,b,c}'])
4212        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
4213            b=[[True, False], [False, True]])
4214        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
4215        self.assertEqual(sql, "$2,$3,$1")
4216        self.assertEqual(params, [
4217            '{{t,f},{f,t}}', '{{1,2},{3,4}}', '{{a,b},{c,d}}'])
4218        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
4219        sql, params = format_query('select %(record)s', values)
4220        self.assertEqual(sql, 'select $1')
4221        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
4222
4223    def testAdaptQueryInlineList(self):
4224        format_query = self.adapter.format_query
4225        values = (3, 7.5, 'hello', True)
4226        sql, params = format_query("select %s,%s,%s,%s", values, inline=True)
4227        self.assertEqual(sql, "select 3,7.5,'hello',true")
4228        self.assertEqual(params, [])
4229        values = [date(2016, 1, 30), 'current_date']
4230        sql, params = format_query("values(%s,%s)", values, inline=True)
4231        self.assertEqual(sql, "values('2016-01-30','current_date')")
4232        self.assertEqual(params, [])
4233        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
4234        sql, params = format_query("%s,%s,%s", values, inline=True)
4235        self.assertEqual(sql,
4236            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
4237        self.assertEqual(params, [])
4238        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
4239            [[True, False], [False, True]])
4240        sql, params = format_query("%s,%s,%s", values, inline=True)
4241        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
4242            "ARRAY[[true,false],[false,true]]")
4243        self.assertEqual(params, [])
4244        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
4245        sql, params = format_query('select %s', values, inline=True)
4246        self.assertEqual(sql,
4247            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
4248        self.assertEqual(params, [])
4249
4250    def testAdaptQueryInlineDict(self):
4251        format_query = self.adapter.format_query
4252        values = dict(i=3, f=7.5, t='hello', b=True)
4253        sql, params = format_query(
4254            "select %(i)s,%(f)s,%(t)s,%(b)s", values, inline=True)
4255        self.assertEqual(sql, "select 3,7.5,'hello',true")
4256        self.assertEqual(params, [])
4257        values = dict(d1='2016-01-30', d2='current_date')
4258        sql, params = format_query(
4259            "values(%(d1)s,%(d2)s)", values, inline=True)
4260        self.assertEqual(sql, "values('2016-01-30','current_date')")
4261        self.assertEqual(params, [])
4262        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
4263        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
4264        self.assertEqual(sql,
4265            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
4266        self.assertEqual(params, [])
4267        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
4268            b=[[True, False], [False, True]])
4269        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
4270        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
4271            "ARRAY[[true,false],[false,true]]")
4272        self.assertEqual(params, [])
4273        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
4274        sql, params = format_query('select %(record)s', values, inline=True)
4275        self.assertEqual(sql,
4276            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
4277        self.assertEqual(params, [])
4278
4279    def testAdaptQueryWithPgRepr(self):
4280        format_query = self.adapter.format_query
4281        self.assertRaises(TypeError, format_query,
4282            '%s', object(), inline=True)
4283        class TestObject:
4284            def __pg_repr__(self):
4285                return "'adapted'"
4286        sql, params = format_query('select %s', [TestObject()], inline=True)
4287        self.assertEqual(sql, "select 'adapted'")
4288        self.assertEqual(params, [])
4289        sql, params = format_query('select %s', [[TestObject()]], inline=True)
4290        self.assertEqual(sql, "select ARRAY['adapted']")
4291        self.assertEqual(params, [])
4292
4293
4294class TestSchemas(unittest.TestCase):
4295    """Test correct handling of schemas (namespaces)."""
4296
4297    cls_set_up = False
4298
4299    @classmethod
4300    def setUpClass(cls):
4301        db = DB()
4302        query = db.query
4303        for num_schema in range(5):
4304            if num_schema:
4305                schema = "s%d" % num_schema
4306                query("drop schema if exists %s cascade" % (schema,))
4307                try:
4308                    query("create schema %s" % (schema,))
4309                except pg.ProgrammingError:
4310                    raise RuntimeError("The test user cannot create schemas.\n"
4311                        "Grant create on database %s to the user"
4312                        " for running these tests." % dbname)
4313            else:
4314                schema = "public"
4315                query("drop table if exists %s.t" % (schema,))
4316                query("drop table if exists %s.t%d" % (schema, num_schema))
4317            query("create table %s.t with oids as select 1 as n, %d as d"
4318                  % (schema, num_schema))
4319            query("create table %s.t%d with oids as select 1 as n, %d as d"
4320                  % (schema, num_schema, num_schema))
4321        db.close()
4322        cls.cls_set_up = True
4323
4324    @classmethod
4325    def tearDownClass(cls):
4326        db = DB()
4327        query = db.query
4328        for num_schema in range(5):
4329            if num_schema:
4330                schema = "s%d" % num_schema
4331                query("drop schema %s cascade" % (schema,))
4332            else:
4333                schema = "public"
4334                query("drop table %s.t" % (schema,))
4335                query("drop table %s.t%d" % (schema, num_schema))
4336        db.close()
4337
4338    def setUp(self):
4339        self.assertTrue(self.cls_set_up)
4340        self.db = DB()
4341
4342    def tearDown(self):
4343        self.doCleanups()
4344        self.db.close()
4345
4346    def testGetTables(self):
4347        tables = self.db.get_tables()
4348        for num_schema in range(5):
4349            if num_schema:
4350                schema = "s" + str(num_schema)
4351            else:
4352                schema = "public"
4353            for t in (schema + ".t",
4354                    schema + ".t" + str(num_schema)):
4355                self.assertIn(t, tables)
4356
4357    def testGetAttnames(self):
4358        get_attnames = self.db.get_attnames
4359        query = self.db.query
4360        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
4361        r = get_attnames("t")
4362        self.assertEqual(r, result)
4363        r = get_attnames("s4.t4")
4364        self.assertEqual(r, result)
4365        query("drop table if exists s3.t3m")
4366        self.addCleanup(query, "drop table s3.t3m")
4367        query("create table s3.t3m with oids as select 1 as m")
4368        result_m = {'oid': 'int', 'm': 'int'}
4369        r = get_attnames("s3.t3m")
4370        self.assertEqual(r, result_m)
4371        query("set search_path to s1,s3")
4372        r = get_attnames("t3")
4373        self.assertEqual(r, result)
4374        r = get_attnames("t3m")
4375        self.assertEqual(r, result_m)
4376
4377    def testGet(self):
4378        get = self.db.get
4379        query = self.db.query
4380        PrgError = pg.ProgrammingError
4381        self.assertEqual(get("t", 1, 'n')['d'], 0)
4382        self.assertEqual(get("t0", 1, 'n')['d'], 0)
4383        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
4384        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
4385        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
4386        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
4387        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
4388        query("set search_path to s2,s4")
4389        self.assertRaises(PrgError, get, "t1", 1, 'n')
4390        self.assertEqual(get("t4", 1, 'n')['d'], 4)
4391        self.assertRaises(PrgError, get, "t3", 1, 'n')
4392        self.assertEqual(get("t", 1, 'n')['d'], 2)
4393        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
4394        query("set search_path to s1,s3")
4395        self.assertRaises(PrgError, get, "t2", 1, 'n')
4396        self.assertEqual(get("t3", 1, 'n')['d'], 3)
4397        self.assertRaises(PrgError, get, "t4", 1, 'n')
4398        self.assertEqual(get("t", 1, 'n')['d'], 1)
4399        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
4400
4401    def testMunging(self):
4402        get = self.db.get
4403        query = self.db.query
4404        r = get("t", 1, 'n')
4405        self.assertIn('oid(t)', r)
4406        query("set search_path to s2")
4407        r = get("t2", 1, 'n')
4408        self.assertIn('oid(t2)', r)
4409        query("set search_path to s3")
4410        r = get("t", 1, 'n')
4411        self.assertIn('oid(t)', r)
4412
4413
4414class TestDebug(unittest.TestCase):
4415    """Test the debug attribute of the DB class."""
4416
4417    def setUp(self):
4418        self.db = DB()
4419        self.query = self.db.query
4420        self.debug = self.db.debug
4421        self.output = StringIO()
4422        self.stdout, sys.stdout = sys.stdout, self.output
4423
4424    def tearDown(self):
4425        sys.stdout = self.stdout
4426        self.output.close()
4427        self.db.debug = debug
4428        self.db.close()
4429
4430    def get_output(self):
4431        return self.output.getvalue()
4432
4433    def send_queries(self):
4434        self.db.query("select 1")
4435        self.db.query("select 2")
4436
4437    def testDebugDefault(self):
4438        if debug:
4439            self.assertEqual(self.db.debug, debug)
4440        else:
4441            self.assertIsNone(self.db.debug)
4442
4443    def testDebugIsFalse(self):
4444        self.db.debug = False
4445        self.send_queries()
4446        self.assertEqual(self.get_output(), "")
4447
4448    def testDebugIsTrue(self):
4449        self.db.debug = True
4450        self.send_queries()
4451        self.assertEqual(self.get_output(), "select 1\nselect 2\n")
4452
4453    def testDebugIsString(self):
4454        self.db.debug = "Test with string: %s."
4455        self.send_queries()
4456        self.assertEqual(self.get_output(),
4457            "Test with string: select 1.\nTest with string: select 2.\n")
4458
4459    def testDebugIsFileLike(self):
4460        with tempfile.TemporaryFile('w+') as debug_file:
4461            self.db.debug = debug_file
4462            self.send_queries()
4463            debug_file.seek(0)
4464            output = debug_file.read()
4465            self.assertEqual(output, "select 1\nselect 2\n")
4466            self.assertEqual(self.get_output(), "")
4467
4468    def testDebugIsCallable(self):
4469        output = []
4470        self.db.debug = output.append
4471        self.db.query("select 1")
4472        self.db.query("select 2")
4473        self.assertEqual(output, ["select 1", "select 2"])
4474        self.assertEqual(self.get_output(), "")
4475
4476    def testDebugMultipleArgs(self):
4477        output = []
4478        self.db.debug = output.append
4479        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
4480        self.db._do_debug(*args)
4481        self.assertEqual(output, ['\n'.join(str(arg) for arg in args)])
4482        self.assertEqual(self.get_output(), "")
4483
4484
4485class TestMemoryLeaks(unittest.TestCase):
4486    """Test that the DB class does not leak memory."""
4487
4488    def getLeaks(self, fut):
4489        ids = set()
4490        objs = []
4491        add_ids = ids.update
4492        gc.collect()
4493        objs[:] = gc.get_objects()
4494        add_ids(id(obj) for obj in objs)
4495        fut()
4496        gc.collect()
4497        objs[:] = gc.get_objects()
4498        objs[:] = [obj for obj in objs if id(obj) not in ids]
4499        if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)):
4500            # workaround for Python issue 26811
4501            objs[:] = [obj for obj in objs if repr(obj) != '(<NULL>,)']
4502        self.assertEqual(len(objs), 0)
4503
4504    def testLeaksWithClose(self):
4505        def fut():
4506            db = DB()
4507            db.query("select $1::int as r", 42).dictresult()
4508            db.close()
4509        self.getLeaks(fut)
4510
4511    def testLeaksWithoutClose(self):
4512        def fut():
4513            db = DB()
4514            db.query("select $1::int as r", 42).dictresult()
4515        self.getLeaks(fut)
4516
4517
4518if __name__ == '__main__':
4519    unittest.main()
Note: See TracBrowser for help on using the repository browser.