source: trunk/tests/test_classic_dbwrapper.py @ 799

Last changed in revision 799, checked in by cito, 4 years ago

Improve adaptation and add query_formatted() method

Also added more tests and documentation.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 163.9 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import tempfile
21import json
22
23import pg  # the module under test
24
25from decimal import Decimal
26from datetime import date
27from operator import itemgetter
28
# Connection settings for the test database.  A LOCAL_PyGreSQL.py module,
# if it can be imported (see below), may override any of these defaults.
# The current user must have create schema privilege on the database.
dbname, dbhost, dbport = 'unittest', None, 5432

# Set to True to let the DB wrapper print debugging output.
debug = False
37
38try:
39    from .LOCAL_PyGreSQL import *
40except (ImportError, ValueError):
41    try:
42        from LOCAL_PyGreSQL import *
43    except ImportError:
44        pass
45
# Python 2/3 compatibility: Python 3 dropped the separate "long" and
# "unicode" types, so alias them to their Python 3 counterparts there.
try:
    long, unicode
except NameError:  # Python >= 3.0
    long, unicode = int, str
55
56try:
57    from collections import OrderedDict
58except ImportError:  # Python 2.6 or 3.0
59    OrderedDict = dict
60
61if str is bytes:
62    from StringIO import StringIO
63else:
64    from io import StringIO
65
windows = (os.name == 'nt')

# libpq under Windows has a known bug that can crash the interface when
# PQhost() is called, so tests asking for the host are skipped there.
do_not_ask_for_host_reason = 'libpq issue on Windows'
do_not_ask_for_host = windows
72
73
def DB():
    """Return a new pg.DB wrapper connected to the test database.

    The module-level dbname, dbhost and dbport settings are used for the
    connection.  Debug output is switched on when the module-level debug
    flag is set, and server messages below warning level are suppressed.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
81
82
class TestAttrDict(unittest.TestCase):
    """Test pg.AttrDict, the read-only ordered dictionary of attribute names."""

    cls = pg.AttrDict
    base = OrderedDict

    def _assertReadOnly(self, operation, *args):
        """Assert that calling operation(*args) is rejected with TypeError."""
        try:
            operation(*args)
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testInit(self):
        empty = self.cls()
        self.assertIsInstance(empty, self.base)
        self.assertEqual(empty, self.base())
        pairs = [('id', 'int'), ('name', 'text')]
        # construction must work from a list and from a one-shot iterator
        for source in (pairs, iter(pairs)):
            filled = self.cls(source)
            self.assertIsInstance(filled, self.base)
            self.assertEqual(filled, self.base(pairs))

    def testIter(self):
        self.assertEqual(list(self.cls()), [])
        names = ['id', 'name', 'age']
        filled = self.cls([(name, None) for name in names])
        self.assertEqual(list(filled), names)

    def testKeys(self):
        self.assertEqual(list(self.cls().keys()), [])
        names = ['id', 'name', 'age']
        filled = self.cls([(name, None) for name in names])
        self.assertEqual(list(filled.keys()), names)

    def testValues(self):
        self.assertEqual(list(self.cls().values()), [])
        pairs = [('id', 'int'), ('name', 'text')]
        filled = self.cls(pairs)
        self.assertEqual(list(filled.values()), [value for _, value in pairs])

    def testItems(self):
        self.assertEqual(list(self.cls().items()), [])
        pairs = [('id', 'int'), ('name', 'text')]
        filled = self.cls(pairs)
        self.assertEqual(list(filled.items()), pairs)

    def testGet(self):
        d = self.cls([('id', 1)])
        try:
            value = d['id']
        except KeyError:
            self.fail('AttrDict should be readable')
        self.assertEqual(value, 1)

    def testSet(self):
        d = self.cls()
        self._assertReadOnly(d.__setitem__, 'id', 1)

    def testDel(self):
        d = self.cls([('id', 1)])
        self._assertReadOnly(d.__delitem__, 'id')

    def testWriteMethods(self):
        d = self.cls([('id', 1)])
        self.assertEqual(d['id'], 1)
        # every mutating dict method must refuse to work
        for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
            self.assertRaises(TypeError, getattr(d, name), d)
164
165
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            pass  # some tests have already closed the connection themselves

    def testAllDBAttributes(self):
        """The public attributes of the wrapper are exactly the expected ones."""
        attributes = [
            'abort', 'adapter',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'dbtypes',
            'debug', 'decode_json', 'delete',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_cast_hook',
            'get_databases', 'get_notice_receiver',
            'get_parameter', 'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query', 'query_formatted',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_cast_hook', 'set_notice_receiver',
            'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
                         if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        """A failing statement raises an error carrying its SQLSTATE code."""
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            self.assertEqual(error.sqlstate, '22012')
        else:
            # without this the test silently passed if nothing was raised
            self.fail('Division by zero should raise a ProgrammingError')

    def testMethodEndcopy(self):
        try:
            self.db.endcopy()
        except IOError:
            pass  # calling endcopy outside of a copy may raise IOError

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        # closing or reopening a wrapper around a borrowed connection
        # leaves that connection in place
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
376
377
378class TestDBClass(unittest.TestCase):
379    """Test the methods of the DB class wrapped pg connection."""
380
381    maxDiff = 80 * 20
382
383    cls_set_up = False
384
385    regtypes = None
386
387    @classmethod
388    def setUpClass(cls):
389        db = DB()
390        db.query("drop table if exists test cascade")
391        db.query("create table test ("
392                 "i2 smallint, i4 integer, i8 bigint,"
393                 " d numeric, f4 real, f8 double precision, m money,"
394                 " v4 varchar(4), c4 char(4), t text)")
395        db.query("create or replace view test_view as"
396                 " select i4, v4 from test")
397        db.close()
398        cls.cls_set_up = True
399
400    @classmethod
401    def tearDownClass(cls):
402        db = DB()
403        db.query("drop table test cascade")
404        db.close()
405
406    def setUp(self):
407        self.assertTrue(self.cls_set_up)
408        self.db = DB()
409        if self.regtypes is None:
410            self.regtypes = self.db.use_regtypes()
411        else:
412            self.db.use_regtypes(self.regtypes)
413        query = self.db.query
414        query('set client_encoding=utf8')
415        query('set standard_conforming_strings=on')
416        query("set lc_monetary='C'")
417        query("set datestyle='ISO,YMD'")
418        query('set bytea_output=hex')
419
420    def tearDown(self):
421        self.doCleanups()
422        self.db.close()
423
424    def createTable(self, table, definition,
425                    temporary=True, oids=None, values=None):
426        query = self.db.query
427        if not '"' in table or '.' in table:
428            table = '"%s"' % table
429        if not temporary:
430            q = 'drop table if exists %s cascade' % table
431            query(q)
432            self.addCleanup(query, q)
433        temporary = 'temporary table' if temporary else 'table'
434        as_query = definition.startswith(('as ', 'AS '))
435        if not as_query and not definition.startswith('('):
436            definition = '(%s)' % definition
437        with_oids = 'with oids' if oids else 'without oids'
438        q = ['create', temporary, table]
439        if as_query:
440            q.extend([with_oids, definition])
441        else:
442            q.extend([definition, with_oids])
443        q = ' '.join(q)
444        query(q)
445        if values:
446            for params in values:
447                if not isinstance(params, (list, tuple)):
448                    params = [params]
449                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
450                q = "insert into %s values (%s)" % (table, values)
451                query(q, params)
452
453    def testClassName(self):
454        self.assertEqual(self.db.__class__.__name__, 'DB')
455
456    def testModuleName(self):
457        self.assertEqual(self.db.__module__, 'pg')
458        self.assertEqual(self.db.__class__.__module__, 'pg')
459
460    def testEscapeLiteral(self):
461        f = self.db.escape_literal
462        r = f(b"plain")
463        self.assertIsInstance(r, bytes)
464        self.assertEqual(r, b"'plain'")
465        r = f(u"plain")
466        self.assertIsInstance(r, unicode)
467        self.assertEqual(r, u"'plain'")
468        r = f(u"that's kÀse".encode('utf-8'))
469        self.assertIsInstance(r, bytes)
470        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
471        r = f(u"that's kÀse")
472        self.assertIsInstance(r, unicode)
473        self.assertEqual(r, u"'that''s kÀse'")
474        self.assertEqual(f(r"It's fine to have a \ inside."),
475                         r" E'It''s fine to have a \\ inside.'")
476        self.assertEqual(f('No "quotes" must be escaped.'),
477                         "'No \"quotes\" must be escaped.'")
478
479    def testEscapeIdentifier(self):
480        f = self.db.escape_identifier
481        r = f(b"plain")
482        self.assertIsInstance(r, bytes)
483        self.assertEqual(r, b'"plain"')
484        r = f(u"plain")
485        self.assertIsInstance(r, unicode)
486        self.assertEqual(r, u'"plain"')
487        r = f(u"that's kÀse".encode('utf-8'))
488        self.assertIsInstance(r, bytes)
489        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
490        r = f(u"that's kÀse")
491        self.assertIsInstance(r, unicode)
492        self.assertEqual(r, u'"that\'s kÀse"')
493        self.assertEqual(f(r"It's fine to have a \ inside."),
494                         '"It\'s fine to have a \\ inside."')
495        self.assertEqual(f('All "quotes" must be escaped.'),
496                         '"All ""quotes"" must be escaped."')
497
498    def testEscapeString(self):
499        f = self.db.escape_string
500        r = f(b"plain")
501        self.assertIsInstance(r, bytes)
502        self.assertEqual(r, b"plain")
503        r = f(u"plain")
504        self.assertIsInstance(r, unicode)
505        self.assertEqual(r, u"plain")
506        r = f(u"that's kÀse".encode('utf-8'))
507        self.assertIsInstance(r, bytes)
508        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
509        r = f(u"that's kÀse")
510        self.assertIsInstance(r, unicode)
511        self.assertEqual(r, u"that''s kÀse")
512        self.assertEqual(f(r"It's fine to have a \ inside."),
513                         r"It''s fine to have a \ inside.")
514
515    def testEscapeBytea(self):
516        f = self.db.escape_bytea
517        # note that escape_byte always returns hex output since Pg 9.0,
518        # regardless of the bytea_output setting
519        r = f(b'plain')
520        self.assertIsInstance(r, bytes)
521        self.assertEqual(r, b'\\x706c61696e')
522        r = f(u'plain')
523        self.assertIsInstance(r, unicode)
524        self.assertEqual(r, u'\\x706c61696e')
525        r = f(u"das is' kÀse".encode('utf-8'))
526        self.assertIsInstance(r, bytes)
527        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
528        r = f(u"das is' kÀse")
529        self.assertIsInstance(r, unicode)
530        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
531        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
532
533    def testUnescapeBytea(self):
534        f = self.db.unescape_bytea
535        r = f(b'plain')
536        self.assertIsInstance(r, bytes)
537        self.assertEqual(r, b'plain')
538        r = f(u'plain')
539        self.assertIsInstance(r, bytes)
540        self.assertEqual(r, b'plain')
541        r = f(b"das is' k\\303\\244se")
542        self.assertIsInstance(r, bytes)
543        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
544        r = f(u"das is' k\\303\\244se")
545        self.assertIsInstance(r, bytes)
546        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
547        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
548        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
549        self.assertEqual(f(r'\\x746861742773206be47365'),
550                         b'\\x746861742773206be47365')
551        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
552
553    def testDecodeJson(self):
554        f = self.db.decode_json
555        self.assertIsNone(f('null'))
556        data = {
557            "id": 1, "name": "Foo", "price": 1234.5,
558            "new": True, "note": None,
559            "tags": ["Bar", "Eek"],
560            "stock": {"warehouse": 300, "retail": 20}}
561        text = json.dumps(data)
562        r = f(text)
563        self.assertIsInstance(r, dict)
564        self.assertEqual(r, data)
565        self.assertIsInstance(r['id'], int)
566        self.assertIsInstance(r['name'], unicode)
567        self.assertIsInstance(r['price'], float)
568        self.assertIsInstance(r['new'], bool)
569        self.assertIsInstance(r['tags'], list)
570        self.assertIsInstance(r['stock'], dict)
571
572    def testEncodeJson(self):
573        f = self.db.encode_json
574        self.assertEqual(f(None), 'null')
575        data = {
576            "id": 1, "name": "Foo", "price": 1234.5,
577            "new": True, "note": None,
578            "tags": ["Bar", "Eek"],
579            "stock": {"warehouse": 300, "retail": 20}}
580        text = json.dumps(data)
581        r = f(data)
582        self.assertIsInstance(r, str)
583        self.assertEqual(r, text)
584
585    def testGetParameter(self):
586        f = self.db.get_parameter
587        self.assertRaises(TypeError, f)
588        self.assertRaises(TypeError, f, None)
589        self.assertRaises(TypeError, f, 42)
590        self.assertRaises(TypeError, f, '')
591        self.assertRaises(TypeError, f, [])
592        self.assertRaises(TypeError, f, [''])
593        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
594        r = f('standard_conforming_strings')
595        self.assertEqual(r, 'on')
596        r = f('lc_monetary')
597        self.assertEqual(r, 'C')
598        r = f('datestyle')
599        self.assertEqual(r, 'ISO, YMD')
600        r = f('bytea_output')
601        self.assertEqual(r, 'hex')
602        r = f(['bytea_output', 'lc_monetary'])
603        self.assertIsInstance(r, list)
604        self.assertEqual(r, ['hex', 'C'])
605        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
606        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
607        r = f(set(['bytea_output', 'lc_monetary']))
608        self.assertIsInstance(r, dict)
609        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
610        r = f(set(['Bytea_Output', ' LC_Monetary ']))
611        self.assertIsInstance(r, dict)
612        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
613        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
614        r = f(s)
615        self.assertIs(r, s)
616        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
617        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
618        r = f(s)
619        self.assertIs(r, s)
620        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
621
622    def testGetParameterServerVersion(self):
623        r = self.db.get_parameter('server_version_num')
624        self.assertIsInstance(r, str)
625        s = self.db.server_version
626        self.assertIsInstance(s, int)
627        self.assertEqual(r, str(s))
628
629    def testGetParameterAll(self):
630        f = self.db.get_parameter
631        r = f('all')
632        self.assertIsInstance(r, dict)
633        self.assertEqual(r['standard_conforming_strings'], 'on')
634        self.assertEqual(r['lc_monetary'], 'C')
635        self.assertEqual(r['DateStyle'], 'ISO, YMD')
636        self.assertEqual(r['bytea_output'], 'hex')
637
638    def testSetParameter(self):
639        f = self.db.set_parameter
640        g = self.db.get_parameter
641        self.assertRaises(TypeError, f)
642        self.assertRaises(TypeError, f, None)
643        self.assertRaises(TypeError, f, 42)
644        self.assertRaises(TypeError, f, '')
645        self.assertRaises(TypeError, f, [])
646        self.assertRaises(TypeError, f, [''])
647        self.assertRaises(ValueError, f, 'all', 'invalid')
648        self.assertRaises(ValueError, f, {
649            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
650        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
651        f('standard_conforming_strings', 'off')
652        self.assertEqual(g('standard_conforming_strings'), 'off')
653        f('datestyle', 'ISO, DMY')
654        self.assertEqual(g('datestyle'), 'ISO, DMY')
655        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
656        self.assertEqual(g('standard_conforming_strings'), 'on')
657        self.assertEqual(g('datestyle'), 'ISO, DMY')
658        f(['default_with_oids', 'standard_conforming_strings'], 'off')
659        self.assertEqual(g('default_with_oids'), 'off')
660        self.assertEqual(g('standard_conforming_strings'), 'off')
661        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
662        self.assertEqual(g('standard_conforming_strings'), 'on')
663        self.assertEqual(g('datestyle'), 'ISO, YMD')
664        f(('default_with_oids', 'standard_conforming_strings'), 'off')
665        self.assertEqual(g('default_with_oids'), 'off')
666        self.assertEqual(g('standard_conforming_strings'), 'off')
667        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
668        self.assertEqual(g('default_with_oids'), 'on')
669        self.assertEqual(g('standard_conforming_strings'), 'on')
670        self.assertRaises(ValueError, f, set(['default_with_oids',
671            'standard_conforming_strings']), ['off', 'on'])
672        f(set(['default_with_oids', 'standard_conforming_strings']),
673            ['off', 'off'])
674        self.assertEqual(g('default_with_oids'), 'off')
675        self.assertEqual(g('standard_conforming_strings'), 'off')
676        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
677        self.assertEqual(g('standard_conforming_strings'), 'on')
678        self.assertEqual(g('datestyle'), 'ISO, YMD')
679
680    def testResetParameter(self):
681        db = DB()
682        f = db.set_parameter
683        g = db.get_parameter
684        r = g('default_with_oids')
685        self.assertIn(r, ('on', 'off'))
686        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
687        r = g('standard_conforming_strings')
688        self.assertIn(r, ('on', 'off'))
689        scs, not_scs = r, 'off' if r == 'on' else 'on'
690        f('default_with_oids', not_dwi)
691        f('standard_conforming_strings', not_scs)
692        self.assertEqual(g('default_with_oids'), not_dwi)
693        self.assertEqual(g('standard_conforming_strings'), not_scs)
694        f('default_with_oids')
695        f('standard_conforming_strings', None)
696        self.assertEqual(g('default_with_oids'), dwi)
697        self.assertEqual(g('standard_conforming_strings'), scs)
698        f('default_with_oids', not_dwi)
699        f('standard_conforming_strings', not_scs)
700        self.assertEqual(g('default_with_oids'), not_dwi)
701        self.assertEqual(g('standard_conforming_strings'), not_scs)
702        f(['default_with_oids', 'standard_conforming_strings'], None)
703        self.assertEqual(g('default_with_oids'), dwi)
704        self.assertEqual(g('standard_conforming_strings'), scs)
705        f('default_with_oids', not_dwi)
706        f('standard_conforming_strings', not_scs)
707        self.assertEqual(g('default_with_oids'), not_dwi)
708        self.assertEqual(g('standard_conforming_strings'), not_scs)
709        f(('default_with_oids', 'standard_conforming_strings'))
710        self.assertEqual(g('default_with_oids'), dwi)
711        self.assertEqual(g('standard_conforming_strings'), scs)
712        f('default_with_oids', not_dwi)
713        f('standard_conforming_strings', not_scs)
714        self.assertEqual(g('default_with_oids'), not_dwi)
715        self.assertEqual(g('standard_conforming_strings'), not_scs)
716        f(set(['default_with_oids', 'standard_conforming_strings']))
717        self.assertEqual(g('default_with_oids'), dwi)
718        self.assertEqual(g('standard_conforming_strings'), scs)
719
720    def testResetParameterAll(self):
721        db = DB()
722        f = db.set_parameter
723        self.assertRaises(ValueError, f, 'all', 0)
724        self.assertRaises(ValueError, f, 'all', 'off')
725        g = db.get_parameter
726        r = g('default_with_oids')
727        self.assertIn(r, ('on', 'off'))
728        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
729        r = g('standard_conforming_strings')
730        self.assertIn(r, ('on', 'off'))
731        scs, not_scs = r, 'off' if r == 'on' else 'on'
732        f('default_with_oids', not_dwi)
733        f('standard_conforming_strings', not_scs)
734        self.assertEqual(g('default_with_oids'), not_dwi)
735        self.assertEqual(g('standard_conforming_strings'), not_scs)
736        f('all')
737        self.assertEqual(g('default_with_oids'), dwi)
738        self.assertEqual(g('standard_conforming_strings'), scs)
739
740    def testSetParameterLocal(self):
741        f = self.db.set_parameter
742        g = self.db.get_parameter
743        self.assertEqual(g('standard_conforming_strings'), 'on')
744        self.db.begin()
745        f('standard_conforming_strings', 'off', local=True)
746        self.assertEqual(g('standard_conforming_strings'), 'off')
747        self.db.end()
748        self.assertEqual(g('standard_conforming_strings'), 'on')
749
750    def testSetParameterSession(self):
751        f = self.db.set_parameter
752        g = self.db.get_parameter
753        self.assertEqual(g('standard_conforming_strings'), 'on')
754        self.db.begin()
755        f('standard_conforming_strings', 'off', local=False)
756        self.assertEqual(g('standard_conforming_strings'), 'off')
757        self.db.end()
758        self.assertEqual(g('standard_conforming_strings'), 'off')
759
760    def testReset(self):
761        db = DB()
762        default_datestyle = db.get_parameter('datestyle')
763        changed_datestyle = 'ISO, DMY'
764        if changed_datestyle == default_datestyle:
765            changed_datestyle == 'ISO, YMD'
766        self.db.set_parameter('datestyle', changed_datestyle)
767        r = self.db.get_parameter('datestyle')
768        self.assertEqual(r, changed_datestyle)
769        con = self.db.db
770        q = con.query("show datestyle")
771        self.db.reset()
772        r = q.getresult()[0][0]
773        self.assertEqual(r, changed_datestyle)
774        q = con.query("show datestyle")
775        r = q.getresult()[0][0]
776        self.assertEqual(r, default_datestyle)
777        r = self.db.get_parameter('datestyle')
778        self.assertEqual(r, default_datestyle)
779
780    def testReopen(self):
781        db = DB()
782        default_datestyle = db.get_parameter('datestyle')
783        changed_datestyle = 'ISO, DMY'
784        if changed_datestyle == default_datestyle:
785            changed_datestyle == 'ISO, YMD'
786        self.db.set_parameter('datestyle', changed_datestyle)
787        r = self.db.get_parameter('datestyle')
788        self.assertEqual(r, changed_datestyle)
789        con = self.db.db
790        q = con.query("show datestyle")
791        self.db.reopen()
792        r = q.getresult()[0][0]
793        self.assertEqual(r, changed_datestyle)
794        self.assertRaises(TypeError, getattr, con, 'query')
795        r = self.db.get_parameter('datestyle')
796        self.assertEqual(r, default_datestyle)
797
798    def testCreateTable(self):
799        table = 'test hello world'
800        values = [(2, "World!"), (1, "Hello")]
801        self.createTable(table, "n smallint, t varchar",
802                         temporary=True, oids=True, values=values)
803        r = self.db.query('select t from "%s" order by n' % table).getresult()
804        r = ', '.join(row[0] for row in r)
805        self.assertEqual(r, "Hello, World!")
806        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
807        self.assertIsInstance(r[0][0], int)
808
809    def testQuery(self):
810        query = self.db.query
811        table = 'test_table'
812        self.createTable(table, "n integer", oids=True)
813        q = "insert into test_table values (1)"
814        r = query(q)
815        self.assertIsInstance(r, int)
816        q = "insert into test_table select 2"
817        r = query(q)
818        self.assertIsInstance(r, int)
819        oid = r
820        q = "select oid from test_table where n=2"
821        r = query(q).getresult()
822        self.assertEqual(len(r), 1)
823        r = r[0]
824        self.assertEqual(len(r), 1)
825        r = r[0]
826        self.assertEqual(r, oid)
827        q = "insert into test_table select 3 union select 4 union select 5"
828        r = query(q)
829        self.assertIsInstance(r, str)
830        self.assertEqual(r, '3')
831        q = "update test_table set n=4 where n<5"
832        r = query(q)
833        self.assertIsInstance(r, str)
834        self.assertEqual(r, '4')
835        q = "delete from test_table"
836        r = query(q)
837        self.assertIsInstance(r, str)
838        self.assertEqual(r, '5')
839
840    def testMultipleQueries(self):
841        self.assertEqual(self.db.query(
842            "create temporary table test_multi (n integer);"
843            "insert into test_multi values (4711);"
844            "select n from test_multi").getresult()[0][0], 4711)
845
846    def testQueryWithParams(self):
847        query = self.db.query
848        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
849        q = "insert into test_table values ($1, $2)"
850        r = query(q, (1, 2))
851        self.assertIsInstance(r, int)
852        r = query(q, [3, 4])
853        self.assertIsInstance(r, int)
854        r = query(q, [5, 6])
855        self.assertIsInstance(r, int)
856        q = "select * from test_table order by 1, 2"
857        self.assertEqual(query(q).getresult(),
858                         [(1, 2), (3, 4), (5, 6)])
859        q = "select * from test_table where n1=$1 and n2=$2"
860        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
861        q = "update test_table set n2=$2 where n1=$1"
862        r = query(q, 3, 7)
863        self.assertEqual(r, '1')
864        q = "select * from test_table order by 1, 2"
865        self.assertEqual(query(q).getresult(),
866                         [(1, 2), (3, 7), (5, 6)])
867        q = "delete from test_table where n2!=$1"
868        r = query(q, 4)
869        self.assertEqual(r, '3')
870
871    def testEmptyQuery(self):
872        self.assertRaises(ValueError, self.db.query, '')
873
874    def testQueryProgrammingError(self):
875        try:
876            self.db.query("select 1/0")
877        except pg.ProgrammingError as error:
878            self.assertEqual(error.sqlstate, '22012')
879
880    def testQueryFormatted(self):
881        f = self.db.query_formatted
882        t = True if pg.get_bool() else 't'
883        q = f("select %s::int, %s::real, %s::text, %s::bool",
884              (3, 2.5, 'hello', True))
885        r = q.getresult()[0]
886        self.assertEqual(r, (3, 2.5, 'hello', t))
887        q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True)
888        r = q.getresult()[0]
889        self.assertEqual(r, (3, 2.5, 'hello', t))
890
891    def testPkey(self):
892        query = self.db.query
893        pkey = self.db.pkey
894        self.assertRaises(KeyError, pkey, 'test')
895        for t in ('pkeytest', 'primary key test'):
896            self.createTable('%s0' % t, 'a smallint')
897            self.createTable('%s1' % t, 'b smallint primary key')
898            self.createTable('%s2' % t,
899                'c smallint, d smallint primary key')
900            self.createTable('%s3' % t,
901                'e smallint, f smallint, g smallint, h smallint, i smallint,'
902                ' primary key (f, h)')
903            self.createTable('%s4' % t,
904                'e smallint, f smallint, g smallint, h smallint, i smallint,'
905                ' primary key (h, f)')
906            self.createTable('%s5' % t,
907                'more_than_one_letter varchar primary key')
908            self.createTable('%s6' % t,
909                '"with space" date primary key')
910            self.createTable('%s7' % t,
911                'a_very_long_column_name varchar, "with space" date, "42" int,'
912                ' primary key (a_very_long_column_name, "with space", "42")')
913            self.assertRaises(KeyError, pkey, '%s0' % t)
914            self.assertEqual(pkey('%s1' % t), 'b')
915            self.assertEqual(pkey('%s1' % t, True), ('b',))
916            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
917            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
918            self.assertEqual(pkey('%s2' % t), 'd')
919            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
920            r = pkey('%s3' % t)
921            self.assertIsInstance(r, tuple)
922            self.assertEqual(r, ('f', 'h'))
923            r = pkey('%s3' % t, composite=False)
924            self.assertIsInstance(r, tuple)
925            self.assertEqual(r, ('f', 'h'))
926            r = pkey('%s4' % t)
927            self.assertIsInstance(r, tuple)
928            self.assertEqual(r, ('h', 'f'))
929            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
930            self.assertEqual(pkey('%s6' % t), 'with space')
931            r = pkey('%s7' % t)
932            self.assertIsInstance(r, tuple)
933            self.assertEqual(r, (
934                'a_very_long_column_name', 'with space', '42'))
935            # a newly added primary key will be detected
936            query('alter table "%s0" add primary key (a)' % t)
937            self.assertEqual(pkey('%s0' % t), 'a')
938            # a changed primary key will not be detected,
939            # indicating that the internal cache is operating
940            query('alter table "%s1" rename column b to x' % t)
941            self.assertEqual(pkey('%s1' % t), 'b')
942            # we get the changed primary key when the cache is flushed
943            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
944
945    def testGetDatabases(self):
946        databases = self.db.get_databases()
947        self.assertIn('template0', databases)
948        self.assertIn('template1', databases)
949        self.assertNotIn('not existing database', databases)
950        self.assertIn('postgres', databases)
951        self.assertIn(dbname, databases)
952
953    def testGetTables(self):
954        get_tables = self.db.get_tables
955        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
956                  'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
957                  'averyveryveryveryveryveryveryreallyreallylongtablename',
958                  'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
959        for t in tables:
960            self.db.query('drop table if exists "%s" cascade' % t)
961        before_tables = get_tables()
962        self.assertIsInstance(before_tables, list)
963        for t in before_tables:
964            t = t.split('.', 1)
965            self.assertGreaterEqual(len(t), 2)
966            if len(t) > 2:
967                self.assertTrue(t[1].startswith('"'))
968            t = t[0]
969            self.assertNotEqual(t, 'information_schema')
970            self.assertFalse(t.startswith('pg_'))
971        for t in tables:
972            self.createTable(t, 'as select 0', temporary=False)
973        current_tables = get_tables()
974        new_tables = [t for t in current_tables if t not in before_tables]
975        expected_new_tables = ['public.%s' % (
976            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
977        self.assertEqual(new_tables, expected_new_tables)
978        self.doCleanups()
979        after_tables = get_tables()
980        self.assertEqual(after_tables, before_tables)
981
982    def testGetRelations(self):
983        get_relations = self.db.get_relations
984        result = get_relations()
985        self.assertIn('public.test', result)
986        self.assertIn('public.test_view', result)
987        result = get_relations('rv')
988        self.assertIn('public.test', result)
989        self.assertIn('public.test_view', result)
990        result = get_relations('r')
991        self.assertIn('public.test', result)
992        self.assertNotIn('public.test_view', result)
993        result = get_relations('v')
994        self.assertNotIn('public.test', result)
995        self.assertIn('public.test_view', result)
996        result = get_relations('cisSt')
997        self.assertNotIn('public.test', result)
998        self.assertNotIn('public.test_view', result)
999
1000    def testGetAttnames(self):
1001        get_attnames = self.db.get_attnames
1002        self.assertRaises(pg.ProgrammingError,
1003                          self.db.get_attnames, 'does_not_exist')
1004        self.assertRaises(pg.ProgrammingError,
1005                          self.db.get_attnames, 'has.too.many.dots')
1006        r = get_attnames('test')
1007        self.assertIsInstance(r, dict)
1008        if self.regtypes:
1009            self.assertEqual(r, dict(
1010                i2='smallint', i4='integer', i8='bigint', d='numeric',
1011                f4='real', f8='double precision', m='money',
1012                v4='character varying', c4='character', t='text'))
1013        else:
1014            self.assertEqual(r, dict(
1015                i2='int', i4='int', i8='int', d='num',
1016                f4='float', f8='float', m='money',
1017                v4='text', c4='text', t='text'))
1018        self.createTable('test_table',
1019                         'n int, alpha smallint, beta bool,'
1020                         ' gamma char(5), tau text, v varchar(3)')
1021        r = get_attnames('test_table')
1022        self.assertIsInstance(r, dict)
1023        if self.regtypes:
1024            self.assertEqual(r, dict(
1025                n='integer', alpha='smallint', beta='boolean',
1026                gamma='character', tau='text', v='character varying'))
1027        else:
1028            self.assertEqual(r, dict(
1029                n='int', alpha='int', beta='bool',
1030                gamma='text', tau='text', v='text'))
1031
1032    def testGetAttnamesWithQuotes(self):
1033        get_attnames = self.db.get_attnames
1034        table = 'test table for get_attnames()'
1035        self.createTable(table,
1036            '"Prime!" smallint, "much space" integer, "Questions?" text')
1037        r = get_attnames(table)
1038        self.assertIsInstance(r, dict)
1039        if self.regtypes:
1040            self.assertEqual(r, {
1041                'Prime!': 'smallint', 'much space': 'integer',
1042                'Questions?': 'text'})
1043        else:
1044            self.assertEqual(r, {
1045                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1046        table = 'yet another test table for get_attnames()'
1047        self.createTable(table,
1048                         'a smallint, b integer, c bigint,'
1049                         ' e numeric, f real, f2 double precision, m money,'
1050                         ' x smallint, y smallint, z smallint,'
1051                         ' Normal_NaMe smallint, "Special Name" smallint,'
1052                         ' t text, u char(2), v varchar(2),'
1053                         ' primary key (y, u)', oids=True)
1054        r = get_attnames(table)
1055        self.assertIsInstance(r, dict)
1056        if self.regtypes:
1057            self.assertEqual(r, {
1058                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1059                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1060                'm': 'money', 'normal_name': 'smallint',
1061                'Special Name': 'smallint', 'u': 'character',
1062                't': 'text', 'v': 'character varying', 'y': 'smallint',
1063                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1064        else:
1065            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1066                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1067                 'normal_name': 'int', 'Special Name': 'int',
1068                 'u': 'text', 't': 'text', 'v': 'text',
1069                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1070
1071    def testGetAttnamesWithRegtypes(self):
1072        get_attnames = self.db.get_attnames
1073        self.createTable('test_table',
1074                         ' n int, alpha smallint, beta bool,'
1075                         ' gamma char(5), tau text, v varchar(3)')
1076        use_regtypes = self.db.use_regtypes
1077        regtypes = use_regtypes()
1078        self.assertEqual(regtypes, self.regtypes)
1079        use_regtypes(True)
1080        try:
1081            r = get_attnames("test_table")
1082            self.assertIsInstance(r, dict)
1083        finally:
1084            use_regtypes(regtypes)
1085        self.assertEqual(r, dict(
1086            n='integer', alpha='smallint', beta='boolean',
1087            gamma='character', tau='text', v='character varying'))
1088
1089    def testGetAttnamesWithoutRegtypes(self):
1090        get_attnames = self.db.get_attnames
1091        self.createTable('test_table',
1092                         ' n int, alpha smallint, beta bool,'
1093                         ' gamma char(5), tau text, v varchar(3)')
1094        use_regtypes = self.db.use_regtypes
1095        regtypes = use_regtypes()
1096        self.assertEqual(regtypes, self.regtypes)
1097        use_regtypes(False)
1098        try:
1099            r = get_attnames("test_table")
1100            self.assertIsInstance(r, dict)
1101        finally:
1102            use_regtypes(regtypes)
1103        self.assertEqual(r, dict(
1104            n='int', alpha='int', beta='bool',
1105            gamma='text', tau='text', v='text'))
1106
1107    def testGetAttnamesIsCached(self):
1108        get_attnames = self.db.get_attnames
1109        int_type = 'integer' if self.regtypes else 'int'
1110        text_type = 'text'
1111        query = self.db.query
1112        self.createTable('test_table', 'col int')
1113        r = get_attnames("test_table")
1114        self.assertIsInstance(r, dict)
1115        self.assertEqual(r, dict(col=int_type))
1116        query("alter table test_table alter column col type text")
1117        query("alter table test_table add column col2 int")
1118        r = get_attnames("test_table")
1119        self.assertEqual(r, dict(col=int_type))
1120        r = get_attnames("test_table", flush=True)
1121        self.assertEqual(r, dict(col=text_type, col2=int_type))
1122        query("alter table test_table drop column col2")
1123        r = get_attnames("test_table")
1124        self.assertEqual(r, dict(col=text_type, col2=int_type))
1125        r = get_attnames("test_table", flush=True)
1126        self.assertEqual(r, dict(col=text_type))
1127        query("alter table test_table drop column col")
1128        r = get_attnames("test_table")
1129        self.assertEqual(r, dict(col=text_type))
1130        r = get_attnames("test_table", flush=True)
1131        self.assertEqual(r, dict())
1132
1133    def testGetAttnamesIsOrdered(self):
1134        get_attnames = self.db.get_attnames
1135        r = get_attnames('test', flush=True)
1136        self.assertIsInstance(r, OrderedDict)
1137        if self.regtypes:
1138            self.assertEqual(r, OrderedDict([
1139                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1140                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1141                ('m', 'money'), ('v4', 'character varying'),
1142                ('c4', 'character'), ('t', 'text')]))
1143        else:
1144            self.assertEqual(r, OrderedDict([
1145                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1146                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1147                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1148        if OrderedDict is not dict:
1149            r = ' '.join(list(r.keys()))
1150            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1151        table = 'test table for get_attnames'
1152        self.createTable(table,
1153                         ' n int, alpha smallint, v varchar(3),'
1154                         ' gamma char(5), tau text, beta bool')
1155        r = get_attnames(table)
1156        self.assertIsInstance(r, OrderedDict)
1157        if self.regtypes:
1158            self.assertEqual(r, OrderedDict([
1159                ('n', 'integer'), ('alpha', 'smallint'),
1160                ('v', 'character varying'), ('gamma', 'character'),
1161                ('tau', 'text'), ('beta', 'boolean')]))
1162        else:
1163            self.assertEqual(r, OrderedDict([
1164                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1165                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1166        if OrderedDict is not dict:
1167            r = ' '.join(list(r.keys()))
1168            self.assertEqual(r, 'n alpha v gamma tau beta')
1169        else:
1170            self.skipTest('OrderedDict is not supported')
1171
1172    def testGetAttnamesIsAttrDict(self):
1173        AttrDict = pg.AttrDict
1174        get_attnames = self.db.get_attnames
1175        r = get_attnames('test', flush=True)
1176        self.assertIsInstance(r, AttrDict)
1177        if self.regtypes:
1178            self.assertEqual(r, AttrDict([
1179                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1180                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1181                ('m', 'money'), ('v4', 'character varying'),
1182                ('c4', 'character'), ('t', 'text')]))
1183        else:
1184            self.assertEqual(r, AttrDict([
1185                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1186                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1187                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1188        r = ' '.join(list(r.keys()))
1189        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1190        table = 'test table for get_attnames'
1191        self.createTable(table,
1192                         ' n int, alpha smallint, v varchar(3),'
1193                         ' gamma char(5), tau text, beta bool')
1194        r = get_attnames(table)
1195        self.assertIsInstance(r, AttrDict)
1196        if self.regtypes:
1197            self.assertEqual(r, AttrDict([
1198                ('n', 'integer'), ('alpha', 'smallint'),
1199                ('v', 'character varying'), ('gamma', 'character'),
1200                ('tau', 'text'), ('beta', 'boolean')]))
1201        else:
1202            self.assertEqual(r, AttrDict([
1203                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1204                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1205        r = ' '.join(list(r.keys()))
1206        self.assertEqual(r, 'n alpha v gamma tau beta')
1207
1208    def testHasTablePrivilege(self):
1209        can = self.db.has_table_privilege
1210        self.assertEqual(can('test'), True)
1211        self.assertEqual(can('test', 'select'), True)
1212        self.assertEqual(can('test', 'SeLeCt'), True)
1213        self.assertEqual(can('test', 'SELECT'), True)
1214        self.assertEqual(can('test', 'insert'), True)
1215        self.assertEqual(can('test', 'update'), True)
1216        self.assertEqual(can('test', 'delete'), True)
1217        self.assertEqual(can('pg_views', 'select'), True)
1218        self.assertEqual(can('pg_views', 'delete'), False)
1219        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
1220        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1221
    def testGet(self):
        """Test get() on a table first without, then with a primary key.

        Covers explicit key names, tuple keys, missing rows, dicts that are
        updated in place, and key lookup via the primary key.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # get() needs at least a table name and a row
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        self.createTable(table, 'n integer, t text',
                         values=enumerate('xyz', start=1))
        # without a primary key, the key name must be passed explicitly
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        # key values and key names may also be given as tuples
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # non-existent rows (and the missing key name in the second
        # and third calls) raise DatabaseError or a subclass of it
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        # a dict row is filled in place and the same object is returned
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(t='x')
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        # once the table has a primary key, it is used by default
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # table names with a trailing '*' are accepted as well
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # a two-element key does not match the single-column primary key
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        s.update(n=2)
        # r is the same object as s here
        self.assertEqual(get(table, r)['t'], 'y')
        # a dict without the primary key column raises KeyError
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1276
    def testGetWithOid(self):
        """Test get() on a table with oids, first without a primary key.

        Rows can be fetched by oid or by an explicit key column. After a
        primary key is added, it takes precedence over an oid in a dict.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        # an empty dict carries no oid to look up
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        # the oid is returned under the key 'oid(<table name>)'
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        # the row can be fetched by oid, passed as a scalar or in a dict
        # under either 'oid' or 'oid(<table name>)'
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        # r was refreshed by the lookup by 'n' above, so its oid now
        # refers to the row with n=3
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        # add a primary key; it is then used when no key name is given
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # lookups by oid alone still work with a primary key present
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # the primary key value in a dict wins over its oid
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # an explicit key name wins over the oid as well
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1340
1341    def testGetWithCompositeKey(self):
1342        get = self.db.get
1343        query = self.db.query
1344        table = 'get_test_table_1'
1345        self.createTable(table, 'n integer primary key, t text',
1346                         values=enumerate('abc', start=1))
1347        self.assertEqual(get(table, 2)['t'], 'b')
1348        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1349        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1350        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1351        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1352        self.assertEqual(get(table, 'b', 't')['n'], 2)
1353        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1354        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1355        table = 'get_test_table_2'
1356        self.createTable(table,
1357                         'n integer, m integer, t text, primary key (n, m)',
1358                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1359                                 for n in range(3) for m in range(2)])
1360        self.assertRaises(KeyError, get, table, 2)
1361        self.assertEqual(get(table, (1, 1))['t'], 'a')
1362        self.assertEqual(get(table, (1, 2))['t'], 'b')
1363        self.assertEqual(get(table, (2, 1))['t'], 'c')
1364        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1365        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1366        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1367        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1368        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1369        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1370        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1371        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1372
1373    def testGetWithQuotedNames(self):
1374        get = self.db.get
1375        query = self.db.query
1376        table = 'test table for get()'
1377        self.createTable(table, '"Prime!" smallint primary key,'
1378                                ' "much space" integer, "Questions?" text',
1379                         values=[(17, 1001, 'No!')])
1380        r = get(table, 17)
1381        self.assertIsInstance(r, dict)
1382        self.assertEqual(r['Prime!'], 17)
1383        self.assertEqual(r['much space'], 1001)
1384        self.assertEqual(r['Questions?'], 'No!')
1385
1386    def testGetFromView(self):
1387        self.db.query('delete from test where i4=14')
1388        self.db.query('insert into test (i4, v4) values('
1389                      "14, 'abc4')")
1390        r = self.db.get('test_view', 14, 'i4')
1391        self.assertIn('v4', r)
1392        self.assertEqual(r['v4'], 'abc4')
1393
1394    def testGetLittleBobbyTables(self):
1395        get = self.db.get
1396        query = self.db.query
1397        self.createTable('test_students',
1398                         'firstname varchar primary key, nickname varchar, grade char(2)',
1399                         values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1400                                 ('Robert', 'Little Bobby Tables', 'D-')])
1401        r = get('test_students', 'Sheldon')
1402        self.assertEqual(r, dict(
1403            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1404        r = get('test_students', 'Robert')
1405        self.assertEqual(r, dict(
1406            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1407        r = get('test_students', "D'Arcy")
1408        self.assertEqual(r, dict(
1409            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1410        try:
1411            get('test_students', "D' Arcy")
1412        except pg.DatabaseError as error:
1413            self.assertEqual(str(error),
1414                'No such record in test_students\nwhere "firstname" = $1\n'
1415                'with $1="D\' Arcy"')
1416        try:
1417            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1418        except pg.DatabaseError as error:
1419            self.assertEqual(str(error),
1420                'No such record in test_students\nwhere "firstname" = $1\n'
1421                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1422        q = "select * from test_students order by 1 limit 4"
1423        r = query(q).getresult()
1424        self.assertEqual(len(r), 3)
1425        self.assertEqual(r[1][2], 'D-')
1426
1427    def testInsert(self):
1428        insert = self.db.insert
1429        query = self.db.query
1430        bool_on = pg.get_bool()
1431        decimal = pg.get_decimal()
1432        table = 'insert_test_table'
1433        self.createTable(table,
1434            'i2 smallint, i4 integer, i8 bigint,'
1435            ' d numeric, f4 real, f8 double precision, m money,'
1436            ' v4 varchar(4), c4 char(4), t text,'
1437            ' b boolean, ts timestamp', oids=True)
1438        oid_table = 'oid(%s)' % table
1439        tests = [dict(i2=None, i4=None, i8=None),
1440             (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1441             (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1442             dict(i2=42, i4=123456, i8=9876543210),
1443             dict(i2=2 ** 15 - 1,
1444                  i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1445             dict(d=None), (dict(d=''), dict(d=None)),
1446             dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1447             dict(f4=None, f8=None), dict(f4=0, f8=0),
1448             (dict(f4='', f8=''), dict(f4=None, f8=None)),
1449             (dict(d=1234.5, f4=1234.5, f8=1234.5),
1450              dict(d=Decimal('1234.5'))),
1451             dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1452             dict(d=Decimal('123456789.9876543212345678987654321')),
1453             dict(m=None), (dict(m=''), dict(m=None)),
1454             dict(m=Decimal('-1234.56')),
1455             (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1456             dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1457             (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1458             (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1459             (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1460             (dict(m=123456), dict(m=Decimal('123456'))),
1461             (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1462             dict(b=None), (dict(b=''), dict(b=None)),
1463             dict(b='f'), dict(b='t'),
1464             (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1465             (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1466             (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1467             (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1468             (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1469             (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1470             dict(v4=None, c4=None, t=None),
1471             (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1472             dict(v4='1234', c4='1234', t='1234' * 10),
1473             dict(v4='abcd', c4='abcd', t='abcdefg'),
1474             (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1475             dict(ts=None), (dict(ts=''), dict(ts=None)),
1476             (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1477             dict(ts='2012-12-21 00:00:00'),
1478             (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1479             dict(ts='2012-12-21 12:21:12'),
1480             dict(ts='2013-01-05 12:13:14'),
1481             dict(ts='current_timestamp')]
1482        for test in tests:
1483            if isinstance(test, dict):
1484                data = test
1485                change = {}
1486            else:
1487                data, change = test
1488            expect = data.copy()
1489            expect.update(change)
1490            if bool_on:
1491                b = expect.get('b')
1492                if b is not None:
1493                    expect['b'] = b == 't'
1494            if decimal is not Decimal:
1495                d = expect.get('d')
1496                if d is not None:
1497                    expect['d'] = decimal(d)
1498                m = expect.get('m')
1499                if m is not None:
1500                    expect['m'] = decimal(m)
1501            self.assertEqual(insert(table, data), data)
1502            self.assertIn(oid_table, data)
1503            oid = data[oid_table]
1504            self.assertIsInstance(oid, int)
1505            data = dict(item for item in data.items()
1506                        if item[0] in expect)
1507            ts = expect.get('ts')
1508            if ts == 'current_timestamp':
1509                ts = expect['ts'] = data['ts']
1510                if len(ts) > 19:
1511                    self.assertEqual(ts[19], '.')
1512                    ts = ts[:19]
1513                else:
1514                    self.assertEqual(len(ts), 19)
1515                self.assertTrue(ts[:4].isdigit())
1516                self.assertEqual(ts[4], '-')
1517                self.assertEqual(ts[10], ' ')
1518                self.assertTrue(ts[11:13].isdigit())
1519                self.assertEqual(ts[13], ':')
1520            self.assertEqual(data, expect)
1521            data = query(
1522                'select oid,* from "%s"' % table).dictresult()[0]
1523            self.assertEqual(data['oid'], oid)
1524            data = dict(item for item in data.items()
1525                        if item[0] in expect)
1526            self.assertEqual(data, expect)
1527            query('delete from "%s"' % table)
1528
    def testInsertWithOid(self):
        """Test insert() into a table that has oids but no primary key.

        The oid of a newly inserted row is returned under the qualified key
        'oid(test_table)', never under a plain 'oid' key.  An oid passed in
        as keyword argument or as plain 'oid' item of the row dict is not
        reused for the new row.  When a row dict is passed, the very same
        dict object is returned, updated in place.
        """
        insert = self.db.insert
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True)
        # column m does not exist in the table
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
        r = insert('test_table', n=1)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        self.assertNotIn('oid', r)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        # passing an existing oid as keyword still creates a row with a new oid
        r = insert('test_table', n=2, oid=oid)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 2)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # the row dict may be None when values are given as keywords
        r = insert('test_table', None, n=3)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 3)
        s = r
        # re-inserting the returned dict inserts another row with n=3
        # and returns the same dict object
        r = insert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # a trailing ' *' on the table name is accepted as well
        r = insert('test_table *', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # keyword arguments take precedence over items of the row dict
        r = insert('test_table', r, n=4)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 4)
        self.assertNotIn('oid', r)
        self.assertIn(qoid, r)
        oid = r[qoid]
        r = insert('test_table', r, n=5, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # a plain 'oid' item in the row dict is not reused either,
        # and it is removed from the returned dict
        r['oid'] = oid = r[qoid]
        r = insert('test_table', r, n=6)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 6)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # three rows with n=3 have been inserted above
        q = 'select n from test_table order by 1 limit 9'
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '1 2 3 3 3 4 5 6')
        query("truncate test_table")
        query("alter table test_table add unique (n)")
        r = insert('test_table', dict(n=7))
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        # inserting n=7 a second time must fail now that n is unique
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
        r['n'] = 6
        # the keyword n=7 overrides the dict value and causes a duplicate;
        # the row dict has been updated to n=7 although the insert failed
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        r['n'] = 6
        r = insert('test_table', r)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 6)
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '6 7')
1596
1597    def testInsertWithQuotedNames(self):
1598        insert = self.db.insert
1599        query = self.db.query
1600        table = 'test table for insert()'
1601        self.createTable(table, '"Prime!" smallint primary key,'
1602                                ' "much space" integer, "Questions?" text')
1603        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1604        r = insert(table, r)
1605        self.assertIsInstance(r, dict)
1606        self.assertEqual(r['Prime!'], 11)
1607        self.assertEqual(r['much space'], 2002)
1608        self.assertEqual(r['Questions?'], 'What?')
1609        r = query('select * from "%s" limit 2' % table).dictresult()
1610        self.assertEqual(len(r), 1)
1611        r = r[0]
1612        self.assertEqual(r['Prime!'], 11)
1613        self.assertEqual(r['much space'], 2002)
1614        self.assertEqual(r['Questions?'], 'What?')
1615
1616    def testInsertIntoView(self):
1617        insert = self.db.insert
1618        query = self.db.query
1619        query("truncate test")
1620        q = 'select * from test_view order by i4 limit 3'
1621        r = query(q).getresult()
1622        self.assertEqual(r, [])
1623        r = dict(i4=1234, v4='abcd')
1624        insert('test', r)
1625        self.assertIsNone(r['i2'])
1626        self.assertEqual(r['i4'], 1234)
1627        self.assertIsNone(r['i8'])
1628        self.assertEqual(r['v4'], 'abcd')
1629        self.assertIsNone(r['c4'])
1630        r = query(q).getresult()
1631        self.assertEqual(r, [(1234, 'abcd')])
1632        r = dict(i4=5678, v4='efgh')
1633        try:
1634            insert('test_view', r)
1635        except pg.ProgrammingError as error:
1636            if self.db.server_version < 90300:
1637                # must setup rules in older PostgreSQL versions
1638                self.skipTest('database cannot insert into view')
1639            self.fail(str(error))
1640        self.assertNotIn('i2', r)
1641        self.assertEqual(r['i4'], 5678)
1642        self.assertNotIn('i8', r)
1643        self.assertEqual(r['v4'], 'efgh')
1644        self.assertNotIn('c4', r)
1645        r = query(q).getresult()
1646        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1647
1648    def testUpdate(self):
1649        update = self.db.update
1650        query = self.db.query
1651        self.assertRaises(pg.ProgrammingError, update,
1652                          'test', i2=2, i4=4, i8=8)
1653        table = 'update_test_table'
1654        self.createTable(table, 'n integer, t text', oids=True,
1655                         values=enumerate('xyz', start=1))
1656        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1657        r = self.db.get(table, 2, 'n')
1658        r['t'] = 'u'
1659        s = update(table, r)
1660        self.assertEqual(s, r)
1661        q = 'select t from "%s" where n=2' % table
1662        r = query(q).getresult()[0][0]
1663        self.assertEqual(r, 'u')
1664
    def testUpdateWithOid(self):
        """Test update() of a table that has oids.

        As long as the table has no primary key, the row to be updated is
        identified by the qualified 'oid(test_table)' item of the row dict
        or by an oid keyword argument; a plain 'oid' item is not used.
        Once a primary key exists, the key columns identify the row, except
        when an oid keyword argument is given.
        """
        update = self.db.update
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        s = get('test_table', 1, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 1)
        s['n'] = 2
        # the dict returned by get() is enough to identify the row
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        r['n'] = 3
        # the qualified oid can also be passed as keyword argument
        oid = r.pop(qoid)
        r = update('test_table', r, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # without oid and without primary key the update must fail
        r.pop(qoid)
        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
        s = get('test_table', 3, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 3)
        # a dict holding nothing but the qualified oid can be passed
        # and comes back holding only that oid
        s.pop('n')
        r = update('test_table', s)
        oid = r.pop(qoid)
        self.assertEqual(r, {})
        q = "select n from test_table limit 2"
        r = query(q).getresult()
        self.assertEqual(r, [(3,)])
        # add a second row so that the oid must select the right one
        query("insert into test_table values (1)")
        # a plain 'oid' item in the row dict is not used to find the row
        self.assertRaises(pg.ProgrammingError,
                          update, 'test_table', dict(oid=oid, n=4))
        r = update('test_table', dict(n=4), oid=oid)
        self.assertEqual(r['n'], 4)
        # a trailing ' *' on the table name is accepted as well
        r = update('test_table *', dict(n=5), oid=oid)
        self.assertEqual(r['n'], 5)
        # add a column and a primary key; re-read the table info (flush)
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        # now the primary key value identifies the row
        s = dict(n=1, m=4)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 4)
        # the key value can also be passed as keyword argument
        s = dict(m=7)
        r = update('test_table', s, n=5)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 7)
        q = "select n, m from test_table order by 1 limit 3"
        r = query(q).getresult()
        self.assertEqual(r, [(1, 4), (5, 7)])
        # a plain 'oid' item does not replace a missing key value ...
        s = dict(m=9, oid=oid)
        self.assertRaises(KeyError, update, 'test_table', s)
        # ... but an oid keyword argument does
        r = update('test_table', s, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 9)
        # with a key value present, the plain 'oid' item is ignored
        # (the row with n=1 is updated, not the row with that oid)
        s = dict(n=1, m=3, oid=oid)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        r = query(q).getresult()
        self.assertEqual(r, [(1, 3), (5, 9)])
1735
1736    def testUpdateWithoutOid(self):
1737        update = self.db.update
1738        query = self.db.query
1739        self.assertRaises(pg.ProgrammingError, update,
1740                          'test', i2=2, i4=4, i8=8)
1741        table = 'update_test_table'
1742        self.createTable(table, 'n integer primary key, t text', oids=False,
1743                         values=enumerate('xyz', start=1))
1744        r = self.db.get(table, 2)
1745        r['t'] = 'u'
1746        s = update(table, r)
1747        self.assertEqual(s, r)
1748        q = 'select t from "%s" where n=2' % table
1749        r = query(q).getresult()[0][0]
1750        self.assertEqual(r, 'u')
1751
1752    def testUpdateWithCompositeKey(self):
1753        update = self.db.update
1754        query = self.db.query
1755        table = 'update_test_table_1'
1756        self.createTable(table, 'n integer primary key, t text',
1757                         values=enumerate('abc', start=1))
1758        self.assertRaises(KeyError, update, table, dict(t='b'))
1759        s = dict(n=2, t='d')
1760        r = update(table, s)
1761        self.assertIs(r, s)
1762        self.assertEqual(r['n'], 2)
1763        self.assertEqual(r['t'], 'd')
1764        q = 'select t from "%s" where n=2' % table
1765        r = query(q).getresult()[0][0]
1766        self.assertEqual(r, 'd')
1767        s.update(dict(n=4, t='e'))
1768        r = update(table, s)
1769        self.assertEqual(r['n'], 4)
1770        self.assertEqual(r['t'], 'e')
1771        q = 'select t from "%s" where n=2' % table
1772        r = query(q).getresult()[0][0]
1773        self.assertEqual(r, 'd')
1774        q = 'select t from "%s" where n=4' % table
1775        r = query(q).getresult()
1776        self.assertEqual(len(r), 0)
1777        query('drop table "%s"' % table)
1778        table = 'update_test_table_2'
1779        self.createTable(table,
1780                         'n integer, m integer, t text, primary key (n, m)',
1781                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1782                                 for n in range(3) for m in range(2)])
1783        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1784        self.assertEqual(update(table,
1785                                dict(n=2, m=2, t='x'))['t'], 'x')
1786        q = 'select t from "%s" where n=2 order by m' % table
1787        r = [r[0] for r in query(q).getresult()]
1788        self.assertEqual(r, ['c', 'x'])
1789
1790    def testUpdateWithQuotedNames(self):
1791        update = self.db.update
1792        query = self.db.query
1793        table = 'test table for update()'
1794        self.createTable(table, '"Prime!" smallint primary key,'
1795                                ' "much space" integer, "Questions?" text',
1796                         values=[(13, 3003, 'Why!')])
1797        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1798        r = update(table, r)
1799        self.assertIsInstance(r, dict)
1800        self.assertEqual(r['Prime!'], 13)
1801        self.assertEqual(r['much space'], 7007)
1802        self.assertEqual(r['Questions?'], 'When?')
1803        r = query('select * from "%s" limit 2' % table).dictresult()
1804        self.assertEqual(len(r), 1)
1805        r = r[0]
1806        self.assertEqual(r['Prime!'], 13)
1807        self.assertEqual(r['much space'], 7007)
1808        self.assertEqual(r['Questions?'], 'When?')
1809
    def testUpsert(self):
        """Test upsert() on a table with a single-column primary key.

        Keyword arguments control how a conflicting existing row is updated:
        False keeps the existing value, True takes the new value, and a
        string is used as expression in which 'included' refers to the
        existing row and 'excluded' to the new row.  The test is skipped
        on servers older than 9.5.
        """
        upsert = self.db.upsert
        query = self.db.query
        self.assertRaises(pg.ProgrammingError, upsert,
                          'test', i2=2, i4=4, i8=8)
        table = 'upsert_test_table'
        self.createTable(table, 'n integer primary key, t text', oids=True)
        s = dict(n=1, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        # n=2 does not exist yet, so the row is simply inserted
        s.update(n=2, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # without keywords, the existing row n=2 gets the new value
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=False keeps the existing value
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=True takes the new value
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # 'included' is the existing row: 'y' || '2'
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # 'excluded' is the new row: 'y' || '3'
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        # both can be combined: existing 'x' || new '2'
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        # not existing columns and oid parameter should be ignored
        s = dict(m=3, u='z')
        r = upsert(table, s, oid='invalid')
        self.assertIs(r, s)
1881
    def testUpsertWithOid(self):
        """Test upsert() on a table that has oids.

        upsert() fails as long as the table has no primary key, even when an
        oid is given.  A plain 'oid' item is removed from the returned row
        dict, and an oid keyword argument has no effect.  The test is
        skipped on servers older than 9.5.
        """
        upsert = self.db.upsert
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        # no primary key yet
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2))
        r = get('test_table', 1, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        oid = r[qoid]
        # an oid does not help either
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2, oid=oid))
        # add a column and a primary key; re-read the table info (flush)
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        s = dict(n=2)
        try:
            r = upsert('test_table', s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        # the missing column m is filled into the row dict
        self.assertIsNone(r['m'])
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, None), (2, None)])
        # a plain 'oid' item is removed; the row's own qualified oid
        # is returned instead
        r['oid'] = oid
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertNotEqual(r[qoid], oid)
        r['m'] = 7
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        r.update(n=1, m=3)
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
        # the oid keyword argument has no effect
        r = upsert('test_table', r, oid='invalid')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=False keeps the existing value
        r['m'] = 5
        r = upsert('test_table', r, m=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=True takes the new value
        r['m'] = 5
        r = upsert('test_table', r, m=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 5)
        # 'included' is the existing row (m=7 for n=2)
        r.update(n=2, m=1)
        r = upsert('test_table', r, m='included.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # 'excluded' is the new row
        r['m'] = 9
        r = upsert('test_table', r, m='excluded.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 9)
        # expressions may compute values; ' *' suffix is accepted as well
        r['m'] = 8
        r = upsert('test_table *', r, m='included.m + 1')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 10)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1965
    def testUpsertWithCompositeKey(self):
        """Test upsert() on a table with a two-column primary key (n, m).

        Keyword arguments control how a conflicting row is updated: False
        keeps the existing value, True takes the new value, and a string is
        used as expression with 'included' for the existing and 'excluded'
        for the new row.  They have no effect when the row is new.  The
        test is skipped on servers older than 9.5.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        self.createTable(table,
                         'n integer, m integer, t text, primary key (n, m)')
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # key (1, 3) is new, so the row is simply inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # without keywords, the existing row gets the new value
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False keeps the existing value
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True takes the new value
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # key (2, 3) is new: the expression is not applied
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # existing 'n' || new 'm'
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2032
2033    def testUpsertWithQuotedNames(self):
2034        upsert = self.db.upsert
2035        query = self.db.query
2036        table = 'test table for upsert()'
2037        self.createTable(table, '"Prime!" smallint primary key,'
2038                                ' "much space" integer, "Questions?" text')
2039        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2040        try:
2041            r = upsert(table, s)
2042        except pg.ProgrammingError as error:
2043            if self.db.server_version < 90500:
2044                self.skipTest('database does not support upsert')
2045            self.fail(str(error))
2046        self.assertIs(r, s)
2047        self.assertEqual(r['Prime!'], 31)
2048        self.assertEqual(r['much space'], 9009)
2049        self.assertEqual(r['Questions?'], 'Yes.')
2050        q = 'select * from "%s" limit 2' % table
2051        r = query(q).getresult()
2052        self.assertEqual(r, [(31, 9009, 'Yes.')])
2053        s.update({'Questions?': 'No.'})
2054        r = upsert(table, s)
2055        self.assertIs(r, s)
2056        self.assertEqual(r['Prime!'], 31)
2057        self.assertEqual(r['much space'], 9009)
2058        self.assertEqual(r['Questions?'], 'No.')
2059        r = query(q).getresult()
2060        self.assertEqual(r, [(31, 9009, 'No.')])
2061
2062    def testClear(self):
2063        clear = self.db.clear
2064        f = False if pg.get_bool() else 'f'
2065        r = clear('test')
2066        result = dict(
2067            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2068        self.assertEqual(r, result)
2069        table = 'clear_test_table'
2070        self.createTable(table,
2071            'n integer, f float, b boolean, d date, t text', oids=True)
2072        r = clear(table)
2073        result = dict(n=0, f=0, b=f, d='', t='')
2074        self.assertEqual(r, result)
2075        r['a'] = r['f'] = r['n'] = 1
2076        r['d'] = r['t'] = 'x'
2077        r['b'] = 't'
2078        r['oid'] = long(1)
2079        r = clear(table, r)
2080        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2081        self.assertEqual(r, result)
2082
2083    def testClearWithQuotedNames(self):
2084        clear = self.db.clear
2085        table = 'test table for clear()'
2086        self.createTable(table, '"Prime!" smallint primary key,'
2087            ' "much space" integer, "Questions?" text')
2088        r = clear(table)
2089        self.assertIsInstance(r, dict)
2090        self.assertEqual(r['Prime!'], 0)
2091        self.assertEqual(r['much space'], 0)
2092        self.assertEqual(r['Questions?'], '')
2093
    def testDelete(self):
        """Test delete() on a table with oids but without primary key.

        Rows are fetched with get() using n as key column and then deleted;
        delete() returns the number of deleted rows (1 or 0 here).
        """
        delete = self.db.delete
        query = self.db.query
        self.assertRaises(pg.ProgrammingError, delete,
                          'test', dict(i2=2, i4=4, i8=8))
        table = 'delete_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        # without a key column name get() raises (the table has no pkey)
        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
        r = self.db.get(table, 1, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        r = self.db.get(table, 3, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        # deleting the same row again deletes nothing
        s = delete(table, r)
        self.assertEqual(s, 0)
        # only the row with n=2 is left
        r = query('select * from "%s"' % table).dictresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        result = {'n': 2, 't': 'y'}
        self.assertEqual(r, result)
        r = self.db.get(table, 2, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        s = delete(table, r)
        self.assertEqual(s, 0)
        # get() raises for a row that no longer exists
        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
        # not existing columns and oid parameter should be ignored
        r.update(m=3, u='z', oid='invalid')
        s = delete(table, r)
        self.assertEqual(s, 0)
2126
    def testDeleteWithOid(self):
        """Test delete() on a table that has oids.

        Without a primary key, the row to delete is identified by the
        qualified 'oid(test_table)' item of the row dict or by an oid keyword
        argument; a plain 'oid' item is not used.  After a primary key has
        been added, rows can also be deleted by key value.  delete() returns
        the number of deleted rows.
        """
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
        # no oid and no primary key: the row cannot be identified
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        s = get('test_table', 1, 'n')
        qoid = 'oid(test_table)'
        self.assertIn(qoid, s)
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        # deleting the same row again deletes nothing
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        # q returns the smallest remaining n and the number of rows
        q = "select min(n),count(n) from test_table"
        self.assertEqual(query(q).getresult()[0], (2, 5))
        oid = get('test_table', 2, 'n')[qoid]
        # a plain 'oid' item in the row dict is not used
        s = dict(oid=oid, n=2)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        # the oid can be passed as keyword argument with no row dict
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        s = dict(oid=oid, n=2)
        oid = get('test_table', 3, 'n')[qoid]
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        # the oid keyword argument identifies the row (n=3)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        # a trailing ' *' on the table name is accepted as well
        s = get('test_table', 4, 'n')
        r = delete('test_table *', s)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # m is not a column yet; neither the item m nor the keyword m=6
        # prevents deleting the row identified by the qualified oid
        oid = get('test_table', 5, 'n')[qoid]
        s = {qoid: oid, 'm': 4}
        r = delete('test_table', s, m=6)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
        # add a column and a primary key; re-read the table info (flush)
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        # insert rows (1, 2) ... (5, 6); the row with n=6 still exists
        for i in range(5):
            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
        # now the key value n is required ...
        s = dict(m=2)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # ... and a plain 'oid' item does not replace it
        s = dict(m=2, oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # an oid keyword argument is accepted; this oid belonged to the
        # row with n=5 that was deleted above, so nothing is deleted
        r = delete('test_table', dict(m=2), oid=oid)
        self.assertEqual(r, 0)
        oid = get('test_table', 1, 'n')[qoid]
        s = dict(oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # the qualified oid returned by get() suffices without the key
        s = get('test_table', 2, 'n')
        del s['n']
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the key value can be passed as keyword argument ...
        r = delete('test_table', n=3)
        self.assertEqual(r, 1)
        r = delete('test_table', n=3)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        # ... also together with an explicit None as row dict ...
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 1)
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # ... and it takes precedence over the key value in the dict
        s = dict(n=6)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 1)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
2215
2216    def testDeleteWithCompositeKey(self):
2217        query = self.db.query
2218        table = 'delete_test_table_1'
2219        self.createTable(table, 'n integer primary key, t text',
2220                         values=enumerate('abc', start=1))
2221        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2222        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2223        r = query('select t from "%s" where n=2' % table).getresult()
2224        self.assertEqual(r, [])
2225        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2226        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2227        self.assertEqual(r, 'c')
2228        table = 'delete_test_table_2'
2229        self.createTable(table,
2230             'n integer, m integer, t text, primary key (n, m)',
2231             values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2232                     for n in range(3) for m in range(2)])
2233        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2234        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2235        r = [r[0] for r in query('select t from "%s" where n=2'
2236            ' order by m' % table).getresult()]
2237        self.assertEqual(r, ['c'])
2238        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2239        r = [r[0] for r in query('select t from "%s" where n=3'
2240             ' order by m' % table).getresult()]
2241        self.assertEqual(r, ['e', 'f'])
2242        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2243        r = [r[0] for r in query('select t from "%s" where n=3'
2244             ' order by m' % table).getresult()]
2245        self.assertEqual(r, ['f'])
2246
2247    def testDeleteWithQuotedNames(self):
2248        delete = self.db.delete
2249        query = self.db.query
2250        table = 'test table for delete()'
2251        self.createTable(table, '"Prime!" smallint primary key,'
2252            ' "much space" integer, "Questions?" text',
2253            values=[(19, 5005, 'Yes!')])
2254        r = {'Prime!': 17}
2255        r = delete(table, r)
2256        self.assertEqual(r, 0)
2257        r = query('select count(*) from "%s"' % table).getresult()
2258        self.assertEqual(r[0][0], 1)
2259        r = {'Prime!': 19}
2260        r = delete(table, r)
2261        self.assertEqual(r, 1)
2262        r = query('select count(*) from "%s"' % table).getresult()
2263        self.assertEqual(r[0][0], 0)
2264
2265    def testDeleteReferenced(self):
2266        delete = self.db.delete
2267        query = self.db.query
2268        self.createTable('test_parent',
2269            'n smallint primary key', values=range(3))
2270        self.createTable('test_child',
2271            'n smallint primary key references test_parent', values=range(3))
2272        q = ("select (select count(*) from test_parent),"
2273             " (select count(*) from test_child)")
2274        self.assertEqual(query(q).getresult()[0], (3, 3))
2275        self.assertRaises(pg.ProgrammingError,
2276                          delete, 'test_parent', None, n=2)
2277        self.assertRaises(pg.ProgrammingError,
2278                          delete, 'test_parent *', None, n=2)
2279        r = delete('test_child', None, n=2)
2280        self.assertEqual(r, 1)
2281        self.assertEqual(query(q).getresult()[0], (3, 2))
2282        r = delete('test_parent', None, n=2)
2283        self.assertEqual(r, 1)
2284        self.assertEqual(query(q).getresult()[0], (2, 2))
2285        self.assertRaises(pg.ProgrammingError,
2286                          delete, 'test_parent', dict(n=0))
2287        self.assertRaises(pg.ProgrammingError,
2288                          delete, 'test_parent *', dict(n=0))
2289        r = delete('test_child', dict(n=0))
2290        self.assertEqual(r, 1)
2291        self.assertEqual(query(q).getresult()[0], (2, 1))
2292        r = delete('test_child', dict(n=0))
2293        self.assertEqual(r, 0)
2294        r = delete('test_parent', dict(n=0))
2295        self.assertEqual(r, 1)
2296        self.assertEqual(query(q).getresult()[0], (1, 1))
2297        r = delete('test_parent', None, n=0)
2298        self.assertEqual(r, 0)
2299        q = "select n from test_parent natural join test_child limit 2"
2300        self.assertEqual(query(q).getresult(), [(1,)])
2301
2302    def testTruncate(self):
2303        truncate = self.db.truncate
2304        self.assertRaises(TypeError, truncate, None)
2305        self.assertRaises(TypeError, truncate, 42)
2306        self.assertRaises(TypeError, truncate, dict(test_table=None))
2307        query = self.db.query
2308        self.createTable('test_table', 'n smallint',
2309                         temporary=False, values=[1] * 3)
2310        q = "select count(*) from test_table"
2311        r = query(q).getresult()[0][0]
2312        self.assertEqual(r, 3)
2313        truncate('test_table')
2314        r = query(q).getresult()[0][0]
2315        self.assertEqual(r, 0)
2316        for i in range(3):
2317            query("insert into test_table values (1)")
2318        r = query(q).getresult()[0][0]
2319        self.assertEqual(r, 3)
2320        truncate('public.test_table')
2321        r = query(q).getresult()[0][0]
2322        self.assertEqual(r, 0)
2323        self.createTable('test_table_2', 'n smallint', temporary=True)
2324        for t in (list, tuple, set):
2325            for i in range(3):
2326                query("insert into test_table values (1)")
2327                query("insert into test_table_2 values (2)")
2328            q = ("select (select count(*) from test_table),"
2329                 " (select count(*) from test_table_2)")
2330            r = query(q).getresult()[0]
2331            self.assertEqual(r, (3, 3))
2332            truncate(t(['test_table', 'test_table_2']))
2333            r = query(q).getresult()[0]
2334            self.assertEqual(r, (0, 0))
2335
2336    def testTruncateRestart(self):
2337        truncate = self.db.truncate
2338        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2339        query = self.db.query
2340        self.createTable('test_table', 'n serial, t text')
2341        for n in range(3):
2342            query("insert into test_table (t) values ('test')")
2343        q = "select count(n), min(n), max(n) from test_table"
2344        r = query(q).getresult()[0]
2345        self.assertEqual(r, (3, 1, 3))
2346        truncate('test_table')
2347        r = query(q).getresult()[0]
2348        self.assertEqual(r, (0, None, None))
2349        for n in range(3):
2350            query("insert into test_table (t) values ('test')")
2351        r = query(q).getresult()[0]
2352        self.assertEqual(r, (3, 4, 6))
2353        truncate('test_table', restart=True)
2354        r = query(q).getresult()[0]
2355        self.assertEqual(r, (0, None, None))
2356        for n in range(3):
2357            query("insert into test_table (t) values ('test')")
2358        r = query(q).getresult()[0]
2359        self.assertEqual(r, (3, 1, 3))
2360
2361    def testTruncateCascade(self):
2362        truncate = self.db.truncate
2363        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2364        query = self.db.query
2365        self.createTable('test_parent', 'n smallint primary key',
2366                         values=range(3))
2367        self.createTable('test_child',
2368                         'n smallint primary key references test_parent (n)',
2369                         values=range(3))
2370        q = ("select (select count(*) from test_parent),"
2371             " (select count(*) from test_child)")
2372        r = query(q).getresult()[0]
2373        self.assertEqual(r, (3, 3))
2374        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2375        truncate(['test_parent', 'test_child'])
2376        r = query(q).getresult()[0]
2377        self.assertEqual(r, (0, 0))
2378        for n in range(3):
2379            query("insert into test_parent (n) values (%d)" % n)
2380            query("insert into test_child (n) values (%d)" % n)
2381        r = query(q).getresult()[0]
2382        self.assertEqual(r, (3, 3))
2383        truncate('test_parent', cascade=True)
2384        r = query(q).getresult()[0]
2385        self.assertEqual(r, (0, 0))
2386        for n in range(3):
2387            query("insert into test_parent (n) values (%d)" % n)
2388            query("insert into test_child (n) values (%d)" % n)
2389        r = query(q).getresult()[0]
2390        self.assertEqual(r, (3, 3))
2391        truncate('test_child')
2392        r = query(q).getresult()[0]
2393        self.assertEqual(r, (3, 0))
2394        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2395        truncate('test_parent', cascade=True)
2396        r = query(q).getresult()[0]
2397        self.assertEqual(r, (0, 0))
2398
2399    def testTruncateOnly(self):
2400        truncate = self.db.truncate
2401        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2402        query = self.db.query
2403        self.createTable('test_parent', 'n smallint')
2404        self.createTable('test_child', 'm smallint) inherits (test_parent')
2405        for n in range(3):
2406            query("insert into test_parent (n) values (1)")
2407            query("insert into test_child (n, m) values (2, 3)")
2408        q = ("select (select count(*) from test_parent),"
2409             " (select count(*) from test_child)")
2410        r = query(q).getresult()[0]
2411        self.assertEqual(r, (6, 3))
2412        truncate('test_parent')
2413        r = query(q).getresult()[0]
2414        self.assertEqual(r, (0, 0))
2415        for n in range(3):
2416            query("insert into test_parent (n) values (1)")
2417            query("insert into test_child (n, m) values (2, 3)")
2418        r = query(q).getresult()[0]
2419        self.assertEqual(r, (6, 3))
2420        truncate('test_parent*')
2421        r = query(q).getresult()[0]
2422        self.assertEqual(r, (0, 0))
2423        for n in range(3):
2424            query("insert into test_parent (n) values (1)")
2425            query("insert into test_child (n, m) values (2, 3)")
2426        r = query(q).getresult()[0]
2427        self.assertEqual(r, (6, 3))
2428        truncate('test_parent', only=True)
2429        r = query(q).getresult()[0]
2430        self.assertEqual(r, (3, 3))
2431        truncate('test_parent', only=False)
2432        r = query(q).getresult()[0]
2433        self.assertEqual(r, (0, 0))
2434        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2435        truncate('test_parent*', only=False)
2436        self.createTable('test_parent_2', 'n smallint')
2437        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2438        for t in '', '_2':
2439            for n in range(3):
2440                query("insert into test_parent%s (n) values (1)" % t)
2441                query("insert into test_child%s (n, m) values (2, 3)" % t)
2442        q = ("select (select count(*) from test_parent),"
2443             " (select count(*) from test_child),"
2444             " (select count(*) from test_parent_2),"
2445             " (select count(*) from test_child_2)")
2446        r = query(q).getresult()[0]
2447        self.assertEqual(r, (6, 3, 6, 3))
2448        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2449        r = query(q).getresult()[0]
2450        self.assertEqual(r, (0, 0, 3, 3))
2451        truncate(['test_parent', 'test_parent_2'], only=False)
2452        r = query(q).getresult()[0]
2453        self.assertEqual(r, (0, 0, 0, 0))
2454        self.assertRaises(ValueError, truncate,
2455            ['test_parent*', 'test_child'], only=[True, False])
2456        truncate(['test_parent*', 'test_child'], only=[False, True])
2457
2458    def testTruncateQuoted(self):
2459        truncate = self.db.truncate
2460        query = self.db.query
2461        table = "test table for truncate()"
2462        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2463        q = 'select count(*) from "%s"' % table
2464        r = query(q).getresult()[0][0]
2465        self.assertEqual(r, 3)
2466        truncate(table)
2467        r = query(q).getresult()[0][0]
2468        self.assertEqual(r, 0)
2469        for i in range(3):
2470            query('insert into "%s" values (1)' % table)
2471        r = query(q).getresult()[0][0]
2472        self.assertEqual(r, 3)
2473        truncate('public."%s"' % table)
2474        r = query(q).getresult()[0][0]
2475        self.assertEqual(r, 0)
2476
2477    def testGetAsList(self):
2478        get_as_list = self.db.get_as_list
2479        self.assertRaises(TypeError, get_as_list)
2480        self.assertRaises(TypeError, get_as_list, None)
2481        query = self.db.query
2482        table = 'test_aslist'
2483        r = query('select 1 as colname').namedresult()[0]
2484        self.assertIsInstance(r, tuple)
2485        named = hasattr(r, 'colname')
2486        names = [(1, 'Homer'), (2, 'Marge'),
2487                 (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
2488        self.createTable(table,
2489            'id smallint primary key, name varchar', values=names)
2490        r = get_as_list(table)
2491        self.assertIsInstance(r, list)
2492        self.assertEqual(r, names)
2493        for t, n in zip(r, names):
2494            self.assertIsInstance(t, tuple)
2495            self.assertEqual(t, n)
2496            if named:
2497                self.assertEqual(t.id, n[0])
2498                self.assertEqual(t.name, n[1])
2499                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
2500        r = get_as_list(table, what='name')
2501        self.assertIsInstance(r, list)
2502        expected = sorted((row[1],) for row in names)
2503        self.assertEqual(r, expected)
2504        r = get_as_list(table, what='name, id')
2505        self.assertIsInstance(r, list)
2506        expected = sorted(tuple(reversed(row)) for row in names)
2507        self.assertEqual(r, expected)
2508        r = get_as_list(table, what=['name', 'id'])
2509        self.assertIsInstance(r, list)
2510        self.assertEqual(r, expected)
2511        r = get_as_list(table, where="name like 'Ba%'")
2512        self.assertIsInstance(r, list)
2513        self.assertEqual(r, names[2:3])
2514        r = get_as_list(table, what='name', where="name like 'Ma%'")
2515        self.assertIsInstance(r, list)
2516        self.assertEqual(r, [('Maggie',), ('Marge',)])
2517        r = get_as_list(table, what='name',
2518            where=["name like 'Ma%'", "name like '%r%'"])
2519        self.assertIsInstance(r, list)
2520        self.assertEqual(r, [('Marge',)])
2521        r = get_as_list(table, what='name', order='id')
2522        self.assertIsInstance(r, list)
2523        expected = [(row[1],) for row in names]
2524        self.assertEqual(r, expected)
2525        r = get_as_list(table, what=['name'], order=['id'])
2526        self.assertIsInstance(r, list)
2527        self.assertEqual(r, expected)
2528        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
2529        self.assertIsInstance(r, list)
2530        self.assertEqual(r, names)
2531        r = get_as_list(table, what='id * 2 as num', order='id desc')
2532        self.assertIsInstance(r, list)
2533        expected = [(n,) for n in range(10, 0, -2)]
2534        self.assertEqual(r, expected)
2535        r = get_as_list(table, limit=2)
2536        self.assertIsInstance(r, list)
2537        self.assertEqual(r, names[:2])
2538        r = get_as_list(table, offset=3)
2539        self.assertIsInstance(r, list)
2540        self.assertEqual(r, names[3:])
2541        r = get_as_list(table, limit=1, offset=2)
2542        self.assertIsInstance(r, list)
2543        self.assertEqual(r, names[2:3])
2544        r = get_as_list(table, scalar=True)
2545        self.assertIsInstance(r, list)
2546        self.assertEqual(r, list(range(1, 6)))
2547        r = get_as_list(table, what='name', scalar=True)
2548        self.assertIsInstance(r, list)
2549        expected = sorted(row[1] for row in names)
2550        self.assertEqual(r, expected)
2551        r = get_as_list(table, what='name', limit=1, scalar=True)
2552        self.assertIsInstance(r, list)
2553        self.assertEqual(r, expected[:1])
2554        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2555        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2556        names.insert(1, (1, 'Snowball'))
2557        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
2558        r = get_as_list(table)
2559        self.assertIsInstance(r, list)
2560        self.assertEqual(r, names)
2561        r = get_as_list(table, what='name', where='id=1', scalar=True)
2562        self.assertIsInstance(r, list)
2563        self.assertEqual(r, ['Homer', 'Snowball'])
2564        # test with unordered query
2565        r = get_as_list(table, order=False)
2566        self.assertIsInstance(r, list)
2567        self.assertEqual(set(r), set(names))
2568        # test with arbitrary from clause
2569        from_table = '(select lower(name) as n2 from "%s") as t2' % table
2570        r = get_as_list(from_table)
2571        self.assertIsInstance(r, list)
2572        r = set(row[0] for row in r)
2573        expected = set(row[1].lower() for row in names)
2574        self.assertEqual(r, expected)
2575        r = get_as_list(from_table, order='n2', scalar=True)
2576        self.assertIsInstance(r, list)
2577        self.assertEqual(r, sorted(expected))
2578        r = get_as_list(from_table, order='n2', limit=1)
2579        self.assertIsInstance(r, list)
2580        self.assertEqual(len(r), 1)
2581        t = r[0]
2582        self.assertIsInstance(t, tuple)
2583        if named:
2584            self.assertEqual(t.n2, 'bart')
2585            self.assertEqual(t._asdict(), dict(n2='bart'))
2586        else:
2587            self.assertEqual(t, ('bart',))
2588
2589    def testGetAsDict(self):
2590        get_as_dict = self.db.get_as_dict
2591        self.assertRaises(TypeError, get_as_dict)
2592        self.assertRaises(TypeError, get_as_dict, None)
2593        # the test table has no primary key
2594        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2595        query = self.db.query
2596        table = 'test_asdict'
2597        r = query('select 1 as colname').namedresult()[0]
2598        self.assertIsInstance(r, tuple)
2599        named = hasattr(r, 'colname')
2600        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2601                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2602        self.createTable(table,
2603            'id smallint primary key, rgb char(7), name varchar',
2604            values=colors)
2605        # keyname must be string, list or tuple
2606        self.assertRaises(KeyError, get_as_dict, table, 3)
2607        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2608        # missing keyname in row
2609        self.assertRaises(KeyError, get_as_dict, table,
2610                          keyname='rgb', what='name')
2611        r = get_as_dict(table)
2612        self.assertIsInstance(r, OrderedDict)
2613        expected = OrderedDict((row[0], row[1:]) for row in colors)
2614        self.assertEqual(r, expected)
2615        for key in r:
2616            self.assertIsInstance(key, int)
2617            self.assertIn(key, expected)
2618            row = r[key]
2619            self.assertIsInstance(row, tuple)
2620            t = expected[key]
2621            self.assertEqual(row, t)
2622            if named:
2623                self.assertEqual(row.rgb, t[0])
2624                self.assertEqual(row.name, t[1])
2625                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2626        if OrderedDict is not dict:  # Python > 2.6
2627            self.assertEqual(r.keys(), expected.keys())
2628        r = get_as_dict(table, keyname='rgb')
2629        self.assertIsInstance(r, OrderedDict)
2630        expected = OrderedDict((row[1], (row[0], row[2]))
2631                               for row in sorted(colors, key=itemgetter(1)))
2632        self.assertEqual(r, expected)
2633        for key in r:
2634            self.assertIsInstance(key, str)
2635            self.assertIn(key, expected)
2636            row = r[key]
2637            self.assertIsInstance(row, tuple)
2638            t = expected[key]
2639            self.assertEqual(row, t)
2640            if named:
2641                self.assertEqual(row.id, t[0])
2642                self.assertEqual(row.name, t[1])
2643                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2644        if OrderedDict is not dict:  # Python > 2.6
2645            self.assertEqual(r.keys(), expected.keys())
2646        r = get_as_dict(table, keyname=['id', 'rgb'])
2647        self.assertIsInstance(r, OrderedDict)
2648        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2649        self.assertEqual(r, expected)
2650        for key in r:
2651            self.assertIsInstance(key, tuple)
2652            self.assertIsInstance(key[0], int)
2653            self.assertIsInstance(key[1], str)
2654            if named:
2655                self.assertEqual(key, (key.id, key.rgb))
2656                self.assertEqual(key._fields, ('id', 'rgb'))
2657            row = r[key]
2658            self.assertIsInstance(row, tuple)
2659            self.assertIsInstance(row[0], str)
2660            t = expected[key]
2661            self.assertEqual(row, t)
2662            if named:
2663                self.assertEqual(row.name, t[0])
2664                self.assertEqual(row._asdict(), dict(name=t[0]))
2665        if OrderedDict is not dict:  # Python > 2.6
2666            self.assertEqual(r.keys(), expected.keys())
2667        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2668        self.assertIsInstance(r, OrderedDict)
2669        expected = OrderedDict((row[:2], row[2]) for row in colors)
2670        self.assertEqual(r, expected)
2671        for key in r:
2672            self.assertIsInstance(key, tuple)
2673            row = r[key]
2674            self.assertIsInstance(row, str)
2675            t = expected[key]
2676            self.assertEqual(row, t)
2677        if OrderedDict is not dict:  # Python > 2.6
2678            self.assertEqual(r.keys(), expected.keys())
2679        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2680        self.assertIsInstance(r, OrderedDict)
2681        expected = OrderedDict((row[1], row[2])
2682            for row in sorted(colors, key=itemgetter(1)))
2683        self.assertEqual(r, expected)
2684        for key in r:
2685            self.assertIsInstance(key, str)
2686            row = r[key]
2687            self.assertIsInstance(row, str)
2688            t = expected[key]
2689            self.assertEqual(row, t)
2690        if OrderedDict is not dict:  # Python > 2.6
2691            self.assertEqual(r.keys(), expected.keys())
2692        r = get_as_dict(table, what='id, name',
2693            where="rgb like '#b%'", scalar=True)
2694        self.assertIsInstance(r, OrderedDict)
2695        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2696        self.assertEqual(r, expected)
2697        for key in r:
2698            self.assertIsInstance(key, int)
2699            row = r[key]
2700            self.assertIsInstance(row, str)
2701            t = expected[key]
2702            self.assertEqual(row, t)
2703        if OrderedDict is not dict:  # Python > 2.6
2704            self.assertEqual(r.keys(), expected.keys())
2705        expected = r
2706        r = get_as_dict(table, what=['name', 'id'],
2707            where=['id > 1', 'id < 4', "rgb like '#b%'",
2708                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2709        self.assertEqual(r, expected)
2710        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2711        self.assertEqual(r, expected)
2712        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2713            where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2714        self.assertEqual(r, expected)
2715        r = get_as_dict(table, limit=1)
2716        self.assertEqual(len(r), 1)
2717        self.assertEqual(r[1][1], 'Aero')
2718        r = get_as_dict(table, offset=3)
2719        self.assertEqual(len(r), 1)
2720        self.assertEqual(r[4][1], 'Desert')
2721        r = get_as_dict(table, order='id desc')
2722        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2723        self.assertEqual(r, expected)
2724        r = get_as_dict(table, where='id > 5')
2725        self.assertIsInstance(r, OrderedDict)
2726        self.assertEqual(len(r), 0)
2727        # test with unordered query
2728        expected = dict((row[0], row[1:]) for row in colors)
2729        r = get_as_dict(table, order=False)
2730        self.assertIsInstance(r, dict)
2731        self.assertEqual(r, expected)
2732        if dict is not OrderedDict:  # Python > 2.6
2733            self.assertNotIsInstance(self, OrderedDict)
2734        # test with arbitrary from clause
2735        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2736        # primary key must be passed explicitly in this case
2737        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2738        r = get_as_dict(from_table, 'id')
2739        self.assertIsInstance(r, OrderedDict)
2740        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2741        self.assertEqual(r, expected)
2742        # test without a primary key
2743        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2744        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2745        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2746        r = get_as_dict(table, keyname='id')
2747        expected = OrderedDict((row[0], row[1:]) for row in colors)
2748        self.assertIsInstance(r, dict)
2749        self.assertEqual(r, expected)
2750        r = (1, '#007fff', 'Azure')
2751        query('insert into "%s" values ($1, $2, $3)' % table, r)
2752        # the last entry will win
2753        expected[1] = r[1:]
2754        r = get_as_dict(table, keyname='id')
2755        self.assertEqual(r, expected)
2756
2757    def testTransaction(self):
2758        query = self.db.query
2759        self.createTable('test_table', 'n integer', temporary=False)
2760        self.db.begin()
2761        query("insert into test_table values (1)")
2762        query("insert into test_table values (2)")
2763        self.db.commit()
2764        self.db.begin()
2765        query("insert into test_table values (3)")
2766        query("insert into test_table values (4)")
2767        self.db.rollback()
2768        self.db.begin()
2769        query("insert into test_table values (5)")
2770        self.db.savepoint('before6')
2771        query("insert into test_table values (6)")
2772        self.db.rollback('before6')
2773        query("insert into test_table values (7)")
2774        self.db.commit()
2775        self.db.begin()
2776        self.db.savepoint('before8')
2777        query("insert into test_table values (8)")
2778        self.db.release('before8')
2779        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
2780        self.db.commit()
2781        self.db.start()
2782        query("insert into test_table values (9)")
2783        self.db.end()
2784        r = [r[0] for r in query(
2785            "select * from test_table order by 1").getresult()]
2786        self.assertEqual(r, [1, 2, 5, 7, 9])
2787        self.db.begin(mode='read only')
2788        self.assertRaises(pg.ProgrammingError,
2789                          query, "insert into test_table values (0)")
2790        self.db.rollback()
2791        self.db.start(mode='Read Only')
2792        self.assertRaises(pg.ProgrammingError,
2793                          query, "insert into test_table values (0)")
2794        self.db.abort()
2795
2796    def testTransactionAliases(self):
2797        self.assertEqual(self.db.begin, self.db.start)
2798        self.assertEqual(self.db.commit, self.db.end)
2799        self.assertEqual(self.db.rollback, self.db.abort)
2800
2801    def testContextManager(self):
2802        query = self.db.query
2803        self.createTable('test_table', 'n integer check(n>0)')
2804        with self.db:
2805            query("insert into test_table values (1)")
2806            query("insert into test_table values (2)")
2807        try:
2808            with self.db:
2809                query("insert into test_table values (3)")
2810                query("insert into test_table values (4)")
2811                raise ValueError('test transaction should rollback')
2812        except ValueError as error:
2813            self.assertEqual(str(error), 'test transaction should rollback')
2814        with self.db:
2815            query("insert into test_table values (5)")
2816        try:
2817            with self.db:
2818                query("insert into test_table values (6)")
2819                query("insert into test_table values (-1)")
2820        except pg.ProgrammingError as error:
2821            self.assertTrue('check' in str(error))
2822        with self.db:
2823            query("insert into test_table values (7)")
2824        r = [r[0] for r in query(
2825            "select * from test_table order by 1").getresult()]
2826        self.assertEqual(r, [1, 2, 5, 7])
2827
2828    def testBytea(self):
2829        query = self.db.query
2830        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2831        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2832        r = self.db.escape_bytea(s)
2833        query('insert into bytea_test values(3, $1)', (r,))
2834        r = query('select * from bytea_test where n=3').getresult()
2835        self.assertEqual(len(r), 1)
2836        r = r[0]
2837        self.assertEqual(len(r), 2)
2838        self.assertEqual(r[0], 3)
2839        r = r[1]
2840        self.assertIsInstance(r, bytes)
2841        self.assertEqual(r, s)
2842
2843    def testInsertUpdateGetBytea(self):
2844        query = self.db.query
2845        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2846        # insert null value
2847        r = self.db.insert('bytea_test', n=0, data=None)
2848        self.assertIsInstance(r, dict)
2849        self.assertIn('n', r)
2850        self.assertEqual(r['n'], 0)
2851        self.assertIn('data', r)
2852        self.assertIsNone(r['data'])
2853        s = b'None'
2854        r = self.db.update('bytea_test', n=0, data=s)
2855        self.assertIsInstance(r, dict)
2856        self.assertIn('n', r)
2857        self.assertEqual(r['n'], 0)
2858        self.assertIn('data', r)
2859        r = r['data']
2860        self.assertIsInstance(r, bytes)
2861        self.assertEqual(r, s)
2862        r = self.db.update('bytea_test', n=0, data=None)
2863        self.assertIsNone(r['data'])
2864        # insert as bytes
2865        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2866        r = self.db.insert('bytea_test', n=5, data=s)
2867        self.assertIsInstance(r, dict)
2868        self.assertIn('n', r)
2869        self.assertEqual(r['n'], 5)
2870        self.assertIn('data', r)
2871        r = r['data']
2872        self.assertIsInstance(r, bytes)
2873        self.assertEqual(r, s)
2874        # update as bytes
2875        s += b"and now even more \x00 nasty \t stuff!\f"
2876        r = self.db.update('bytea_test', n=5, data=s)
2877        self.assertIsInstance(r, dict)
2878        self.assertIn('n', r)
2879        self.assertEqual(r['n'], 5)
2880        self.assertIn('data', r)
2881        r = r['data']
2882        self.assertIsInstance(r, bytes)
2883        self.assertEqual(r, s)
2884        r = query('select * from bytea_test where n=5').getresult()
2885        self.assertEqual(len(r), 1)
2886        r = r[0]
2887        self.assertEqual(len(r), 2)
2888        self.assertEqual(r[0], 5)
2889        r = r[1]
2890        self.assertIsInstance(r, bytes)
2891        self.assertEqual(r, s)
2892        r = self.db.get('bytea_test', dict(n=5))
2893        self.assertIsInstance(r, dict)
2894        self.assertIn('n', r)
2895        self.assertEqual(r['n'], 5)
2896        self.assertIn('data', r)
2897        r = r['data']
2898        self.assertIsInstance(r, bytes)
2899        self.assertEqual(r, s)
2900
2901    def testUpsertBytea(self):
2902        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2903        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2904        r = dict(n=7, data=s)
2905        try:
2906            r = self.db.upsert('bytea_test', r)
2907        except pg.ProgrammingError as error:
2908            if self.db.server_version < 90500:
2909                self.skipTest('database does not support upsert')
2910            self.fail(str(error))
2911        self.assertIsInstance(r, dict)
2912        self.assertIn('n', r)
2913        self.assertEqual(r['n'], 7)
2914        self.assertIn('data', r)
2915        self.assertIsInstance(r['data'], bytes)
2916        self.assertEqual(r['data'], s)
2917        r['data'] = None
2918        r = self.db.upsert('bytea_test', r)
2919        self.assertIsInstance(r, dict)
2920        self.assertIn('n', r)
2921        self.assertEqual(r['n'], 7)
2922        self.assertIn('data', r)
2923        self.assertIsNone(r['data'], bytes)
2924
2925    def testInsertGetJson(self):
2926        try:
2927            self.createTable('json_test', 'n smallint primary key, data json')
2928        except pg.ProgrammingError as error:
2929            if self.db.server_version < 90200:
2930                self.skipTest('database does not support json')
2931            self.fail(str(error))
2932        jsondecode = pg.get_jsondecode()
2933        # insert null value
2934        r = self.db.insert('json_test', n=0, data=None)
2935        self.assertIsInstance(r, dict)
2936        self.assertIn('n', r)
2937        self.assertEqual(r['n'], 0)
2938        self.assertIn('data', r)
2939        self.assertIsNone(r['data'])
2940        r = self.db.get('json_test', 0)
2941        self.assertIsInstance(r, dict)
2942        self.assertIn('n', r)
2943        self.assertEqual(r['n'], 0)
2944        self.assertIn('data', r)
2945        self.assertIsNone(r['data'])
2946        # insert JSON object
2947        data = {
2948            "id": 1, "name": "Foo", "price": 1234.5,
2949            "new": True, "note": None,
2950            "tags": ["Bar", "Eek"],
2951            "stock": {"warehouse": 300, "retail": 20}}
2952        r = self.db.insert('json_test', n=1, data=data)
2953        self.assertIsInstance(r, dict)
2954        self.assertIn('n', r)
2955        self.assertEqual(r['n'], 1)
2956        self.assertIn('data', r)
2957        r = r['data']
2958        if jsondecode is None:
2959            self.assertIsInstance(r, str)
2960            r = json.loads(r)
2961        self.assertIsInstance(r, dict)
2962        self.assertEqual(r, data)
2963        self.assertIsInstance(r['id'], int)
2964        self.assertIsInstance(r['name'], unicode)
2965        self.assertIsInstance(r['price'], float)
2966        self.assertIsInstance(r['new'], bool)
2967        self.assertIsInstance(r['tags'], list)
2968        self.assertIsInstance(r['stock'], dict)
2969        r = self.db.get('json_test', 1)
2970        self.assertIsInstance(r, dict)
2971        self.assertIn('n', r)
2972        self.assertEqual(r['n'], 1)
2973        self.assertIn('data', r)
2974        r = r['data']
2975        if jsondecode is None:
2976            self.assertIsInstance(r, str)
2977            r = json.loads(r)
2978        self.assertIsInstance(r, dict)
2979        self.assertEqual(r, data)
2980        self.assertIsInstance(r['id'], int)
2981        self.assertIsInstance(r['name'], unicode)
2982        self.assertIsInstance(r['price'], float)
2983        self.assertIsInstance(r['new'], bool)
2984        self.assertIsInstance(r['tags'], list)
2985        self.assertIsInstance(r['stock'], dict)
2986        # insert JSON object as text
2987        self.db.insert('json_test', n=2, data=json.dumps(data))
2988        q = "select data from json_test where n in (1, 2) order by n"
2989        r = self.db.query(q).getresult()
2990        self.assertEqual(len(r), 2)
2991        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
2992        self.assertEqual(r[0][0], r[1][0])
2993
2994    def testInsertGetJsonb(self):
2995        try:
2996            self.createTable('jsonb_test',
2997                             'n smallint primary key, data jsonb')
2998        except pg.ProgrammingError as error:
2999            if self.db.server_version < 90400:
3000                self.skipTest('database does not support jsonb')
3001            self.fail(str(error))
3002        jsondecode = pg.get_jsondecode()
3003        # insert null value
3004        r = self.db.insert('jsonb_test', n=0, data=None)
3005        self.assertIsInstance(r, dict)
3006        self.assertIn('n', r)
3007        self.assertEqual(r['n'], 0)
3008        self.assertIn('data', r)
3009        self.assertIsNone(r['data'])
3010        r = self.db.get('jsonb_test', 0)
3011        self.assertIsInstance(r, dict)
3012        self.assertIn('n', r)
3013        self.assertEqual(r['n'], 0)
3014        self.assertIn('data', r)
3015        self.assertIsNone(r['data'])
3016        # insert JSON object
3017        data = {
3018            "id": 1, "name": "Foo", "price": 1234.5,
3019            "new": True, "note": None,
3020            "tags": ["Bar", "Eek"],
3021            "stock": {"warehouse": 300, "retail": 20}}
3022        r = self.db.insert('jsonb_test', n=1, data=data)
3023        self.assertIsInstance(r, dict)
3024        self.assertIn('n', r)
3025        self.assertEqual(r['n'], 1)
3026        self.assertIn('data', r)
3027        r = r['data']
3028        if jsondecode is None:
3029            self.assertIsInstance(r, str)
3030            r = json.loads(r)
3031        self.assertIsInstance(r, dict)
3032        self.assertEqual(r, data)
3033        self.assertIsInstance(r['id'], int)
3034        self.assertIsInstance(r['name'], unicode)
3035        self.assertIsInstance(r['price'], float)
3036        self.assertIsInstance(r['new'], bool)
3037        self.assertIsInstance(r['tags'], list)
3038        self.assertIsInstance(r['stock'], dict)
3039        r = self.db.get('jsonb_test', 1)
3040        self.assertIsInstance(r, dict)
3041        self.assertIn('n', r)
3042        self.assertEqual(r['n'], 1)
3043        self.assertIn('data', r)
3044        r = r['data']
3045        if jsondecode is None:
3046            self.assertIsInstance(r, str)
3047            r = json.loads(r)
3048        self.assertIsInstance(r, dict)
3049        self.assertEqual(r, data)
3050        self.assertIsInstance(r['id'], int)
3051        self.assertIsInstance(r['name'], unicode)
3052        self.assertIsInstance(r['price'], float)
3053        self.assertIsInstance(r['new'], bool)
3054        self.assertIsInstance(r['tags'], list)
3055        self.assertIsInstance(r['stock'], dict)
3056
3057    def testArray(self):
3058        self.createTable('arraytest',
3059            'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
3060            ' d numeric[], f4 real[], f8 double precision[], m money[],'
3061            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
3062        r = self.db.get_attnames('arraytest')
3063        if self.regtypes:
3064            self.assertEqual(r, dict(
3065                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
3066                d='numeric[]', f4='real[]', f8='double precision[]',
3067                m='money[]', b='boolean[]',
3068                v4='character varying[]', c4='character[]', t='text[]'))
3069        else:
3070            self.assertEqual(r, dict(
3071                id='int', i2='int[]', i4='int[]', i8='int[]',
3072                d='num[]', f4='float[]', f8='float[]', m='money[]',
3073                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
3074        decimal = pg.get_decimal()
3075        if decimal is Decimal:
3076            long_decimal = decimal('123456789.123456789')
3077            odd_money = decimal('1234567891234567.89')
3078        else:
3079            long_decimal = decimal('12345671234.5')
3080            odd_money = decimal('1234567123.25')
3081        t, f = (True, False) if pg.get_bool() else ('t', 'f')
3082        data = dict(id=42, i2=[42, 1234, None, 0, -1],
3083            i4=[42, 123456789, None, 0, 1, -1],
3084            i8=[long(42), long(123456789123456789), None,
3085                long(0), long(1), long(-1)],
3086            d=[decimal(42), long_decimal, None,
3087               decimal(0), decimal(1), decimal(-1), -long_decimal],
3088            f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
3089                float('inf'), float('-inf')],
3090            f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
3091                float('inf'), float('-inf')],
3092            m=[decimal('42.00'), odd_money, None,
3093               decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
3094            b=[t, f, t, None, f, t, None, None, t],
3095            v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
3096            t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
3097        r = data.copy()
3098        self.db.insert('arraytest', r)
3099        self.assertEqual(r, data)
3100        self.db.insert('arraytest', r)
3101        r = self.db.get('arraytest', 42, 'id')
3102        self.assertEqual(r, data)
3103        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
3104        self.assertEqual(r, data)
3105
3106    def testArrayLiteral(self):
3107        insert = self.db.insert
3108        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3109        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3110        insert('arraytest', r)
3111        self.assertEqual(r['i'], [1, 2, 3])
3112        self.assertEqual(r['t'], ['a', 'b', 'c'])
3113        r = dict(i='{1,2,3}', t='{a,b,c}')
3114        self.db.insert('arraytest', r)
3115        self.assertEqual(r['i'], [1, 2, 3])
3116        self.assertEqual(r['t'], ['a', 'b', 'c'])
3117        L = pg.Literal
3118        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3119        self.db.insert('arraytest', r)
3120        self.assertEqual(r['i'], [1, 2, 3])
3121        self.assertEqual(r['t'], ['a', 'b', 'c'])
3122        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3123        self.assertRaises(pg.ProgrammingError, self.db.insert, 'arraytest', r)
3124
3125    def testArrayOfIds(self):
3126        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3127        r = self.db.get_attnames('arraytest')
3128        if self.regtypes:
3129            self.assertEqual(r, dict(
3130                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3131        else:
3132            self.assertEqual(r, dict(
3133                oid='int', c='int[]', o='int[]', x='int[]'))
3134        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3135        r = data.copy()
3136        self.db.insert('arraytest', r)
3137        qoid = 'oid(arraytest)'
3138        oid = r.pop(qoid)
3139        self.assertEqual(r, data)
3140        r = {qoid: oid}
3141        self.db.get('arraytest', r)
3142        self.assertEqual(oid, r.pop(qoid))
3143        self.assertEqual(r, data)
3144
3145    def testArrayOfText(self):
3146        self.createTable('arraytest', 'data text[]', oids=True)
3147        r = self.db.get_attnames('arraytest')
3148        self.assertEqual(r['data'], 'text[]')
3149        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3150                'null', 'NULL', 'Null', 'nulL',
3151                "It's all \\ kinds of\r nasty stuff!\n"]
3152        r = dict(data=data)
3153        self.db.insert('arraytest', r)
3154        self.assertEqual(r['data'], data)
3155        self.assertIsInstance(r['data'][1], str)
3156        self.assertIsNone(r['data'][2])
3157        r['data'] = None
3158        self.db.get('arraytest', r)
3159        self.assertEqual(r['data'], data)
3160        self.assertIsInstance(r['data'][1], str)
3161        self.assertIsNone(r['data'][2])
3162
3163    def testArrayOfBytea(self):
3164        self.createTable('arraytest', 'data bytea[]', oids=True)
3165        r = self.db.get_attnames('arraytest')
3166        self.assertEqual(r['data'], 'bytea[]')
3167        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3168                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3169        r = dict(data=data)
3170        self.db.insert('arraytest', r)
3171        self.assertEqual(r['data'], data)
3172        self.assertIsInstance(r['data'][1], bytes)
3173        self.assertIsNone(r['data'][2])
3174        r['data'] = None
3175        self.db.get('arraytest', r)
3176        self.assertEqual(r['data'], data)
3177        self.assertIsInstance(r['data'][1], bytes)
3178        self.assertIsNone(r['data'][2])
3179
3180    def testArrayOfJson(self):
3181        try:
3182            self.createTable('arraytest', 'data json[]', oids=True)
3183        except pg.ProgrammingError as error:
3184            if self.db.server_version < 90200:
3185                self.skipTest('database does not support json')
3186            self.fail(str(error))
3187        r = self.db.get_attnames('arraytest')
3188        self.assertEqual(r['data'], 'json[]')
3189        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3190        jsondecode = pg.get_jsondecode()
3191        r = dict(data=data)
3192        self.db.insert('arraytest', r)
3193        if jsondecode is None:
3194            r['data'] = [json.loads(d) for d in r['data']]
3195        self.assertEqual(r['data'], data)
3196        r['data'] = None
3197        self.db.get('arraytest', r)
3198        if jsondecode is None:
3199            r['data'] = [json.loads(d) for d in r['data']]
3200        self.assertEqual(r['data'], data)
3201        r = dict(data=[json.dumps(d) for d in data])
3202        self.db.insert('arraytest', r)
3203        if jsondecode is None:
3204            r['data'] = [json.loads(d) for d in r['data']]
3205        self.assertEqual(r['data'], data)
3206        r['data'] = None
3207        self.db.get('arraytest', r)
3208        # insert empty json values
3209        r = dict(data=['', None])
3210        self.db.insert('arraytest', r)
3211        r = r['data']
3212        self.assertIsInstance(r, list)
3213        self.assertEqual(len(r), 2)
3214        self.assertIsNone(r[0])
3215        self.assertIsNone(r[1])
3216
3217    def testArrayOfJsonb(self):
3218        try:
3219            self.createTable('arraytest', 'data jsonb[]', oids=True)
3220        except pg.ProgrammingError as error:
3221            if self.db.server_version < 90400:
3222                self.skipTest('database does not support jsonb')
3223            self.fail(str(error))
3224        r = self.db.get_attnames('arraytest')
3225        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3226        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3227        jsondecode = pg.get_jsondecode()
3228        r = dict(data=data)
3229        self.db.insert('arraytest', r)
3230        if jsondecode is None:
3231            r['data'] = [json.loads(d) for d in r['data']]
3232        self.assertEqual(r['data'], data)
3233        r['data'] = None
3234        self.db.get('arraytest', r)
3235        if jsondecode is None:
3236            r['data'] = [json.loads(d) for d in r['data']]
3237        self.assertEqual(r['data'], data)
3238        r = dict(data=[json.dumps(d) for d in data])
3239        self.db.insert('arraytest', r)
3240        if jsondecode is None:
3241            r['data'] = [json.loads(d) for d in r['data']]
3242        self.assertEqual(r['data'], data)
3243        r['data'] = None
3244        self.db.get('arraytest', r)
3245        # insert empty json values
3246        r = dict(data=['', None])
3247        self.db.insert('arraytest', r)
3248        r = r['data']
3249        self.assertIsInstance(r, list)
3250        self.assertEqual(len(r), 2)
3251        self.assertIsNone(r[0])
3252        self.assertIsNone(r[1])
3253
3254    def testDeepArray(self):
3255        self.createTable('arraytest', 'data text[][][]', oids=True)
3256        r = self.db.get_attnames('arraytest')
3257        self.assertEqual(r['data'], 'text[]')
3258        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3259        r = dict(data=data)
3260        self.db.insert('arraytest', r)
3261        self.assertEqual(r['data'], data)
3262        r['data'] = None
3263        self.db.get('arraytest', r)
3264        self.assertEqual(r['data'], data)
3265
3266    def testInsertUpdateGetRecord(self):
3267        query = self.db.query
3268        query('create type test_person_type as'
3269              ' (name varchar, age smallint, married bool,'
3270              ' weight real, salary money)')
3271        self.addCleanup(query, 'drop type test_person_type')
3272        self.createTable('test_person', 'person test_person_type',
3273                         temporary=False, oids=True)
3274        attnames = self.db.get_attnames('test_person')
3275        self.assertEqual(len(attnames), 2)
3276        self.assertIn('oid', attnames)
3277        self.assertIn('person', attnames)
3278        person_typ = attnames['person']
3279        if self.regtypes:
3280            self.assertEqual(person_typ, 'test_person_type')
3281        else:
3282            self.assertEqual(person_typ, 'record')
3283        if self.regtypes:
3284            self.assertEqual(person_typ.attnames,
3285                dict(name='character varying', age='smallint',
3286                    married='boolean', weight='real', salary='money'))
3287        else:
3288            self.assertEqual(person_typ.attnames,
3289                dict(name='text', age='int', married='bool',
3290                    weight='float', salary='money'))
3291        decimal = pg.get_decimal()
3292        if pg.get_bool():
3293            bool_class = bool
3294            t, f = True, False
3295        else:
3296            bool_class = str
3297            t, f = 't', 'f'
3298        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3299        r = self.db.insert('test_person', None, person=person)
3300        p = r['person']
3301        self.assertIsInstance(p, tuple)
3302        self.assertEqual(p, person)
3303        self.assertEqual(p.name, 'John Doe')
3304        self.assertIsInstance(p.name, str)
3305        self.assertIsInstance(p.age, int)
3306        self.assertIsInstance(p.married, bool_class)
3307        self.assertIsInstance(p.weight, float)
3308        self.assertIsInstance(p.salary, decimal)
3309        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3310        r['person'] = person
3311        self.db.update('test_person', r)
3312        p = r['person']
3313        self.assertIsInstance(p, tuple)
3314        self.assertEqual(p, person)
3315        self.assertEqual(p.name, 'Jane Roe')
3316        self.assertIsInstance(p.name, str)
3317        self.assertIsInstance(p.age, int)
3318        self.assertIsInstance(p.married, bool_class)
3319        self.assertIsInstance(p.weight, float)
3320        self.assertIsInstance(p.salary, decimal)
3321        r['person'] = None
3322        self.db.get('test_person', r)
3323        p = r['person']
3324        self.assertIsInstance(p, tuple)
3325        self.assertEqual(p, person)
3326        self.assertEqual(p.name, 'Jane Roe')
3327        self.assertIsInstance(p.name, str)
3328        self.assertIsInstance(p.age, int)
3329        self.assertIsInstance(p.married, bool_class)
3330        self.assertIsInstance(p.weight, float)
3331        self.assertIsInstance(p.salary, decimal)
3332        person = (None,) * 5
3333        r = self.db.insert('test_person', None, person=person)
3334        p = r['person']
3335        self.assertIsInstance(p, tuple)
3336        self.assertIsNone(p.name)
3337        self.assertIsNone(p.age)
3338        self.assertIsNone(p.married)
3339        self.assertIsNone(p.weight)
3340        self.assertIsNone(p.salary)
3341        r['person'] = None
3342        self.db.get('test_person', r)
3343        p = r['person']
3344        self.assertIsInstance(p, tuple)
3345        self.assertIsNone(p.name)
3346        self.assertIsNone(p.age)
3347        self.assertIsNone(p.married)
3348        self.assertIsNone(p.weight)
3349        self.assertIsNone(p.salary)
3350        r = self.db.insert('test_person', None, person=None)
3351        self.assertIsNone(r['person'])
3352        r['person'] = None
3353        self.db.get('test_person', r)
3354        self.assertIsNone(r['person'])
3355
3356    def testRecordInsertBytea(self):
3357        query = self.db.query
3358        query('create type test_person_type as'
3359              ' (name text, picture bytea)')
3360        self.addCleanup(query, 'drop type test_person_type')
3361        self.createTable('test_person', 'person test_person_type',
3362                         temporary=False, oids=True)
3363        person_typ = self.db.get_attnames('test_person')['person']
3364        self.assertEqual(person_typ.attnames,
3365                         dict(name='text', picture='bytea'))
3366        person = ('John Doe', b'O\x00ps\xff!')
3367        r = self.db.insert('test_person', None, person=person)
3368        p = r['person']
3369        self.assertIsInstance(p, tuple)
3370        self.assertEqual(p, person)
3371        self.assertEqual(p.name, 'John Doe')
3372        self.assertIsInstance(p.name, str)
3373        self.assertEqual(p.picture, person[1])
3374        self.assertIsInstance(p.picture, bytes)
3375
3376    def testRecordInsertJson(self):
3377        query = self.db.query
3378        try:
3379            query('create type test_person_type as'
3380                  ' (name text, data json)')
3381        except pg.ProgrammingError as error:
3382            if self.db.server_version < 90200:
3383                self.skipTest('database does not support json')
3384            self.fail(str(error))
3385        self.addCleanup(query, 'drop type test_person_type')
3386        self.createTable('test_person', 'person test_person_type',
3387                         temporary=False, oids=True)
3388        person_typ = self.db.get_attnames('test_person')['person']
3389        self.assertEqual(person_typ.attnames,
3390                         dict(name='text', data='json'))
3391        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3392        r = self.db.insert('test_person', None, person=person)
3393        p = r['person']
3394        self.assertIsInstance(p, tuple)
3395        if pg.get_jsondecode() is None:
3396            p = p._replace(data=json.loads(p.data))
3397        self.assertEqual(p, person)
3398        self.assertEqual(p.name, 'John Doe')
3399        self.assertIsInstance(p.name, str)
3400        self.assertEqual(p.data, person[1])
3401        self.assertIsInstance(p.data, dict)
3402
3403    def testRecordLiteral(self):
3404        query = self.db.query
3405        query('create type test_person_type as'
3406              ' (name varchar, age smallint)')
3407        self.addCleanup(query, 'drop type test_person_type')
3408        self.createTable('test_person', 'person test_person_type',
3409                         temporary=False, oids=True)
3410        person_typ = self.db.get_attnames('test_person')['person']
3411        if self.regtypes:
3412            self.assertEqual(person_typ, 'test_person_type')
3413        else:
3414            self.assertEqual(person_typ, 'record')
3415        if self.regtypes:
3416            self.assertEqual(person_typ.attnames,
3417                             dict(name='character varying', age='smallint'))
3418        else:
3419            self.assertEqual(person_typ.attnames,
3420                             dict(name='text', age='int'))
3421        person = pg.Literal("('John Doe', 61)")
3422        r = self.db.insert('test_person', None, person=person)
3423        p = r['person']
3424        self.assertIsInstance(p, tuple)
3425        self.assertEqual(p.name, 'John Doe')
3426        self.assertIsInstance(p.name, str)
3427        self.assertEqual(p.age, 61)
3428        self.assertIsInstance(p.age, int)
3429
3430    def testDbTypesInfo(self):
3431        dbtypes = self.db.dbtypes
3432        self.assertIsInstance(dbtypes, dict)
3433        self.assertNotIn('numeric', dbtypes)
3434        typ = dbtypes['numeric']
3435        self.assertIn('numeric', dbtypes)
3436        self.assertEqual(typ, 'numeric' if self.regtypes else 'num')
3437        self.assertEqual(typ.oid, 1700)
3438        self.assertEqual(typ.pgtype, 'numeric')
3439        self.assertEqual(typ.regtype, 'numeric')
3440        self.assertEqual(typ.simple, 'num')
3441        self.assertEqual(typ.typtype, 'b')
3442        self.assertEqual(typ.category, 'N')
3443        self.assertEqual(typ.delim, ',')
3444        self.assertEqual(typ.relid, 0)
3445        self.assertIs(dbtypes[1700], typ)
3446        self.assertNotIn('pg_type', dbtypes)
3447        typ = dbtypes['pg_type']
3448        self.assertIn('pg_type', dbtypes)
3449        self.assertEqual(typ, 'pg_type' if self.regtypes else 'record')
3450        self.assertIsInstance(typ.oid, int)
3451        self.assertEqual(typ.pgtype, 'pg_type')
3452        self.assertEqual(typ.regtype, 'pg_type')
3453        self.assertEqual(typ.simple, 'record')
3454        self.assertEqual(typ.typtype, 'c')
3455        self.assertEqual(typ.category, 'C')
3456        self.assertEqual(typ.delim, ',')
3457        self.assertNotEqual(typ.relid, 0)
3458        attnames = typ.attnames
3459        self.assertIsInstance(attnames, dict)
3460        self.assertIs(attnames, dbtypes.get_attnames('pg_type'))
3461        self.assertEqual(list(attnames)[0], 'typname')
3462        typname = attnames['typname']
3463        self.assertEqual(typname, 'name' if self.regtypes else 'text')
3464        self.assertEqual(typname.typtype, 'b')  # base
3465        self.assertEqual(typname.category, 'S')  # string
3466        self.assertEqual(list(attnames)[3], 'typlen')
3467        typlen = attnames['typlen']
3468        self.assertEqual(typlen, 'smallint' if self.regtypes else 'int')
3469        self.assertEqual(typlen.typtype, 'b')  # base
3470        self.assertEqual(typlen.category, 'N')  # numeric
3471
3472    def testDbTypesTypecast(self):
3473        dbtypes = self.db.dbtypes
3474        self.assertIsInstance(dbtypes, dict)
3475        self.assertNotIn('int4', dbtypes)
3476        self.assertIs(dbtypes.get_typecast('int4'), int)
3477        dbtypes.set_typecast('int4', float)
3478        self.assertIs(dbtypes.get_typecast('int4'), float)
3479        dbtypes.reset_typecast('int4')
3480        self.assertIs(dbtypes.get_typecast('int4'), int)
3481        dbtypes.set_typecast('int4', float)
3482        self.assertIs(dbtypes.get_typecast('int4'), float)
3483        dbtypes.reset_typecast()
3484        self.assertIs(dbtypes.get_typecast('int4'), int)
3485        self.assertNotIn('circle', dbtypes)
3486        self.assertIsNone(dbtypes.get_typecast('circle'))
3487        squared_circle = lambda v: 'Squared Circle: %s' % v
3488        dbtypes.set_typecast('circle', squared_circle)
3489        self.assertIs(dbtypes.get_typecast('circle'), squared_circle)
3490        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3491        self.assertIn('circle', dbtypes)
3492        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3493        self.assertEqual(dbtypes.typecast('Impossible', 'circle'),
3494            'Squared Circle: Impossible')
3495        dbtypes.reset_typecast('circle')
3496        self.assertIsNone(dbtypes.get_typecast('circle'))
3497
3498    def testGetSetTypeCast(self):
3499        get_typecast = pg.get_typecast
3500        set_typecast = pg.set_typecast
3501        dbtypes = self.db.dbtypes
3502        self.assertIsInstance(dbtypes, dict)
3503        self.assertNotIn('int4', dbtypes)
3504        self.assertNotIn('real', dbtypes)
3505        self.assertNotIn('bool', dbtypes)
3506        self.assertIs(get_typecast('int4'), int)
3507        self.assertIs(get_typecast('float4'), float)
3508        self.assertIs(get_typecast('bool'), pg.cast_bool)
3509        cast_circle = get_typecast('circle')
3510        self.addCleanup(set_typecast, 'circle', cast_circle)
3511        squared_circle = lambda v: 'Squared Circle: %s' % v
3512        self.assertNotIn('circle', dbtypes)
3513        set_typecast('circle', squared_circle)
3514        self.assertNotIn('circle', dbtypes)
3515        self.assertIs(get_typecast('circle'), squared_circle)
3516        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3517        self.assertIn('circle', dbtypes)
3518        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3519        set_typecast('circle', cast_circle)
3520        self.assertIs(get_typecast('circle'), cast_circle)
3521
3522    def testNotificationHandler(self):
3523        # the notification handler itself is tested separately
3524        f = self.db.notification_handler
3525        callback = lambda arg_dict: None
3526        handler = f('test', callback)
3527        self.assertIsInstance(handler, pg.NotificationHandler)
3528        self.assertIs(handler.db, self.db)
3529        self.assertEqual(handler.event, 'test')
3530        self.assertEqual(handler.stop_event, 'stop_test')
3531        self.assertIs(handler.callback, callback)
3532        self.assertIsInstance(handler.arg_dict, dict)
3533        self.assertEqual(handler.arg_dict, {})
3534        self.assertIsNone(handler.timeout)
3535        self.assertFalse(handler.listening)
3536        handler.close()
3537        self.assertIsNone(handler.db)
3538        self.db.reopen()
3539        self.assertIsNone(handler.db)
3540        handler = f('test2', callback, timeout=2)
3541        self.assertIsInstance(handler, pg.NotificationHandler)
3542        self.assertIs(handler.db, self.db)
3543        self.assertEqual(handler.event, 'test2')
3544        self.assertEqual(handler.stop_event, 'stop_test2')
3545        self.assertIs(handler.callback, callback)
3546        self.assertIsInstance(handler.arg_dict, dict)
3547        self.assertEqual(handler.arg_dict, {})
3548        self.assertEqual(handler.timeout, 2)
3549        self.assertFalse(handler.listening)
3550        handler.close()
3551        self.assertIsNone(handler.db)
3552        self.db.reopen()
3553        self.assertIsNone(handler.db)
3554        arg_dict = {'testing': 3}
3555        handler = f('test3', callback, arg_dict=arg_dict)
3556        self.assertIsInstance(handler, pg.NotificationHandler)
3557        self.assertIs(handler.db, self.db)
3558        self.assertEqual(handler.event, 'test3')
3559        self.assertEqual(handler.stop_event, 'stop_test3')
3560        self.assertIs(handler.callback, callback)
3561        self.assertIs(handler.arg_dict, arg_dict)
3562        self.assertEqual(arg_dict['testing'], 3)
3563        self.assertIsNone(handler.timeout)
3564        self.assertFalse(handler.listening)
3565        handler.close()
3566        self.assertIsNone(handler.db)
3567        self.db.reopen()
3568        self.assertIsNone(handler.db)
3569        handler = f('test4', callback, stop_event='stop4')
3570        self.assertIsInstance(handler, pg.NotificationHandler)
3571        self.assertIs(handler.db, self.db)
3572        self.assertEqual(handler.event, 'test4')
3573        self.assertEqual(handler.stop_event, 'stop4')
3574        self.assertIs(handler.callback, callback)
3575        self.assertIsInstance(handler.arg_dict, dict)
3576        self.assertEqual(handler.arg_dict, {})
3577        self.assertIsNone(handler.timeout)
3578        self.assertFalse(handler.listening)
3579        handler.close()
3580        self.assertIsNone(handler.db)
3581        self.db.reopen()
3582        self.assertIsNone(handler.db)
3583        arg_dict = {'testing': 5}
3584        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
3585        self.assertIsInstance(handler, pg.NotificationHandler)
3586        self.assertIs(handler.db, self.db)
3587        self.assertEqual(handler.event, 'test5')
3588        self.assertEqual(handler.stop_event, 'stop5')
3589        self.assertIs(handler.callback, callback)
3590        self.assertIs(handler.arg_dict, arg_dict)
3591        self.assertEqual(arg_dict['testing'], 5)
3592        self.assertEqual(handler.timeout, 1.5)
3593        self.assertFalse(handler.listening)
3594        handler.close()
3595        self.assertIsNone(handler.db)
3596        self.db.reopen()
3597        self.assertIsNone(handler.db)
3598
3599
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        """Switch global pg options to non-standard values, then set up.

        The previous option values are saved so that tearDownClass() can
        restore them.  Since unittest does not call tearDownClass() when
        setUpClass() raises, the saved options are restored here before
        re-raising; otherwise later test classes would silently run with
        the modified global options.
        """
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            # use the opposite of the current bool option
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            cls.set_option('namedresult', None)
            cls.set_option('jsondecode', None)
            # Use the opposite of the default regtypes setting.  The
            # temporary connection is closed explicitly instead of relying
            # on garbage collection to release it.
            db = DB()
            try:
                cls.regtypes = not db.use_regtypes()
            finally:
                db.close()
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            for option in list(cls.saved_options):
                cls.reset_option(option)
            raise

    @classmethod
    def tearDownClass(cls):
        """Tear down the base class, then restore the saved options.

        The options are restored even if the base class teardown fails.
        """
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls.reset_option('jsondecode')
            cls.reset_option('namedresult')
            cls.reset_option('bool')
            cls.reset_option('decimal')

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a pg option and set it to value.

        Calls pg.get_<option>() to remember the old value and returns the
        result of pg.set_<option>(value).
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])
3630
3631
class TestDBClassAdapter(unittest.TestCase):
    """Test the adapter object associated with the DB class."""

    def setUp(self):
        self.db = DB()
        self.adapter = self.db.adapter

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            # presumably raised if the connection is already closed
            pass

    def _record_type(self):
        """Return a 'record' simple type with fixed attribute types.

        The attribute lookup of the type is overridden so that the record
        has the columns i (int), f (float), t (text), b (bool),
        i3 (int[]) and t3 (text[]), as needed by the typed record tests.
        """
        t = self.adapter.simple_type
        typ = t('record')
        typ._get_attnames = lambda _self: pg.AttrDict([
            ('i', t('int')), ('f', t('float')),
            ('t', t('text')), ('b', t('bool')),
            ('i3', t('int[]')), ('t3', t('text[]'))])
        return typ

    def testGuessSimpleType(self):
        f = self.adapter.guess_simple_type
        self.assertEqual(f(pg.Bytea(b'test')), 'bytea')
        self.assertEqual(f('string'), 'text')
        self.assertEqual(f(b'string'), 'text')
        self.assertEqual(f(True), 'bool')
        self.assertEqual(f(3), 'int')
        self.assertEqual(f(2.75), 'float')
        self.assertEqual(f(Decimal('4.25')), 'num')
        self.assertEqual(f(date(2016, 1, 30)), 'date')
        self.assertEqual(f([1, 2, 3]), 'int[]')
        self.assertEqual(f([[[123]]]), 'int[]')
        self.assertEqual(f(['a', 'b', 'c']), 'text[]')
        self.assertEqual(f([[['abc']]]), 'text[]')
        self.assertEqual(f([False, True]), 'bool[]')
        self.assertEqual(f([[[False]]]), 'bool[]')
        r = f(('string', True, 3, 2.75, [1], [False]))
        self.assertEqual(r, 'record')
        self.assertEqual(list(r.attnames.values()),
            ['text', 'bool', 'int', 'float', 'int[]', 'bool[]'])

    def testAdaptQueryTypedList(self):
        format_query = self.adapter.format_query
        # the number of values and types must match
        self.assertRaises(TypeError, format_query,
            '%s,%s', (1, 2), ('int2',))
        self.assertRaises(TypeError, format_query,
            '%s,%s', (1,), ('int2', 'int2'))
        values = (3, 7.5, 'hello', True)
        types = ('int4', 'float4', 'text', 'bool')
        sql, params = format_query("select %s,%s,%s,%s", values, types)
        self.assertEqual(sql, 'select $1,$2,$3,$4')
        self.assertEqual(params, [3, 7.5, 'hello', 't'])
        types = ('bool', 'bool', 'bool', 'bool')
        sql, params = format_query("select %s,%s,%s,%s", values, types)
        self.assertEqual(sql, 'select $1,$2,$3,$4')
        self.assertEqual(params, ['t', 't', 'f', 't'])
        # a date typed as 'current_date' is expected to be inlined
        values = ('2016-01-30', 'current_date')
        types = ('date', 'date')
        sql, params = format_query("values(%s,%s)", values, types)
        self.assertEqual(sql, 'values($1,current_date)')
        self.assertEqual(params, ['2016-01-30'])
        values = ([1, 2, 3], ['a', 'b', 'c'])
        types = ('_int4', '_text')
        sql, params = format_query("%s::int4[],%s::text[]", values, types)
        self.assertEqual(sql, '$1::int4[],$2::text[]')
        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
        types = ('_bool', '_bool')
        sql, params = format_query("%s::bool[],%s::bool[]", values, types)
        self.assertEqual(sql, '$1::bool[],$2::bool[]')
        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
        types = [self._record_type()]
        sql, params = format_query('select %s', values, types)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryTypedDict(self):
        format_query = self.adapter.format_query
        self.assertRaises(TypeError, format_query,
            '%s,%s', dict(i1=1, i2=2), dict(i1='int2'))
        # the expected parameter numbering follows the alphabetical
        # order of the keys, not their order of appearance in the query
        values = dict(i=3, f=7.5, t='hello', b=True)
        types = dict(i='int4', f='float4',
            t='text', b='bool')
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
        self.assertEqual(sql, 'select $3,$2,$4,$1')
        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
        types = dict(i='bool', f='bool',
            t='bool', b='bool')
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
        self.assertEqual(sql, 'select $3,$2,$4,$1')
        self.assertEqual(params, ['t', 't', 't', 'f'])
        values = dict(d1='2016-01-30', d2='current_date')
        types = dict(d1='date', d2='date')
        sql, params = format_query("values(%(d1)s,%(d2)s)", values, types)
        self.assertEqual(sql, 'values($1,current_date)')
        self.assertEqual(params, ['2016-01-30'])
        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
        types = dict(i='_int4', t='_text')
        sql, params = format_query(
            "%(i)s::int4[],%(t)s::text[]", values, types)
        self.assertEqual(sql, '$1::int4[],$2::text[]')
        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
        types = dict(i='_bool', t='_bool')
        sql, params = format_query(
            "%(i)s::bool[],%(t)s::bool[]", values, types)
        self.assertEqual(sql, '$1::bool[],$2::bool[]')
        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
        types = dict(record=self._record_type())
        sql, params = format_query('select %(record)s', values, types)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryUntypedList(self):
        format_query = self.adapter.format_query
        values = (3, 7.5, 'hello', True)
        sql, params = format_query("select %s,%s,%s,%s", values)
        self.assertEqual(sql, 'select $1,$2,$3,$4')
        self.assertEqual(params, [3, 7.5, 'hello', 't'])
        # without types, 'current_date' is passed as an ordinary parameter
        values = [date(2016, 1, 30), 'current_date']
        sql, params = format_query("values(%s,%s)", values)
        self.assertEqual(sql, 'values($1,$2)')
        self.assertEqual(params, values)
        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
        sql, params = format_query("%s,%s,%s", values)
        self.assertEqual(sql, "$1,$2,$3")
        self.assertEqual(params, ['{1,2,3}', '{a,b,c}', '{t,f,t}'])
        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
            [[True, False], [False, True]])
        sql, params = format_query("%s,%s,%s", values)
        self.assertEqual(sql, "$1,$2,$3")
        self.assertEqual(params, [
            '{{1,2},{3,4}}', '{{a,b},{c,d}}', '{{t,f},{f,t}}'])
        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
        sql, params = format_query('select %s', values)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryUntypedDict(self):
        format_query = self.adapter.format_query
        values = dict(i=3, f=7.5, t='hello', b=True)
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values)
        self.assertEqual(sql, 'select $3,$2,$4,$1')
        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
        values = dict(d1='2016-01-30', d2='current_date')
        sql, params = format_query("values(%(d1)s,%(d2)s)", values)
        self.assertEqual(sql, 'values($1,$2)')
        self.assertEqual(params, [values['d1'], values['d2']])
        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
        self.assertEqual(sql, "$2,$3,$1")
        self.assertEqual(params, ['{t,f,t}', '{1,2,3}', '{a,b,c}'])
        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
            b=[[True, False], [False, True]])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
        self.assertEqual(sql, "$2,$3,$1")
        self.assertEqual(params, [
            '{{t,f},{f,t}}', '{{1,2},{3,4}}', '{{a,b},{c,d}}'])
        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
        sql, params = format_query('select %(record)s', values)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryInlineList(self):
        format_query = self.adapter.format_query
        # with inline=True, values become SQL literals and no params remain
        values = (3, 7.5, 'hello', True)
        sql, params = format_query("select %s,%s,%s,%s", values, inline=True)
        self.assertEqual(sql, "select 3,7.5,'hello',true")
        self.assertEqual(params, [])
        values = [date(2016, 1, 30), 'current_date']
        sql, params = format_query("values(%s,%s)", values, inline=True)
        self.assertEqual(sql, "values('2016-01-30','current_date')")
        self.assertEqual(params, [])
        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
        sql, params = format_query("%s,%s,%s", values, inline=True)
        self.assertEqual(sql,
            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
        self.assertEqual(params, [])
        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
            [[True, False], [False, True]])
        sql, params = format_query("%s,%s,%s", values, inline=True)
        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
            "ARRAY[[true,false],[false,true]]")
        self.assertEqual(params, [])
        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
        sql, params = format_query('select %s', values, inline=True)
        self.assertEqual(sql,
            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
        self.assertEqual(params, [])

    def testAdaptQueryInlineDict(self):
        format_query = self.adapter.format_query
        values = dict(i=3, f=7.5, t='hello', b=True)
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values, inline=True)
        self.assertEqual(sql, "select 3,7.5,'hello',true")
        self.assertEqual(params, [])
        values = dict(d1='2016-01-30', d2='current_date')
        sql, params = format_query(
            "values(%(d1)s,%(d2)s)", values, inline=True)
        self.assertEqual(sql, "values('2016-01-30','current_date')")
        self.assertEqual(params, [])
        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
        self.assertEqual(sql,
            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
        self.assertEqual(params, [])
        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
            b=[[True, False], [False, True]])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
            "ARRAY[[true,false],[false,true]]")
        self.assertEqual(params, [])
        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
        sql, params = format_query('select %(record)s', values, inline=True)
        self.assertEqual(sql,
            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
        self.assertEqual(params, [])

    def testAdaptQueryWithPgRepr(self):
        format_query = self.adapter.format_query
        self.assertRaises(TypeError, format_query,
            '%s', object(), inline=True)
        # objects providing __pg_repr__ are inlined using its result
        class TestObject:
            def __pg_repr__(self):
                return "'adapted'"
        sql, params = format_query('select %s', [TestObject()], inline=True)
        self.assertEqual(sql, "select 'adapted'")
        self.assertEqual(params, [])
        sql, params = format_query('select %s', [[TestObject()]], inline=True)
        self.assertEqual(sql, "select ARRAY['adapted']")
        self.assertEqual(params, [])
3872
3873
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    # set by setUpClass() once all fixtures have been created
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create the schemas and tables used by the tests.

        The schemas s1 to s4 are (re)created.  Each of them, and also the
        public schema (number 0), gets the tables t and t<num>, each with
        one row where n = 1 and d = <num>.  The connection is closed even
        if creating the fixtures fails.
        """
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError("The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d" % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            db.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the schemas and tables created by setUpClass().

        The connection is closed even if dropping fails.
        """
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.db = DB()

    def tearDown(self):
        # Run the registered cleanups now, while the connection they use
        # (e.g. to drop tables) is still open, and close it afterwards.
        try:
            self.doCleanups()
        finally:
            self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        # unqualified names must now be resolved via the search path
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
3992
3993
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class."""

    def setUp(self):
        self.db = DB()
        # remember the debug setting the connection started with,
        # so that tearDown() can restore exactly this value
        self.debug = self.db.debug
        # capture everything printed to stdout during the test
        self.output = StringIO()
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        sys.stdout = self.stdout
        self.output.close()
        # undo the debug target installed by the test before closing
        self.db.debug = self.debug
        self.db.close()

    def get_output(self):
        """Return everything written to stdout since setUp()."""
        return self.output.getvalue()

    def send_queries(self):
        """Run the two queries whose debug output the tests check."""
        self.db.query("select 1")
        self.db.query("select 2")

    def testDebugDefault(self):
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")

    def testDebugMultipleArgs(self):
        output = []
        self.db.debug = output.append
        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
        self.db._do_debug(*args)
        # multiple arguments are expected to be joined line by line
        self.assertEqual(output, ['\n'.join(str(arg) for arg in args)])
        self.assertEqual(self.get_output(), "")
4063
4064
if __name__ == '__main__':
    # Run every test case of this module when executed as a script;
    # '__main__' is also the default module searched by unittest.main().
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.