source: trunk/tests/test_classic_dbwrapper.py @ 804

Last change on this file since 804 was 804, checked in by cito, 4 years ago

Make the automatic conversion to arrays configurable

The automatic conversion of arrays to lists can now be
disabled with the set_array() method.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 167.6 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import tempfile
21import json
22
23import pg  # the module under test
24
25from decimal import Decimal
26from datetime import date
27from operator import itemgetter
28
# Connection parameters of the database the tests run against.  When a
# LOCAL_PyGreSQL module can be imported (see the import right below),
# its settings replace these defaults.  The connecting user must be
# allowed to create schemas in that database.
dbname, dbhost, dbport = 'unittest', None, 5432

# Set to True to make the DB wrapper print debugging output.
debug = False
37
38try:
39    from .LOCAL_PyGreSQL import *
40except (ImportError, ValueError):
41    try:
42        from LOCAL_PyGreSQL import *
43    except ImportError:
44        pass
45
# Python 3 has no separate ``long`` and ``unicode`` builtins.  Probe for
# each name and, when it is missing, alias it to ``int`` / ``str`` so the
# tests can use the Python 2 names on both major versions.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str
55
56try:
57    from collections import OrderedDict
58except ImportError:  # Python 2.6 or 3.0
59    OrderedDict = dict
60
61if str is bytes:
62    from StringIO import StringIO
63else:
64    from io import StringIO
65
# True when running on Windows.
windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(), so tests that
# query the host are skipped there:
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
72
73
def DB():
    """Open a fresh pg.DB wrapper on the configured test database.

    The module-level ``debug`` flag, when set, is copied to the wrapper,
    and the session is told to report only warnings and worse.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
81
82
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names."""

    cls = pg.AttrDict
    base = OrderedDict

    def testInit(self):
        a = self.cls()
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base())
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))
        iteritems = iter(items)
        a = self.cls(iteritems)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))

    def testIter(self):
        a = self.cls()
        self.assertEqual(list(a), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a), keys)

    def testKeys(self):
        a = self.cls()
        self.assertEqual(list(a.keys()), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a.keys()), keys)

    def testValues(self):
        a = self.cls()
        self.assertEqual(list(a.values()), [])
        items = [('id', 'int'), ('name', 'text')]
        values = [item[1] for item in items]
        a = self.cls(items)
        self.assertEqual(list(a.values()), values)

    def testItems(self):
        a = self.cls()
        self.assertEqual(list(a.items()), [])
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertEqual(list(a.items()), items)

    def testGet(self):
        a = self.cls([('id', 1)])
        try:
            self.assertEqual(a['id'], 1)
        except KeyError:
            self.fail('AttrDict should be readable')

    def testSet(self):
        a = self.cls()
        try:
            a['id'] = 1
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testDel(self):
        a = self.cls([('id', 1)])
        try:
            del a['id']
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        a = self.cls([('id', 1)])
        self.assertEqual(a['id'], 1)
        # Call every mutating method with arguments that an ordinary dict
        # would accept, so that a TypeError can only come from the
        # read-only guard and not from a bad call signature.
        calls = [
            ('clear', ()),
            ('update', ({'id': 2},)),
            ('pop', ('id',)),
            ('setdefault', ('name', 'text')),
            ('popitem', ()),
        ]
        for name, args in calls:
            method = getattr(a, name)
            self.assertRaises(TypeError, method, *args)
        # none of the calls may have modified the contents
        self.assertEqual(a, self.base([('id', 1)]))
164
165
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            pass  # some tests close the connection themselves

    def testAllDBAttributes(self):
        attributes = [
            'abort', 'adapter',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'dbtypes',
            'debug', 'decode_json', 'delete',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_cast_hook',
            'get_databases', 'get_notice_receiver',
            'get_parameter', 'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query', 'query_formatted',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_cast_hook', 'set_notice_receiver',
            'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
                         if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            self.assertEqual(error.sqlstate, '22012')
        else:
            # without this the test would pass even if no error was raised
            self.fail('Division by zero should raise an error')

    def testMethodEndcopy(self):
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
376
377
378class TestDBClass(unittest.TestCase):
379    """Test the methods of the DB class wrapped pg connection."""
380
    # maximum length (in characters) of diffs shown in assertion failures
    maxDiff = 80 * 20

    # set to True by setUpClass; setUp asserts it so tests never run
    # without the shared "test" table
    cls_set_up = False

    # regtypes setting to apply to each new connection; while None, setUp
    # records the connection's current setting instead
    regtypes = None
386
387    @classmethod
388    def setUpClass(cls):
389        db = DB()
390        db.query("drop table if exists test cascade")
391        db.query("create table test ("
392                 "i2 smallint, i4 integer, i8 bigint,"
393                 " d numeric, f4 real, f8 double precision, m money,"
394                 " v4 varchar(4), c4 char(4), t text)")
395        db.query("create or replace view test_view as"
396                 " select i4, v4 from test")
397        db.close()
398        cls.cls_set_up = True
399
400    @classmethod
401    def tearDownClass(cls):
402        db = DB()
403        db.query("drop table test cascade")
404        db.close()
405
406    def setUp(self):
407        self.assertTrue(self.cls_set_up)
408        self.db = DB()
409        if self.regtypes is None:
410            self.regtypes = self.db.use_regtypes()
411        else:
412            self.db.use_regtypes(self.regtypes)
413        query = self.db.query
414        query('set client_encoding=utf8')
415        query('set standard_conforming_strings=on')
416        query("set lc_monetary='C'")
417        query("set datestyle='ISO,YMD'")
418        query('set bytea_output=hex')
419
420    def tearDown(self):
421        self.doCleanups()
422        self.db.close()
423
424    def createTable(self, table, definition,
425                    temporary=True, oids=None, values=None):
426        query = self.db.query
427        if not '"' in table or '.' in table:
428            table = '"%s"' % table
429        if not temporary:
430            q = 'drop table if exists %s cascade' % table
431            query(q)
432            self.addCleanup(query, q)
433        temporary = 'temporary table' if temporary else 'table'
434        as_query = definition.startswith(('as ', 'AS '))
435        if not as_query and not definition.startswith('('):
436            definition = '(%s)' % definition
437        with_oids = 'with oids' if oids else 'without oids'
438        q = ['create', temporary, table]
439        if as_query:
440            q.extend([with_oids, definition])
441        else:
442            q.extend([definition, with_oids])
443        q = ' '.join(q)
444        query(q)
445        if values:
446            for params in values:
447                if not isinstance(params, (list, tuple)):
448                    params = [params]
449                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
450                q = "insert into %s values (%s)" % (table, values)
451                query(q, params)
452
453    def testClassName(self):
454        self.assertEqual(self.db.__class__.__name__, 'DB')
455
456    def testModuleName(self):
457        self.assertEqual(self.db.__module__, 'pg')
458        self.assertEqual(self.db.__class__.__module__, 'pg')
459
460    def testEscapeLiteral(self):
461        f = self.db.escape_literal
462        r = f(b"plain")
463        self.assertIsInstance(r, bytes)
464        self.assertEqual(r, b"'plain'")
465        r = f(u"plain")
466        self.assertIsInstance(r, unicode)
467        self.assertEqual(r, u"'plain'")
468        r = f(u"that's kÀse".encode('utf-8'))
469        self.assertIsInstance(r, bytes)
470        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
471        r = f(u"that's kÀse")
472        self.assertIsInstance(r, unicode)
473        self.assertEqual(r, u"'that''s kÀse'")
474        self.assertEqual(f(r"It's fine to have a \ inside."),
475                         r" E'It''s fine to have a \\ inside.'")
476        self.assertEqual(f('No "quotes" must be escaped.'),
477                         "'No \"quotes\" must be escaped.'")
478
479    def testEscapeIdentifier(self):
480        f = self.db.escape_identifier
481        r = f(b"plain")
482        self.assertIsInstance(r, bytes)
483        self.assertEqual(r, b'"plain"')
484        r = f(u"plain")
485        self.assertIsInstance(r, unicode)
486        self.assertEqual(r, u'"plain"')
487        r = f(u"that's kÀse".encode('utf-8'))
488        self.assertIsInstance(r, bytes)
489        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
490        r = f(u"that's kÀse")
491        self.assertIsInstance(r, unicode)
492        self.assertEqual(r, u'"that\'s kÀse"')
493        self.assertEqual(f(r"It's fine to have a \ inside."),
494                         '"It\'s fine to have a \\ inside."')
495        self.assertEqual(f('All "quotes" must be escaped.'),
496                         '"All ""quotes"" must be escaped."')
497
498    def testEscapeString(self):
499        f = self.db.escape_string
500        r = f(b"plain")
501        self.assertIsInstance(r, bytes)
502        self.assertEqual(r, b"plain")
503        r = f(u"plain")
504        self.assertIsInstance(r, unicode)
505        self.assertEqual(r, u"plain")
506        r = f(u"that's kÀse".encode('utf-8'))
507        self.assertIsInstance(r, bytes)
508        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
509        r = f(u"that's kÀse")
510        self.assertIsInstance(r, unicode)
511        self.assertEqual(r, u"that''s kÀse")
512        self.assertEqual(f(r"It's fine to have a \ inside."),
513                         r"It''s fine to have a \ inside.")
514
515    def testEscapeBytea(self):
516        f = self.db.escape_bytea
517        # note that escape_byte always returns hex output since Pg 9.0,
518        # regardless of the bytea_output setting
519        r = f(b'plain')
520        self.assertIsInstance(r, bytes)
521        self.assertEqual(r, b'\\x706c61696e')
522        r = f(u'plain')
523        self.assertIsInstance(r, unicode)
524        self.assertEqual(r, u'\\x706c61696e')
525        r = f(u"das is' kÀse".encode('utf-8'))
526        self.assertIsInstance(r, bytes)
527        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
528        r = f(u"das is' kÀse")
529        self.assertIsInstance(r, unicode)
530        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
531        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
532
533    def testUnescapeBytea(self):
534        f = self.db.unescape_bytea
535        r = f(b'plain')
536        self.assertIsInstance(r, bytes)
537        self.assertEqual(r, b'plain')
538        r = f(u'plain')
539        self.assertIsInstance(r, bytes)
540        self.assertEqual(r, b'plain')
541        r = f(b"das is' k\\303\\244se")
542        self.assertIsInstance(r, bytes)
543        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
544        r = f(u"das is' k\\303\\244se")
545        self.assertIsInstance(r, bytes)
546        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
547        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
548        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
549        self.assertEqual(f(r'\\x746861742773206be47365'),
550                         b'\\x746861742773206be47365')
551        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
552
553    def testDecodeJson(self):
554        f = self.db.decode_json
555        self.assertIsNone(f('null'))
556        data = {
557            "id": 1, "name": "Foo", "price": 1234.5,
558            "new": True, "note": None,
559            "tags": ["Bar", "Eek"],
560            "stock": {"warehouse": 300, "retail": 20}}
561        text = json.dumps(data)
562        r = f(text)
563        self.assertIsInstance(r, dict)
564        self.assertEqual(r, data)
565        self.assertIsInstance(r['id'], int)
566        self.assertIsInstance(r['name'], unicode)
567        self.assertIsInstance(r['price'], float)
568        self.assertIsInstance(r['new'], bool)
569        self.assertIsInstance(r['tags'], list)
570        self.assertIsInstance(r['stock'], dict)
571
572    def testEncodeJson(self):
573        f = self.db.encode_json
574        self.assertEqual(f(None), 'null')
575        data = {
576            "id": 1, "name": "Foo", "price": 1234.5,
577            "new": True, "note": None,
578            "tags": ["Bar", "Eek"],
579            "stock": {"warehouse": 300, "retail": 20}}
580        text = json.dumps(data)
581        r = f(data)
582        self.assertIsInstance(r, str)
583        self.assertEqual(r, text)
584
585    def testGetParameter(self):
586        f = self.db.get_parameter
587        self.assertRaises(TypeError, f)
588        self.assertRaises(TypeError, f, None)
589        self.assertRaises(TypeError, f, 42)
590        self.assertRaises(TypeError, f, '')
591        self.assertRaises(TypeError, f, [])
592        self.assertRaises(TypeError, f, [''])
593        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
594        r = f('standard_conforming_strings')
595        self.assertEqual(r, 'on')
596        r = f('lc_monetary')
597        self.assertEqual(r, 'C')
598        r = f('datestyle')
599        self.assertEqual(r, 'ISO, YMD')
600        r = f('bytea_output')
601        self.assertEqual(r, 'hex')
602        r = f(['bytea_output', 'lc_monetary'])
603        self.assertIsInstance(r, list)
604        self.assertEqual(r, ['hex', 'C'])
605        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
606        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
607        r = f(set(['bytea_output', 'lc_monetary']))
608        self.assertIsInstance(r, dict)
609        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
610        r = f(set(['Bytea_Output', ' LC_Monetary ']))
611        self.assertIsInstance(r, dict)
612        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
613        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
614        r = f(s)
615        self.assertIs(r, s)
616        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
617        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
618        r = f(s)
619        self.assertIs(r, s)
620        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
621
622    def testGetParameterServerVersion(self):
623        r = self.db.get_parameter('server_version_num')
624        self.assertIsInstance(r, str)
625        s = self.db.server_version
626        self.assertIsInstance(s, int)
627        self.assertEqual(r, str(s))
628
629    def testGetParameterAll(self):
630        f = self.db.get_parameter
631        r = f('all')
632        self.assertIsInstance(r, dict)
633        self.assertEqual(r['standard_conforming_strings'], 'on')
634        self.assertEqual(r['lc_monetary'], 'C')
635        self.assertEqual(r['DateStyle'], 'ISO, YMD')
636        self.assertEqual(r['bytea_output'], 'hex')
637
    def testSetParameter(self):
        """Check the different argument forms accepted by set_parameter."""
        f = self.db.set_parameter
        g = self.db.get_parameter
        # missing or malformed parameter names
        self.assertRaises(TypeError, f)
        self.assertRaises(TypeError, f, None)
        self.assertRaises(TypeError, f, 42)
        self.assertRaises(TypeError, f, '')
        self.assertRaises(TypeError, f, [])
        self.assertRaises(TypeError, f, [''])
        # 'all' cannot be set to a value, and a dict of names/values
        # cannot be combined with an extra value argument
        self.assertRaises(ValueError, f, 'all', 'invalid')
        self.assertRaises(ValueError, f, {
            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
        # a single name with a single value
        f('standard_conforming_strings', 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        f('datestyle', 'ISO, DMY')
        self.assertEqual(g('datestyle'), 'ISO, DMY')
        # a list of names paired with a list of values
        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertEqual(g('datestyle'), 'ISO, DMY')
        # a list of names sharing one value
        f(['default_with_oids', 'standard_conforming_strings'], 'off')
        self.assertEqual(g('default_with_oids'), 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        # the same with tuples
        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertEqual(g('datestyle'), 'ISO, YMD')
        f(('default_with_oids', 'standard_conforming_strings'), 'off')
        self.assertEqual(g('default_with_oids'), 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        # a set of names sharing one value
        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
        self.assertEqual(g('default_with_oids'), 'on')
        self.assertEqual(g('standard_conforming_strings'), 'on')
        # a set of names may only get a list of values if all values are
        # equal (presumably because a set has no defined order)
        self.assertRaises(ValueError, f, set(['default_with_oids',
            'standard_conforming_strings']), ['off', 'on'])
        f(set(['default_with_oids', 'standard_conforming_strings']),
            ['off', 'off'])
        self.assertEqual(g('default_with_oids'), 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        # a dict mapping names to values
        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertEqual(g('datestyle'), 'ISO, YMD')
679
680    def testResetParameter(self):
681        db = DB()
682        f = db.set_parameter
683        g = db.get_parameter
684        r = g('default_with_oids')
685        self.assertIn(r, ('on', 'off'))
686        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
687        r = g('standard_conforming_strings')
688        self.assertIn(r, ('on', 'off'))
689        scs, not_scs = r, 'off' if r == 'on' else 'on'
690        f('default_with_oids', not_dwi)
691        f('standard_conforming_strings', not_scs)
692        self.assertEqual(g('default_with_oids'), not_dwi)
693        self.assertEqual(g('standard_conforming_strings'), not_scs)
694        f('default_with_oids')
695        f('standard_conforming_strings', None)
696        self.assertEqual(g('default_with_oids'), dwi)
697        self.assertEqual(g('standard_conforming_strings'), scs)
698        f('default_with_oids', not_dwi)
699        f('standard_conforming_strings', not_scs)
700        self.assertEqual(g('default_with_oids'), not_dwi)
701        self.assertEqual(g('standard_conforming_strings'), not_scs)
702        f(['default_with_oids', 'standard_conforming_strings'], None)
703        self.assertEqual(g('default_with_oids'), dwi)
704        self.assertEqual(g('standard_conforming_strings'), scs)
705        f('default_with_oids', not_dwi)
706        f('standard_conforming_strings', not_scs)
707        self.assertEqual(g('default_with_oids'), not_dwi)
708        self.assertEqual(g('standard_conforming_strings'), not_scs)
709        f(('default_with_oids', 'standard_conforming_strings'))
710        self.assertEqual(g('default_with_oids'), dwi)
711        self.assertEqual(g('standard_conforming_strings'), scs)
712        f('default_with_oids', not_dwi)
713        f('standard_conforming_strings', not_scs)
714        self.assertEqual(g('default_with_oids'), not_dwi)
715        self.assertEqual(g('standard_conforming_strings'), not_scs)
716        f(set(['default_with_oids', 'standard_conforming_strings']))
717        self.assertEqual(g('default_with_oids'), dwi)
718        self.assertEqual(g('standard_conforming_strings'), scs)
719
720    def testResetParameterAll(self):
721        db = DB()
722        f = db.set_parameter
723        self.assertRaises(ValueError, f, 'all', 0)
724        self.assertRaises(ValueError, f, 'all', 'off')
725        g = db.get_parameter
726        r = g('default_with_oids')
727        self.assertIn(r, ('on', 'off'))
728        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
729        r = g('standard_conforming_strings')
730        self.assertIn(r, ('on', 'off'))
731        scs, not_scs = r, 'off' if r == 'on' else 'on'
732        f('default_with_oids', not_dwi)
733        f('standard_conforming_strings', not_scs)
734        self.assertEqual(g('default_with_oids'), not_dwi)
735        self.assertEqual(g('standard_conforming_strings'), not_scs)
736        f('all')
737        self.assertEqual(g('default_with_oids'), dwi)
738        self.assertEqual(g('standard_conforming_strings'), scs)
739
740    def testSetParameterLocal(self):
741        f = self.db.set_parameter
742        g = self.db.get_parameter
743        self.assertEqual(g('standard_conforming_strings'), 'on')
744        self.db.begin()
745        f('standard_conforming_strings', 'off', local=True)
746        self.assertEqual(g('standard_conforming_strings'), 'off')
747        self.db.end()
748        self.assertEqual(g('standard_conforming_strings'), 'on')
749
750    def testSetParameterSession(self):
751        f = self.db.set_parameter
752        g = self.db.get_parameter
753        self.assertEqual(g('standard_conforming_strings'), 'on')
754        self.db.begin()
755        f('standard_conforming_strings', 'off', local=False)
756        self.assertEqual(g('standard_conforming_strings'), 'off')
757        self.db.end()
758        self.assertEqual(g('standard_conforming_strings'), 'off')
759
760    def testReset(self):
761        db = DB()
762        default_datestyle = db.get_parameter('datestyle')
763        changed_datestyle = 'ISO, DMY'
764        if changed_datestyle == default_datestyle:
765            changed_datestyle == 'ISO, YMD'
766        self.db.set_parameter('datestyle', changed_datestyle)
767        r = self.db.get_parameter('datestyle')
768        self.assertEqual(r, changed_datestyle)
769        con = self.db.db
770        q = con.query("show datestyle")
771        self.db.reset()
772        r = q.getresult()[0][0]
773        self.assertEqual(r, changed_datestyle)
774        q = con.query("show datestyle")
775        r = q.getresult()[0][0]
776        self.assertEqual(r, default_datestyle)
777        r = self.db.get_parameter('datestyle')
778        self.assertEqual(r, default_datestyle)
779
780    def testReopen(self):
781        db = DB()
782        default_datestyle = db.get_parameter('datestyle')
783        changed_datestyle = 'ISO, DMY'
784        if changed_datestyle == default_datestyle:
785            changed_datestyle == 'ISO, YMD'
786        self.db.set_parameter('datestyle', changed_datestyle)
787        r = self.db.get_parameter('datestyle')
788        self.assertEqual(r, changed_datestyle)
789        con = self.db.db
790        q = con.query("show datestyle")
791        self.db.reopen()
792        r = q.getresult()[0][0]
793        self.assertEqual(r, changed_datestyle)
794        self.assertRaises(TypeError, getattr, con, 'query')
795        r = self.db.get_parameter('datestyle')
796        self.assertEqual(r, default_datestyle)
797
798    def testCreateTable(self):
799        table = 'test hello world'
800        values = [(2, "World!"), (1, "Hello")]
801        self.createTable(table, "n smallint, t varchar",
802                         temporary=True, oids=True, values=values)
803        r = self.db.query('select t from "%s" order by n' % table).getresult()
804        r = ', '.join(row[0] for row in r)
805        self.assertEqual(r, "Hello, World!")
806        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
807        self.assertIsInstance(r[0][0], int)
808
809    def testQuery(self):
810        query = self.db.query
811        table = 'test_table'
812        self.createTable(table, "n integer", oids=True)
813        q = "insert into test_table values (1)"
814        r = query(q)
815        self.assertIsInstance(r, int)
816        q = "insert into test_table select 2"
817        r = query(q)
818        self.assertIsInstance(r, int)
819        oid = r
820        q = "select oid from test_table where n=2"
821        r = query(q).getresult()
822        self.assertEqual(len(r), 1)
823        r = r[0]
824        self.assertEqual(len(r), 1)
825        r = r[0]
826        self.assertEqual(r, oid)
827        q = "insert into test_table select 3 union select 4 union select 5"
828        r = query(q)
829        self.assertIsInstance(r, str)
830        self.assertEqual(r, '3')
831        q = "update test_table set n=4 where n<5"
832        r = query(q)
833        self.assertIsInstance(r, str)
834        self.assertEqual(r, '4')
835        q = "delete from test_table"
836        r = query(q)
837        self.assertIsInstance(r, str)
838        self.assertEqual(r, '5')
839
840    def testMultipleQueries(self):
841        self.assertEqual(self.db.query(
842            "create temporary table test_multi (n integer);"
843            "insert into test_multi values (4711);"
844            "select n from test_multi").getresult()[0][0], 4711)
845
846    def testQueryWithParams(self):
847        query = self.db.query
848        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
849        q = "insert into test_table values ($1, $2)"
850        r = query(q, (1, 2))
851        self.assertIsInstance(r, int)
852        r = query(q, [3, 4])
853        self.assertIsInstance(r, int)
854        r = query(q, [5, 6])
855        self.assertIsInstance(r, int)
856        q = "select * from test_table order by 1, 2"
857        self.assertEqual(query(q).getresult(),
858                         [(1, 2), (3, 4), (5, 6)])
859        q = "select * from test_table where n1=$1 and n2=$2"
860        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
861        q = "update test_table set n2=$2 where n1=$1"
862        r = query(q, 3, 7)
863        self.assertEqual(r, '1')
864        q = "select * from test_table order by 1, 2"
865        self.assertEqual(query(q).getresult(),
866                         [(1, 2), (3, 7), (5, 6)])
867        q = "delete from test_table where n2!=$1"
868        r = query(q, 4)
869        self.assertEqual(r, '3')
870
871    def testEmptyQuery(self):
872        self.assertRaises(ValueError, self.db.query, '')
873
874    def testQueryProgrammingError(self):
875        try:
876            self.db.query("select 1/0")
877        except pg.ProgrammingError as error:
878            self.assertEqual(error.sqlstate, '22012')
879
880    def testQueryFormatted(self):
881        f = self.db.query_formatted
882        t = True if pg.get_bool() else 't'
883        q = f("select %s::int, %s::real, %s::text, %s::bool",
884              (3, 2.5, 'hello', True))
885        r = q.getresult()[0]
886        self.assertEqual(r, (3, 2.5, 'hello', t))
887        q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True)
888        r = q.getresult()[0]
889        self.assertEqual(r, (3, 2.5, 'hello', t))
890
    def testPkey(self):
        """Test pkey() for various primary key definitions.

        Each table layout is created twice: once with a plain name prefix
        and once with a prefix containing spaces.
        """
        query = self.db.query
        pkey = self.db.pkey
        # pkey() raises KeyError when no primary key is found
        self.assertRaises(KeyError, pkey, 'test')
        for t in ('pkeytest', 'primary key test'):
            # 0: no primary key; 1, 2: single-column key;
            # 3, 4: two-column key in different orders;
            # 5-7: long, spaced and numeric column names
            self.createTable('%s0' % t, 'a smallint')
            self.createTable('%s1' % t, 'b smallint primary key')
            self.createTable('%s2' % t,
                'c smallint, d smallint primary key')
            self.createTable('%s3' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (f, h)')
            self.createTable('%s4' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (h, f)')
            self.createTable('%s5' % t,
                'more_than_one_letter varchar primary key')
            self.createTable('%s6' % t,
                '"with space" date primary key')
            self.createTable('%s7' % t,
                'a_very_long_column_name varchar, "with space" date, "42" int,'
                ' primary key (a_very_long_column_name, "with space", "42")')
            self.assertRaises(KeyError, pkey, '%s0' % t)
            # a single-column key is returned as a string,
            # or as a 1-tuple when composite=True is requested
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s1' % t, True), ('b',))
            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
            self.assertEqual(pkey('%s2' % t), 'd')
            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
            # multi-column keys are always tuples, even with composite=False,
            # and keep the column order of the key definition
            r = pkey('%s3' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s3' % t, composite=False)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s4' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('h', 'f'))
            # column names are returned unquoted
            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s6' % t), 'with space')
            r = pkey('%s7' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (
                'a_very_long_column_name', 'with space', '42'))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
944
945    def testGetDatabases(self):
946        databases = self.db.get_databases()
947        self.assertIn('template0', databases)
948        self.assertIn('template1', databases)
949        self.assertNotIn('not existing database', databases)
950        self.assertIn('postgres', databases)
951        self.assertIn(dbname, databases)
952
953    def testGetTables(self):
954        get_tables = self.db.get_tables
955        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
956                  'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
957                  'averyveryveryveryveryveryveryreallyreallylongtablename',
958                  'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
959        for t in tables:
960            self.db.query('drop table if exists "%s" cascade' % t)
961        before_tables = get_tables()
962        self.assertIsInstance(before_tables, list)
963        for t in before_tables:
964            t = t.split('.', 1)
965            self.assertGreaterEqual(len(t), 2)
966            if len(t) > 2:
967                self.assertTrue(t[1].startswith('"'))
968            t = t[0]
969            self.assertNotEqual(t, 'information_schema')
970            self.assertFalse(t.startswith('pg_'))
971        for t in tables:
972            self.createTable(t, 'as select 0', temporary=False)
973        current_tables = get_tables()
974        new_tables = [t for t in current_tables if t not in before_tables]
975        expected_new_tables = ['public.%s' % (
976            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
977        self.assertEqual(new_tables, expected_new_tables)
978        self.doCleanups()
979        after_tables = get_tables()
980        self.assertEqual(after_tables, before_tables)
981
982    def testGetRelations(self):
983        get_relations = self.db.get_relations
984        result = get_relations()
985        self.assertIn('public.test', result)
986        self.assertIn('public.test_view', result)
987        result = get_relations('rv')
988        self.assertIn('public.test', result)
989        self.assertIn('public.test_view', result)
990        result = get_relations('r')
991        self.assertIn('public.test', result)
992        self.assertNotIn('public.test_view', result)
993        result = get_relations('v')
994        self.assertNotIn('public.test', result)
995        self.assertIn('public.test_view', result)
996        result = get_relations('cisSt')
997        self.assertNotIn('public.test', result)
998        self.assertNotIn('public.test_view', result)
999
1000    def testGetAttnames(self):
1001        get_attnames = self.db.get_attnames
1002        self.assertRaises(pg.ProgrammingError,
1003                          self.db.get_attnames, 'does_not_exist')
1004        self.assertRaises(pg.ProgrammingError,
1005                          self.db.get_attnames, 'has.too.many.dots')
1006        r = get_attnames('test')
1007        self.assertIsInstance(r, dict)
1008        if self.regtypes:
1009            self.assertEqual(r, dict(
1010                i2='smallint', i4='integer', i8='bigint', d='numeric',
1011                f4='real', f8='double precision', m='money',
1012                v4='character varying', c4='character', t='text'))
1013        else:
1014            self.assertEqual(r, dict(
1015                i2='int', i4='int', i8='int', d='num',
1016                f4='float', f8='float', m='money',
1017                v4='text', c4='text', t='text'))
1018        self.createTable('test_table',
1019                         'n int, alpha smallint, beta bool,'
1020                         ' gamma char(5), tau text, v varchar(3)')
1021        r = get_attnames('test_table')
1022        self.assertIsInstance(r, dict)
1023        if self.regtypes:
1024            self.assertEqual(r, dict(
1025                n='integer', alpha='smallint', beta='boolean',
1026                gamma='character', tau='text', v='character varying'))
1027        else:
1028            self.assertEqual(r, dict(
1029                n='int', alpha='int', beta='bool',
1030                gamma='text', tau='text', v='text'))
1031
1032    def testGetAttnamesWithQuotes(self):
1033        get_attnames = self.db.get_attnames
1034        table = 'test table for get_attnames()'
1035        self.createTable(table,
1036            '"Prime!" smallint, "much space" integer, "Questions?" text')
1037        r = get_attnames(table)
1038        self.assertIsInstance(r, dict)
1039        if self.regtypes:
1040            self.assertEqual(r, {
1041                'Prime!': 'smallint', 'much space': 'integer',
1042                'Questions?': 'text'})
1043        else:
1044            self.assertEqual(r, {
1045                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1046        table = 'yet another test table for get_attnames()'
1047        self.createTable(table,
1048                         'a smallint, b integer, c bigint,'
1049                         ' e numeric, f real, f2 double precision, m money,'
1050                         ' x smallint, y smallint, z smallint,'
1051                         ' Normal_NaMe smallint, "Special Name" smallint,'
1052                         ' t text, u char(2), v varchar(2),'
1053                         ' primary key (y, u)', oids=True)
1054        r = get_attnames(table)
1055        self.assertIsInstance(r, dict)
1056        if self.regtypes:
1057            self.assertEqual(r, {
1058                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1059                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1060                'm': 'money', 'normal_name': 'smallint',
1061                'Special Name': 'smallint', 'u': 'character',
1062                't': 'text', 'v': 'character varying', 'y': 'smallint',
1063                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1064        else:
1065            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1066                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1067                 'normal_name': 'int', 'Special Name': 'int',
1068                 'u': 'text', 't': 'text', 'v': 'text',
1069                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1070
1071    def testGetAttnamesWithRegtypes(self):
1072        get_attnames = self.db.get_attnames
1073        self.createTable('test_table',
1074                         ' n int, alpha smallint, beta bool,'
1075                         ' gamma char(5), tau text, v varchar(3)')
1076        use_regtypes = self.db.use_regtypes
1077        regtypes = use_regtypes()
1078        self.assertEqual(regtypes, self.regtypes)
1079        use_regtypes(True)
1080        try:
1081            r = get_attnames("test_table")
1082            self.assertIsInstance(r, dict)
1083        finally:
1084            use_regtypes(regtypes)
1085        self.assertEqual(r, dict(
1086            n='integer', alpha='smallint', beta='boolean',
1087            gamma='character', tau='text', v='character varying'))
1088
1089    def testGetAttnamesWithoutRegtypes(self):
1090        get_attnames = self.db.get_attnames
1091        self.createTable('test_table',
1092                         ' n int, alpha smallint, beta bool,'
1093                         ' gamma char(5), tau text, v varchar(3)')
1094        use_regtypes = self.db.use_regtypes
1095        regtypes = use_regtypes()
1096        self.assertEqual(regtypes, self.regtypes)
1097        use_regtypes(False)
1098        try:
1099            r = get_attnames("test_table")
1100            self.assertIsInstance(r, dict)
1101        finally:
1102            use_regtypes(regtypes)
1103        self.assertEqual(r, dict(
1104            n='int', alpha='int', beta='bool',
1105            gamma='text', tau='text', v='text'))
1106
1107    def testGetAttnamesIsCached(self):
1108        get_attnames = self.db.get_attnames
1109        int_type = 'integer' if self.regtypes else 'int'
1110        text_type = 'text'
1111        query = self.db.query
1112        self.createTable('test_table', 'col int')
1113        r = get_attnames("test_table")
1114        self.assertIsInstance(r, dict)
1115        self.assertEqual(r, dict(col=int_type))
1116        query("alter table test_table alter column col type text")
1117        query("alter table test_table add column col2 int")
1118        r = get_attnames("test_table")
1119        self.assertEqual(r, dict(col=int_type))
1120        r = get_attnames("test_table", flush=True)
1121        self.assertEqual(r, dict(col=text_type, col2=int_type))
1122        query("alter table test_table drop column col2")
1123        r = get_attnames("test_table")
1124        self.assertEqual(r, dict(col=text_type, col2=int_type))
1125        r = get_attnames("test_table", flush=True)
1126        self.assertEqual(r, dict(col=text_type))
1127        query("alter table test_table drop column col")
1128        r = get_attnames("test_table")
1129        self.assertEqual(r, dict(col=text_type))
1130        r = get_attnames("test_table", flush=True)
1131        self.assertEqual(r, dict())
1132
1133    def testGetAttnamesIsOrdered(self):
1134        get_attnames = self.db.get_attnames
1135        r = get_attnames('test', flush=True)
1136        self.assertIsInstance(r, OrderedDict)
1137        if self.regtypes:
1138            self.assertEqual(r, OrderedDict([
1139                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1140                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1141                ('m', 'money'), ('v4', 'character varying'),
1142                ('c4', 'character'), ('t', 'text')]))
1143        else:
1144            self.assertEqual(r, OrderedDict([
1145                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1146                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1147                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1148        if OrderedDict is not dict:
1149            r = ' '.join(list(r.keys()))
1150            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1151        table = 'test table for get_attnames'
1152        self.createTable(table,
1153                         ' n int, alpha smallint, v varchar(3),'
1154                         ' gamma char(5), tau text, beta bool')
1155        r = get_attnames(table)
1156        self.assertIsInstance(r, OrderedDict)
1157        if self.regtypes:
1158            self.assertEqual(r, OrderedDict([
1159                ('n', 'integer'), ('alpha', 'smallint'),
1160                ('v', 'character varying'), ('gamma', 'character'),
1161                ('tau', 'text'), ('beta', 'boolean')]))
1162        else:
1163            self.assertEqual(r, OrderedDict([
1164                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1165                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1166        if OrderedDict is not dict:
1167            r = ' '.join(list(r.keys()))
1168            self.assertEqual(r, 'n alpha v gamma tau beta')
1169        else:
1170            self.skipTest('OrderedDict is not supported')
1171
1172    def testGetAttnamesIsAttrDict(self):
1173        AttrDict = pg.AttrDict
1174        get_attnames = self.db.get_attnames
1175        r = get_attnames('test', flush=True)
1176        self.assertIsInstance(r, AttrDict)
1177        if self.regtypes:
1178            self.assertEqual(r, AttrDict([
1179                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1180                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1181                ('m', 'money'), ('v4', 'character varying'),
1182                ('c4', 'character'), ('t', 'text')]))
1183        else:
1184            self.assertEqual(r, AttrDict([
1185                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1186                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1187                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1188        r = ' '.join(list(r.keys()))
1189        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1190        table = 'test table for get_attnames'
1191        self.createTable(table,
1192                         ' n int, alpha smallint, v varchar(3),'
1193                         ' gamma char(5), tau text, beta bool')
1194        r = get_attnames(table)
1195        self.assertIsInstance(r, AttrDict)
1196        if self.regtypes:
1197            self.assertEqual(r, AttrDict([
1198                ('n', 'integer'), ('alpha', 'smallint'),
1199                ('v', 'character varying'), ('gamma', 'character'),
1200                ('tau', 'text'), ('beta', 'boolean')]))
1201        else:
1202            self.assertEqual(r, AttrDict([
1203                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1204                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1205        r = ' '.join(list(r.keys()))
1206        self.assertEqual(r, 'n alpha v gamma tau beta')
1207
1208    def testHasTablePrivilege(self):
1209        can = self.db.has_table_privilege
1210        self.assertEqual(can('test'), True)
1211        self.assertEqual(can('test', 'select'), True)
1212        self.assertEqual(can('test', 'SeLeCt'), True)
1213        self.assertEqual(can('test', 'SELECT'), True)
1214        self.assertEqual(can('test', 'insert'), True)
1215        self.assertEqual(can('test', 'update'), True)
1216        self.assertEqual(can('test', 'delete'), True)
1217        self.assertEqual(can('pg_views', 'select'), True)
1218        self.assertEqual(can('pg_views', 'delete'), False)
1219        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
1220        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1221
    def testGet(self):
        """Test get() on a table first without, then with a primary key.

        Covers explicit key names, missing rows, and passing a dict that
        is updated in place and returned.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # calling get() without a table or without a row raises TypeError
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        self.createTable(table, 'n integer, t text',
                         values=enumerate('xyz', start=1))
        # without a primary key, the key column must be named explicitly
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # these fail with DatabaseError: no key column given, or no such row
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        # a dict passed as row is filled in place and returned
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        # looking up by 't' overwrites 'n' with the value of the found row
        s.update(t='x')
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        # with a primary key, the key column can be omitted
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # a trailing '*' on the table name is accepted as well
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # two values for a single-column primary key raise KeyError
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        s.update(n=2)
        # r is the same object as s here (asserted above)
        self.assertEqual(get(table, r)['t'], 'y')
        # a dict without the primary key column raises KeyError
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1276
    def testGetWithOid(self):
        """Test get() on a table with oids, without and with a primary key.

        Rows can be looked up by oid, given as a plain value with keyname
        'oid' or inside a dict under 'oid' or 'oid(<table>)'.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        # an empty dict provides no oid value
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        # the oid is returned under the key 'oid(<table>)'
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        # presumably r was updated in place by the lookup by 'n' above,
        # so its oid now refers to row 3
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        # with a primary key, no key column needs to be given
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # lookups by oid still work after adding the primary key
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # if the dict also contains the primary key, the key is used
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # an explicitly given key column is used instead of the oid
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1340
1341    def testGetWithCompositeKey(self):
1342        get = self.db.get
1343        query = self.db.query
1344        table = 'get_test_table_1'
1345        self.createTable(table, 'n integer primary key, t text',
1346                         values=enumerate('abc', start=1))
1347        self.assertEqual(get(table, 2)['t'], 'b')
1348        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1349        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1350        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1351        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1352        self.assertEqual(get(table, 'b', 't')['n'], 2)
1353        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1354        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1355        table = 'get_test_table_2'
1356        self.createTable(table,
1357                         'n integer, m integer, t text, primary key (n, m)',
1358                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1359                                 for n in range(3) for m in range(2)])
1360        self.assertRaises(KeyError, get, table, 2)
1361        self.assertEqual(get(table, (1, 1))['t'], 'a')
1362        self.assertEqual(get(table, (1, 2))['t'], 'b')
1363        self.assertEqual(get(table, (2, 1))['t'], 'c')
1364        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1365        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1366        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1367        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1368        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1369        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1370        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1371        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1372
1373    def testGetWithQuotedNames(self):
1374        get = self.db.get
1375        query = self.db.query
1376        table = 'test table for get()'
1377        self.createTable(table, '"Prime!" smallint primary key,'
1378                                ' "much space" integer, "Questions?" text',
1379                         values=[(17, 1001, 'No!')])
1380        r = get(table, 17)
1381        self.assertIsInstance(r, dict)
1382        self.assertEqual(r['Prime!'], 17)
1383        self.assertEqual(r['much space'], 1001)
1384        self.assertEqual(r['Questions?'], 'No!')
1385
1386    def testGetFromView(self):
1387        self.db.query('delete from test where i4=14')
1388        self.db.query('insert into test (i4, v4) values('
1389                      "14, 'abc4')")
1390        r = self.db.get('test_view', 14, 'i4')
1391        self.assertIn('v4', r)
1392        self.assertEqual(r['v4'], 'abc4')
1393
1394    def testGetLittleBobbyTables(self):
1395        get = self.db.get
1396        query = self.db.query
1397        self.createTable('test_students',
1398                         'firstname varchar primary key, nickname varchar, grade char(2)',
1399                         values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1400                                 ('Robert', 'Little Bobby Tables', 'D-')])
1401        r = get('test_students', 'Sheldon')
1402        self.assertEqual(r, dict(
1403            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1404        r = get('test_students', 'Robert')
1405        self.assertEqual(r, dict(
1406            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1407        r = get('test_students', "D'Arcy")
1408        self.assertEqual(r, dict(
1409            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1410        try:
1411            get('test_students', "D' Arcy")
1412        except pg.DatabaseError as error:
1413            self.assertEqual(str(error),
1414                'No such record in test_students\nwhere "firstname" = $1\n'
1415                'with $1="D\' Arcy"')
1416        try:
1417            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1418        except pg.DatabaseError as error:
1419            self.assertEqual(str(error),
1420                'No such record in test_students\nwhere "firstname" = $1\n'
1421                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1422        q = "select * from test_students order by 1 limit 4"
1423        r = query(q).getresult()
1424        self.assertEqual(len(r), 3)
1425        self.assertEqual(r[1][2], 'D-')
1426
1427    def testInsert(self):
1428        insert = self.db.insert
1429        query = self.db.query
1430        bool_on = pg.get_bool()
1431        decimal = pg.get_decimal()
1432        table = 'insert_test_table'
1433        self.createTable(table,
1434            'i2 smallint, i4 integer, i8 bigint,'
1435            ' d numeric, f4 real, f8 double precision, m money,'
1436            ' v4 varchar(4), c4 char(4), t text,'
1437            ' b boolean, ts timestamp', oids=True)
1438        oid_table = 'oid(%s)' % table
1439        tests = [dict(i2=None, i4=None, i8=None),
1440             (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1441             (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1442             dict(i2=42, i4=123456, i8=9876543210),
1443             dict(i2=2 ** 15 - 1,
1444                  i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1445             dict(d=None), (dict(d=''), dict(d=None)),
1446             dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1447             dict(f4=None, f8=None), dict(f4=0, f8=0),
1448             (dict(f4='', f8=''), dict(f4=None, f8=None)),
1449             (dict(d=1234.5, f4=1234.5, f8=1234.5),
1450              dict(d=Decimal('1234.5'))),
1451             dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1452             dict(d=Decimal('123456789.9876543212345678987654321')),
1453             dict(m=None), (dict(m=''), dict(m=None)),
1454             dict(m=Decimal('-1234.56')),
1455             (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1456             dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1457             (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1458             (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1459             (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1460             (dict(m=123456), dict(m=Decimal('123456'))),
1461             (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1462             dict(b=None), (dict(b=''), dict(b=None)),
1463             dict(b='f'), dict(b='t'),
1464             (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1465             (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1466             (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1467             (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1468             (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1469             (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1470             dict(v4=None, c4=None, t=None),
1471             (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1472             dict(v4='1234', c4='1234', t='1234' * 10),
1473             dict(v4='abcd', c4='abcd', t='abcdefg'),
1474             (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1475             dict(ts=None), (dict(ts=''), dict(ts=None)),
1476             (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1477             dict(ts='2012-12-21 00:00:00'),
1478             (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1479             dict(ts='2012-12-21 12:21:12'),
1480             dict(ts='2013-01-05 12:13:14'),
1481             dict(ts='current_timestamp')]
1482        for test in tests:
1483            if isinstance(test, dict):
1484                data = test
1485                change = {}
1486            else:
1487                data, change = test
1488            expect = data.copy()
1489            expect.update(change)
1490            if bool_on:
1491                b = expect.get('b')
1492                if b is not None:
1493                    expect['b'] = b == 't'
1494            if decimal is not Decimal:
1495                d = expect.get('d')
1496                if d is not None:
1497                    expect['d'] = decimal(d)
1498                m = expect.get('m')
1499                if m is not None:
1500                    expect['m'] = decimal(m)
1501            self.assertEqual(insert(table, data), data)
1502            self.assertIn(oid_table, data)
1503            oid = data[oid_table]
1504            self.assertIsInstance(oid, int)
1505            data = dict(item for item in data.items()
1506                        if item[0] in expect)
1507            ts = expect.get('ts')
1508            if ts == 'current_timestamp':
1509                ts = expect['ts'] = data['ts']
1510                if len(ts) > 19:
1511                    self.assertEqual(ts[19], '.')
1512                    ts = ts[:19]
1513                else:
1514                    self.assertEqual(len(ts), 19)
1515                self.assertTrue(ts[:4].isdigit())
1516                self.assertEqual(ts[4], '-')
1517                self.assertEqual(ts[10], ' ')
1518                self.assertTrue(ts[11:13].isdigit())
1519                self.assertEqual(ts[13], ':')
1520            self.assertEqual(data, expect)
1521            data = query(
1522                'select oid,* from "%s"' % table).dictresult()[0]
1523            self.assertEqual(data['oid'], oid)
1524            data = dict(item for item in data.items()
1525                        if item[0] in expect)
1526            self.assertEqual(data, expect)
1527            query('delete from "%s"' % table)
1528
    def testInsertWithOid(self):
        """Test insert() on a table created with oids.

        The oid of an inserted row is reported in the returned dict under
        the qualified key 'oid(test_table)'; a plain 'oid' key is never
        part of the result.
        """
        insert = self.db.insert
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True)
        # inserting into a non-existing column must fail
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
        r = insert('test_table', n=1)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        self.assertNotIn('oid', r)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        # passing an existing oid as keyword still inserts a new row
        # which gets a different oid
        r = insert('test_table', n=2, oid=oid)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 2)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        r = insert('test_table', None, n=3)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 3)
        s = r
        # when a dict is passed, that same dict object is filled and returned
        r = insert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # a table name with a trailing ' *' is accepted as well
        r = insert('test_table *', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # keyword arguments override the values of the passed dict
        r = insert('test_table', r, n=4)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 4)
        self.assertNotIn('oid', r)
        self.assertIn(qoid, r)
        oid = r[qoid]
        r = insert('test_table', r, n=5, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # a plain 'oid' key in the dict does not survive the insert,
        # and the new row again gets a new oid
        r['oid'] = oid = r[qoid]
        r = insert('test_table', r, n=6)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 6)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        q = 'select n from test_table order by 1 limit 9'
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '1 2 3 3 3 4 5 6')
        query("truncate test_table")
        query("alter table test_table add unique (n)")
        r = insert('test_table', dict(n=7))
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        # inserting the same value again violates the unique constraint
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
        r['n'] = 6
        # the keyword n=7 is applied to the dict even though insert fails
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        r['n'] = 6
        r = insert('test_table', r)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 6)
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '6 7')
1596
1597    def testInsertWithQuotedNames(self):
1598        insert = self.db.insert
1599        query = self.db.query
1600        table = 'test table for insert()'
1601        self.createTable(table, '"Prime!" smallint primary key,'
1602                                ' "much space" integer, "Questions?" text')
1603        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1604        r = insert(table, r)
1605        self.assertIsInstance(r, dict)
1606        self.assertEqual(r['Prime!'], 11)
1607        self.assertEqual(r['much space'], 2002)
1608        self.assertEqual(r['Questions?'], 'What?')
1609        r = query('select * from "%s" limit 2' % table).dictresult()
1610        self.assertEqual(len(r), 1)
1611        r = r[0]
1612        self.assertEqual(r['Prime!'], 11)
1613        self.assertEqual(r['much space'], 2002)
1614        self.assertEqual(r['Questions?'], 'What?')
1615
1616    def testInsertIntoView(self):
1617        insert = self.db.insert
1618        query = self.db.query
1619        query("truncate test")
1620        q = 'select * from test_view order by i4 limit 3'
1621        r = query(q).getresult()
1622        self.assertEqual(r, [])
1623        r = dict(i4=1234, v4='abcd')
1624        insert('test', r)
1625        self.assertIsNone(r['i2'])
1626        self.assertEqual(r['i4'], 1234)
1627        self.assertIsNone(r['i8'])
1628        self.assertEqual(r['v4'], 'abcd')
1629        self.assertIsNone(r['c4'])
1630        r = query(q).getresult()
1631        self.assertEqual(r, [(1234, 'abcd')])
1632        r = dict(i4=5678, v4='efgh')
1633        try:
1634            insert('test_view', r)
1635        except pg.ProgrammingError as error:
1636            if self.db.server_version < 90300:
1637                # must setup rules in older PostgreSQL versions
1638                self.skipTest('database cannot insert into view')
1639            self.fail(str(error))
1640        self.assertNotIn('i2', r)
1641        self.assertEqual(r['i4'], 5678)
1642        self.assertNotIn('i8', r)
1643        self.assertEqual(r['v4'], 'efgh')
1644        self.assertNotIn('c4', r)
1645        r = query(q).getresult()
1646        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1647
1648    def testUpdate(self):
1649        update = self.db.update
1650        query = self.db.query
1651        self.assertRaises(pg.ProgrammingError, update,
1652                          'test', i2=2, i4=4, i8=8)
1653        table = 'update_test_table'
1654        self.createTable(table, 'n integer, t text', oids=True,
1655                         values=enumerate('xyz', start=1))
1656        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1657        r = self.db.get(table, 2, 'n')
1658        r['t'] = 'u'
1659        s = update(table, r)
1660        self.assertEqual(s, r)
1661        q = 'select t from "%s" where n=2' % table
1662        r = query(q).getresult()[0][0]
1663        self.assertEqual(r, 'u')
1664
    def testUpdateWithOid(self):
        """Test update() using oids, before and after adding a primary key.

        The row oid is kept in the dict under the qualified key
        'oid(test_table)' or can be passed with the oid keyword argument.
        """
        update = self.db.update
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        s = get('test_table', 1, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 1)
        s['n'] = 2
        # the dict fetched by get() carries the oid of its row
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        r['n'] = 3
        # the oid can also be passed as keyword argument
        oid = r.pop(qoid)
        r = update('test_table', r, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # without any oid (and no primary key) the update must fail
        r.pop(qoid)
        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
        s = get('test_table', 3, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 3)
        # a dict holding nothing but the qualified oid is accepted
        s.pop('n')
        r = update('test_table', s)
        oid = r.pop(qoid)
        self.assertEqual(r, {})
        q = "select n from test_table limit 2"
        r = query(q).getresult()
        self.assertEqual(r, [(3,)])
        query("insert into test_table values (1)")
        # passing the oid as a plain 'oid' dict key instead of as
        # keyword argument makes the update fail
        self.assertRaises(pg.ProgrammingError,
                          update, 'test_table', dict(oid=oid, n=4))
        r = update('test_table', dict(n=4), oid=oid)
        self.assertEqual(r['n'], 4)
        # a table name with a trailing ' *' is accepted as well
        r = update('test_table *', dict(n=5), oid=oid)
        self.assertEqual(r['n'], 5)
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        # re-read the table metadata after altering the table
        # (flush=True presumably bypasses cached values)
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        # now rows can be identified by their primary key
        s = dict(n=1, m=4)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 4)
        # the key value can also be passed as keyword argument
        s = dict(m=7)
        r = update('test_table', s, n=5)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 7)
        q = "select n, m from test_table order by 1 limit 3"
        r = query(q).getresult()
        self.assertEqual(r, [(1, 4), (5, 7)])
        # with a primary key, a plain 'oid' dict key does not replace the
        # key value, but the oid keyword argument does
        s = dict(m=9, oid=oid)
        self.assertRaises(KeyError, update, 'test_table', s)
        r = update('test_table', s, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 9)
        # the primary key in the dict selects the row (n=1 here), not the
        # plain 'oid' key, which belongs to the row with n=5
        s = dict(n=1, m=3, oid=oid)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        r = query(q).getresult()
        self.assertEqual(r, [(1, 3), (5, 9)])
1735
1736    def testUpdateWithoutOid(self):
1737        update = self.db.update
1738        query = self.db.query
1739        self.assertRaises(pg.ProgrammingError, update,
1740                          'test', i2=2, i4=4, i8=8)
1741        table = 'update_test_table'
1742        self.createTable(table, 'n integer primary key, t text', oids=False,
1743                         values=enumerate('xyz', start=1))
1744        r = self.db.get(table, 2)
1745        r['t'] = 'u'
1746        s = update(table, r)
1747        self.assertEqual(s, r)
1748        q = 'select t from "%s" where n=2' % table
1749        r = query(q).getresult()[0][0]
1750        self.assertEqual(r, 'u')
1751
1752    def testUpdateWithCompositeKey(self):
1753        update = self.db.update
1754        query = self.db.query
1755        table = 'update_test_table_1'
1756        self.createTable(table, 'n integer primary key, t text',
1757                         values=enumerate('abc', start=1))
1758        self.assertRaises(KeyError, update, table, dict(t='b'))
1759        s = dict(n=2, t='d')
1760        r = update(table, s)
1761        self.assertIs(r, s)
1762        self.assertEqual(r['n'], 2)
1763        self.assertEqual(r['t'], 'd')
1764        q = 'select t from "%s" where n=2' % table
1765        r = query(q).getresult()[0][0]
1766        self.assertEqual(r, 'd')
1767        s.update(dict(n=4, t='e'))
1768        r = update(table, s)
1769        self.assertEqual(r['n'], 4)
1770        self.assertEqual(r['t'], 'e')
1771        q = 'select t from "%s" where n=2' % table
1772        r = query(q).getresult()[0][0]
1773        self.assertEqual(r, 'd')
1774        q = 'select t from "%s" where n=4' % table
1775        r = query(q).getresult()
1776        self.assertEqual(len(r), 0)
1777        query('drop table "%s"' % table)
1778        table = 'update_test_table_2'
1779        self.createTable(table,
1780                         'n integer, m integer, t text, primary key (n, m)',
1781                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1782                                 for n in range(3) for m in range(2)])
1783        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1784        self.assertEqual(update(table,
1785                                dict(n=2, m=2, t='x'))['t'], 'x')
1786        q = 'select t from "%s" where n=2 order by m' % table
1787        r = [r[0] for r in query(q).getresult()]
1788        self.assertEqual(r, ['c', 'x'])
1789
1790    def testUpdateWithQuotedNames(self):
1791        update = self.db.update
1792        query = self.db.query
1793        table = 'test table for update()'
1794        self.createTable(table, '"Prime!" smallint primary key,'
1795                                ' "much space" integer, "Questions?" text',
1796                         values=[(13, 3003, 'Why!')])
1797        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1798        r = update(table, r)
1799        self.assertIsInstance(r, dict)
1800        self.assertEqual(r['Prime!'], 13)
1801        self.assertEqual(r['much space'], 7007)
1802        self.assertEqual(r['Questions?'], 'When?')
1803        r = query('select * from "%s" limit 2' % table).dictresult()
1804        self.assertEqual(len(r), 1)
1805        r = r[0]
1806        self.assertEqual(r['Prime!'], 13)
1807        self.assertEqual(r['much space'], 7007)
1808        self.assertEqual(r['Questions?'], 'When?')
1809
    def testUpsert(self):
        """Test upsert() on a table with a single-column primary key.

        Keyword arguments control what happens with an existing row:
        False keeps the old value, True takes the new one, and a string is
        used as an SQL expression where 'included' refers to the existing
        row and 'excluded' to the proposed row.
        """
        upsert = self.db.upsert
        query = self.db.query
        self.assertRaises(pg.ProgrammingError, upsert,
                          'test', i2=2, i4=4, i8=8)
        table = 'upsert_test_table'
        self.createTable(table, 'n integer primary key, t text', oids=True)
        s = dict(n=1, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # upsert is only available since PostgreSQL 9.5
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        s.update(n=2, t='y')
        # dict.fromkeys(s) passes None for every column as keyword argument
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # without keywords, the existing row n=2 gets the new value
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=False: the existing value of t is kept and returned
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=True: the existing value of t is replaced
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # 'included.t' is the value of the existing row ('y' here)
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # 'excluded.t' is the proposed value ('y' here)
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        # not existing columns and oid parameter should be ignored
        s = dict(m=3, u='z')
        r = upsert(table, s, oid='invalid')
        self.assertIs(r, s)
1881
    def testUpsertWithOid(self):
        """Test upsert() on a table with oids, before and after adding a pkey.

        Without a primary key upsert() fails, even when an oid is given.
        Once a primary key exists, oids passed in are not used to find rows.
        """
        upsert = self.db.upsert
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        # no primary key yet, so upsert must fail
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2))
        r = get('test_table', 1, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        oid = r[qoid]
        # a plain 'oid' key does not help either
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2, oid=oid))
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        # re-read the table metadata after altering the table
        # (flush=True presumably bypasses cached values)
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        s = dict(n=2)
        try:
            r = upsert('test_table', s)
        except pg.ProgrammingError as error:
            # upsert is only available since PostgreSQL 9.5
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, None), (2, None)])
        # a plain 'oid' key (here of the row n=1) is dropped from the dict;
        # the qualified oid of the row n=2 is returned instead
        r['oid'] = oid
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertNotEqual(r[qoid], oid)
        r['m'] = 7
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        r.update(n=1, m=3)
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
        # an invalid oid keyword does not affect the result
        r = upsert('test_table', r, oid='invalid')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=False: the existing value of m is kept
        r['m'] = 5
        r = upsert('test_table', r, m=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=True: the existing value of m is replaced
        r['m'] = 5
        r = upsert('test_table', r, m=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 5)
        # 'included.m' keeps the existing value (7 for n=2)
        r.update(n=2, m=1)
        r = upsert('test_table', r, m='included.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # 'excluded.m' takes the proposed value
        r['m'] = 9
        r = upsert('test_table', r, m='excluded.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 9)
        # arbitrary expressions work, also with a ' *' table name
        r['m'] = 8
        r = upsert('test_table *', r, m='included.m + 1')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 10)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1965
    def testUpsertWithCompositeKey(self):
        """Test upsert() on a table with the composite primary key (n, m).

        Keyword arguments are only applied when the key already exists;
        rows with a new key are simply inserted.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        self.createTable(table,
                         'n integer, m integer, t text, primary key (n, m)')
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # upsert is only available since PostgreSQL 9.5
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # (1, 3) is a new key, so a new row is inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False: the existing value of t is kept
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True: the existing value of t is replaced
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # (2, 3) is a new key, so the expression for t is not used
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # combine the existing ('n') and the proposed ('m') values
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2032
2033    def testUpsertWithQuotedNames(self):
2034        upsert = self.db.upsert
2035        query = self.db.query
2036        table = 'test table for upsert()'
2037        self.createTable(table, '"Prime!" smallint primary key,'
2038                                ' "much space" integer, "Questions?" text')
2039        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2040        try:
2041            r = upsert(table, s)
2042        except pg.ProgrammingError as error:
2043            if self.db.server_version < 90500:
2044                self.skipTest('database does not support upsert')
2045            self.fail(str(error))
2046        self.assertIs(r, s)
2047        self.assertEqual(r['Prime!'], 31)
2048        self.assertEqual(r['much space'], 9009)
2049        self.assertEqual(r['Questions?'], 'Yes.')
2050        q = 'select * from "%s" limit 2' % table
2051        r = query(q).getresult()
2052        self.assertEqual(r, [(31, 9009, 'Yes.')])
2053        s.update({'Questions?': 'No.'})
2054        r = upsert(table, s)
2055        self.assertIs(r, s)
2056        self.assertEqual(r['Prime!'], 31)
2057        self.assertEqual(r['much space'], 9009)
2058        self.assertEqual(r['Questions?'], 'No.')
2059        r = query(q).getresult()
2060        self.assertEqual(r, [(31, 9009, 'No.')])
2061
2062    def testClear(self):
2063        clear = self.db.clear
2064        f = False if pg.get_bool() else 'f'
2065        r = clear('test')
2066        result = dict(
2067            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2068        self.assertEqual(r, result)
2069        table = 'clear_test_table'
2070        self.createTable(table,
2071            'n integer, f float, b boolean, d date, t text', oids=True)
2072        r = clear(table)
2073        result = dict(n=0, f=0, b=f, d='', t='')
2074        self.assertEqual(r, result)
2075        r['a'] = r['f'] = r['n'] = 1
2076        r['d'] = r['t'] = 'x'
2077        r['b'] = 't'
2078        r['oid'] = long(1)
2079        r = clear(table, r)
2080        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2081        self.assertEqual(r, result)
2082
2083    def testClearWithQuotedNames(self):
2084        clear = self.db.clear
2085        table = 'test table for clear()'
2086        self.createTable(table, '"Prime!" smallint primary key,'
2087            ' "much space" integer, "Questions?" text')
2088        r = clear(table)
2089        self.assertIsInstance(r, dict)
2090        self.assertEqual(r['Prime!'], 0)
2091        self.assertEqual(r['much space'], 0)
2092        self.assertEqual(r['Questions?'], '')
2093
    def testDelete(self):
        """Test delete() with rows of a table that has oids but no pkey.

        delete() returns the number of deleted rows, i.e. 1 for the first
        call and 0 when the same row is deleted again.
        """
        delete = self.db.delete
        query = self.db.query
        self.assertRaises(pg.ProgrammingError, delete,
                          'test', dict(i2=2, i4=4, i8=8))
        table = 'delete_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        # no primary key, so get() needs the name of the key column
        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
        r = self.db.get(table, 1, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        r = self.db.get(table, 3, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        s = delete(table, r)
        self.assertEqual(s, 0)
        r = query('select * from "%s"' % table).dictresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        result = {'n': 2, 't': 'y'}
        self.assertEqual(r, result)
        r = self.db.get(table, 2, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        s = delete(table, r)
        self.assertEqual(s, 0)
        # the row is gone, so fetching it raises an error
        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
        # not existing columns and oid parameter should be ignored
        r.update(m=3, u='z', oid='invalid')
        s = delete(table, r)
        self.assertEqual(s, 0)
2126
    def testDeleteWithOid(self):
        """Test delete() using oids, before and after adding a primary key.

        The query q reports (min(n), count(n)) so each step can verify
        which rows remain in the table.
        """
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
        # no primary key and no oid: the row cannot be identified
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        # a dict fetched by get() carries the qualified oid of its row
        s = get('test_table', 1, 'n')
        qoid = 'oid(test_table)'
        self.assertIn(qoid, s)
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        q = "select min(n),count(n) from test_table"
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # a plain 'oid' dict key is not accepted, the oid keyword is
        oid = get('test_table', 2, 'n')[qoid]
        s = dict(oid=oid, n=2)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the oid keyword (row n=3) is used, not the dict's 'oid' key
        s = dict(oid=oid, n=2)
        oid = get('test_table', 3, 'n')[qoid]
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        # a table name with a trailing ' *' is accepted as well
        s = get('test_table', 4, 'n')
        r = delete('test_table *', s)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # m is not a column yet; with the qualified oid the row is deleted
        oid = get('test_table', 5, 'n')[qoid]
        s = {qoid: oid, 'm': 4}
        r = delete('test_table', s, m=6)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        # re-read the table metadata after altering the table
        # (flush=True presumably bypasses cached values)
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        for i in range(5):
            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
        # with a primary key, the key value is required in the dict,
        # a plain 'oid' key does not replace it
        s = dict(m=2)
        self.assertRaises(KeyError, delete, 'test_table', s)
        s = dict(m=2, oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # the oid keyword can be used instead of the key; this oid belongs
        # to a row that was already deleted above, so nothing is deleted
        r = delete('test_table', dict(m=2), oid=oid)
        self.assertEqual(r, 0)
        oid = get('test_table', 1, 'n')[qoid]
        s = dict(oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # a dict holding nothing but the qualified oid is sufficient
        s = get('test_table', 2, 'n')
        del s['n']
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the key value can be passed as keyword argument only
        r = delete('test_table', n=3)
        self.assertEqual(r, 1)
        r = delete('test_table', n=3)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 1)
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # the keyword n=5 overrides the value n=6 in the dict
        s = dict(n=6)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 1)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
2215
2216    def testDeleteWithCompositeKey(self):
2217        query = self.db.query
2218        table = 'delete_test_table_1'
2219        self.createTable(table, 'n integer primary key, t text',
2220                         values=enumerate('abc', start=1))
2221        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2222        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2223        r = query('select t from "%s" where n=2' % table).getresult()
2224        self.assertEqual(r, [])
2225        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2226        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2227        self.assertEqual(r, 'c')
2228        table = 'delete_test_table_2'
2229        self.createTable(table,
2230             'n integer, m integer, t text, primary key (n, m)',
2231             values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2232                     for n in range(3) for m in range(2)])
2233        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2234        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2235        r = [r[0] for r in query('select t from "%s" where n=2'
2236            ' order by m' % table).getresult()]
2237        self.assertEqual(r, ['c'])
2238        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2239        r = [r[0] for r in query('select t from "%s" where n=3'
2240             ' order by m' % table).getresult()]
2241        self.assertEqual(r, ['e', 'f'])
2242        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2243        r = [r[0] for r in query('select t from "%s" where n=3'
2244             ' order by m' % table).getresult()]
2245        self.assertEqual(r, ['f'])
2246
2247    def testDeleteWithQuotedNames(self):
2248        delete = self.db.delete
2249        query = self.db.query
2250        table = 'test table for delete()'
2251        self.createTable(table, '"Prime!" smallint primary key,'
2252            ' "much space" integer, "Questions?" text',
2253            values=[(19, 5005, 'Yes!')])
2254        r = {'Prime!': 17}
2255        r = delete(table, r)
2256        self.assertEqual(r, 0)
2257        r = query('select count(*) from "%s"' % table).getresult()
2258        self.assertEqual(r[0][0], 1)
2259        r = {'Prime!': 19}
2260        r = delete(table, r)
2261        self.assertEqual(r, 1)
2262        r = query('select count(*) from "%s"' % table).getresult()
2263        self.assertEqual(r[0][0], 0)
2264
2265    def testDeleteReferenced(self):
2266        delete = self.db.delete
2267        query = self.db.query
2268        self.createTable('test_parent',
2269            'n smallint primary key', values=range(3))
2270        self.createTable('test_child',
2271            'n smallint primary key references test_parent', values=range(3))
2272        q = ("select (select count(*) from test_parent),"
2273             " (select count(*) from test_child)")
2274        self.assertEqual(query(q).getresult()[0], (3, 3))
2275        self.assertRaises(pg.ProgrammingError,
2276                          delete, 'test_parent', None, n=2)
2277        self.assertRaises(pg.ProgrammingError,
2278                          delete, 'test_parent *', None, n=2)
2279        r = delete('test_child', None, n=2)
2280        self.assertEqual(r, 1)
2281        self.assertEqual(query(q).getresult()[0], (3, 2))
2282        r = delete('test_parent', None, n=2)
2283        self.assertEqual(r, 1)
2284        self.assertEqual(query(q).getresult()[0], (2, 2))
2285        self.assertRaises(pg.ProgrammingError,
2286                          delete, 'test_parent', dict(n=0))
2287        self.assertRaises(pg.ProgrammingError,
2288                          delete, 'test_parent *', dict(n=0))
2289        r = delete('test_child', dict(n=0))
2290        self.assertEqual(r, 1)
2291        self.assertEqual(query(q).getresult()[0], (2, 1))
2292        r = delete('test_child', dict(n=0))
2293        self.assertEqual(r, 0)
2294        r = delete('test_parent', dict(n=0))
2295        self.assertEqual(r, 1)
2296        self.assertEqual(query(q).getresult()[0], (1, 1))
2297        r = delete('test_parent', None, n=0)
2298        self.assertEqual(r, 0)
2299        q = "select n from test_parent natural join test_child limit 2"
2300        self.assertEqual(query(q).getresult(), [(1,)])
2301
2302    def testTruncate(self):
2303        truncate = self.db.truncate
2304        self.assertRaises(TypeError, truncate, None)
2305        self.assertRaises(TypeError, truncate, 42)
2306        self.assertRaises(TypeError, truncate, dict(test_table=None))
2307        query = self.db.query
2308        self.createTable('test_table', 'n smallint',
2309                         temporary=False, values=[1] * 3)
2310        q = "select count(*) from test_table"
2311        r = query(q).getresult()[0][0]
2312        self.assertEqual(r, 3)
2313        truncate('test_table')
2314        r = query(q).getresult()[0][0]
2315        self.assertEqual(r, 0)
2316        for i in range(3):
2317            query("insert into test_table values (1)")
2318        r = query(q).getresult()[0][0]
2319        self.assertEqual(r, 3)
2320        truncate('public.test_table')
2321        r = query(q).getresult()[0][0]
2322        self.assertEqual(r, 0)
2323        self.createTable('test_table_2', 'n smallint', temporary=True)
2324        for t in (list, tuple, set):
2325            for i in range(3):
2326                query("insert into test_table values (1)")
2327                query("insert into test_table_2 values (2)")
2328            q = ("select (select count(*) from test_table),"
2329                 " (select count(*) from test_table_2)")
2330            r = query(q).getresult()[0]
2331            self.assertEqual(r, (3, 3))
2332            truncate(t(['test_table', 'test_table_2']))
2333            r = query(q).getresult()[0]
2334            self.assertEqual(r, (0, 0))
2335
2336    def testTruncateRestart(self):
2337        truncate = self.db.truncate
2338        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2339        query = self.db.query
2340        self.createTable('test_table', 'n serial, t text')
2341        for n in range(3):
2342            query("insert into test_table (t) values ('test')")
2343        q = "select count(n), min(n), max(n) from test_table"
2344        r = query(q).getresult()[0]
2345        self.assertEqual(r, (3, 1, 3))
2346        truncate('test_table')
2347        r = query(q).getresult()[0]
2348        self.assertEqual(r, (0, None, None))
2349        for n in range(3):
2350            query("insert into test_table (t) values ('test')")
2351        r = query(q).getresult()[0]
2352        self.assertEqual(r, (3, 4, 6))
2353        truncate('test_table', restart=True)
2354        r = query(q).getresult()[0]
2355        self.assertEqual(r, (0, None, None))
2356        for n in range(3):
2357            query("insert into test_table (t) values ('test')")
2358        r = query(q).getresult()[0]
2359        self.assertEqual(r, (3, 1, 3))
2360
2361    def testTruncateCascade(self):
2362        truncate = self.db.truncate
2363        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2364        query = self.db.query
2365        self.createTable('test_parent', 'n smallint primary key',
2366                         values=range(3))
2367        self.createTable('test_child',
2368                         'n smallint primary key references test_parent (n)',
2369                         values=range(3))
2370        q = ("select (select count(*) from test_parent),"
2371             " (select count(*) from test_child)")
2372        r = query(q).getresult()[0]
2373        self.assertEqual(r, (3, 3))
2374        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2375        truncate(['test_parent', 'test_child'])
2376        r = query(q).getresult()[0]
2377        self.assertEqual(r, (0, 0))
2378        for n in range(3):
2379            query("insert into test_parent (n) values (%d)" % n)
2380            query("insert into test_child (n) values (%d)" % n)
2381        r = query(q).getresult()[0]
2382        self.assertEqual(r, (3, 3))
2383        truncate('test_parent', cascade=True)
2384        r = query(q).getresult()[0]
2385        self.assertEqual(r, (0, 0))
2386        for n in range(3):
2387            query("insert into test_parent (n) values (%d)" % n)
2388            query("insert into test_child (n) values (%d)" % n)
2389        r = query(q).getresult()[0]
2390        self.assertEqual(r, (3, 3))
2391        truncate('test_child')
2392        r = query(q).getresult()[0]
2393        self.assertEqual(r, (3, 0))
2394        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2395        truncate('test_parent', cascade=True)
2396        r = query(q).getresult()[0]
2397        self.assertEqual(r, (0, 0))
2398
2399    def testTruncateOnly(self):
2400        truncate = self.db.truncate
2401        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2402        query = self.db.query
2403        self.createTable('test_parent', 'n smallint')
2404        self.createTable('test_child', 'm smallint) inherits (test_parent')
2405        for n in range(3):
2406            query("insert into test_parent (n) values (1)")
2407            query("insert into test_child (n, m) values (2, 3)")
2408        q = ("select (select count(*) from test_parent),"
2409             " (select count(*) from test_child)")
2410        r = query(q).getresult()[0]
2411        self.assertEqual(r, (6, 3))
2412        truncate('test_parent')
2413        r = query(q).getresult()[0]
2414        self.assertEqual(r, (0, 0))
2415        for n in range(3):
2416            query("insert into test_parent (n) values (1)")
2417            query("insert into test_child (n, m) values (2, 3)")
2418        r = query(q).getresult()[0]
2419        self.assertEqual(r, (6, 3))
2420        truncate('test_parent*')
2421        r = query(q).getresult()[0]
2422        self.assertEqual(r, (0, 0))
2423        for n in range(3):
2424            query("insert into test_parent (n) values (1)")
2425            query("insert into test_child (n, m) values (2, 3)")
2426        r = query(q).getresult()[0]
2427        self.assertEqual(r, (6, 3))
2428        truncate('test_parent', only=True)
2429        r = query(q).getresult()[0]
2430        self.assertEqual(r, (3, 3))
2431        truncate('test_parent', only=False)
2432        r = query(q).getresult()[0]
2433        self.assertEqual(r, (0, 0))
2434        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2435        truncate('test_parent*', only=False)
2436        self.createTable('test_parent_2', 'n smallint')
2437        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2438        for t in '', '_2':
2439            for n in range(3):
2440                query("insert into test_parent%s (n) values (1)" % t)
2441                query("insert into test_child%s (n, m) values (2, 3)" % t)
2442        q = ("select (select count(*) from test_parent),"
2443             " (select count(*) from test_child),"
2444             " (select count(*) from test_parent_2),"
2445             " (select count(*) from test_child_2)")
2446        r = query(q).getresult()[0]
2447        self.assertEqual(r, (6, 3, 6, 3))
2448        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2449        r = query(q).getresult()[0]
2450        self.assertEqual(r, (0, 0, 3, 3))
2451        truncate(['test_parent', 'test_parent_2'], only=False)
2452        r = query(q).getresult()[0]
2453        self.assertEqual(r, (0, 0, 0, 0))
2454        self.assertRaises(ValueError, truncate,
2455            ['test_parent*', 'test_child'], only=[True, False])
2456        truncate(['test_parent*', 'test_child'], only=[False, True])
2457
2458    def testTruncateQuoted(self):
2459        truncate = self.db.truncate
2460        query = self.db.query
2461        table = "test table for truncate()"
2462        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2463        q = 'select count(*) from "%s"' % table
2464        r = query(q).getresult()[0][0]
2465        self.assertEqual(r, 3)
2466        truncate(table)
2467        r = query(q).getresult()[0][0]
2468        self.assertEqual(r, 0)
2469        for i in range(3):
2470            query('insert into "%s" values (1)' % table)
2471        r = query(q).getresult()[0][0]
2472        self.assertEqual(r, 3)
2473        truncate('public."%s"' % table)
2474        r = query(q).getresult()[0][0]
2475        self.assertEqual(r, 0)
2476
2477    def testGetAsList(self):
2478        get_as_list = self.db.get_as_list
2479        self.assertRaises(TypeError, get_as_list)
2480        self.assertRaises(TypeError, get_as_list, None)
2481        query = self.db.query
2482        table = 'test_aslist'
2483        r = query('select 1 as colname').namedresult()[0]
2484        self.assertIsInstance(r, tuple)
2485        named = hasattr(r, 'colname')
2486        names = [(1, 'Homer'), (2, 'Marge'),
2487                 (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
2488        self.createTable(table,
2489            'id smallint primary key, name varchar', values=names)
2490        r = get_as_list(table)
2491        self.assertIsInstance(r, list)
2492        self.assertEqual(r, names)
2493        for t, n in zip(r, names):
2494            self.assertIsInstance(t, tuple)
2495            self.assertEqual(t, n)
2496            if named:
2497                self.assertEqual(t.id, n[0])
2498                self.assertEqual(t.name, n[1])
2499                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
2500        r = get_as_list(table, what='name')
2501        self.assertIsInstance(r, list)
2502        expected = sorted((row[1],) for row in names)
2503        self.assertEqual(r, expected)
2504        r = get_as_list(table, what='name, id')
2505        self.assertIsInstance(r, list)
2506        expected = sorted(tuple(reversed(row)) for row in names)
2507        self.assertEqual(r, expected)
2508        r = get_as_list(table, what=['name', 'id'])
2509        self.assertIsInstance(r, list)
2510        self.assertEqual(r, expected)
2511        r = get_as_list(table, where="name like 'Ba%'")
2512        self.assertIsInstance(r, list)
2513        self.assertEqual(r, names[2:3])
2514        r = get_as_list(table, what='name', where="name like 'Ma%'")
2515        self.assertIsInstance(r, list)
2516        self.assertEqual(r, [('Maggie',), ('Marge',)])
2517        r = get_as_list(table, what='name',
2518            where=["name like 'Ma%'", "name like '%r%'"])
2519        self.assertIsInstance(r, list)
2520        self.assertEqual(r, [('Marge',)])
2521        r = get_as_list(table, what='name', order='id')
2522        self.assertIsInstance(r, list)
2523        expected = [(row[1],) for row in names]
2524        self.assertEqual(r, expected)
2525        r = get_as_list(table, what=['name'], order=['id'])
2526        self.assertIsInstance(r, list)
2527        self.assertEqual(r, expected)
2528        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
2529        self.assertIsInstance(r, list)
2530        self.assertEqual(r, names)
2531        r = get_as_list(table, what='id * 2 as num', order='id desc')
2532        self.assertIsInstance(r, list)
2533        expected = [(n,) for n in range(10, 0, -2)]
2534        self.assertEqual(r, expected)
2535        r = get_as_list(table, limit=2)
2536        self.assertIsInstance(r, list)
2537        self.assertEqual(r, names[:2])
2538        r = get_as_list(table, offset=3)
2539        self.assertIsInstance(r, list)
2540        self.assertEqual(r, names[3:])
2541        r = get_as_list(table, limit=1, offset=2)
2542        self.assertIsInstance(r, list)
2543        self.assertEqual(r, names[2:3])
2544        r = get_as_list(table, scalar=True)
2545        self.assertIsInstance(r, list)
2546        self.assertEqual(r, list(range(1, 6)))
2547        r = get_as_list(table, what='name', scalar=True)
2548        self.assertIsInstance(r, list)
2549        expected = sorted(row[1] for row in names)
2550        self.assertEqual(r, expected)
2551        r = get_as_list(table, what='name', limit=1, scalar=True)
2552        self.assertIsInstance(r, list)
2553        self.assertEqual(r, expected[:1])
2554        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2555        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2556        names.insert(1, (1, 'Snowball'))
2557        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
2558        r = get_as_list(table)
2559        self.assertIsInstance(r, list)
2560        self.assertEqual(r, names)
2561        r = get_as_list(table, what='name', where='id=1', scalar=True)
2562        self.assertIsInstance(r, list)
2563        self.assertEqual(r, ['Homer', 'Snowball'])
2564        # test with unordered query
2565        r = get_as_list(table, order=False)
2566        self.assertIsInstance(r, list)
2567        self.assertEqual(set(r), set(names))
2568        # test with arbitrary from clause
2569        from_table = '(select lower(name) as n2 from "%s") as t2' % table
2570        r = get_as_list(from_table)
2571        self.assertIsInstance(r, list)
2572        r = set(row[0] for row in r)
2573        expected = set(row[1].lower() for row in names)
2574        self.assertEqual(r, expected)
2575        r = get_as_list(from_table, order='n2', scalar=True)
2576        self.assertIsInstance(r, list)
2577        self.assertEqual(r, sorted(expected))
2578        r = get_as_list(from_table, order='n2', limit=1)
2579        self.assertIsInstance(r, list)
2580        self.assertEqual(len(r), 1)
2581        t = r[0]
2582        self.assertIsInstance(t, tuple)
2583        if named:
2584            self.assertEqual(t.n2, 'bart')
2585            self.assertEqual(t._asdict(), dict(n2='bart'))
2586        else:
2587            self.assertEqual(t, ('bart',))
2588
2589    def testGetAsDict(self):
2590        get_as_dict = self.db.get_as_dict
2591        self.assertRaises(TypeError, get_as_dict)
2592        self.assertRaises(TypeError, get_as_dict, None)
2593        # the test table has no primary key
2594        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2595        query = self.db.query
2596        table = 'test_asdict'
2597        r = query('select 1 as colname').namedresult()[0]
2598        self.assertIsInstance(r, tuple)
2599        named = hasattr(r, 'colname')
2600        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2601                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2602        self.createTable(table,
2603            'id smallint primary key, rgb char(7), name varchar',
2604            values=colors)
2605        # keyname must be string, list or tuple
2606        self.assertRaises(KeyError, get_as_dict, table, 3)
2607        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2608        # missing keyname in row
2609        self.assertRaises(KeyError, get_as_dict, table,
2610                          keyname='rgb', what='name')
2611        r = get_as_dict(table)
2612        self.assertIsInstance(r, OrderedDict)
2613        expected = OrderedDict((row[0], row[1:]) for row in colors)
2614        self.assertEqual(r, expected)
2615        for key in r:
2616            self.assertIsInstance(key, int)
2617            self.assertIn(key, expected)
2618            row = r[key]
2619            self.assertIsInstance(row, tuple)
2620            t = expected[key]
2621            self.assertEqual(row, t)
2622            if named:
2623                self.assertEqual(row.rgb, t[0])
2624                self.assertEqual(row.name, t[1])
2625                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2626        if OrderedDict is not dict:  # Python > 2.6
2627            self.assertEqual(r.keys(), expected.keys())
2628        r = get_as_dict(table, keyname='rgb')
2629        self.assertIsInstance(r, OrderedDict)
2630        expected = OrderedDict((row[1], (row[0], row[2]))
2631                               for row in sorted(colors, key=itemgetter(1)))
2632        self.assertEqual(r, expected)
2633        for key in r:
2634            self.assertIsInstance(key, str)
2635            self.assertIn(key, expected)
2636            row = r[key]
2637            self.assertIsInstance(row, tuple)
2638            t = expected[key]
2639            self.assertEqual(row, t)
2640            if named:
2641                self.assertEqual(row.id, t[0])
2642                self.assertEqual(row.name, t[1])
2643                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2644        if OrderedDict is not dict:  # Python > 2.6
2645            self.assertEqual(r.keys(), expected.keys())
2646        r = get_as_dict(table, keyname=['id', 'rgb'])
2647        self.assertIsInstance(r, OrderedDict)
2648        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2649        self.assertEqual(r, expected)
2650        for key in r:
2651            self.assertIsInstance(key, tuple)
2652            self.assertIsInstance(key[0], int)
2653            self.assertIsInstance(key[1], str)
2654            if named:
2655                self.assertEqual(key, (key.id, key.rgb))
2656                self.assertEqual(key._fields, ('id', 'rgb'))
2657            row = r[key]
2658            self.assertIsInstance(row, tuple)
2659            self.assertIsInstance(row[0], str)
2660            t = expected[key]
2661            self.assertEqual(row, t)
2662            if named:
2663                self.assertEqual(row.name, t[0])
2664                self.assertEqual(row._asdict(), dict(name=t[0]))
2665        if OrderedDict is not dict:  # Python > 2.6
2666            self.assertEqual(r.keys(), expected.keys())
2667        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2668        self.assertIsInstance(r, OrderedDict)
2669        expected = OrderedDict((row[:2], row[2]) for row in colors)
2670        self.assertEqual(r, expected)
2671        for key in r:
2672            self.assertIsInstance(key, tuple)
2673            row = r[key]
2674            self.assertIsInstance(row, str)
2675            t = expected[key]
2676            self.assertEqual(row, t)
2677        if OrderedDict is not dict:  # Python > 2.6
2678            self.assertEqual(r.keys(), expected.keys())
2679        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2680        self.assertIsInstance(r, OrderedDict)
2681        expected = OrderedDict((row[1], row[2])
2682            for row in sorted(colors, key=itemgetter(1)))
2683        self.assertEqual(r, expected)
2684        for key in r:
2685            self.assertIsInstance(key, str)
2686            row = r[key]
2687            self.assertIsInstance(row, str)
2688            t = expected[key]
2689            self.assertEqual(row, t)
2690        if OrderedDict is not dict:  # Python > 2.6
2691            self.assertEqual(r.keys(), expected.keys())
2692        r = get_as_dict(table, what='id, name',
2693            where="rgb like '#b%'", scalar=True)
2694        self.assertIsInstance(r, OrderedDict)
2695        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2696        self.assertEqual(r, expected)
2697        for key in r:
2698            self.assertIsInstance(key, int)
2699            row = r[key]
2700            self.assertIsInstance(row, str)
2701            t = expected[key]
2702            self.assertEqual(row, t)
2703        if OrderedDict is not dict:  # Python > 2.6
2704            self.assertEqual(r.keys(), expected.keys())
2705        expected = r
2706        r = get_as_dict(table, what=['name', 'id'],
2707            where=['id > 1', 'id < 4', "rgb like '#b%'",
2708                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2709        self.assertEqual(r, expected)
2710        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2711        self.assertEqual(r, expected)
2712        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2713            where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2714        self.assertEqual(r, expected)
2715        r = get_as_dict(table, limit=1)
2716        self.assertEqual(len(r), 1)
2717        self.assertEqual(r[1][1], 'Aero')
2718        r = get_as_dict(table, offset=3)
2719        self.assertEqual(len(r), 1)
2720        self.assertEqual(r[4][1], 'Desert')
2721        r = get_as_dict(table, order='id desc')
2722        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2723        self.assertEqual(r, expected)
2724        r = get_as_dict(table, where='id > 5')
2725        self.assertIsInstance(r, OrderedDict)
2726        self.assertEqual(len(r), 0)
2727        # test with unordered query
2728        expected = dict((row[0], row[1:]) for row in colors)
2729        r = get_as_dict(table, order=False)
2730        self.assertIsInstance(r, dict)
2731        self.assertEqual(r, expected)
2732        if dict is not OrderedDict:  # Python > 2.6
2733            self.assertNotIsInstance(self, OrderedDict)
2734        # test with arbitrary from clause
2735        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2736        # primary key must be passed explicitly in this case
2737        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2738        r = get_as_dict(from_table, 'id')
2739        self.assertIsInstance(r, OrderedDict)
2740        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2741        self.assertEqual(r, expected)
2742        # test without a primary key
2743        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2744        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2745        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2746        r = get_as_dict(table, keyname='id')
2747        expected = OrderedDict((row[0], row[1:]) for row in colors)
2748        self.assertIsInstance(r, dict)
2749        self.assertEqual(r, expected)
2750        r = (1, '#007fff', 'Azure')
2751        query('insert into "%s" values ($1, $2, $3)' % table, r)
2752        # the last entry will win
2753        expected[1] = r[1:]
2754        r = get_as_dict(table, keyname='id')
2755        self.assertEqual(r, expected)
2756
2757    def testTransaction(self):
2758        query = self.db.query
2759        self.createTable('test_table', 'n integer', temporary=False)
2760        self.db.begin()
2761        query("insert into test_table values (1)")
2762        query("insert into test_table values (2)")
2763        self.db.commit()
2764        self.db.begin()
2765        query("insert into test_table values (3)")
2766        query("insert into test_table values (4)")
2767        self.db.rollback()
2768        self.db.begin()
2769        query("insert into test_table values (5)")
2770        self.db.savepoint('before6')
2771        query("insert into test_table values (6)")
2772        self.db.rollback('before6')
2773        query("insert into test_table values (7)")
2774        self.db.commit()
2775        self.db.begin()
2776        self.db.savepoint('before8')
2777        query("insert into test_table values (8)")
2778        self.db.release('before8')
2779        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
2780        self.db.commit()
2781        self.db.start()
2782        query("insert into test_table values (9)")
2783        self.db.end()
2784        r = [r[0] for r in query(
2785            "select * from test_table order by 1").getresult()]
2786        self.assertEqual(r, [1, 2, 5, 7, 9])
2787        self.db.begin(mode='read only')
2788        self.assertRaises(pg.ProgrammingError,
2789                          query, "insert into test_table values (0)")
2790        self.db.rollback()
2791        self.db.start(mode='Read Only')
2792        self.assertRaises(pg.ProgrammingError,
2793                          query, "insert into test_table values (0)")
2794        self.db.abort()
2795
2796    def testTransactionAliases(self):
2797        self.assertEqual(self.db.begin, self.db.start)
2798        self.assertEqual(self.db.commit, self.db.end)
2799        self.assertEqual(self.db.rollback, self.db.abort)
2800
2801    def testContextManager(self):
2802        query = self.db.query
2803        self.createTable('test_table', 'n integer check(n>0)')
2804        with self.db:
2805            query("insert into test_table values (1)")
2806            query("insert into test_table values (2)")
2807        try:
2808            with self.db:
2809                query("insert into test_table values (3)")
2810                query("insert into test_table values (4)")
2811                raise ValueError('test transaction should rollback')
2812        except ValueError as error:
2813            self.assertEqual(str(error), 'test transaction should rollback')
2814        with self.db:
2815            query("insert into test_table values (5)")
2816        try:
2817            with self.db:
2818                query("insert into test_table values (6)")
2819                query("insert into test_table values (-1)")
2820        except pg.ProgrammingError as error:
2821            self.assertTrue('check' in str(error))
2822        with self.db:
2823            query("insert into test_table values (7)")
2824        r = [r[0] for r in query(
2825            "select * from test_table order by 1").getresult()]
2826        self.assertEqual(r, [1, 2, 5, 7])
2827
2828    def testBytea(self):
2829        query = self.db.query
2830        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2831        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2832        r = self.db.escape_bytea(s)
2833        query('insert into bytea_test values(3, $1)', (r,))
2834        r = query('select * from bytea_test where n=3').getresult()
2835        self.assertEqual(len(r), 1)
2836        r = r[0]
2837        self.assertEqual(len(r), 2)
2838        self.assertEqual(r[0], 3)
2839        r = r[1]
2840        if pg.get_bytea_escaped():
2841            self.assertNotEqual(r, s)
2842            r = pg.unescape_bytea(r)
2843        self.assertIsInstance(r, bytes)
2844        self.assertEqual(r, s)
2845
2846    def testInsertUpdateGetBytea(self):
2847        query = self.db.query
2848        unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None
2849        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2850        # insert null value
2851        r = self.db.insert('bytea_test', n=0, data=None)
2852        self.assertIsInstance(r, dict)
2853        self.assertIn('n', r)
2854        self.assertEqual(r['n'], 0)
2855        self.assertIn('data', r)
2856        self.assertIsNone(r['data'])
2857        s = b'None'
2858        r = self.db.update('bytea_test', n=0, data=s)
2859        self.assertIsInstance(r, dict)
2860        self.assertIn('n', r)
2861        self.assertEqual(r['n'], 0)
2862        self.assertIn('data', r)
2863        r = r['data']
2864        if unescape:
2865            self.assertNotEqual(r, s)
2866            r = unescape(r)
2867        self.assertIsInstance(r, bytes)
2868        self.assertEqual(r, s)
2869        r = self.db.update('bytea_test', n=0, data=None)
2870        self.assertIsNone(r['data'])
2871        # insert as bytes
2872        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2873        r = self.db.insert('bytea_test', n=5, data=s)
2874        self.assertIsInstance(r, dict)
2875        self.assertIn('n', r)
2876        self.assertEqual(r['n'], 5)
2877        self.assertIn('data', r)
2878        r = r['data']
2879        if unescape:
2880            self.assertNotEqual(r, s)
2881            r = unescape(r)
2882        self.assertIsInstance(r, bytes)
2883        self.assertEqual(r, s)
2884        # update as bytes
2885        s += b"and now even more \x00 nasty \t stuff!\f"
2886        r = self.db.update('bytea_test', n=5, data=s)
2887        self.assertIsInstance(r, dict)
2888        self.assertIn('n', r)
2889        self.assertEqual(r['n'], 5)
2890        self.assertIn('data', r)
2891        r = r['data']
2892        if unescape:
2893            self.assertNotEqual(r, s)
2894            r = unescape(r)
2895        self.assertIsInstance(r, bytes)
2896        self.assertEqual(r, s)
2897        r = query('select * from bytea_test where n=5').getresult()
2898        self.assertEqual(len(r), 1)
2899        r = r[0]
2900        self.assertEqual(len(r), 2)
2901        self.assertEqual(r[0], 5)
2902        r = r[1]
2903        if unescape:
2904            self.assertNotEqual(r, s)
2905            r = unescape(r)
2906        self.assertIsInstance(r, bytes)
2907        self.assertEqual(r, s)
2908        r = self.db.get('bytea_test', dict(n=5))
2909        self.assertIsInstance(r, dict)
2910        self.assertIn('n', r)
2911        self.assertEqual(r['n'], 5)
2912        self.assertIn('data', r)
2913        r = r['data']
2914        if unescape:
2915            self.assertNotEqual(r, s)
2916            r = pg.unescape_bytea(r)
2917        self.assertIsInstance(r, bytes)
2918        self.assertEqual(r, s)
2919
2920    def testUpsertBytea(self):
2921        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2922        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2923        r = dict(n=7, data=s)
2924        try:
2925            r = self.db.upsert('bytea_test', r)
2926        except pg.ProgrammingError as error:
2927            if self.db.server_version < 90500:
2928                self.skipTest('database does not support upsert')
2929            self.fail(str(error))
2930        self.assertIsInstance(r, dict)
2931        self.assertIn('n', r)
2932        self.assertEqual(r['n'], 7)
2933        self.assertIn('data', r)
2934        if pg.get_bytea_escaped():
2935            self.assertNotEqual(r['data'], s)
2936            r['data'] = pg.unescape_bytea(r['data'])
2937        self.assertIsInstance(r['data'], bytes)
2938        self.assertEqual(r['data'], s)
2939        r['data'] = None
2940        r = self.db.upsert('bytea_test', r)
2941        self.assertIsInstance(r, dict)
2942        self.assertIn('n', r)
2943        self.assertEqual(r['n'], 7)
2944        self.assertIn('data', r)
2945        self.assertIsNone(r['data'])
2946
2947    def testInsertGetJson(self):
2948        try:
2949            self.createTable('json_test', 'n smallint primary key, data json')
2950        except pg.ProgrammingError as error:
2951            if self.db.server_version < 90200:
2952                self.skipTest('database does not support json')
2953            self.fail(str(error))
2954        jsondecode = pg.get_jsondecode()
2955        # insert null value
2956        r = self.db.insert('json_test', n=0, data=None)
2957        self.assertIsInstance(r, dict)
2958        self.assertIn('n', r)
2959        self.assertEqual(r['n'], 0)
2960        self.assertIn('data', r)
2961        self.assertIsNone(r['data'])
2962        r = self.db.get('json_test', 0)
2963        self.assertIsInstance(r, dict)
2964        self.assertIn('n', r)
2965        self.assertEqual(r['n'], 0)
2966        self.assertIn('data', r)
2967        self.assertIsNone(r['data'])
2968        # insert JSON object
2969        data = {
2970            "id": 1, "name": "Foo", "price": 1234.5,
2971            "new": True, "note": None,
2972            "tags": ["Bar", "Eek"],
2973            "stock": {"warehouse": 300, "retail": 20}}
2974        r = self.db.insert('json_test', n=1, data=data)
2975        self.assertIsInstance(r, dict)
2976        self.assertIn('n', r)
2977        self.assertEqual(r['n'], 1)
2978        self.assertIn('data', r)
2979        r = r['data']
2980        if jsondecode is None:
2981            self.assertIsInstance(r, str)
2982            r = json.loads(r)
2983        self.assertIsInstance(r, dict)
2984        self.assertEqual(r, data)
2985        self.assertIsInstance(r['id'], int)
2986        self.assertIsInstance(r['name'], unicode)
2987        self.assertIsInstance(r['price'], float)
2988        self.assertIsInstance(r['new'], bool)
2989        self.assertIsInstance(r['tags'], list)
2990        self.assertIsInstance(r['stock'], dict)
2991        r = self.db.get('json_test', 1)
2992        self.assertIsInstance(r, dict)
2993        self.assertIn('n', r)
2994        self.assertEqual(r['n'], 1)
2995        self.assertIn('data', r)
2996        r = r['data']
2997        if jsondecode is None:
2998            self.assertIsInstance(r, str)
2999            r = json.loads(r)
3000        self.assertIsInstance(r, dict)
3001        self.assertEqual(r, data)
3002        self.assertIsInstance(r['id'], int)
3003        self.assertIsInstance(r['name'], unicode)
3004        self.assertIsInstance(r['price'], float)
3005        self.assertIsInstance(r['new'], bool)
3006        self.assertIsInstance(r['tags'], list)
3007        self.assertIsInstance(r['stock'], dict)
3008        # insert JSON object as text
3009        self.db.insert('json_test', n=2, data=json.dumps(data))
3010        q = "select data from json_test where n in (1, 2) order by n"
3011        r = self.db.query(q).getresult()
3012        self.assertEqual(len(r), 2)
3013        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
3014        self.assertEqual(r[0][0], r[1][0])
3015
3016    def testInsertGetJsonb(self):
3017        try:
3018            self.createTable('jsonb_test',
3019                             'n smallint primary key, data jsonb')
3020        except pg.ProgrammingError as error:
3021            if self.db.server_version < 90400:
3022                self.skipTest('database does not support jsonb')
3023            self.fail(str(error))
3024        jsondecode = pg.get_jsondecode()
3025        # insert null value
3026        r = self.db.insert('jsonb_test', n=0, data=None)
3027        self.assertIsInstance(r, dict)
3028        self.assertIn('n', r)
3029        self.assertEqual(r['n'], 0)
3030        self.assertIn('data', r)
3031        self.assertIsNone(r['data'])
3032        r = self.db.get('jsonb_test', 0)
3033        self.assertIsInstance(r, dict)
3034        self.assertIn('n', r)
3035        self.assertEqual(r['n'], 0)
3036        self.assertIn('data', r)
3037        self.assertIsNone(r['data'])
3038        # insert JSON object
3039        data = {
3040            "id": 1, "name": "Foo", "price": 1234.5,
3041            "new": True, "note": None,
3042            "tags": ["Bar", "Eek"],
3043            "stock": {"warehouse": 300, "retail": 20}}
3044        r = self.db.insert('jsonb_test', n=1, data=data)
3045        self.assertIsInstance(r, dict)
3046        self.assertIn('n', r)
3047        self.assertEqual(r['n'], 1)
3048        self.assertIn('data', r)
3049        r = r['data']
3050        if jsondecode is None:
3051            self.assertIsInstance(r, str)
3052            r = json.loads(r)
3053        self.assertIsInstance(r, dict)
3054        self.assertEqual(r, data)
3055        self.assertIsInstance(r['id'], int)
3056        self.assertIsInstance(r['name'], unicode)
3057        self.assertIsInstance(r['price'], float)
3058        self.assertIsInstance(r['new'], bool)
3059        self.assertIsInstance(r['tags'], list)
3060        self.assertIsInstance(r['stock'], dict)
3061        r = self.db.get('jsonb_test', 1)
3062        self.assertIsInstance(r, dict)
3063        self.assertIn('n', r)
3064        self.assertEqual(r['n'], 1)
3065        self.assertIn('data', r)
3066        r = r['data']
3067        if jsondecode is None:
3068            self.assertIsInstance(r, str)
3069            r = json.loads(r)
3070        self.assertIsInstance(r, dict)
3071        self.assertEqual(r, data)
3072        self.assertIsInstance(r['id'], int)
3073        self.assertIsInstance(r['name'], unicode)
3074        self.assertIsInstance(r['price'], float)
3075        self.assertIsInstance(r['new'], bool)
3076        self.assertIsInstance(r['tags'], list)
3077        self.assertIsInstance(r['stock'], dict)
3078
3079    def testArray(self):
3080        returns_arrays = pg.get_array()
3081        self.createTable('arraytest',
3082            'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
3083            ' d numeric[], f4 real[], f8 double precision[], m money[],'
3084            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
3085        r = self.db.get_attnames('arraytest')
3086        if self.regtypes:
3087            self.assertEqual(r, dict(
3088                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
3089                d='numeric[]', f4='real[]', f8='double precision[]',
3090                m='money[]', b='boolean[]',
3091                v4='character varying[]', c4='character[]', t='text[]'))
3092        else:
3093            self.assertEqual(r, dict(
3094                id='int', i2='int[]', i4='int[]', i8='int[]',
3095                d='num[]', f4='float[]', f8='float[]', m='money[]',
3096                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
3097        decimal = pg.get_decimal()
3098        if decimal is Decimal:
3099            long_decimal = decimal('123456789.123456789')
3100            odd_money = decimal('1234567891234567.89')
3101        else:
3102            long_decimal = decimal('12345671234.5')
3103            odd_money = decimal('1234567123.25')
3104        t, f = (True, False) if pg.get_bool() else ('t', 'f')
3105        data = dict(id=42, i2=[42, 1234, None, 0, -1],
3106            i4=[42, 123456789, None, 0, 1, -1],
3107            i8=[long(42), long(123456789123456789), None,
3108                long(0), long(1), long(-1)],
3109            d=[decimal(42), long_decimal, None,
3110               decimal(0), decimal(1), decimal(-1), -long_decimal],
3111            f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
3112                float('inf'), float('-inf')],
3113            f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
3114                float('inf'), float('-inf')],
3115            m=[decimal('42.00'), odd_money, None,
3116               decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
3117            b=[t, f, t, None, f, t, None, None, t],
3118            v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
3119            t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
3120        r = data.copy()
3121        self.db.insert('arraytest', r)
3122        if returns_arrays:
3123            self.assertEqual(r, data)
3124        else:
3125            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3126        self.db.insert('arraytest', r)
3127        r = self.db.get('arraytest', 42, 'id')
3128        if returns_arrays:
3129            self.assertEqual(r, data)
3130        else:
3131            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3132        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
3133        if returns_arrays:
3134            self.assertEqual(r, data)
3135        else:
3136            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3137
3138    def testArrayLiteral(self):
3139        insert = self.db.insert
3140        returns_arrays = pg.get_array()
3141        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3142        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3143        insert('arraytest', r)
3144        if returns_arrays:
3145            self.assertEqual(r['i'], [1, 2, 3])
3146            self.assertEqual(r['t'], ['a', 'b', 'c'])
3147        else:
3148            self.assertEqual(r['i'], '{1,2,3}')
3149            self.assertEqual(r['t'], '{a,b,c}')
3150        r = dict(i='{1,2,3}', t='{a,b,c}')
3151        self.db.insert('arraytest', r)
3152        if returns_arrays:
3153            self.assertEqual(r['i'], [1, 2, 3])
3154            self.assertEqual(r['t'], ['a', 'b', 'c'])
3155        else:
3156            self.assertEqual(r['i'], '{1,2,3}')
3157            self.assertEqual(r['t'], '{a,b,c}')
3158        L = pg.Literal
3159        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3160        self.db.insert('arraytest', r)
3161        if returns_arrays:
3162            self.assertEqual(r['i'], [1, 2, 3])
3163            self.assertEqual(r['t'], ['a', 'b', 'c'])
3164        else:
3165            self.assertEqual(r['i'], '{1,2,3}')
3166            self.assertEqual(r['t'], '{a,b,c}')
3167        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3168        self.assertRaises(pg.ProgrammingError, self.db.insert, 'arraytest', r)
3169
3170    def testArrayOfIds(self):
3171        array_on = pg.get_array()
3172        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3173        r = self.db.get_attnames('arraytest')
3174        if self.regtypes:
3175            self.assertEqual(r, dict(
3176                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3177        else:
3178            self.assertEqual(r, dict(
3179                oid='int', c='int[]', o='int[]', x='int[]'))
3180        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3181        r = data.copy()
3182        self.db.insert('arraytest', r)
3183        qoid = 'oid(arraytest)'
3184        oid = r.pop(qoid)
3185        if array_on:
3186            self.assertEqual(r, data)
3187        else:
3188            self.assertEqual(r['o'], '{21,22,23}')
3189        r = {qoid: oid}
3190        self.db.get('arraytest', r)
3191        self.assertEqual(oid, r.pop(qoid))
3192        if array_on:
3193            self.assertEqual(r, data)
3194        else:
3195            self.assertEqual(r['o'], '{21,22,23}')
3196
3197    def testArrayOfText(self):
3198        array_on = pg.get_array()
3199        self.createTable('arraytest', 'data text[]', oids=True)
3200        r = self.db.get_attnames('arraytest')
3201        self.assertEqual(r['data'], 'text[]')
3202        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3203                'null', 'NULL', 'Null', 'nulL',
3204                "It's all \\ kinds of\r nasty stuff!\n"]
3205        r = dict(data=data)
3206        self.db.insert('arraytest', r)
3207        if not array_on:
3208            r['data'] = pg.cast_array(r['data'])
3209        self.assertEqual(r['data'], data)
3210        self.assertIsInstance(r['data'][1], str)
3211        self.assertIsNone(r['data'][2])
3212        r['data'] = None
3213        self.db.get('arraytest', r)
3214        if not array_on:
3215            r['data'] = pg.cast_array(r['data'])
3216        self.assertEqual(r['data'], data)
3217        self.assertIsInstance(r['data'][1], str)
3218        self.assertIsNone(r['data'][2])
3219
3220    def testArrayOfBytea(self):
3221        array_on = pg.get_array()
3222        bytea_escaped = pg.get_bytea_escaped()
3223        self.createTable('arraytest', 'data bytea[]', oids=True)
3224        r = self.db.get_attnames('arraytest')
3225        self.assertEqual(r['data'], 'bytea[]')
3226        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3227                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3228        r = dict(data=data)
3229        self.db.insert('arraytest', r)
3230        if array_on:
3231            self.assertIsInstance(r['data'], list)
3232        if array_on and not bytea_escaped:
3233            self.assertEqual(r['data'], data)
3234            self.assertIsInstance(r['data'][1], bytes)
3235            self.assertIsNone(r['data'][2])
3236        else:
3237            self.assertNotEqual(r['data'], data)
3238        r['data'] = None
3239        self.db.get('arraytest', r)
3240        if array_on:
3241            self.assertIsInstance(r['data'], list)
3242        if array_on and not bytea_escaped:
3243            self.assertEqual(r['data'], data)
3244            self.assertIsInstance(r['data'][1], bytes)
3245            self.assertIsNone(r['data'][2])
3246        else:
3247            self.assertNotEqual(r['data'], data)
3248
3249    def testArrayOfJson(self):
3250        try:
3251            self.createTable('arraytest', 'data json[]', oids=True)
3252        except pg.ProgrammingError as error:
3253            if self.db.server_version < 90200:
3254                self.skipTest('database does not support json')
3255            self.fail(str(error))
3256        r = self.db.get_attnames('arraytest')
3257        self.assertEqual(r['data'], 'json[]')
3258        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3259        array_on = pg.get_array()
3260        jsondecode = pg.get_jsondecode()
3261        r = dict(data=data)
3262        self.db.insert('arraytest', r)
3263        if not array_on:
3264            r['data'] = pg.cast_array(r['data'], jsondecode)
3265        if jsondecode is None:
3266            r['data'] = [json.loads(d) for d in r['data']]
3267        self.assertEqual(r['data'], data)
3268        r['data'] = None
3269        self.db.get('arraytest', r)
3270        if not array_on:
3271            r['data'] = pg.cast_array(r['data'], jsondecode)
3272        if jsondecode is None:
3273            r['data'] = [json.loads(d) for d in r['data']]
3274        self.assertEqual(r['data'], data)
3275        r = dict(data=[json.dumps(d) for d in data])
3276        self.db.insert('arraytest', r)
3277        if not array_on:
3278            r['data'] = pg.cast_array(r['data'], jsondecode)
3279        if jsondecode is None:
3280            r['data'] = [json.loads(d) for d in r['data']]
3281        self.assertEqual(r['data'], data)
3282        r['data'] = None
3283        self.db.get('arraytest', r)
3284        # insert empty json values
3285        r = dict(data=['', None])
3286        self.db.insert('arraytest', r)
3287        r = r['data']
3288        if array_on:
3289            self.assertIsInstance(r, list)
3290            self.assertEqual(len(r), 2)
3291            self.assertIsNone(r[0])
3292            self.assertIsNone(r[1])
3293        else:
3294            self.assertEqual(r, '{NULL,NULL}')
3295
3296    def testArrayOfJsonb(self):
3297        try:
3298            self.createTable('arraytest', 'data jsonb[]', oids=True)
3299        except pg.ProgrammingError as error:
3300            if self.db.server_version < 90400:
3301                self.skipTest('database does not support jsonb')
3302            self.fail(str(error))
3303        r = self.db.get_attnames('arraytest')
3304        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3305        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3306        array_on = pg.get_array()
3307        jsondecode = pg.get_jsondecode()
3308        r = dict(data=data)
3309        self.db.insert('arraytest', r)
3310        if not array_on:
3311            r['data'] = pg.cast_array(r['data'], jsondecode)
3312        if jsondecode is None:
3313            r['data'] = [json.loads(d) for d in r['data']]
3314        self.assertEqual(r['data'], data)
3315        r['data'] = None
3316        self.db.get('arraytest', r)
3317        if not array_on:
3318            r['data'] = pg.cast_array(r['data'], jsondecode)
3319        if jsondecode is None:
3320            r['data'] = [json.loads(d) for d in r['data']]
3321        self.assertEqual(r['data'], data)
3322        r = dict(data=[json.dumps(d) for d in data])
3323        self.db.insert('arraytest', r)
3324        if not array_on:
3325            r['data'] = pg.cast_array(r['data'], jsondecode)
3326        if jsondecode is None:
3327            r['data'] = [json.loads(d) for d in r['data']]
3328        self.assertEqual(r['data'], data)
3329        r['data'] = None
3330        self.db.get('arraytest', r)
3331        # insert empty json values
3332        r = dict(data=['', None])
3333        self.db.insert('arraytest', r)
3334        r = r['data']
3335        if array_on:
3336            self.assertIsInstance(r, list)
3337            self.assertEqual(len(r), 2)
3338            self.assertIsNone(r[0])
3339            self.assertIsNone(r[1])
3340        else:
3341            self.assertEqual(r, '{NULL,NULL}')
3342
3343    def testDeepArray(self):
3344        array_on = pg.get_array()
3345        self.createTable('arraytest', 'data text[][][]', oids=True)
3346        r = self.db.get_attnames('arraytest')
3347        self.assertEqual(r['data'], 'text[]')
3348        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3349        r = dict(data=data)
3350        self.db.insert('arraytest', r)
3351        if array_on:
3352            self.assertEqual(r['data'], data)
3353        else:
3354            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3355        r['data'] = None
3356        self.db.get('arraytest', r)
3357        if array_on:
3358            self.assertEqual(r['data'], data)
3359        else:
3360            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3361
3362    def testInsertUpdateGetRecord(self):
3363        query = self.db.query
3364        query('create type test_person_type as'
3365              ' (name varchar, age smallint, married bool,'
3366              ' weight real, salary money)')
3367        self.addCleanup(query, 'drop type test_person_type')
3368        self.createTable('test_person', 'person test_person_type',
3369                         temporary=False, oids=True)
3370        attnames = self.db.get_attnames('test_person')
3371        self.assertEqual(len(attnames), 2)
3372        self.assertIn('oid', attnames)
3373        self.assertIn('person', attnames)
3374        person_typ = attnames['person']
3375        if self.regtypes:
3376            self.assertEqual(person_typ, 'test_person_type')
3377        else:
3378            self.assertEqual(person_typ, 'record')
3379        if self.regtypes:
3380            self.assertEqual(person_typ.attnames,
3381                dict(name='character varying', age='smallint',
3382                    married='boolean', weight='real', salary='money'))
3383        else:
3384            self.assertEqual(person_typ.attnames,
3385                dict(name='text', age='int', married='bool',
3386                    weight='float', salary='money'))
3387        decimal = pg.get_decimal()
3388        if pg.get_bool():
3389            bool_class = bool
3390            t, f = True, False
3391        else:
3392            bool_class = str
3393            t, f = 't', 'f'
3394        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3395        r = self.db.insert('test_person', None, person=person)
3396        p = r['person']
3397        self.assertIsInstance(p, tuple)
3398        self.assertEqual(p, person)
3399        self.assertEqual(p.name, 'John Doe')
3400        self.assertIsInstance(p.name, str)
3401        self.assertIsInstance(p.age, int)
3402        self.assertIsInstance(p.married, bool_class)
3403        self.assertIsInstance(p.weight, float)
3404        self.assertIsInstance(p.salary, decimal)
3405        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3406        r['person'] = person
3407        self.db.update('test_person', r)
3408        p = r['person']
3409        self.assertIsInstance(p, tuple)
3410        self.assertEqual(p, person)
3411        self.assertEqual(p.name, 'Jane Roe')
3412        self.assertIsInstance(p.name, str)
3413        self.assertIsInstance(p.age, int)
3414        self.assertIsInstance(p.married, bool_class)
3415        self.assertIsInstance(p.weight, float)
3416        self.assertIsInstance(p.salary, decimal)
3417        r['person'] = None
3418        self.db.get('test_person', r)
3419        p = r['person']
3420        self.assertIsInstance(p, tuple)
3421        self.assertEqual(p, person)
3422        self.assertEqual(p.name, 'Jane Roe')
3423        self.assertIsInstance(p.name, str)
3424        self.assertIsInstance(p.age, int)
3425        self.assertIsInstance(p.married, bool_class)
3426        self.assertIsInstance(p.weight, float)
3427        self.assertIsInstance(p.salary, decimal)
3428        person = (None,) * 5
3429        r = self.db.insert('test_person', None, person=person)
3430        p = r['person']
3431        self.assertIsInstance(p, tuple)
3432        self.assertIsNone(p.name)
3433        self.assertIsNone(p.age)
3434        self.assertIsNone(p.married)
3435        self.assertIsNone(p.weight)
3436        self.assertIsNone(p.salary)
3437        r['person'] = None
3438        self.db.get('test_person', r)
3439        p = r['person']
3440        self.assertIsInstance(p, tuple)
3441        self.assertIsNone(p.name)
3442        self.assertIsNone(p.age)
3443        self.assertIsNone(p.married)
3444        self.assertIsNone(p.weight)
3445        self.assertIsNone(p.salary)
3446        r = self.db.insert('test_person', None, person=None)
3447        self.assertIsNone(r['person'])
3448        r['person'] = None
3449        self.db.get('test_person', r)
3450        self.assertIsNone(r['person'])
3451
3452    def testRecordInsertBytea(self):
3453        query = self.db.query
3454        query('create type test_person_type as'
3455              ' (name text, picture bytea)')
3456        self.addCleanup(query, 'drop type test_person_type')
3457        self.createTable('test_person', 'person test_person_type',
3458                         temporary=False, oids=True)
3459        person_typ = self.db.get_attnames('test_person')['person']
3460        self.assertEqual(person_typ.attnames,
3461                         dict(name='text', picture='bytea'))
3462        person = ('John Doe', b'O\x00ps\xff!')
3463        r = self.db.insert('test_person', None, person=person)
3464        p = r['person']
3465        self.assertIsInstance(p, tuple)
3466        self.assertEqual(p, person)
3467        self.assertEqual(p.name, 'John Doe')
3468        self.assertIsInstance(p.name, str)
3469        self.assertEqual(p.picture, person[1])
3470        self.assertIsInstance(p.picture, bytes)
3471
3472    def testRecordInsertJson(self):
3473        query = self.db.query
3474        try:
3475            query('create type test_person_type as'
3476                  ' (name text, data json)')
3477        except pg.ProgrammingError as error:
3478            if self.db.server_version < 90200:
3479                self.skipTest('database does not support json')
3480            self.fail(str(error))
3481        self.addCleanup(query, 'drop type test_person_type')
3482        self.createTable('test_person', 'person test_person_type',
3483                         temporary=False, oids=True)
3484        person_typ = self.db.get_attnames('test_person')['person']
3485        self.assertEqual(person_typ.attnames,
3486                         dict(name='text', data='json'))
3487        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3488        r = self.db.insert('test_person', None, person=person)
3489        p = r['person']
3490        self.assertIsInstance(p, tuple)
3491        if pg.get_jsondecode() is None:
3492            p = p._replace(data=json.loads(p.data))
3493        self.assertEqual(p, person)
3494        self.assertEqual(p.name, 'John Doe')
3495        self.assertIsInstance(p.name, str)
3496        self.assertEqual(p.data, person[1])
3497        self.assertIsInstance(p.data, dict)
3498
3499    def testRecordLiteral(self):
3500        query = self.db.query
3501        query('create type test_person_type as'
3502              ' (name varchar, age smallint)')
3503        self.addCleanup(query, 'drop type test_person_type')
3504        self.createTable('test_person', 'person test_person_type',
3505                         temporary=False, oids=True)
3506        person_typ = self.db.get_attnames('test_person')['person']
3507        if self.regtypes:
3508            self.assertEqual(person_typ, 'test_person_type')
3509        else:
3510            self.assertEqual(person_typ, 'record')
3511        if self.regtypes:
3512            self.assertEqual(person_typ.attnames,
3513                             dict(name='character varying', age='smallint'))
3514        else:
3515            self.assertEqual(person_typ.attnames,
3516                             dict(name='text', age='int'))
3517        person = pg.Literal("('John Doe', 61)")
3518        r = self.db.insert('test_person', None, person=person)
3519        p = r['person']
3520        self.assertIsInstance(p, tuple)
3521        self.assertEqual(p.name, 'John Doe')
3522        self.assertIsInstance(p.name, str)
3523        self.assertEqual(p.age, 61)
3524        self.assertIsInstance(p.age, int)
3525
3526    def testDbTypesInfo(self):
3527        dbtypes = self.db.dbtypes
3528        self.assertIsInstance(dbtypes, dict)
3529        self.assertNotIn('numeric', dbtypes)
3530        typ = dbtypes['numeric']
3531        self.assertIn('numeric', dbtypes)
3532        self.assertEqual(typ, 'numeric' if self.regtypes else 'num')
3533        self.assertEqual(typ.oid, 1700)
3534        self.assertEqual(typ.pgtype, 'numeric')
3535        self.assertEqual(typ.regtype, 'numeric')
3536        self.assertEqual(typ.simple, 'num')
3537        self.assertEqual(typ.typtype, 'b')
3538        self.assertEqual(typ.category, 'N')
3539        self.assertEqual(typ.delim, ',')
3540        self.assertEqual(typ.relid, 0)
3541        self.assertIs(dbtypes[1700], typ)
3542        self.assertNotIn('pg_type', dbtypes)
3543        typ = dbtypes['pg_type']
3544        self.assertIn('pg_type', dbtypes)
3545        self.assertEqual(typ, 'pg_type' if self.regtypes else 'record')
3546        self.assertIsInstance(typ.oid, int)
3547        self.assertEqual(typ.pgtype, 'pg_type')
3548        self.assertEqual(typ.regtype, 'pg_type')
3549        self.assertEqual(typ.simple, 'record')
3550        self.assertEqual(typ.typtype, 'c')
3551        self.assertEqual(typ.category, 'C')
3552        self.assertEqual(typ.delim, ',')
3553        self.assertNotEqual(typ.relid, 0)
3554        attnames = typ.attnames
3555        self.assertIsInstance(attnames, dict)
3556        self.assertIs(attnames, dbtypes.get_attnames('pg_type'))
3557        self.assertEqual(list(attnames)[0], 'typname')
3558        typname = attnames['typname']
3559        self.assertEqual(typname, 'name' if self.regtypes else 'text')
3560        self.assertEqual(typname.typtype, 'b')  # base
3561        self.assertEqual(typname.category, 'S')  # string
3562        self.assertEqual(list(attnames)[3], 'typlen')
3563        typlen = attnames['typlen']
3564        self.assertEqual(typlen, 'smallint' if self.regtypes else 'int')
3565        self.assertEqual(typlen.typtype, 'b')  # base
3566        self.assertEqual(typlen.category, 'N')  # numeric
3567
3568    def testDbTypesTypecast(self):
3569        dbtypes = self.db.dbtypes
3570        self.assertIsInstance(dbtypes, dict)
3571        self.assertNotIn('int4', dbtypes)
3572        self.assertIs(dbtypes.get_typecast('int4'), int)
3573        dbtypes.set_typecast('int4', float)
3574        self.assertIs(dbtypes.get_typecast('int4'), float)
3575        dbtypes.reset_typecast('int4')
3576        self.assertIs(dbtypes.get_typecast('int4'), int)
3577        dbtypes.set_typecast('int4', float)
3578        self.assertIs(dbtypes.get_typecast('int4'), float)
3579        dbtypes.reset_typecast()
3580        self.assertIs(dbtypes.get_typecast('int4'), int)
3581        self.assertNotIn('circle', dbtypes)
3582        self.assertIsNone(dbtypes.get_typecast('circle'))
3583        squared_circle = lambda v: 'Squared Circle: %s' % v
3584        dbtypes.set_typecast('circle', squared_circle)
3585        self.assertIs(dbtypes.get_typecast('circle'), squared_circle)
3586        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3587        self.assertIn('circle', dbtypes)
3588        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3589        self.assertEqual(dbtypes.typecast('Impossible', 'circle'),
3590            'Squared Circle: Impossible')
3591        dbtypes.reset_typecast('circle')
3592        self.assertIsNone(dbtypes.get_typecast('circle'))
3593
3594    def testGetSetTypeCast(self):
3595        get_typecast = pg.get_typecast
3596        set_typecast = pg.set_typecast
3597        dbtypes = self.db.dbtypes
3598        self.assertIsInstance(dbtypes, dict)
3599        self.assertNotIn('int4', dbtypes)
3600        self.assertNotIn('real', dbtypes)
3601        self.assertNotIn('bool', dbtypes)
3602        self.assertIs(get_typecast('int4'), int)
3603        self.assertIs(get_typecast('float4'), float)
3604        self.assertIs(get_typecast('bool'), pg.cast_bool)
3605        cast_circle = get_typecast('circle')
3606        self.addCleanup(set_typecast, 'circle', cast_circle)
3607        squared_circle = lambda v: 'Squared Circle: %s' % v
3608        self.assertNotIn('circle', dbtypes)
3609        set_typecast('circle', squared_circle)
3610        self.assertNotIn('circle', dbtypes)
3611        self.assertIs(get_typecast('circle'), squared_circle)
3612        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3613        self.assertIn('circle', dbtypes)
3614        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3615        set_typecast('circle', cast_circle)
3616        self.assertIs(get_typecast('circle'), cast_circle)
3617
3618    def testNotificationHandler(self):
3619        # the notification handler itself is tested separately
3620        f = self.db.notification_handler
3621        callback = lambda arg_dict: None
3622        handler = f('test', callback)
3623        self.assertIsInstance(handler, pg.NotificationHandler)
3624        self.assertIs(handler.db, self.db)
3625        self.assertEqual(handler.event, 'test')
3626        self.assertEqual(handler.stop_event, 'stop_test')
3627        self.assertIs(handler.callback, callback)
3628        self.assertIsInstance(handler.arg_dict, dict)
3629        self.assertEqual(handler.arg_dict, {})
3630        self.assertIsNone(handler.timeout)
3631        self.assertFalse(handler.listening)
3632        handler.close()
3633        self.assertIsNone(handler.db)
3634        self.db.reopen()
3635        self.assertIsNone(handler.db)
3636        handler = f('test2', callback, timeout=2)
3637        self.assertIsInstance(handler, pg.NotificationHandler)
3638        self.assertIs(handler.db, self.db)
3639        self.assertEqual(handler.event, 'test2')
3640        self.assertEqual(handler.stop_event, 'stop_test2')
3641        self.assertIs(handler.callback, callback)
3642        self.assertIsInstance(handler.arg_dict, dict)
3643        self.assertEqual(handler.arg_dict, {})
3644        self.assertEqual(handler.timeout, 2)
3645        self.assertFalse(handler.listening)
3646        handler.close()
3647        self.assertIsNone(handler.db)
3648        self.db.reopen()
3649        self.assertIsNone(handler.db)
3650        arg_dict = {'testing': 3}
3651        handler = f('test3', callback, arg_dict=arg_dict)
3652        self.assertIsInstance(handler, pg.NotificationHandler)
3653        self.assertIs(handler.db, self.db)
3654        self.assertEqual(handler.event, 'test3')
3655        self.assertEqual(handler.stop_event, 'stop_test3')
3656        self.assertIs(handler.callback, callback)
3657        self.assertIs(handler.arg_dict, arg_dict)
3658        self.assertEqual(arg_dict['testing'], 3)
3659        self.assertIsNone(handler.timeout)
3660        self.assertFalse(handler.listening)
3661        handler.close()
3662        self.assertIsNone(handler.db)
3663        self.db.reopen()
3664        self.assertIsNone(handler.db)
3665        handler = f('test4', callback, stop_event='stop4')
3666        self.assertIsInstance(handler, pg.NotificationHandler)
3667        self.assertIs(handler.db, self.db)
3668        self.assertEqual(handler.event, 'test4')
3669        self.assertEqual(handler.stop_event, 'stop4')
3670        self.assertIs(handler.callback, callback)
3671        self.assertIsInstance(handler.arg_dict, dict)
3672        self.assertEqual(handler.arg_dict, {})
3673        self.assertIsNone(handler.timeout)
3674        self.assertFalse(handler.listening)
3675        handler.close()
3676        self.assertIsNone(handler.db)
3677        self.db.reopen()
3678        self.assertIsNone(handler.db)
3679        arg_dict = {'testing': 5}
3680        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
3681        self.assertIsInstance(handler, pg.NotificationHandler)
3682        self.assertIs(handler.db, self.db)
3683        self.assertEqual(handler.event, 'test5')
3684        self.assertEqual(handler.stop_event, 'stop5')
3685        self.assertIs(handler.callback, callback)
3686        self.assertIs(handler.arg_dict, arg_dict)
3687        self.assertEqual(arg_dict['testing'], 5)
3688        self.assertEqual(handler.timeout, 1.5)
3689        self.assertFalse(handler.listening)
3690        handler.close()
3691        self.assertIsNone(handler.db)
3692        self.db.reopen()
3693        self.assertIsNone(handler.db)
3694
3695
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        """Flip the global pg options, then run the inherited class setup.

        The previous option values are remembered in ``saved_options`` so
        that tearDownClass can restore them. Because unittest does not call
        tearDownClass when setUpClass raises, the options are restored here
        if anything during the setup fails.
        """
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            not_array = not pg.get_array()
            cls.set_option('array', not_array)
            not_bytea_escaped = not pg.get_bytea_escaped()
            cls.set_option('bytea_escaped', not_bytea_escaped)
            cls.set_option('namedresult', None)
            cls.set_option('jsondecode', None)
            # Use the opposite of a fresh connection's regtypes setting.
            # The helper connection is closed explicitly instead of
            # relying on garbage collection to release it.
            db = DB()
            try:
                cls.regtypes = not db.use_regtypes()
            finally:
                db.close()
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            cls._restore_saved_options()
            raise

    @classmethod
    def tearDownClass(cls):
        """Run the inherited class teardown, then restore the options."""
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # restore the global options even if the teardown failed
            cls.reset_option('jsondecode')
            cls.reset_option('namedresult')
            cls.reset_option('bool')
            cls.reset_option('array')
            cls.reset_option('bytea_escaped')
            cls.reset_option('decimal')

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option, then change it.

        Reads the old value with ``pg.get_<option>()``, stores it in
        ``saved_options`` and returns the result of ``pg.set_<option>(value)``.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _restore_saved_options(cls):
        """Restore every option saved so far (used when setup fails)."""
        for option in list(cls.saved_options):
            cls.reset_option(option)
3732
3733
3734class TestDBClassAdapter(unittest.TestCase):
3735    """Test the adapter object associatd with the DB class."""
3736
3737    def setUp(self):
3738        self.db = DB()
3739        self.adapter = self.db.adapter
3740
3741    def tearDown(self):
3742        try:
3743            self.db.close()
3744        except pg.InternalError:
3745            pass
3746
3747    def testGuessSimpleType(self):
3748        f = self.adapter.guess_simple_type
3749        self.assertEqual(f(pg.Bytea(b'test')), 'bytea')
3750        self.assertEqual(f('string'), 'text')
3751        self.assertEqual(f(b'string'), 'text')
3752        self.assertEqual(f(True), 'bool')
3753        self.assertEqual(f(3), 'int')
3754        self.assertEqual(f(2.75), 'float')
3755        self.assertEqual(f(Decimal('4.25')), 'num')
3756        self.assertEqual(f(date(2016, 1, 30)), 'date')
3757        self.assertEqual(f([1, 2, 3]), 'int[]')
3758        self.assertEqual(f([[[123]]]), 'int[]')
3759        self.assertEqual(f(['a', 'b', 'c']), 'text[]')
3760        self.assertEqual(f([[['abc']]]), 'text[]')
3761        self.assertEqual(f([False, True]), 'bool[]')
3762        self.assertEqual(f([[[False]]]), 'bool[]')
3763        r = f(('string', True, 3, 2.75, [1], [False]))
3764        self.assertEqual(r, 'record')
3765        self.assertEqual(list(r.attnames.values()),
3766            ['text', 'bool', 'int', 'float', 'int[]', 'bool[]'])
3767
3768    def testAdaptQueryTypedList(self):
3769        format_query = self.adapter.format_query
3770        self.assertRaises(TypeError, format_query,
3771            '%s,%s', (1, 2), ('int2',))
3772        self.assertRaises(TypeError, format_query,
3773            '%s,%s', (1,), ('int2', 'int2'))
3774        values = (3, 7.5, 'hello', True)
3775        types = ('int4', 'float4', 'text', 'bool')
3776        sql, params = format_query("select %s,%s,%s,%s", values, types)
3777        self.assertEqual(sql, 'select $1,$2,$3,$4')
3778        self.assertEqual(params, [3, 7.5, 'hello', 't'])
3779        types = ('bool', 'bool', 'bool', 'bool')
3780        sql, params = format_query("select %s,%s,%s,%s", values, types)
3781        self.assertEqual(sql, 'select $1,$2,$3,$4')
3782        self.assertEqual(params, ['t', 't', 'f', 't'])
3783        values = ('2016-01-30', 'current_date')
3784        types = ('date', 'date')
3785        sql, params = format_query("values(%s,%s)", values, types)
3786        self.assertEqual(sql, 'values($1,current_date)')
3787        self.assertEqual(params, ['2016-01-30'])
3788        values = ([1, 2, 3], ['a', 'b', 'c'])
3789        types = ('_int4', '_text')
3790        sql, params = format_query("%s::int4[],%s::text[]", values, types)
3791        self.assertEqual(sql, '$1::int4[],$2::text[]')
3792        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
3793        types = ('_bool', '_bool')
3794        sql, params = format_query("%s::bool[],%s::bool[]", values, types)
3795        self.assertEqual(sql, '$1::bool[],$2::bool[]')
3796        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
3797        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
3798        t = self.adapter.simple_type
3799        typ = t('record')
3800        typ._get_attnames = lambda _self: pg.AttrDict([
3801            ('i', t('int')), ('f', t('float')),
3802            ('t', t('text')), ('b', t('bool')),
3803            ('i3', t('int[]')), ('t3', t('text[]'))])
3804        types = [typ]
3805        sql, params = format_query('select %s', values, types)
3806        self.assertEqual(sql, 'select $1')
3807        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
3808
3809    def testAdaptQueryTypedDict(self):
3810        format_query = self.adapter.format_query
3811        self.assertRaises(TypeError, format_query,
3812            '%s,%s', dict(i1=1, i2=2), dict(i1='int2'))
3813        values = dict(i=3, f=7.5, t='hello', b=True)
3814        types = dict(i='int4', f='float4',
3815            t='text', b='bool')
3816        sql, params = format_query(
3817            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
3818        self.assertEqual(sql, 'select $3,$2,$4,$1')
3819        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
3820        types = dict(i='bool', f='bool',
3821            t='bool', b='bool')
3822        sql, params = format_query(
3823            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
3824        self.assertEqual(sql, 'select $3,$2,$4,$1')
3825        self.assertEqual(params, ['t', 't', 't', 'f'])
3826        values = dict(d1='2016-01-30', d2='current_date')
3827        types = dict(d1='date', d2='date')
3828        sql, params = format_query("values(%(d1)s,%(d2)s)", values, types)
3829        self.assertEqual(sql, 'values($1,current_date)')
3830        self.assertEqual(params, ['2016-01-30'])
3831        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3832        types = dict(i='_int4', t='_text')
3833        sql, params = format_query(
3834            "%(i)s::int4[],%(t)s::text[]", values, types)
3835        self.assertEqual(sql, '$1::int4[],$2::text[]')
3836        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
3837        types = dict(i='_bool', t='_bool')
3838        sql, params = format_query(
3839            "%(i)s::bool[],%(t)s::bool[]", values, types)
3840        self.assertEqual(sql, '$1::bool[],$2::bool[]')
3841        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
3842        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
3843        t = self.adapter.simple_type
3844        typ = t('record')
3845        typ._get_attnames = lambda _self: pg.AttrDict([
3846            ('i', t('int')), ('f', t('float')),
3847            ('t', t('text')), ('b', t('bool')),
3848            ('i3', t('int[]')), ('t3', t('text[]'))])
3849        types = dict(record=typ)
3850        sql, params = format_query('select %(record)s', values, types)
3851        self.assertEqual(sql, 'select $1')
3852        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
3853
3854    def testAdaptQueryUntypedList(self):
3855        format_query = self.adapter.format_query
3856        values = (3, 7.5, 'hello', True)
3857        sql, params = format_query("select %s,%s,%s,%s", values)
3858        self.assertEqual(sql, 'select $1,$2,$3,$4')
3859        self.assertEqual(params, [3, 7.5, 'hello', 't'])
3860        values = [date(2016, 1, 30), 'current_date']
3861        sql, params = format_query("values(%s,%s)", values)
3862        self.assertEqual(sql, 'values($1,$2)')
3863        self.assertEqual(params, values)
3864        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
3865        sql, params = format_query("%s,%s,%s", values)
3866        self.assertEqual(sql, "$1,$2,$3")
3867        self.assertEqual(params, ['{1,2,3}', '{a,b,c}', '{t,f,t}'])
3868        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
3869            [[True, False], [False, True]])
3870        sql, params = format_query("%s,%s,%s", values)
3871        self.assertEqual(sql, "$1,$2,$3")
3872        self.assertEqual(params, [
3873            '{{1,2},{3,4}}', '{{a,b},{c,d}}', '{{t,f},{f,t}}'])
3874        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
3875        sql, params = format_query('select %s', values)
3876        self.assertEqual(sql, 'select $1')
3877        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
3878
3879    def testAdaptQueryUntypedDict(self):
3880        format_query = self.adapter.format_query
3881        values = dict(i=3, f=7.5, t='hello', b=True)
3882        sql, params = format_query(
3883            "select %(i)s,%(f)s,%(t)s,%(b)s", values)
3884        self.assertEqual(sql, 'select $3,$2,$4,$1')
3885        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
3886        values = dict(d1='2016-01-30', d2='current_date')
3887        sql, params = format_query("values(%(d1)s,%(d2)s)", values)
3888        self.assertEqual(sql, 'values($1,$2)')
3889        self.assertEqual(params, [values['d1'], values['d2']])
3890        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
3891        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
3892        self.assertEqual(sql, "$2,$3,$1")
3893        self.assertEqual(params, ['{t,f,t}', '{1,2,3}', '{a,b,c}'])
3894        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
3895            b=[[True, False], [False, True]])
3896        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
3897        self.assertEqual(sql, "$2,$3,$1")
3898        self.assertEqual(params, [
3899            '{{t,f},{f,t}}', '{{1,2},{3,4}}', '{{a,b},{c,d}}'])
3900        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
3901        sql, params = format_query('select %(record)s', values)
3902        self.assertEqual(sql, 'select $1')
3903        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])
3904
3905    def testAdaptQueryInlineList(self):
3906        format_query = self.adapter.format_query
3907        values = (3, 7.5, 'hello', True)
3908        sql, params = format_query("select %s,%s,%s,%s", values, inline=True)
3909        self.assertEqual(sql, "select 3,7.5,'hello',true")
3910        self.assertEqual(params, [])
3911        values = [date(2016, 1, 30), 'current_date']
3912        sql, params = format_query("values(%s,%s)", values, inline=True)
3913        self.assertEqual(sql, "values('2016-01-30','current_date')")
3914        self.assertEqual(params, [])
3915        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
3916        sql, params = format_query("%s,%s,%s", values, inline=True)
3917        self.assertEqual(sql,
3918            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
3919        self.assertEqual(params, [])
3920        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
3921            [[True, False], [False, True]])
3922        sql, params = format_query("%s,%s,%s", values, inline=True)
3923        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
3924            "ARRAY[[true,false],[false,true]]")
3925        self.assertEqual(params, [])
3926        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
3927        sql, params = format_query('select %s', values, inline=True)
3928        self.assertEqual(sql,
3929            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
3930        self.assertEqual(params, [])
3931
3932    def testAdaptQueryInlineDict(self):
3933        format_query = self.adapter.format_query
3934        values = dict(i=3, f=7.5, t='hello', b=True)
3935        sql, params = format_query(
3936            "select %(i)s,%(f)s,%(t)s,%(b)s", values, inline=True)
3937        self.assertEqual(sql, "select 3,7.5,'hello',true")
3938        self.assertEqual(params, [])
3939        values = dict(d1='2016-01-30', d2='current_date')
3940        sql, params = format_query(
3941            "values(%(d1)s,%(d2)s)", values, inline=True)
3942        self.assertEqual(sql, "values('2016-01-30','current_date')")
3943        self.assertEqual(params, [])
3944        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
3945        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
3946        self.assertEqual(sql,
3947            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
3948        self.assertEqual(params, [])
3949        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
3950            b=[[True, False], [False, True]])
3951        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
3952        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
3953            "ARRAY[[true,false],[false,true]]")
3954        self.assertEqual(params, [])
3955        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
3956        sql, params = format_query('select %(record)s', values, inline=True)
3957        self.assertEqual(sql,
3958            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
3959        self.assertEqual(params, [])
3960
3961    def testAdaptQueryWithPgRepr(self):
3962        format_query = self.adapter.format_query
3963        self.assertRaises(TypeError, format_query,
3964            '%s', object(), inline=True)
3965        class TestObject:
3966            def __pg_repr__(self):
3967                return "'adapted'"
3968        sql, params = format_query('select %s', [TestObject()], inline=True)
3969        self.assertEqual(sql, "select 'adapted'")
3970        self.assertEqual(params, [])
3971        sql, params = format_query('select %s', [[TestObject()]], inline=True)
3972        self.assertEqual(sql, "select ARRAY['adapted']")
3973        self.assertEqual(params, [])
3974
3975
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create tables t and t<n> in public and in schemas s1 to s4.

        Each table holds one row with n = 1 and d = <schema number>
        (0 for public). The helper connection is always closed, even
        when a query fails or the user may not create schemas.
        """
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError("The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d" % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            db.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the schemas and public tables created by setUpClass."""
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.db = DB()

    def tearDown(self):
        # run the cleanups first, since they may still use the connection
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        """Check that get_tables() lists the tables of every schema."""
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        """Check get_attnames() with qualified names and the search path."""
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        """Check that get() resolves table names through the search path."""
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        """Check the oid key name in rows returned by get()."""
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
4094
4095
4096class TestDebug(unittest.TestCase):
4097    """Test the debug attribute of the DB class."""
4098
4099    def setUp(self):
4100        self.db = DB()
4101        self.query = self.db.query
4102        self.debug = self.db.debug
4103        self.output = StringIO()
4104        self.stdout, sys.stdout = sys.stdout, self.output
4105
4106    def tearDown(self):
4107        sys.stdout = self.stdout
4108        self.output.close()
4109        self.db.debug = debug
4110        self.db.close()
4111
4112    def get_output(self):
4113        return self.output.getvalue()
4114
4115    def send_queries(self):
4116        self.db.query("select 1")
4117        self.db.query("select 2")
4118
4119    def testDebugDefault(self):
4120        if debug:
4121            self.assertEqual(self.db.debug, debug)
4122        else:
4123            self.assertIsNone(self.db.debug)
4124
4125    def testDebugIsFalse(self):
4126        self.db.debug = False
4127        self.send_queries()
4128        self.assertEqual(self.get_output(), "")
4129
4130    def testDebugIsTrue(self):
4131        self.db.debug = True
4132        self.send_queries()
4133        self.assertEqual(self.get_output(), "select 1\nselect 2\n")
4134
4135    def testDebugIsString(self):
4136        self.db.debug = "Test with string: %s."
4137        self.send_queries()
4138        self.assertEqual(self.get_output(),
4139            "Test with string: select 1.\nTest with string: select 2.\n")
4140
4141    def testDebugIsFileLike(self):
4142        with tempfile.TemporaryFile('w+') as debug_file:
4143            self.db.debug = debug_file
4144            self.send_queries()
4145            debug_file.seek(0)
4146            output = debug_file.read()
4147            self.assertEqual(output, "select 1\nselect 2\n")
4148            self.assertEqual(self.get_output(), "")
4149
4150    def testDebugIsCallable(self):
4151        output = []
4152        self.db.debug = output.append
4153        self.db.query("select 1")
4154        self.db.query("select 2")
4155        self.assertEqual(output, ["select 1", "select 2"])
4156        self.assertEqual(self.get_output(), "")
4157
4158    def testDebugMultipleArgs(self):
4159        output = []
4160        self.db.debug = output.append
4161        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
4162        self.db._do_debug(*args)
4163        self.assertEqual(output, ['\n'.join(str(arg) for arg in args)])
4164        self.assertEqual(self.get_output(), "")
4165
4166
if __name__ == '__main__':
    # Run every test case in this module when executed as a script.
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.