source: trunk/tests/test_classic_dbwrapper.py @ 833

Last change on this file since 833 was 833, checked in by cito, 4 years ago

Small fixes to make trunk run with Python 2.6 again

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 175.7 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import json
21import tempfile
22
23import pg  # the module under test
24
25from decimal import Decimal
26from datetime import date, time, datetime, timedelta
27from uuid import UUID
28from time import strftime
29from operator import itemgetter
30
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults below.
# The current user must have create schema privilege on the database.
dbname = 'unittest'  # name of the test database
dbhost = None  # passed to pg.DB as host; None = use the default (TODO confirm)
dbport = 5432  # port of the test database server

debug = False  # let DB wrapper print debugging output

# Override the settings above from a LOCAL_PyGreSQL module if there is one.
# The relative import only works when the tests run as part of a package;
# otherwise it fails with ImportError or (Python 2) ValueError, so we fall
# back to an absolute import and finally keep the defaults.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass
47
# Compatibility aliases so that the same tests run on Python 2 and 3.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    # NOTE(review): plain dict does not keep insertion order on these
    # versions, so order-sensitive comparisons may behave differently
    OrderedDict = dict

if str is bytes:  # Python 2
    from StringIO import StringIO
else:
    from io import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(); tests that ask the
# connection for its host are skipped when this flag is set:
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
74
75
def DB():
    """Return a new pg.DB wrapper connected to the test database.

    The connection parameters come from the module-level settings; debug
    output is switched on only when the module-level debug flag is set.
    """
    wrapper = pg.DB(dbname, dbhost, dbport)
    if debug:  # leave the wrapper's own default alone unless requested
        wrapper.debug = debug
    # keep the test output free of notices below warning level
    wrapper.query("set client_min_messages=warning")
    return wrapper
83
84
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names."""

    cls = pg.AttrDict
    base = OrderedDict  # may be a plain dict on old Python versions

    def testInit(self):
        a = self.cls()
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base())
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))
        # an iterator over the items must work as well as a list
        iteritems = iter(items)
        a = self.cls(iteritems)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))

    def testIter(self):
        a = self.cls()
        self.assertEqual(list(a), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        # iteration must yield the keys in insertion order
        self.assertEqual(list(a), keys)

    def testKeys(self):
        a = self.cls()
        self.assertEqual(list(a.keys()), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a.keys()), keys)

    def testValues(self):
        a = self.cls()
        self.assertEqual(list(a.values()), [])
        items = [('id', 'int'), ('name', 'text')]
        values = [item[1] for item in items]
        a = self.cls(items)
        self.assertEqual(list(a.values()), values)

    def testItems(self):
        a = self.cls()
        self.assertEqual(list(a.items()), [])
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertEqual(list(a.items()), items)

    def testGet(self):
        a = self.cls([('id', 1)])
        try:
            self.assertEqual(a['id'], 1)
        except KeyError:
            self.fail('AttrDict should be readable')

    def testSet(self):
        a = self.cls()
        try:
            a['id'] = 1
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testDel(self):
        a = self.cls([('id', 1)])
        try:
            del a['id']
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')
        # the refused deletion must not have removed the item
        self.assertEqual(a['id'], 1)

    def testWriteMethods(self):
        a = self.cls([('id', 1)])
        self.assertEqual(a['id'], 1)
        # Call every mutating method with arguments that would be valid
        # for an ordinary dict.  This way a TypeError can only come from
        # the read-only protection, and not from a wrong number of
        # arguments or from an unhashable key.
        calls = [
            ('clear', ()),
            ('update', ({'name': 'text'},)),
            ('pop', ('id',)),
            ('setdefault', ('name', 'text')),
            ('popitem', ()),
        ]
        for name, args in calls:
            method = getattr(a, name)
            self.assertRaises(TypeError, method, *args)
        # none of the refused calls may have changed the contents
        self.assertEqual(a, self.base([('id', 1)]))
166
167
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # some tests close the connection themselves; closing it a second
        # time raises InternalError, which is fine here
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        attributes = [
            'abort', 'adapter',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'date_format', 'db', 'dbname', 'dbtypes',
            'debug', 'decode_json', 'delete',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_cast_hook',
            'get_databases', 'get_notice_receiver',
            'get_parameter', 'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query', 'query_formatted',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_cast_hook', 'set_notice_receiver',
            'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        # __dir__ is not called in Python 2.6 for old-style classes
        # (whose class objects have no __class__), so call it explicitly
        db_attributes = dir(self.db) if hasattr(
            self.db.__class__, '__class__') else self.db.__dir__()
        db_attributes = [a for a in db_attributes
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        # parameters can be passed as separate arguments, tuple or list
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryDataError(self):
        try:
            self.db.query("select 1/0")
        except pg.DataError as error:
            # SQLSTATE 22012 is division_by_zero
            self.assertEqual(error.sqlstate, '22012')
        else:
            # previously the test passed silently if nothing was raised
            self.fail('Division by zero should raise a DataError')

    def testMethodEndcopy(self):
        # outside of a COPY operation, endcopy may raise IOError;
        # the test only checks that nothing else goes wrong
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        # reset keeps the same connection object
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        # reopen replaces the connection object, even after close
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        # closing or reopening a wrapper around a borrowed connection
        # must leave the wrapped connection in place
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        # also accept another DB wrapper or the db keyword argument
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        # any object carrying the connection as _cnx is accepted as well
        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
381
382
383class TestDBClass(unittest.TestCase):
384    """Test the methods of the DB class wrapped pg connection."""
385
    # show assertion diffs of up to 1600 characters (20 lines of 80 chars)
    maxDiff = 80 * 20

    # set to True by setUpClass; setUp asserts it to detect a missing
    # class setup
    cls_set_up = False

    # None: setUp reads the current use_regtypes() setting of the new
    # connection; otherwise setUp applies this value
    # (NOTE(review): presumably overridden by subclasses)
    regtypes = None
391
392    @classmethod
393    def setUpClass(cls):
394        db = DB()
395        db.query("drop table if exists test cascade")
396        db.query("create table test ("
397                 "i2 smallint, i4 integer, i8 bigint,"
398                 " d numeric, f4 real, f8 double precision, m money,"
399                 " v4 varchar(4), c4 char(4), t text)")
400        db.query("create or replace view test_view as"
401                 " select i4, v4 from test")
402        db.close()
403        cls.cls_set_up = True
404
405    @classmethod
406    def tearDownClass(cls):
407        db = DB()
408        db.query("drop table test cascade")
409        db.close()
410
411    def setUp(self):
412        self.assertTrue(self.cls_set_up)
413        self.db = DB()
414        if self.regtypes is None:
415            self.regtypes = self.db.use_regtypes()
416        else:
417            self.db.use_regtypes(self.regtypes)
418        query = self.db.query
419        query('set client_encoding=utf8')
420        query('set standard_conforming_strings=on')
421        query("set lc_monetary='C'")
422        query("set datestyle='ISO,YMD'")
423        query('set bytea_output=hex')
424
    def tearDown(self):
        # Run the registered cleanups first: e.g. createTable registers a
        # "drop table" through self.db.query, so they need the connection
        # to still be open when they run.
        self.doCleanups()
        self.db.close()
428
429    def createTable(self, table, definition,
430                    temporary=True, oids=None, values=None):
431        query = self.db.query
432        if not '"' in table or '.' in table:
433            table = '"%s"' % table
434        if not temporary:
435            q = 'drop table if exists %s cascade' % table
436            query(q)
437            self.addCleanup(query, q)
438        temporary = 'temporary table' if temporary else 'table'
439        as_query = definition.startswith(('as ', 'AS '))
440        if not as_query and not definition.startswith('('):
441            definition = '(%s)' % definition
442        with_oids = 'with oids' if oids else 'without oids'
443        q = ['create', temporary, table]
444        if as_query:
445            q.extend([with_oids, definition])
446        else:
447            q.extend([definition, with_oids])
448        q = ' '.join(q)
449        query(q)
450        if values:
451            for params in values:
452                if not isinstance(params, (list, tuple)):
453                    params = [params]
454                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
455                q = "insert into %s values (%s)" % (table, values)
456                query(q, params)
457
458    def testClassName(self):
459        self.assertEqual(self.db.__class__.__name__, 'DB')
460
461    def testModuleName(self):
462        self.assertEqual(self.db.__module__, 'pg')
463        self.assertEqual(self.db.__class__.__module__, 'pg')
464
465    def testEscapeLiteral(self):
466        f = self.db.escape_literal
467        r = f(b"plain")
468        self.assertIsInstance(r, bytes)
469        self.assertEqual(r, b"'plain'")
470        r = f(u"plain")
471        self.assertIsInstance(r, unicode)
472        self.assertEqual(r, u"'plain'")
473        r = f(u"that's kÀse".encode('utf-8'))
474        self.assertIsInstance(r, bytes)
475        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
476        r = f(u"that's kÀse")
477        self.assertIsInstance(r, unicode)
478        self.assertEqual(r, u"'that''s kÀse'")
479        self.assertEqual(f(r"It's fine to have a \ inside."),
480                         r" E'It''s fine to have a \\ inside.'")
481        self.assertEqual(f('No "quotes" must be escaped.'),
482                         "'No \"quotes\" must be escaped.'")
483
484    def testEscapeIdentifier(self):
485        f = self.db.escape_identifier
486        r = f(b"plain")
487        self.assertIsInstance(r, bytes)
488        self.assertEqual(r, b'"plain"')
489        r = f(u"plain")
490        self.assertIsInstance(r, unicode)
491        self.assertEqual(r, u'"plain"')
492        r = f(u"that's kÀse".encode('utf-8'))
493        self.assertIsInstance(r, bytes)
494        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
495        r = f(u"that's kÀse")
496        self.assertIsInstance(r, unicode)
497        self.assertEqual(r, u'"that\'s kÀse"')
498        self.assertEqual(f(r"It's fine to have a \ inside."),
499                         '"It\'s fine to have a \\ inside."')
500        self.assertEqual(f('All "quotes" must be escaped.'),
501                         '"All ""quotes"" must be escaped."')
502
503    def testEscapeString(self):
504        f = self.db.escape_string
505        r = f(b"plain")
506        self.assertIsInstance(r, bytes)
507        self.assertEqual(r, b"plain")
508        r = f(u"plain")
509        self.assertIsInstance(r, unicode)
510        self.assertEqual(r, u"plain")
511        r = f(u"that's kÀse".encode('utf-8'))
512        self.assertIsInstance(r, bytes)
513        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
514        r = f(u"that's kÀse")
515        self.assertIsInstance(r, unicode)
516        self.assertEqual(r, u"that''s kÀse")
517        self.assertEqual(f(r"It's fine to have a \ inside."),
518                         r"It''s fine to have a \ inside.")
519
520    def testEscapeBytea(self):
521        f = self.db.escape_bytea
522        # note that escape_byte always returns hex output since Pg 9.0,
523        # regardless of the bytea_output setting
524        r = f(b'plain')
525        self.assertIsInstance(r, bytes)
526        self.assertEqual(r, b'\\x706c61696e')
527        r = f(u'plain')
528        self.assertIsInstance(r, unicode)
529        self.assertEqual(r, u'\\x706c61696e')
530        r = f(u"das is' kÀse".encode('utf-8'))
531        self.assertIsInstance(r, bytes)
532        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
533        r = f(u"das is' kÀse")
534        self.assertIsInstance(r, unicode)
535        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
536        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
537
538    def testUnescapeBytea(self):
539        f = self.db.unescape_bytea
540        r = f(b'plain')
541        self.assertIsInstance(r, bytes)
542        self.assertEqual(r, b'plain')
543        r = f(u'plain')
544        self.assertIsInstance(r, bytes)
545        self.assertEqual(r, b'plain')
546        r = f(b"das is' k\\303\\244se")
547        self.assertIsInstance(r, bytes)
548        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
549        r = f(u"das is' k\\303\\244se")
550        self.assertIsInstance(r, bytes)
551        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
552        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
553        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
554        self.assertEqual(f(r'\\x746861742773206be47365'),
555                         b'\\x746861742773206be47365')
556        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
557
558    def testDecodeJson(self):
559        f = self.db.decode_json
560        self.assertIsNone(f('null'))
561        data = {
562            "id": 1, "name": "Foo", "price": 1234.5,
563            "new": True, "note": None,
564            "tags": ["Bar", "Eek"],
565            "stock": {"warehouse": 300, "retail": 20}}
566        text = json.dumps(data)
567        r = f(text)
568        self.assertIsInstance(r, dict)
569        self.assertEqual(r, data)
570        self.assertIsInstance(r['id'], int)
571        self.assertIsInstance(r['name'], unicode)
572        self.assertIsInstance(r['price'], float)
573        self.assertIsInstance(r['new'], bool)
574        self.assertIsInstance(r['tags'], list)
575        self.assertIsInstance(r['stock'], dict)
576
577    def testEncodeJson(self):
578        f = self.db.encode_json
579        self.assertEqual(f(None), 'null')
580        data = {
581            "id": 1, "name": "Foo", "price": 1234.5,
582            "new": True, "note": None,
583            "tags": ["Bar", "Eek"],
584            "stock": {"warehouse": 300, "retail": 20}}
585        text = json.dumps(data)
586        r = f(data)
587        self.assertIsInstance(r, str)
588        self.assertEqual(r, text)
589
590    def testGetParameter(self):
591        f = self.db.get_parameter
592        self.assertRaises(TypeError, f)
593        self.assertRaises(TypeError, f, None)
594        self.assertRaises(TypeError, f, 42)
595        self.assertRaises(TypeError, f, '')
596        self.assertRaises(TypeError, f, [])
597        self.assertRaises(TypeError, f, [''])
598        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
599        r = f('standard_conforming_strings')
600        self.assertEqual(r, 'on')
601        r = f('lc_monetary')
602        self.assertEqual(r, 'C')
603        r = f('datestyle')
604        self.assertEqual(r, 'ISO, YMD')
605        r = f('bytea_output')
606        self.assertEqual(r, 'hex')
607        r = f(['bytea_output', 'lc_monetary'])
608        self.assertIsInstance(r, list)
609        self.assertEqual(r, ['hex', 'C'])
610        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
611        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
612        r = f(set(['bytea_output', 'lc_monetary']))
613        self.assertIsInstance(r, dict)
614        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
615        r = f(set(['Bytea_Output', ' LC_Monetary ']))
616        self.assertIsInstance(r, dict)
617        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
618        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
619        r = f(s)
620        self.assertIs(r, s)
621        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
622        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
623        r = f(s)
624        self.assertIs(r, s)
625        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
626
627    def testGetParameterServerVersion(self):
628        r = self.db.get_parameter('server_version_num')
629        self.assertIsInstance(r, str)
630        s = self.db.server_version
631        self.assertIsInstance(s, int)
632        self.assertEqual(r, str(s))
633
634    def testGetParameterAll(self):
635        f = self.db.get_parameter
636        r = f('all')
637        self.assertIsInstance(r, dict)
638        self.assertEqual(r['standard_conforming_strings'], 'on')
639        self.assertEqual(r['lc_monetary'], 'C')
640        self.assertEqual(r['DateStyle'], 'ISO, YMD')
641        self.assertEqual(r['bytea_output'], 'hex')
642
    def testSetParameter(self):
        # Walks through a fixed sequence of settings on self.db; each check
        # relies on the state left by the previous calls.
        f = self.db.set_parameter
        g = self.db.get_parameter
        # missing or invalid parameter names
        self.assertRaises(TypeError, f)
        self.assertRaises(TypeError, f, None)
        self.assertRaises(TypeError, f, 42)
        self.assertRaises(TypeError, f, '')
        self.assertRaises(TypeError, f, [])
        self.assertRaises(TypeError, f, [''])
        # 'all' with a value, and a dict of names plus an extra value,
        # are both rejected with ValueError
        self.assertRaises(ValueError, f, 'all', 'invalid')
        self.assertRaises(ValueError, f, {
            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
        # single parameters
        f('standard_conforming_strings', 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        f('datestyle', 'ISO, DMY')
        self.assertEqual(g('datestyle'), 'ISO, DMY')
        # a list of names with a list of values, or with one common value
        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertEqual(g('datestyle'), 'ISO, DMY')
        f(['default_with_oids', 'standard_conforming_strings'], 'off')
        self.assertEqual(g('default_with_oids'), 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        # the same with tuples
        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertEqual(g('datestyle'), 'ISO, YMD')
        f(('default_with_oids', 'standard_conforming_strings'), 'off')
        self.assertEqual(g('default_with_oids'), 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        # a set of names accepts one common value, or a list of values only
        # if they are all equal (presumably since a set has no order)
        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
        self.assertEqual(g('default_with_oids'), 'on')
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertRaises(ValueError, f, set(['default_with_oids',
            'standard_conforming_strings']), ['off', 'on'])
        f(set(['default_with_oids', 'standard_conforming_strings']),
            ['off', 'off'])
        self.assertEqual(g('default_with_oids'), 'off')
        self.assertEqual(g('standard_conforming_strings'), 'off')
        # a dict maps parameter names to their values
        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
        self.assertEqual(g('standard_conforming_strings'), 'on')
        self.assertEqual(g('datestyle'), 'ISO, YMD')
684
685    def testResetParameter(self):
686        db = DB()
687        f = db.set_parameter
688        g = db.get_parameter
689        r = g('default_with_oids')
690        self.assertIn(r, ('on', 'off'))
691        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
692        r = g('standard_conforming_strings')
693        self.assertIn(r, ('on', 'off'))
694        scs, not_scs = r, 'off' if r == 'on' else 'on'
695        f('default_with_oids', not_dwi)
696        f('standard_conforming_strings', not_scs)
697        self.assertEqual(g('default_with_oids'), not_dwi)
698        self.assertEqual(g('standard_conforming_strings'), not_scs)
699        f('default_with_oids')
700        f('standard_conforming_strings', None)
701        self.assertEqual(g('default_with_oids'), dwi)
702        self.assertEqual(g('standard_conforming_strings'), scs)
703        f('default_with_oids', not_dwi)
704        f('standard_conforming_strings', not_scs)
705        self.assertEqual(g('default_with_oids'), not_dwi)
706        self.assertEqual(g('standard_conforming_strings'), not_scs)
707        f(['default_with_oids', 'standard_conforming_strings'], None)
708        self.assertEqual(g('default_with_oids'), dwi)
709        self.assertEqual(g('standard_conforming_strings'), scs)
710        f('default_with_oids', not_dwi)
711        f('standard_conforming_strings', not_scs)
712        self.assertEqual(g('default_with_oids'), not_dwi)
713        self.assertEqual(g('standard_conforming_strings'), not_scs)
714        f(('default_with_oids', 'standard_conforming_strings'))
715        self.assertEqual(g('default_with_oids'), dwi)
716        self.assertEqual(g('standard_conforming_strings'), scs)
717        f('default_with_oids', not_dwi)
718        f('standard_conforming_strings', not_scs)
719        self.assertEqual(g('default_with_oids'), not_dwi)
720        self.assertEqual(g('standard_conforming_strings'), not_scs)
721        f(set(['default_with_oids', 'standard_conforming_strings']))
722        self.assertEqual(g('default_with_oids'), dwi)
723        self.assertEqual(g('standard_conforming_strings'), scs)
724
725    def testResetParameterAll(self):
726        db = DB()
727        f = db.set_parameter
728        self.assertRaises(ValueError, f, 'all', 0)
729        self.assertRaises(ValueError, f, 'all', 'off')
730        g = db.get_parameter
731        r = g('default_with_oids')
732        self.assertIn(r, ('on', 'off'))
733        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
734        r = g('standard_conforming_strings')
735        self.assertIn(r, ('on', 'off'))
736        scs, not_scs = r, 'off' if r == 'on' else 'on'
737        f('default_with_oids', not_dwi)
738        f('standard_conforming_strings', not_scs)
739        self.assertEqual(g('default_with_oids'), not_dwi)
740        self.assertEqual(g('standard_conforming_strings'), not_scs)
741        f('all')
742        self.assertEqual(g('default_with_oids'), dwi)
743        self.assertEqual(g('standard_conforming_strings'), scs)
744
745    def testSetParameterLocal(self):
746        f = self.db.set_parameter
747        g = self.db.get_parameter
748        self.assertEqual(g('standard_conforming_strings'), 'on')
749        self.db.begin()
750        f('standard_conforming_strings', 'off', local=True)
751        self.assertEqual(g('standard_conforming_strings'), 'off')
752        self.db.end()
753        self.assertEqual(g('standard_conforming_strings'), 'on')
754
755    def testSetParameterSession(self):
756        f = self.db.set_parameter
757        g = self.db.get_parameter
758        self.assertEqual(g('standard_conforming_strings'), 'on')
759        self.db.begin()
760        f('standard_conforming_strings', 'off', local=False)
761        self.assertEqual(g('standard_conforming_strings'), 'off')
762        self.db.end()
763        self.assertEqual(g('standard_conforming_strings'), 'off')
764
765    def testReset(self):
766        db = DB()
767        default_datestyle = db.get_parameter('datestyle')
768        changed_datestyle = 'ISO, DMY'
769        if changed_datestyle == default_datestyle:
770            changed_datestyle == 'ISO, YMD'
771        self.db.set_parameter('datestyle', changed_datestyle)
772        r = self.db.get_parameter('datestyle')
773        self.assertEqual(r, changed_datestyle)
774        con = self.db.db
775        q = con.query("show datestyle")
776        self.db.reset()
777        r = q.getresult()[0][0]
778        self.assertEqual(r, changed_datestyle)
779        q = con.query("show datestyle")
780        r = q.getresult()[0][0]
781        self.assertEqual(r, default_datestyle)
782        r = self.db.get_parameter('datestyle')
783        self.assertEqual(r, default_datestyle)
784
785    def testReopen(self):
786        db = DB()
787        default_datestyle = db.get_parameter('datestyle')
788        changed_datestyle = 'ISO, DMY'
789        if changed_datestyle == default_datestyle:
790            changed_datestyle == 'ISO, YMD'
791        self.db.set_parameter('datestyle', changed_datestyle)
792        r = self.db.get_parameter('datestyle')
793        self.assertEqual(r, changed_datestyle)
794        con = self.db.db
795        q = con.query("show datestyle")
796        self.db.reopen()
797        r = q.getresult()[0][0]
798        self.assertEqual(r, changed_datestyle)
799        self.assertRaises(TypeError, getattr, con, 'query')
800        r = self.db.get_parameter('datestyle')
801        self.assertEqual(r, default_datestyle)
802
803    def testCreateTable(self):
804        table = 'test hello world'
805        values = [(2, "World!"), (1, "Hello")]
806        self.createTable(table, "n smallint, t varchar",
807                         temporary=True, oids=True, values=values)
808        r = self.db.query('select t from "%s" order by n' % table).getresult()
809        r = ', '.join(row[0] for row in r)
810        self.assertEqual(r, "Hello, World!")
811        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
812        self.assertIsInstance(r[0][0], int)
813
814    def testQuery(self):
815        query = self.db.query
816        table = 'test_table'
817        self.createTable(table, "n integer", oids=True)
818        q = "insert into test_table values (1)"
819        r = query(q)
820        self.assertIsInstance(r, int)
821        q = "insert into test_table select 2"
822        r = query(q)
823        self.assertIsInstance(r, int)
824        oid = r
825        q = "select oid from test_table where n=2"
826        r = query(q).getresult()
827        self.assertEqual(len(r), 1)
828        r = r[0]
829        self.assertEqual(len(r), 1)
830        r = r[0]
831        self.assertEqual(r, oid)
832        q = "insert into test_table select 3 union select 4 union select 5"
833        r = query(q)
834        self.assertIsInstance(r, str)
835        self.assertEqual(r, '3')
836        q = "update test_table set n=4 where n<5"
837        r = query(q)
838        self.assertIsInstance(r, str)
839        self.assertEqual(r, '4')
840        q = "delete from test_table"
841        r = query(q)
842        self.assertIsInstance(r, str)
843        self.assertEqual(r, '5')
844
845    def testMultipleQueries(self):
846        self.assertEqual(self.db.query(
847            "create temporary table test_multi (n integer);"
848            "insert into test_multi values (4711);"
849            "select n from test_multi").getresult()[0][0], 4711)
850
851    def testQueryWithParams(self):
852        query = self.db.query
853        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
854        q = "insert into test_table values ($1, $2)"
855        r = query(q, (1, 2))
856        self.assertIsInstance(r, int)
857        r = query(q, [3, 4])
858        self.assertIsInstance(r, int)
859        r = query(q, [5, 6])
860        self.assertIsInstance(r, int)
861        q = "select * from test_table order by 1, 2"
862        self.assertEqual(query(q).getresult(),
863                         [(1, 2), (3, 4), (5, 6)])
864        q = "select * from test_table where n1=$1 and n2=$2"
865        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
866        q = "update test_table set n2=$2 where n1=$1"
867        r = query(q, 3, 7)
868        self.assertEqual(r, '1')
869        q = "select * from test_table order by 1, 2"
870        self.assertEqual(query(q).getresult(),
871                         [(1, 2), (3, 7), (5, 6)])
872        q = "delete from test_table where n2!=$1"
873        r = query(q, 4)
874        self.assertEqual(r, '3')
875
876    def testEmptyQuery(self):
877        self.assertRaises(ValueError, self.db.query, '')
878
879    def testQueryDataError(self):
880        try:
881            self.db.query("select 1/0")
882        except pg.DataError as error:
883            self.assertEqual(error.sqlstate, '22012')
884
885    def testQueryFormatted(self):
886        f = self.db.query_formatted
887        t = True if pg.get_bool() else 't'
888        q = f("select %s::int, %s::real, %s::text, %s::bool",
889              (3, 2.5, 'hello', True))
890        r = q.getresult()[0]
891        self.assertEqual(r, (3, 2.5, 'hello', t))
892        q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True)
893        r = q.getresult()[0]
894        if isinstance(r[1], Decimal):
895            # Python 2.6 cannot compare float and Decimal
896            r = list(r)
897            r[1] = float(r[1])
898            r = tuple(r)
899        self.assertEqual(r, (3, 2.5, 'hello', t))
900
    def testPkey(self):
        """Check pkey() for single, composite and quoted primary keys and caching."""
        query = self.db.query
        pkey = self.db.pkey
        # pkey() raises KeyError for this table (as for the key-less '%s0' below)
        self.assertRaises(KeyError, pkey, 'test')
        # run all checks once with plain names, once with names needing quotes
        for t in ('pkeytest', 'primary key test'):
            self.createTable('%s0' % t, 'a smallint')
            self.createTable('%s1' % t, 'b smallint primary key')
            self.createTable('%s2' % t,
                'c smallint, d smallint primary key')
            self.createTable('%s3' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (f, h)')
            self.createTable('%s4' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (h, f)')
            self.createTable('%s5' % t,
                'more_than_one_letter varchar primary key')
            self.createTable('%s6' % t,
                '"with space" date primary key')
            self.createTable('%s7' % t,
                'a_very_long_column_name varchar, "with space" date, "42" int,'
                ' primary key (a_very_long_column_name, "with space", "42")')
            # a table without primary key raises KeyError
            self.assertRaises(KeyError, pkey, '%s0' % t)
            # a single-column key is a string by default or with
            # composite=False, and a 1-tuple with composite=True
            # (also when passed positionally)
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s1' % t, True), ('b',))
            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
            self.assertEqual(pkey('%s2' % t), 'd')
            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
            # a multi-column key is always a tuple, even with composite=False
            r = pkey('%s3' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s3' % t, composite=False)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            # the tuple follows the order of the key definition,
            # not the order of the columns in the table
            r = pkey('%s4' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('h', 'f'))
            # column names are returned unquoted
            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s6' % t), 'with space')
            r = pkey('%s7' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (
                'a_very_long_column_name', 'with space', '42'))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
954
955    def testGetDatabases(self):
956        databases = self.db.get_databases()
957        self.assertIn('template0', databases)
958        self.assertIn('template1', databases)
959        self.assertNotIn('not existing database', databases)
960        self.assertIn('postgres', databases)
961        self.assertIn(dbname, databases)
962
963    def testGetTables(self):
964        get_tables = self.db.get_tables
965        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
966                  'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
967                  'averyveryveryveryveryveryveryreallyreallylongtablename',
968                  'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
969        for t in tables:
970            self.db.query('drop table if exists "%s" cascade' % t)
971        before_tables = get_tables()
972        self.assertIsInstance(before_tables, list)
973        for t in before_tables:
974            t = t.split('.', 1)
975            self.assertGreaterEqual(len(t), 2)
976            if len(t) > 2:
977                self.assertTrue(t[1].startswith('"'))
978            t = t[0]
979            self.assertNotEqual(t, 'information_schema')
980            self.assertFalse(t.startswith('pg_'))
981        for t in tables:
982            self.createTable(t, 'as select 0', temporary=False)
983        current_tables = get_tables()
984        new_tables = [t for t in current_tables if t not in before_tables]
985        expected_new_tables = ['public.%s' % (
986            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
987        self.assertEqual(new_tables, expected_new_tables)
988        self.doCleanups()
989        after_tables = get_tables()
990        self.assertEqual(after_tables, before_tables)
991
992    def testGetRelations(self):
993        get_relations = self.db.get_relations
994        result = get_relations()
995        self.assertIn('public.test', result)
996        self.assertIn('public.test_view', result)
997        result = get_relations('rv')
998        self.assertIn('public.test', result)
999        self.assertIn('public.test_view', result)
1000        result = get_relations('r')
1001        self.assertIn('public.test', result)
1002        self.assertNotIn('public.test_view', result)
1003        result = get_relations('v')
1004        self.assertNotIn('public.test', result)
1005        self.assertIn('public.test_view', result)
1006        result = get_relations('cisSt')
1007        self.assertNotIn('public.test', result)
1008        self.assertNotIn('public.test_view', result)
1009
1010    def testGetAttnames(self):
1011        get_attnames = self.db.get_attnames
1012        self.assertRaises(pg.ProgrammingError,
1013                          self.db.get_attnames, 'does_not_exist')
1014        self.assertRaises(pg.ProgrammingError,
1015                          self.db.get_attnames, 'has.too.many.dots')
1016        r = get_attnames('test')
1017        self.assertIsInstance(r, dict)
1018        if self.regtypes:
1019            self.assertEqual(r, dict(
1020                i2='smallint', i4='integer', i8='bigint', d='numeric',
1021                f4='real', f8='double precision', m='money',
1022                v4='character varying', c4='character', t='text'))
1023        else:
1024            self.assertEqual(r, dict(
1025                i2='int', i4='int', i8='int', d='num',
1026                f4='float', f8='float', m='money',
1027                v4='text', c4='text', t='text'))
1028        self.createTable('test_table',
1029                         'n int, alpha smallint, beta bool,'
1030                         ' gamma char(5), tau text, v varchar(3)')
1031        r = get_attnames('test_table')
1032        self.assertIsInstance(r, dict)
1033        if self.regtypes:
1034            self.assertEqual(r, dict(
1035                n='integer', alpha='smallint', beta='boolean',
1036                gamma='character', tau='text', v='character varying'))
1037        else:
1038            self.assertEqual(r, dict(
1039                n='int', alpha='int', beta='bool',
1040                gamma='text', tau='text', v='text'))
1041
1042    def testGetAttnamesWithQuotes(self):
1043        get_attnames = self.db.get_attnames
1044        table = 'test table for get_attnames()'
1045        self.createTable(table,
1046            '"Prime!" smallint, "much space" integer, "Questions?" text')
1047        r = get_attnames(table)
1048        self.assertIsInstance(r, dict)
1049        if self.regtypes:
1050            self.assertEqual(r, {
1051                'Prime!': 'smallint', 'much space': 'integer',
1052                'Questions?': 'text'})
1053        else:
1054            self.assertEqual(r, {
1055                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1056        table = 'yet another test table for get_attnames()'
1057        self.createTable(table,
1058                         'a smallint, b integer, c bigint,'
1059                         ' e numeric, f real, f2 double precision, m money,'
1060                         ' x smallint, y smallint, z smallint,'
1061                         ' Normal_NaMe smallint, "Special Name" smallint,'
1062                         ' t text, u char(2), v varchar(2),'
1063                         ' primary key (y, u)', oids=True)
1064        r = get_attnames(table)
1065        self.assertIsInstance(r, dict)
1066        if self.regtypes:
1067            self.assertEqual(r, {
1068                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1069                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1070                'm': 'money', 'normal_name': 'smallint',
1071                'Special Name': 'smallint', 'u': 'character',
1072                't': 'text', 'v': 'character varying', 'y': 'smallint',
1073                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1074        else:
1075            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1076                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1077                 'normal_name': 'int', 'Special Name': 'int',
1078                 'u': 'text', 't': 'text', 'v': 'text',
1079                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1080
1081    def testGetAttnamesWithRegtypes(self):
1082        get_attnames = self.db.get_attnames
1083        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1084            ' gamma char(5), tau text, v varchar(3)')
1085        use_regtypes = self.db.use_regtypes
1086        regtypes = use_regtypes()
1087        self.assertEqual(regtypes, self.regtypes)
1088        use_regtypes(True)
1089        try:
1090            r = get_attnames("test_table")
1091            self.assertIsInstance(r, dict)
1092        finally:
1093            use_regtypes(regtypes)
1094        self.assertEqual(r, dict(
1095            n='integer', alpha='smallint', beta='boolean',
1096            gamma='character', tau='text', v='character varying'))
1097
1098    def testGetAttnamesWithoutRegtypes(self):
1099        get_attnames = self.db.get_attnames
1100        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1101            ' gamma char(5), tau text, v varchar(3)')
1102        use_regtypes = self.db.use_regtypes
1103        regtypes = use_regtypes()
1104        self.assertEqual(regtypes, self.regtypes)
1105        use_regtypes(False)
1106        try:
1107            r = get_attnames("test_table")
1108            self.assertIsInstance(r, dict)
1109        finally:
1110            use_regtypes(regtypes)
1111        self.assertEqual(r, dict(
1112            n='int', alpha='int', beta='bool',
1113            gamma='text', tau='text', v='text'))
1114
    def testGetAttnamesIsCached(self):
        """Check that get_attnames() caches results until flush=True is passed."""
        get_attnames = self.db.get_attnames
        int_type = 'integer' if self.regtypes else 'int'
        text_type = 'text'
        query = self.db.query
        self.createTable('test_table', 'col int')
        r = get_attnames("test_table")
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(col=int_type))
        # change the column type and add a column behind the cache's back
        query("alter table test_table alter column col type text")
        query("alter table test_table add column col2 int")
        # without flushing, the cached (now stale) columns are returned
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=int_type))
        # flushing re-reads the current columns
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict(col=text_type, col2=int_type))
        # the same holds when dropping columns
        query("alter table test_table drop column col2")
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=text_type, col2=int_type))
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict(col=text_type))
        query("alter table test_table drop column col")
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=text_type))
        # a table without any columns yields an empty dict
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict())
1140
1141    def testGetAttnamesIsOrdered(self):
1142        get_attnames = self.db.get_attnames
1143        r = get_attnames('test', flush=True)
1144        self.assertIsInstance(r, OrderedDict)
1145        if self.regtypes:
1146            self.assertEqual(r, OrderedDict([
1147                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1148                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1149                ('m', 'money'), ('v4', 'character varying'),
1150                ('c4', 'character'), ('t', 'text')]))
1151        else:
1152            self.assertEqual(r, OrderedDict([
1153                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1154                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1155                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1156        if OrderedDict is not dict:
1157            r = ' '.join(list(r.keys()))
1158            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1159        table = 'test table for get_attnames'
1160        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1161                ' gamma char(5), tau text, beta bool')
1162        r = get_attnames(table)
1163        self.assertIsInstance(r, OrderedDict)
1164        if self.regtypes:
1165            self.assertEqual(r, OrderedDict([
1166                ('n', 'integer'), ('alpha', 'smallint'),
1167                ('v', 'character varying'), ('gamma', 'character'),
1168                ('tau', 'text'), ('beta', 'boolean')]))
1169        else:
1170            self.assertEqual(r, OrderedDict([
1171                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1172                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1173        if OrderedDict is not dict:
1174            r = ' '.join(list(r.keys()))
1175            self.assertEqual(r, 'n alpha v gamma tau beta')
1176        else:
1177            self.skipTest('OrderedDict is not supported')
1178
1179    def testGetAttnamesIsAttrDict(self):
1180        AttrDict = pg.AttrDict
1181        get_attnames = self.db.get_attnames
1182        r = get_attnames('test', flush=True)
1183        self.assertIsInstance(r, AttrDict)
1184        if self.regtypes:
1185            self.assertEqual(r, AttrDict([
1186                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1187                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1188                ('m', 'money'), ('v4', 'character varying'),
1189                ('c4', 'character'), ('t', 'text')]))
1190        else:
1191            self.assertEqual(r, AttrDict([
1192                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1193                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1194                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1195        r = ' '.join(list(r.keys()))
1196        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1197        table = 'test table for get_attnames'
1198        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1199                ' gamma char(5), tau text, beta bool')
1200        r = get_attnames(table)
1201        self.assertIsInstance(r, AttrDict)
1202        if self.regtypes:
1203            self.assertEqual(r, AttrDict([
1204                ('n', 'integer'), ('alpha', 'smallint'),
1205                ('v', 'character varying'), ('gamma', 'character'),
1206                ('tau', 'text'), ('beta', 'boolean')]))
1207        else:
1208            self.assertEqual(r, AttrDict([
1209                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1210                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1211        r = ' '.join(list(r.keys()))
1212        self.assertEqual(r, 'n alpha v gamma tau beta')
1213
1214    def testHasTablePrivilege(self):
1215        can = self.db.has_table_privilege
1216        self.assertEqual(can('test'), True)
1217        self.assertEqual(can('test', 'select'), True)
1218        self.assertEqual(can('test', 'SeLeCt'), True)
1219        self.assertEqual(can('test', 'SELECT'), True)
1220        self.assertEqual(can('test', 'insert'), True)
1221        self.assertEqual(can('test', 'update'), True)
1222        self.assertEqual(can('test', 'delete'), True)
1223        self.assertRaises(pg.DataError, can, 'test', 'foobar')
1224        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1225        r = self.db.query('select rolsuper FROM pg_roles'
1226            ' where rolname=current_user').getresult()[0][0]
1227        if not pg.get_bool():
1228            r = r == 't'
1229        if r:
1230            self.skipTest('must not be superuser')
1231        self.assertEqual(can('pg_views', 'select'), True)
1232        self.assertEqual(can('pg_views', 'delete'), False)
1233
    def testGet(self):
        """Check get() with explicit key columns, a primary key and row dicts."""
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # get() needs at least a table name and a row/key argument
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        self.createTable(table, 'n integer, t text',
                         values=enumerate('xyz', start=1))
        # the table has no primary key yet, so a key column must be named
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        # key values and key column names may also be given as tuples
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        # any column can serve as the key column
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # NOTE(review): the calls below without a key column presumably fail
        # for the same reason as the ProgrammingError check above (if
        # ProgrammingError subclasses DatabaseError), not because the row
        # is missing -- confirm that this is the intended check
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        s = dict(n=3)
        # a row dict also needs a key column while there is no primary key
        self.assertRaises(pg.ProgrammingError, get, table, s)
        # a dict passed as row is filled in place and returned
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(t='x')
        # only the named key column is used; n=3 is overwritten by the row
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        # several key columns can be named at once
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        # with a primary key, the key column can be omitted
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # a trailing '*' (with or without a space) is accepted after the name
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # a two-element key for the single-column primary key is rejected
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        s.update(n=2)
        # r is the same object as s, so this looks up n=2
        self.assertEqual(get(table, r)['t'], 'y')
        s.pop('n')
        # a row dict without the primary key column raises KeyError
        self.assertRaises(KeyError, get, table, s)
1288
    def testGetWithOid(self):
        """Check get() on a table with OIDs, before and after adding a key."""
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        # without a primary key, a bare key value needs a key column name
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        # looking up by 'oid' with a dict that has no oid raises KeyError
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        # the result carries the row's oid under the key 'oid(<table>)'
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        # the oid can be passed as a value with key name 'oid', or in a dict
        # under either the plain 'oid' key or the qualified 'oid(<table>)' key
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        r['n'] = 3
        # an explicitly named key column wins over the oid stored in r
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        # presumably r was updated in place by the get() above, so its oid
        # now belongs to the row with n=3 -- confirm
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        # with a primary key, no key column name is needed
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        # a row dict is now looked up by its primary key value
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # lookups by oid alone still work after adding the primary key
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # if the dict has both, the primary key takes precedence over 'oid'
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # an explicitly named key column also takes precedence over 'oid'
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1352
1353    def testGetWithCompositeKey(self):
1354        get = self.db.get
1355        query = self.db.query
1356        table = 'get_test_table_1'
1357        self.createTable(table, 'n integer primary key, t text',
1358                         values=enumerate('abc', start=1))
1359        self.assertEqual(get(table, 2)['t'], 'b')
1360        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1361        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1362        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1363        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1364        self.assertEqual(get(table, 'b', 't')['n'], 2)
1365        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1366        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1367        table = 'get_test_table_2'
1368        self.createTable(table,
1369                         'n integer, m integer, t text, primary key (n, m)',
1370                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1371                                 for n in range(3) for m in range(2)])
1372        self.assertRaises(KeyError, get, table, 2)
1373        self.assertEqual(get(table, (1, 1))['t'], 'a')
1374        self.assertEqual(get(table, (1, 2))['t'], 'b')
1375        self.assertEqual(get(table, (2, 1))['t'], 'c')
1376        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1377        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1378        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1379        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1380        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1381        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1382        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1383        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1384
1385    def testGetWithQuotedNames(self):
1386        get = self.db.get
1387        query = self.db.query
1388        table = 'test table for get()'
1389        self.createTable(table, '"Prime!" smallint primary key,'
1390                                ' "much space" integer, "Questions?" text',
1391                         values=[(17, 1001, 'No!')])
1392        r = get(table, 17)
1393        self.assertIsInstance(r, dict)
1394        self.assertEqual(r['Prime!'], 17)
1395        self.assertEqual(r['much space'], 1001)
1396        self.assertEqual(r['Questions?'], 'No!')
1397
1398    def testGetFromView(self):
1399        self.db.query('delete from test where i4=14')
1400        self.db.query('insert into test (i4, v4) values('
1401                      "14, 'abc4')")
1402        r = self.db.get('test_view', 14, 'i4')
1403        self.assertIn('v4', r)
1404        self.assertEqual(r['v4'], 'abc4')
1405
1406    def testGetLittleBobbyTables(self):
1407        get = self.db.get
1408        query = self.db.query
1409        self.createTable('test_students',
1410                         'firstname varchar primary key, nickname varchar, grade char(2)',
1411                         values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1412                                 ('Robert', 'Little Bobby Tables', 'D-')])
1413        r = get('test_students', 'Sheldon')
1414        self.assertEqual(r, dict(
1415            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1416        r = get('test_students', 'Robert')
1417        self.assertEqual(r, dict(
1418            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1419        r = get('test_students', "D'Arcy")
1420        self.assertEqual(r, dict(
1421            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1422        try:
1423            get('test_students', "D' Arcy")
1424        except pg.DatabaseError as error:
1425            self.assertEqual(str(error),
1426                'No such record in test_students\nwhere "firstname" = $1\n'
1427                'with $1="D\' Arcy"')
1428        try:
1429            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1430        except pg.DatabaseError as error:
1431            self.assertEqual(str(error),
1432                'No such record in test_students\nwhere "firstname" = $1\n'
1433                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1434        q = "select * from test_students order by 1 limit 4"
1435        r = query(q).getresult()
1436        self.assertEqual(len(r), 3)
1437        self.assertEqual(r[1][2], 'D-')
1438
1439    def testInsert(self):
1440        insert = self.db.insert
1441        query = self.db.query
1442        bool_on = pg.get_bool()
1443        decimal = pg.get_decimal()
1444        table = 'insert_test_table'
1445        self.createTable(table,
1446            'i2 smallint, i4 integer, i8 bigint,'
1447            ' d numeric, f4 real, f8 double precision, m money,'
1448            ' v4 varchar(4), c4 char(4), t text,'
1449            ' b boolean, ts timestamp', oids=True)
1450        oid_table = 'oid(%s)' % table
1451        tests = [dict(i2=None, i4=None, i8=None),
1452             (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1453             (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1454             dict(i2=42, i4=123456, i8=9876543210),
1455             dict(i2=2 ** 15 - 1,
1456                  i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1457             dict(d=None), (dict(d=''), dict(d=None)),
1458             dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1459             dict(f4=None, f8=None), dict(f4=0, f8=0),
1460             (dict(f4='', f8=''), dict(f4=None, f8=None)),
1461             (dict(d=1234.5, f4=1234.5, f8=1234.5),
1462              dict(d=Decimal('1234.5'))),
1463             dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1464             dict(d=Decimal('123456789.9876543212345678987654321')),
1465             dict(m=None), (dict(m=''), dict(m=None)),
1466             dict(m=Decimal('-1234.56')),
1467             (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1468             dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1469             (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1470             (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1471             (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1472             (dict(m=123456), dict(m=Decimal('123456'))),
1473             (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1474             dict(b=None), (dict(b=''), dict(b=None)),
1475             dict(b='f'), dict(b='t'),
1476             (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1477             (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1478             (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1479             (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1480             (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1481             (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1482             dict(v4=None, c4=None, t=None),
1483             (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1484             dict(v4='1234', c4='1234', t='1234' * 10),
1485             dict(v4='abcd', c4='abcd', t='abcdefg'),
1486             (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1487             dict(ts=None), (dict(ts=''), dict(ts=None)),
1488             (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1489             dict(ts='2012-12-21 00:00:00'),
1490             (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1491             dict(ts='2012-12-21 12:21:12'),
1492             dict(ts='2013-01-05 12:13:14'),
1493             dict(ts='current_timestamp')]
1494        for test in tests:
1495            if isinstance(test, dict):
1496                data = test
1497                change = {}
1498            else:
1499                data, change = test
1500            expect = data.copy()
1501            expect.update(change)
1502            if bool_on:
1503                b = expect.get('b')
1504                if b is not None:
1505                    expect['b'] = b == 't'
1506            if decimal is not Decimal:
1507                d = expect.get('d')
1508                if d is not None:
1509                    expect['d'] = decimal(d)
1510                m = expect.get('m')
1511                if m is not None:
1512                    expect['m'] = decimal(m)
1513            self.assertEqual(insert(table, data), data)
1514            self.assertIn(oid_table, data)
1515            oid = data[oid_table]
1516            self.assertIsInstance(oid, int)
1517            data = dict(item for item in data.items()
1518                        if item[0] in expect)
1519            ts = expect.get('ts')
1520            if ts:
1521                if ts == 'current_timestamp':
1522                    ts = data['ts']
1523                    self.assertIsInstance(ts, datetime)
1524                    self.assertEqual(ts.strftime('%Y-%m-%d'),
1525                        strftime('%Y-%m-%d'))
1526                else:
1527                    ts = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S')
1528                expect['ts'] = ts
1529            self.assertEqual(data, expect)
1530            data = query(
1531                'select oid,* from "%s"' % table).dictresult()[0]
1532            self.assertEqual(data['oid'], oid)
1533            data = dict(item for item in data.items()
1534                        if item[0] in expect)
1535            self.assertEqual(data, expect)
1536            query('delete from "%s"' % table)
1537
    def testInsertWithOid(self):
        """Test insert() on a table with oids.

        The oid of an inserted row is returned under the qualified key
        'oid(test_table)', never as a plain 'oid' key.  Passing an oid,
        either as keyword or as plain 'oid' key, does not reuse it: every
        call inserts a new row that gets a new oid.
        """
        insert = self.db.insert
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True)
        # a column that does not exist in the table is rejected
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
        r = insert('test_table', n=1)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        self.assertNotIn('oid', r)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        oid = r[qoid]
        # the result holds only the column values and the qualified oid
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        # an oid passed as keyword is not reused for the new row
        r = insert('test_table', n=2, oid=oid)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 2)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # None as row with keyword arguments works like keywords alone
        r = insert('test_table', None, n=3)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 3)
        s = r
        # a passed dict is updated in place and returned (same object);
        # inserting it again adds another row with the same value
        r = insert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # a trailing ' *' after the table name is accepted as well
        r = insert('test_table *', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # keyword arguments override the values in the passed dict
        r = insert('test_table', r, n=4)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 4)
        self.assertNotIn('oid', r)
        self.assertIn(qoid, r)
        oid = r[qoid]
        r = insert('test_table', r, n=5, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # a plain 'oid' key in the dict is not reused either, and it is
        # no longer in the dict after the insert
        r['oid'] = oid = r[qoid]
        r = insert('test_table', r, n=6)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 6)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # every call above added a row, including the repeated n=3
        q = 'select n from test_table order by 1 limit 9'
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '1 2 3 3 3 4 5 6')
        query("truncate test_table")
        query("alter table test_table add unique (n)")
        r = insert('test_table', dict(n=7))
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        # duplicate values now violate the unique constraint
        self.assertRaises(pg.IntegrityError, insert, 'test_table', r)
        r['n'] = 6
        self.assertRaises(pg.IntegrityError, insert, 'test_table', r, n=7)
        # the keyword value was applied to the dict even though the
        # insert failed
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        r['n'] = 6
        r = insert('test_table', r)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 6)
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '6 7')
1605
1606    def testInsertWithQuotedNames(self):
1607        insert = self.db.insert
1608        query = self.db.query
1609        table = 'test table for insert()'
1610        self.createTable(table, '"Prime!" smallint primary key,'
1611                                ' "much space" integer, "Questions?" text')
1612        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1613        r = insert(table, r)
1614        self.assertIsInstance(r, dict)
1615        self.assertEqual(r['Prime!'], 11)
1616        self.assertEqual(r['much space'], 2002)
1617        self.assertEqual(r['Questions?'], 'What?')
1618        r = query('select * from "%s" limit 2' % table).dictresult()
1619        self.assertEqual(len(r), 1)
1620        r = r[0]
1621        self.assertEqual(r['Prime!'], 11)
1622        self.assertEqual(r['much space'], 2002)
1623        self.assertEqual(r['Questions?'], 'What?')
1624
1625    def testInsertIntoView(self):
1626        insert = self.db.insert
1627        query = self.db.query
1628        query("truncate test")
1629        q = 'select * from test_view order by i4 limit 3'
1630        r = query(q).getresult()
1631        self.assertEqual(r, [])
1632        r = dict(i4=1234, v4='abcd')
1633        insert('test', r)
1634        self.assertIsNone(r['i2'])
1635        self.assertEqual(r['i4'], 1234)
1636        self.assertIsNone(r['i8'])
1637        self.assertEqual(r['v4'], 'abcd')
1638        self.assertIsNone(r['c4'])
1639        r = query(q).getresult()
1640        self.assertEqual(r, [(1234, 'abcd')])
1641        r = dict(i4=5678, v4='efgh')
1642        try:
1643            insert('test_view', r)
1644        except pg.NotSupportedError as error:
1645            if self.db.server_version < 90300:
1646                # must setup rules in older PostgreSQL versions
1647                self.skipTest('database cannot insert into view')
1648            self.fail(str(error))
1649        self.assertNotIn('i2', r)
1650        self.assertEqual(r['i4'], 5678)
1651        self.assertNotIn('i8', r)
1652        self.assertEqual(r['v4'], 'efgh')
1653        self.assertNotIn('c4', r)
1654        r = query(q).getresult()
1655        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1656
1657    def testUpdate(self):
1658        update = self.db.update
1659        query = self.db.query
1660        self.assertRaises(pg.ProgrammingError, update,
1661                          'test', i2=2, i4=4, i8=8)
1662        table = 'update_test_table'
1663        self.createTable(table, 'n integer, t text', oids=True,
1664                         values=enumerate('xyz', start=1))
1665        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1666        r = self.db.get(table, 2, 'n')
1667        r['t'] = 'u'
1668        s = update(table, r)
1669        self.assertEqual(s, r)
1670        q = 'select t from "%s" where n=2' % table
1671        r = query(q).getresult()[0][0]
1672        self.assertEqual(r, 'u')
1673
    def testUpdateWithOid(self):
        """Test update() on a table with oids, before and after it gets a
        primary key.

        Rows fetched with get() carry their oid under the qualified key
        'oid(test_table)', which update() uses to find the row.  A plain
        'oid' key in the passed dict is not accepted as row identifier;
        the oid can be passed as keyword argument instead.
        """
        update = self.db.update
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        s = get('test_table', 1, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 1)
        s['n'] = 2
        # the passed dict is updated in place and returned
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        r['n'] = 3
        # the oid can also be passed as keyword argument
        oid = r.pop(qoid)
        r = update('test_table', r, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # without oid and without a primary key the row cannot be found
        r.pop(qoid)
        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
        s = get('test_table', 3, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 3)
        # with only the qualified oid left, nothing but the oid comes back
        s.pop('n')
        r = update('test_table', s)
        oid = r.pop(qoid)
        self.assertEqual(r, {})
        q = "select n from test_table limit 2"
        r = query(q).getresult()
        self.assertEqual(r, [(3,)])
        query("insert into test_table values (1)")
        # a plain 'oid' key is not accepted as row identifier ...
        self.assertRaises(pg.ProgrammingError,
                          update, 'test_table', dict(oid=oid, n=4))
        # ... but the oid keyword argument is
        r = update('test_table', dict(n=4), oid=oid)
        self.assertEqual(r['n'], 4)
        r = update('test_table *', dict(n=5), oid=oid)
        self.assertEqual(r['n'], 5)
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        # flush the cached table metadata so the changes are seen
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        # now rows can be found by their primary key
        s = dict(n=1, m=4)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 4)
        # the key value can also be passed as keyword argument
        s = dict(m=7)
        r = update('test_table', s, n=5)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 7)
        q = "select n, m from test_table order by 1 limit 3"
        r = query(q).getresult()
        self.assertEqual(r, [(1, 4), (5, 7)])
        # a plain 'oid' key does not replace the missing key column,
        # but the oid keyword argument does
        s = dict(m=9, oid=oid)
        self.assertRaises(KeyError, update, 'test_table', s)
        r = update('test_table', s, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 9)
        # when the key column is given, it decides which row is updated,
        # even though the plain 'oid' key points to another row
        s = dict(n=1, m=3, oid=oid)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        r = query(q).getresult()
        self.assertEqual(r, [(1, 3), (5, 9)])
1744
1745    def testUpdateWithoutOid(self):
1746        update = self.db.update
1747        query = self.db.query
1748        self.assertRaises(pg.ProgrammingError, update,
1749                          'test', i2=2, i4=4, i8=8)
1750        table = 'update_test_table'
1751        self.createTable(table, 'n integer primary key, t text', oids=False,
1752                         values=enumerate('xyz', start=1))
1753        r = self.db.get(table, 2)
1754        r['t'] = 'u'
1755        s = update(table, r)
1756        self.assertEqual(s, r)
1757        q = 'select t from "%s" where n=2' % table
1758        r = query(q).getresult()[0][0]
1759        self.assertEqual(r, 'u')
1760
1761    def testUpdateWithCompositeKey(self):
1762        update = self.db.update
1763        query = self.db.query
1764        table = 'update_test_table_1'
1765        self.createTable(table, 'n integer primary key, t text',
1766                         values=enumerate('abc', start=1))
1767        self.assertRaises(KeyError, update, table, dict(t='b'))
1768        s = dict(n=2, t='d')
1769        r = update(table, s)
1770        self.assertIs(r, s)
1771        self.assertEqual(r['n'], 2)
1772        self.assertEqual(r['t'], 'd')
1773        q = 'select t from "%s" where n=2' % table
1774        r = query(q).getresult()[0][0]
1775        self.assertEqual(r, 'd')
1776        s.update(dict(n=4, t='e'))
1777        r = update(table, s)
1778        self.assertEqual(r['n'], 4)
1779        self.assertEqual(r['t'], 'e')
1780        q = 'select t from "%s" where n=2' % table
1781        r = query(q).getresult()[0][0]
1782        self.assertEqual(r, 'd')
1783        q = 'select t from "%s" where n=4' % table
1784        r = query(q).getresult()
1785        self.assertEqual(len(r), 0)
1786        query('drop table "%s"' % table)
1787        table = 'update_test_table_2'
1788        self.createTable(table,
1789                         'n integer, m integer, t text, primary key (n, m)',
1790                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1791                                 for n in range(3) for m in range(2)])
1792        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1793        self.assertEqual(update(table,
1794                                dict(n=2, m=2, t='x'))['t'], 'x')
1795        q = 'select t from "%s" where n=2 order by m' % table
1796        r = [r[0] for r in query(q).getresult()]
1797        self.assertEqual(r, ['c', 'x'])
1798
1799    def testUpdateWithQuotedNames(self):
1800        update = self.db.update
1801        query = self.db.query
1802        table = 'test table for update()'
1803        self.createTable(table, '"Prime!" smallint primary key,'
1804                                ' "much space" integer, "Questions?" text',
1805                         values=[(13, 3003, 'Why!')])
1806        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1807        r = update(table, r)
1808        self.assertIsInstance(r, dict)
1809        self.assertEqual(r['Prime!'], 13)
1810        self.assertEqual(r['much space'], 7007)
1811        self.assertEqual(r['Questions?'], 'When?')
1812        r = query('select * from "%s" limit 2' % table).dictresult()
1813        self.assertEqual(len(r), 1)
1814        r = r[0]
1815        self.assertEqual(r['Prime!'], 13)
1816        self.assertEqual(r['much space'], 7007)
1817        self.assertEqual(r['Questions?'], 'When?')
1818
    def testUpsert(self):
        """Test upsert() on a table with a single-column primary key.

        New keys are inserted and existing keys are updated.  Keyword
        arguments control the update of existing rows: False keeps the
        stored value, True takes the new one, and a string is used as an
        SQL expression where 'included' refers to the stored row and
        'excluded' to the proposed values.  The passed dict is returned
        with the values that were actually stored.
        """
        upsert = self.db.upsert
        query = self.db.query
        # upsert() on the 'test' table must fail (presumably because it
        # has no primary key - not visible here)
        self.assertRaises(pg.ProgrammingError, upsert,
                          'test', i2=2, i4=4, i8=8)
        table = 'upsert_test_table'
        self.createTable(table, 'n integer primary key, t text', oids=True)
        s = dict(n=1, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # skip on servers older than PostgreSQL 9.5
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        s.update(n=2, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # an existing key updates the row
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=False keeps the stored value, also in the returned dict
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=True takes the new value
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # 'included' refers to the stored row
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # 'excluded' refers to the proposed values
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        # not existing columns and oid parameter should be ignored
        s = dict(m=3, u='z')
        r = upsert(table, s, oid='invalid')
        self.assertIs(r, s)
1890
    def testUpsertWithOid(self):
        """Test upsert() on a table with oids, before and after it gets a
        primary key.

        Without a primary key upsert() fails, even when an oid is given.
        Once the table has a primary key, the key decides which row is
        inserted or updated; a plain 'oid' key or an oid keyword argument
        does not.
        """
        upsert = self.db.upsert
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        # the table has no primary key yet, so upsert() fails
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2))
        r = get('test_table', 1, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        oid = r[qoid]
        # ... and an oid does not help either
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2, oid=oid))
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        # flush the cached table metadata so the changes are seen
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        s = dict(n=2)
        try:
            r = upsert('test_table', s)
        except pg.ProgrammingError as error:
            # skip on servers older than PostgreSQL 9.5
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, None), (2, None)])
        # a plain 'oid' key (here of row 1) is ignored and removed; the
        # dict gets the qualified oid of the row with key n=2 instead
        r['oid'] = oid
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertNotEqual(r[qoid], oid)
        r['m'] = 7
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # the key decides which row is updated, not the qualified oid of
        # row 2 that is still in the dict
        r.update(n=1, m=3)
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
        # an oid keyword argument is ignored as well
        r = upsert('test_table', r, oid='invalid')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=False keeps the stored value
        r['m'] = 5
        r = upsert('test_table', r, m=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=True takes the new value
        r['m'] = 5
        r = upsert('test_table', r, m=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 5)
        # 'included' refers to the stored row (m=7 for n=2)
        r.update(n=2, m=1)
        r = upsert('test_table', r, m='included.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # 'excluded' refers to the proposed values
        r['m'] = 9
        r = upsert('test_table', r, m='excluded.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 9)
        # a trailing ' *' after the table name is accepted as well
        r['m'] = 8
        r = upsert('test_table *', r, m='included.m + 1')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 10)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1974
    def testUpsertWithCompositeKey(self):
        """Test upsert() on a table with a two-column primary key.

        A row is only updated when all key columns match an existing row;
        otherwise a new row is inserted and update expressions passed as
        keyword arguments are not applied.  For existing rows, False keeps
        the stored value, True takes the new one, and a string is used as
        an SQL expression where 'included' refers to the stored row and
        'excluded' to the proposed values.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        self.createTable(table,
                         'n integer, m integer, t text, primary key (n, m)')
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # skip on servers older than PostgreSQL 9.5
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # a different m makes a different key, so a new row is inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # the same key updates the existing row
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False keeps the stored value
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True takes the new value
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # for a new key the update expression is not used
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # stored 'n' concatenated with proposed 'm'
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2041
2042    def testUpsertWithQuotedNames(self):
2043        upsert = self.db.upsert
2044        query = self.db.query
2045        table = 'test table for upsert()'
2046        self.createTable(table, '"Prime!" smallint primary key,'
2047                                ' "much space" integer, "Questions?" text')
2048        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2049        try:
2050            r = upsert(table, s)
2051        except pg.ProgrammingError as error:
2052            if self.db.server_version < 90500:
2053                self.skipTest('database does not support upsert')
2054            self.fail(str(error))
2055        self.assertIs(r, s)
2056        self.assertEqual(r['Prime!'], 31)
2057        self.assertEqual(r['much space'], 9009)
2058        self.assertEqual(r['Questions?'], 'Yes.')
2059        q = 'select * from "%s" limit 2' % table
2060        r = query(q).getresult()
2061        self.assertEqual(r, [(31, 9009, 'Yes.')])
2062        s.update({'Questions?': 'No.'})
2063        r = upsert(table, s)
2064        self.assertIs(r, s)
2065        self.assertEqual(r['Prime!'], 31)
2066        self.assertEqual(r['much space'], 9009)
2067        self.assertEqual(r['Questions?'], 'No.')
2068        r = query(q).getresult()
2069        self.assertEqual(r, [(31, 9009, 'No.')])
2070
2071    def testClear(self):
2072        clear = self.db.clear
2073        f = False if pg.get_bool() else 'f'
2074        r = clear('test')
2075        result = dict(
2076            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2077        self.assertEqual(r, result)
2078        table = 'clear_test_table'
2079        self.createTable(table,
2080            'n integer, f float, b boolean, d date, t text', oids=True)
2081        r = clear(table)
2082        result = dict(n=0, f=0, b=f, d='', t='')
2083        self.assertEqual(r, result)
2084        r['a'] = r['f'] = r['n'] = 1
2085        r['d'] = r['t'] = 'x'
2086        r['b'] = 't'
2087        r['oid'] = long(1)
2088        r = clear(table, r)
2089        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2090        self.assertEqual(r, result)
2091
2092    def testClearWithQuotedNames(self):
2093        clear = self.db.clear
2094        table = 'test table for clear()'
2095        self.createTable(table, '"Prime!" smallint primary key,'
2096            ' "much space" integer, "Questions?" text')
2097        r = clear(table)
2098        self.assertIsInstance(r, dict)
2099        self.assertEqual(r['Prime!'], 0)
2100        self.assertEqual(r['much space'], 0)
2101        self.assertEqual(r['Questions?'], '')
2102
    def testDelete(self):
        """Test delete() on a table with oids but without a primary key.

        delete() returns the number of deleted rows, so deleting the same
        row a second time returns 0.
        """
        delete = self.db.delete
        query = self.db.query
        # deleting from the 'test' table by column values alone must fail
        self.assertRaises(pg.ProgrammingError, delete,
                          'test', dict(i2=2, i4=4, i8=8))
        table = 'delete_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        # the table has no primary key, so get() needs the key column name
        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
        r = self.db.get(table, 1, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        r = self.db.get(table, 3, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        # the row is already gone, so nothing is deleted this time
        s = delete(table, r)
        self.assertEqual(s, 0)
        r = query('select * from "%s"' % table).dictresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        result = {'n': 2, 't': 'y'}
        self.assertEqual(r, result)
        r = self.db.get(table, 2, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        s = delete(table, r)
        self.assertEqual(s, 0)
        # fetching a deleted row raises an error
        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
        # not existing columns and a plain 'oid' key should be ignored
        r.update(m=3, u='z', oid='invalid')
        s = delete(table, r)
        self.assertEqual(s, 0)
2135
    def testDeleteWithOid(self):
        """Test delete() on a table with oids, before and after it gets a
        primary key.

        Rows can be identified by the qualified oid key 'oid(test_table)'
        or by an oid keyword argument, but not by a plain 'oid' key in the
        passed dict.  Once the table has a primary key, key values can be
        passed in the dict or as keyword arguments as well.  delete()
        returns the number of deleted rows.
        """
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
        # without primary key and oid, the row cannot be identified
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        s = get('test_table', 1, 'n')
        qoid = 'oid(test_table)'
        self.assertIn(qoid, s)
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        # q yields the smallest remaining n and the number of rows
        q = "select min(n),count(n) from test_table"
        self.assertEqual(query(q).getresult()[0], (2, 5))
        oid = get('test_table', 2, 'n')[qoid]
        # a plain 'oid' key is not accepted ...
        s = dict(oid=oid, n=2)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        # ... but the oid keyword argument is, even without a row dict
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the oid keyword (row 3) wins over the plain 'oid' key (row 2)
        s = dict(oid=oid, n=2)
        oid = get('test_table', 3, 'n')[qoid]
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        # a trailing ' *' after the table name is accepted as well
        s = get('test_table', 4, 'n')
        r = delete('test_table *', s)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # m is not a column yet, so it is ignored in the dict and keywords
        oid = get('test_table', 5, 'n')[qoid]
        s = {qoid: oid, 'm': 4}
        r = delete('test_table', s, m=6)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        # flush the cached table metadata so the changes are seen
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        for i in range(5):
            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
        # with a primary key, a dict without the key column raises
        # KeyError, even if it has a plain 'oid' key
        s = dict(m=2)
        self.assertRaises(KeyError, delete, 'test_table', s)
        s = dict(m=2, oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # the oid keyword is accepted, but this oid belongs to the row
        # with n=5 that was deleted above, so nothing is deleted
        r = delete('test_table', dict(m=2), oid=oid)
        self.assertEqual(r, 0)
        oid = get('test_table', 1, 'n')[qoid]
        s = dict(oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # the qualified oid key alone still identifies the row
        s = get('test_table', 2, 'n')
        del s['n']
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # key values can be passed as keyword arguments
        r = delete('test_table', n=3)
        self.assertEqual(r, 1)
        r = delete('test_table', n=3)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 1)
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # a keyword key value overrides the one in the dict
        s = dict(n=6)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 1)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
2224
    def testDeleteWithCompositeKey(self):
        """Check delete() on tables with a single and a composite primary key.

        delete() returns the number of deleted rows and raises KeyError if
        the given row does not contain all primary key columns.
        """
        query = self.db.query
        # first table: single-column primary key n, rows (1,'a')..(3,'c')
        table = 'delete_test_table_1'
        self.createTable(table, 'n integer primary key, t text',
                         values=enumerate('abc', start=1))
        # a row without the primary key column cannot be deleted
        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
        r = query('select t from "%s" where n=2' % table).getresult()
        self.assertEqual(r, [])
        # deleting the same key again affects no rows
        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
        self.assertEqual(r, 'c')
        # second table: composite key (n, m) with n in 1..3, m in 1..2,
        # and t running from 'a' to 'f' in (n, m) order
        table = 'delete_test_table_2'
        self.createTable(table,
             'n integer, m integer, t text, primary key (n, m)',
             values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
                     for n in range(3) for m in range(2)])
        # only part of the composite key given (m is missing)
        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
        r = [r[0] for r in query('select t from "%s" where n=2'
            ' order by m' % table).getresult()]
        self.assertEqual(r, ['c'])
        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
        # rows with another value of n are not affected
        r = [r[0] for r in query('select t from "%s" where n=3'
             ' order by m' % table).getresult()]
        self.assertEqual(r, ['e', 'f'])
        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
        r = [r[0] for r in query('select t from "%s" where n=3'
             ' order by m' % table).getresult()]
        self.assertEqual(r, ['f'])
2245            ' order by m' % table).getresult()]
2246        self.assertEqual(r, ['c'])
2247        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2248        r = [r[0] for r in query('select t from "%s" where n=3'
2249             ' order by m' % table).getresult()]
2250        self.assertEqual(r, ['e', 'f'])
2251        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2252        r = [r[0] for r in query('select t from "%s" where n=3'
2253             ' order by m' % table).getresult()]
2254        self.assertEqual(r, ['f'])
2255
2256    def testDeleteWithQuotedNames(self):
2257        delete = self.db.delete
2258        query = self.db.query
2259        table = 'test table for delete()'
2260        self.createTable(table, '"Prime!" smallint primary key,'
2261            ' "much space" integer, "Questions?" text',
2262            values=[(19, 5005, 'Yes!')])
2263        r = {'Prime!': 17}
2264        r = delete(table, r)
2265        self.assertEqual(r, 0)
2266        r = query('select count(*) from "%s"' % table).getresult()
2267        self.assertEqual(r[0][0], 1)
2268        r = {'Prime!': 19}
2269        r = delete(table, r)
2270        self.assertEqual(r, 1)
2271        r = query('select count(*) from "%s"' % table).getresult()
2272        self.assertEqual(r[0][0], 0)
2273
2274    def testDeleteReferenced(self):
2275        delete = self.db.delete
2276        query = self.db.query
2277        self.createTable('test_parent',
2278            'n smallint primary key', values=range(3))
2279        self.createTable('test_child',
2280            'n smallint primary key references test_parent', values=range(3))
2281        q = ("select (select count(*) from test_parent),"
2282             " (select count(*) from test_child)")
2283        self.assertEqual(query(q).getresult()[0], (3, 3))
2284        self.assertRaises(pg.IntegrityError,
2285                          delete, 'test_parent', None, n=2)
2286        self.assertRaises(pg.IntegrityError,
2287                          delete, 'test_parent *', None, n=2)
2288        r = delete('test_child', None, n=2)
2289        self.assertEqual(r, 1)
2290        self.assertEqual(query(q).getresult()[0], (3, 2))
2291        r = delete('test_parent', None, n=2)
2292        self.assertEqual(r, 1)
2293        self.assertEqual(query(q).getresult()[0], (2, 2))
2294        self.assertRaises(pg.IntegrityError,
2295                          delete, 'test_parent', dict(n=0))
2296        self.assertRaises(pg.IntegrityError,
2297                          delete, 'test_parent *', dict(n=0))
2298        r = delete('test_child', dict(n=0))
2299        self.assertEqual(r, 1)
2300        self.assertEqual(query(q).getresult()[0], (2, 1))
2301        r = delete('test_child', dict(n=0))
2302        self.assertEqual(r, 0)
2303        r = delete('test_parent', dict(n=0))
2304        self.assertEqual(r, 1)
2305        self.assertEqual(query(q).getresult()[0], (1, 1))
2306        r = delete('test_parent', None, n=0)
2307        self.assertEqual(r, 0)
2308        q = "select n from test_parent natural join test_child limit 2"
2309        self.assertEqual(query(q).getresult(), [(1,)])
2310
2311    def testTruncate(self):
2312        truncate = self.db.truncate
2313        self.assertRaises(TypeError, truncate, None)
2314        self.assertRaises(TypeError, truncate, 42)
2315        self.assertRaises(TypeError, truncate, dict(test_table=None))
2316        query = self.db.query
2317        self.createTable('test_table', 'n smallint',
2318                         temporary=False, values=[1] * 3)
2319        q = "select count(*) from test_table"
2320        r = query(q).getresult()[0][0]
2321        self.assertEqual(r, 3)
2322        truncate('test_table')
2323        r = query(q).getresult()[0][0]
2324        self.assertEqual(r, 0)
2325        for i in range(3):
2326            query("insert into test_table values (1)")
2327        r = query(q).getresult()[0][0]
2328        self.assertEqual(r, 3)
2329        truncate('public.test_table')
2330        r = query(q).getresult()[0][0]
2331        self.assertEqual(r, 0)
2332        self.createTable('test_table_2', 'n smallint', temporary=True)
2333        for t in (list, tuple, set):
2334            for i in range(3):
2335                query("insert into test_table values (1)")
2336                query("insert into test_table_2 values (2)")
2337            q = ("select (select count(*) from test_table),"
2338                 " (select count(*) from test_table_2)")
2339            r = query(q).getresult()[0]
2340            self.assertEqual(r, (3, 3))
2341            truncate(t(['test_table', 'test_table_2']))
2342            r = query(q).getresult()[0]
2343            self.assertEqual(r, (0, 0))
2344
2345    def testTruncateRestart(self):
2346        truncate = self.db.truncate
2347        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2348        query = self.db.query
2349        self.createTable('test_table', 'n serial, t text')
2350        for n in range(3):
2351            query("insert into test_table (t) values ('test')")
2352        q = "select count(n), min(n), max(n) from test_table"
2353        r = query(q).getresult()[0]
2354        self.assertEqual(r, (3, 1, 3))
2355        truncate('test_table')
2356        r = query(q).getresult()[0]
2357        self.assertEqual(r, (0, None, None))
2358        for n in range(3):
2359            query("insert into test_table (t) values ('test')")
2360        r = query(q).getresult()[0]
2361        self.assertEqual(r, (3, 4, 6))
2362        truncate('test_table', restart=True)
2363        r = query(q).getresult()[0]
2364        self.assertEqual(r, (0, None, None))
2365        for n in range(3):
2366            query("insert into test_table (t) values ('test')")
2367        r = query(q).getresult()[0]
2368        self.assertEqual(r, (3, 1, 3))
2369
2370    def testTruncateCascade(self):
2371        truncate = self.db.truncate
2372        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2373        query = self.db.query
2374        self.createTable('test_parent', 'n smallint primary key',
2375                         values=range(3))
2376        self.createTable('test_child',
2377                         'n smallint primary key references test_parent (n)',
2378                         values=range(3))
2379        q = ("select (select count(*) from test_parent),"
2380             " (select count(*) from test_child)")
2381        r = query(q).getresult()[0]
2382        self.assertEqual(r, (3, 3))
2383        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2384        truncate(['test_parent', 'test_child'])
2385        r = query(q).getresult()[0]
2386        self.assertEqual(r, (0, 0))
2387        for n in range(3):
2388            query("insert into test_parent (n) values (%d)" % n)
2389            query("insert into test_child (n) values (%d)" % n)
2390        r = query(q).getresult()[0]
2391        self.assertEqual(r, (3, 3))
2392        truncate('test_parent', cascade=True)
2393        r = query(q).getresult()[0]
2394        self.assertEqual(r, (0, 0))
2395        for n in range(3):
2396            query("insert into test_parent (n) values (%d)" % n)
2397            query("insert into test_child (n) values (%d)" % n)
2398        r = query(q).getresult()[0]
2399        self.assertEqual(r, (3, 3))
2400        truncate('test_child')
2401        r = query(q).getresult()[0]
2402        self.assertEqual(r, (3, 0))
2403        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2404        truncate('test_parent', cascade=True)
2405        r = query(q).getresult()[0]
2406        self.assertEqual(r, (0, 0))
2407
2408    def testTruncateOnly(self):
2409        truncate = self.db.truncate
2410        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2411        query = self.db.query
2412        self.createTable('test_parent', 'n smallint')
2413        self.createTable('test_child', 'm smallint) inherits (test_parent')
2414        for n in range(3):
2415            query("insert into test_parent (n) values (1)")
2416            query("insert into test_child (n, m) values (2, 3)")
2417        q = ("select (select count(*) from test_parent),"
2418             " (select count(*) from test_child)")
2419        r = query(q).getresult()[0]
2420        self.assertEqual(r, (6, 3))
2421        truncate('test_parent')
2422        r = query(q).getresult()[0]
2423        self.assertEqual(r, (0, 0))
2424        for n in range(3):
2425            query("insert into test_parent (n) values (1)")
2426            query("insert into test_child (n, m) values (2, 3)")
2427        r = query(q).getresult()[0]
2428        self.assertEqual(r, (6, 3))
2429        truncate('test_parent*')
2430        r = query(q).getresult()[0]
2431        self.assertEqual(r, (0, 0))
2432        for n in range(3):
2433            query("insert into test_parent (n) values (1)")
2434            query("insert into test_child (n, m) values (2, 3)")
2435        r = query(q).getresult()[0]
2436        self.assertEqual(r, (6, 3))
2437        truncate('test_parent', only=True)
2438        r = query(q).getresult()[0]
2439        self.assertEqual(r, (3, 3))
2440        truncate('test_parent', only=False)
2441        r = query(q).getresult()[0]
2442        self.assertEqual(r, (0, 0))
2443        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2444        truncate('test_parent*', only=False)
2445        self.createTable('test_parent_2', 'n smallint')
2446        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2447        for t in '', '_2':
2448            for n in range(3):
2449                query("insert into test_parent%s (n) values (1)" % t)
2450                query("insert into test_child%s (n, m) values (2, 3)" % t)
2451        q = ("select (select count(*) from test_parent),"
2452             " (select count(*) from test_child),"
2453             " (select count(*) from test_parent_2),"
2454             " (select count(*) from test_child_2)")
2455        r = query(q).getresult()[0]
2456        self.assertEqual(r, (6, 3, 6, 3))
2457        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2458        r = query(q).getresult()[0]
2459        self.assertEqual(r, (0, 0, 3, 3))
2460        truncate(['test_parent', 'test_parent_2'], only=False)
2461        r = query(q).getresult()[0]
2462        self.assertEqual(r, (0, 0, 0, 0))
2463        self.assertRaises(ValueError, truncate,
2464            ['test_parent*', 'test_child'], only=[True, False])
2465        truncate(['test_parent*', 'test_child'], only=[False, True])
2466
2467    def testTruncateQuoted(self):
2468        truncate = self.db.truncate
2469        query = self.db.query
2470        table = "test table for truncate()"
2471        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2472        q = 'select count(*) from "%s"' % table
2473        r = query(q).getresult()[0][0]
2474        self.assertEqual(r, 3)
2475        truncate(table)
2476        r = query(q).getresult()[0][0]
2477        self.assertEqual(r, 0)
2478        for i in range(3):
2479            query('insert into "%s" values (1)' % table)
2480        r = query(q).getresult()[0][0]
2481        self.assertEqual(r, 3)
2482        truncate('public."%s"' % table)
2483        r = query(q).getresult()[0][0]
2484        self.assertEqual(r, 0)
2485
2486    def testGetAsList(self):
2487        get_as_list = self.db.get_as_list
2488        self.assertRaises(TypeError, get_as_list)
2489        self.assertRaises(TypeError, get_as_list, None)
2490        query = self.db.query
2491        table = 'test_aslist'
2492        r = query('select 1 as colname').namedresult()[0]
2493        self.assertIsInstance(r, tuple)
2494        named = hasattr(r, 'colname')
2495        names = [(1, 'Homer'), (2, 'Marge'),
2496                 (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
2497        self.createTable(table,
2498            'id smallint primary key, name varchar', values=names)
2499        r = get_as_list(table)
2500        self.assertIsInstance(r, list)
2501        self.assertEqual(r, names)
2502        for t, n in zip(r, names):
2503            self.assertIsInstance(t, tuple)
2504            self.assertEqual(t, n)
2505            if named:
2506                self.assertEqual(t.id, n[0])
2507                self.assertEqual(t.name, n[1])
2508                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
2509        r = get_as_list(table, what='name')
2510        self.assertIsInstance(r, list)
2511        expected = sorted((row[1],) for row in names)
2512        self.assertEqual(r, expected)
2513        r = get_as_list(table, what='name, id')
2514        self.assertIsInstance(r, list)
2515        expected = sorted(tuple(reversed(row)) for row in names)
2516        self.assertEqual(r, expected)
2517        r = get_as_list(table, what=['name', 'id'])
2518        self.assertIsInstance(r, list)
2519        self.assertEqual(r, expected)
2520        r = get_as_list(table, where="name like 'Ba%'")
2521        self.assertIsInstance(r, list)
2522        self.assertEqual(r, names[2:3])
2523        r = get_as_list(table, what='name', where="name like 'Ma%'")
2524        self.assertIsInstance(r, list)
2525        self.assertEqual(r, [('Maggie',), ('Marge',)])
2526        r = get_as_list(table, what='name',
2527            where=["name like 'Ma%'", "name like '%r%'"])
2528        self.assertIsInstance(r, list)
2529        self.assertEqual(r, [('Marge',)])
2530        r = get_as_list(table, what='name', order='id')
2531        self.assertIsInstance(r, list)
2532        expected = [(row[1],) for row in names]
2533        self.assertEqual(r, expected)
2534        r = get_as_list(table, what=['name'], order=['id'])
2535        self.assertIsInstance(r, list)
2536        self.assertEqual(r, expected)
2537        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
2538        self.assertIsInstance(r, list)
2539        self.assertEqual(r, names)
2540        r = get_as_list(table, what='id * 2 as num', order='id desc')
2541        self.assertIsInstance(r, list)
2542        expected = [(n,) for n in range(10, 0, -2)]
2543        self.assertEqual(r, expected)
2544        r = get_as_list(table, limit=2)
2545        self.assertIsInstance(r, list)
2546        self.assertEqual(r, names[:2])
2547        r = get_as_list(table, offset=3)
2548        self.assertIsInstance(r, list)
2549        self.assertEqual(r, names[3:])
2550        r = get_as_list(table, limit=1, offset=2)
2551        self.assertIsInstance(r, list)
2552        self.assertEqual(r, names[2:3])
2553        r = get_as_list(table, scalar=True)
2554        self.assertIsInstance(r, list)
2555        self.assertEqual(r, list(range(1, 6)))
2556        r = get_as_list(table, what='name', scalar=True)
2557        self.assertIsInstance(r, list)
2558        expected = sorted(row[1] for row in names)
2559        self.assertEqual(r, expected)
2560        r = get_as_list(table, what='name', limit=1, scalar=True)
2561        self.assertIsInstance(r, list)
2562        self.assertEqual(r, expected[:1])
2563        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2564        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2565        names.insert(1, (1, 'Snowball'))
2566        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
2567        r = get_as_list(table)
2568        self.assertIsInstance(r, list)
2569        self.assertEqual(r, names)
2570        r = get_as_list(table, what='name', where='id=1', scalar=True)
2571        self.assertIsInstance(r, list)
2572        self.assertEqual(r, ['Homer', 'Snowball'])
2573        # test with unordered query
2574        r = get_as_list(table, order=False)
2575        self.assertIsInstance(r, list)
2576        self.assertEqual(set(r), set(names))
2577        # test with arbitrary from clause
2578        from_table = '(select lower(name) as n2 from "%s") as t2' % table
2579        r = get_as_list(from_table)
2580        self.assertIsInstance(r, list)
2581        r = set(row[0] for row in r)
2582        expected = set(row[1].lower() for row in names)
2583        self.assertEqual(r, expected)
2584        r = get_as_list(from_table, order='n2', scalar=True)
2585        self.assertIsInstance(r, list)
2586        self.assertEqual(r, sorted(expected))
2587        r = get_as_list(from_table, order='n2', limit=1)
2588        self.assertIsInstance(r, list)
2589        self.assertEqual(len(r), 1)
2590        t = r[0]
2591        self.assertIsInstance(t, tuple)
2592        if named:
2593            self.assertEqual(t.n2, 'bart')
2594            self.assertEqual(t._asdict(), dict(n2='bart'))
2595        else:
2596            self.assertEqual(t, ('bart',))
2597
2598    def testGetAsDict(self):
2599        get_as_dict = self.db.get_as_dict
2600        self.assertRaises(TypeError, get_as_dict)
2601        self.assertRaises(TypeError, get_as_dict, None)
2602        # the test table has no primary key
2603        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2604        query = self.db.query
2605        table = 'test_asdict'
2606        r = query('select 1 as colname').namedresult()[0]
2607        self.assertIsInstance(r, tuple)
2608        named = hasattr(r, 'colname')
2609        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2610                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2611        self.createTable(table,
2612            'id smallint primary key, rgb char(7), name varchar',
2613            values=colors)
2614        # keyname must be string, list or tuple
2615        self.assertRaises(KeyError, get_as_dict, table, 3)
2616        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2617        # missing keyname in row
2618        self.assertRaises(KeyError, get_as_dict, table,
2619                          keyname='rgb', what='name')
2620        r = get_as_dict(table)
2621        self.assertIsInstance(r, OrderedDict)
2622        expected = OrderedDict((row[0], row[1:]) for row in colors)
2623        self.assertEqual(r, expected)
2624        for key in r:
2625            self.assertIsInstance(key, int)
2626            self.assertIn(key, expected)
2627            row = r[key]
2628            self.assertIsInstance(row, tuple)
2629            t = expected[key]
2630            self.assertEqual(row, t)
2631            if named:
2632                self.assertEqual(row.rgb, t[0])
2633                self.assertEqual(row.name, t[1])
2634                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2635        if OrderedDict is not dict:  # Python > 2.6
2636            self.assertEqual(r.keys(), expected.keys())
2637        r = get_as_dict(table, keyname='rgb')
2638        self.assertIsInstance(r, OrderedDict)
2639        expected = OrderedDict((row[1], (row[0], row[2]))
2640                               for row in sorted(colors, key=itemgetter(1)))
2641        self.assertEqual(r, expected)
2642        for key in r:
2643            self.assertIsInstance(key, str)
2644            self.assertIn(key, expected)
2645            row = r[key]
2646            self.assertIsInstance(row, tuple)
2647            t = expected[key]
2648            self.assertEqual(row, t)
2649            if named:
2650                self.assertEqual(row.id, t[0])
2651                self.assertEqual(row.name, t[1])
2652                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2653        if OrderedDict is not dict:  # Python > 2.6
2654            self.assertEqual(r.keys(), expected.keys())
2655        r = get_as_dict(table, keyname=['id', 'rgb'])
2656        self.assertIsInstance(r, OrderedDict)
2657        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2658        self.assertEqual(r, expected)
2659        for key in r:
2660            self.assertIsInstance(key, tuple)
2661            self.assertIsInstance(key[0], int)
2662            self.assertIsInstance(key[1], str)
2663            if named:
2664                self.assertEqual(key, (key.id, key.rgb))
2665                self.assertEqual(key._fields, ('id', 'rgb'))
2666            row = r[key]
2667            self.assertIsInstance(row, tuple)
2668            self.assertIsInstance(row[0], str)
2669            t = expected[key]
2670            self.assertEqual(row, t)
2671            if named:
2672                self.assertEqual(row.name, t[0])
2673                self.assertEqual(row._asdict(), dict(name=t[0]))
2674        if OrderedDict is not dict:  # Python > 2.6
2675            self.assertEqual(r.keys(), expected.keys())
2676        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2677        self.assertIsInstance(r, OrderedDict)
2678        expected = OrderedDict((row[:2], row[2]) for row in colors)
2679        self.assertEqual(r, expected)
2680        for key in r:
2681            self.assertIsInstance(key, tuple)
2682            row = r[key]
2683            self.assertIsInstance(row, str)
2684            t = expected[key]
2685            self.assertEqual(row, t)
2686        if OrderedDict is not dict:  # Python > 2.6
2687            self.assertEqual(r.keys(), expected.keys())
2688        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2689        self.assertIsInstance(r, OrderedDict)
2690        expected = OrderedDict((row[1], row[2])
2691            for row in sorted(colors, key=itemgetter(1)))
2692        self.assertEqual(r, expected)
2693        for key in r:
2694            self.assertIsInstance(key, str)
2695            row = r[key]
2696            self.assertIsInstance(row, str)
2697            t = expected[key]
2698            self.assertEqual(row, t)
2699        if OrderedDict is not dict:  # Python > 2.6
2700            self.assertEqual(r.keys(), expected.keys())
2701        r = get_as_dict(table, what='id, name',
2702            where="rgb like '#b%'", scalar=True)
2703        self.assertIsInstance(r, OrderedDict)
2704        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2705        self.assertEqual(r, expected)
2706        for key in r:
2707            self.assertIsInstance(key, int)
2708            row = r[key]
2709            self.assertIsInstance(row, str)
2710            t = expected[key]
2711            self.assertEqual(row, t)
2712        if OrderedDict is not dict:  # Python > 2.6
2713            self.assertEqual(r.keys(), expected.keys())
2714        expected = r
2715        r = get_as_dict(table, what=['name', 'id'],
2716            where=['id > 1', 'id < 4', "rgb like '#b%'",
2717                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2718        self.assertEqual(r, expected)
2719        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2720        self.assertEqual(r, expected)
2721        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2722            where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2723        self.assertEqual(r, expected)
2724        r = get_as_dict(table, limit=1)
2725        self.assertEqual(len(r), 1)
2726        self.assertEqual(r[1][1], 'Aero')
2727        r = get_as_dict(table, offset=3)
2728        self.assertEqual(len(r), 1)
2729        self.assertEqual(r[4][1], 'Desert')
2730        r = get_as_dict(table, order='id desc')
2731        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2732        self.assertEqual(r, expected)
2733        r = get_as_dict(table, where='id > 5')
2734        self.assertIsInstance(r, OrderedDict)
2735        self.assertEqual(len(r), 0)
2736        # test with unordered query
2737        expected = dict((row[0], row[1:]) for row in colors)
2738        r = get_as_dict(table, order=False)
2739        self.assertIsInstance(r, dict)
2740        self.assertEqual(r, expected)
2741        if dict is not OrderedDict:  # Python > 2.6
2742            self.assertNotIsInstance(self, OrderedDict)
2743        # test with arbitrary from clause
2744        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2745        # primary key must be passed explicitly in this case
2746        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2747        r = get_as_dict(from_table, 'id')
2748        self.assertIsInstance(r, OrderedDict)
2749        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2750        self.assertEqual(r, expected)
2751        # test without a primary key
2752        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2753        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2754        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2755        r = get_as_dict(table, keyname='id')
2756        expected = OrderedDict((row[0], row[1:]) for row in colors)
2757        self.assertIsInstance(r, dict)
2758        self.assertEqual(r, expected)
2759        r = (1, '#007fff', 'Azure')
2760        query('insert into "%s" values ($1, $2, $3)' % table, r)
2761        # the last entry will win
2762        expected[1] = r[1:]
2763        r = get_as_dict(table, keyname='id')
2764        self.assertEqual(r, expected)
2765
2766    def testTransaction(self):
2767        query = self.db.query
2768        self.createTable('test_table', 'n integer', temporary=False)
2769        self.db.begin()
2770        query("insert into test_table values (1)")
2771        query("insert into test_table values (2)")
2772        self.db.commit()
2773        self.db.begin()
2774        query("insert into test_table values (3)")
2775        query("insert into test_table values (4)")
2776        self.db.rollback()
2777        self.db.begin()
2778        query("insert into test_table values (5)")
2779        self.db.savepoint('before6')
2780        query("insert into test_table values (6)")
2781        self.db.rollback('before6')
2782        query("insert into test_table values (7)")
2783        self.db.commit()
2784        self.db.begin()
2785        self.db.savepoint('before8')
2786        query("insert into test_table values (8)")
2787        self.db.release('before8')
2788        self.assertRaises(pg.InternalError, self.db.rollback, 'before8')
2789        self.db.commit()
2790        self.db.start()
2791        query("insert into test_table values (9)")
2792        self.db.end()
2793        r = [r[0] for r in query(
2794            "select * from test_table order by 1").getresult()]
2795        self.assertEqual(r, [1, 2, 5, 7, 9])
2796        self.db.begin(mode='read only')
2797        self.assertRaises(pg.InternalError,
2798                          query, "insert into test_table values (0)")
2799        self.db.rollback()
2800        self.db.start(mode='Read Only')
2801        self.assertRaises(pg.InternalError,
2802                          query, "insert into test_table values (0)")
2803        self.db.abort()
2804
2805    def testTransactionAliases(self):
2806        self.assertEqual(self.db.begin, self.db.start)
2807        self.assertEqual(self.db.commit, self.db.end)
2808        self.assertEqual(self.db.rollback, self.db.abort)
2809
2810    def testContextManager(self):
2811        query = self.db.query
2812        self.createTable('test_table', 'n integer check(n>0)')
2813        with self.db:
2814            query("insert into test_table values (1)")
2815            query("insert into test_table values (2)")
2816        try:
2817            with self.db:
2818                query("insert into test_table values (3)")
2819                query("insert into test_table values (4)")
2820                raise ValueError('test transaction should rollback')
2821        except ValueError as error:
2822            self.assertEqual(str(error), 'test transaction should rollback')
2823        with self.db:
2824            query("insert into test_table values (5)")
2825        try:
2826            with self.db:
2827                query("insert into test_table values (6)")
2828                query("insert into test_table values (-1)")
2829        except pg.IntegrityError as error:
2830            self.assertTrue('check' in str(error))
2831        with self.db:
2832            query("insert into test_table values (7)")
2833        r = [r[0] for r in query(
2834            "select * from test_table order by 1").getresult()]
2835        self.assertEqual(r, [1, 2, 5, 7])
2836
2837    def testBytea(self):
2838        query = self.db.query
2839        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2840        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2841        r = self.db.escape_bytea(s)
2842        query('insert into bytea_test values(3, $1)', (r,))
2843        r = query('select * from bytea_test where n=3').getresult()
2844        self.assertEqual(len(r), 1)
2845        r = r[0]
2846        self.assertEqual(len(r), 2)
2847        self.assertEqual(r[0], 3)
2848        r = r[1]
2849        if pg.get_bytea_escaped():
2850            self.assertNotEqual(r, s)
2851            r = pg.unescape_bytea(r)
2852        self.assertIsInstance(r, bytes)
2853        self.assertEqual(r, s)
2854
2855    def testInsertUpdateGetBytea(self):
2856        query = self.db.query
2857        unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None
2858        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2859        # insert null value
2860        r = self.db.insert('bytea_test', n=0, data=None)
2861        self.assertIsInstance(r, dict)
2862        self.assertIn('n', r)
2863        self.assertEqual(r['n'], 0)
2864        self.assertIn('data', r)
2865        self.assertIsNone(r['data'])
2866        s = b'None'
2867        r = self.db.update('bytea_test', n=0, data=s)
2868        self.assertIsInstance(r, dict)
2869        self.assertIn('n', r)
2870        self.assertEqual(r['n'], 0)
2871        self.assertIn('data', r)
2872        r = r['data']
2873        if unescape:
2874            self.assertNotEqual(r, s)
2875            r = unescape(r)
2876        self.assertIsInstance(r, bytes)
2877        self.assertEqual(r, s)
2878        r = self.db.update('bytea_test', n=0, data=None)
2879        self.assertIsNone(r['data'])
2880        # insert as bytes
2881        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2882        r = self.db.insert('bytea_test', n=5, data=s)
2883        self.assertIsInstance(r, dict)
2884        self.assertIn('n', r)
2885        self.assertEqual(r['n'], 5)
2886        self.assertIn('data', r)
2887        r = r['data']
2888        if unescape:
2889            self.assertNotEqual(r, s)
2890            r = unescape(r)
2891        self.assertIsInstance(r, bytes)
2892        self.assertEqual(r, s)
2893        # update as bytes
2894        s += b"and now even more \x00 nasty \t stuff!\f"
2895        r = self.db.update('bytea_test', n=5, data=s)
2896        self.assertIsInstance(r, dict)
2897        self.assertIn('n', r)
2898        self.assertEqual(r['n'], 5)
2899        self.assertIn('data', r)
2900        r = r['data']
2901        if unescape:
2902            self.assertNotEqual(r, s)
2903            r = unescape(r)
2904        self.assertIsInstance(r, bytes)
2905        self.assertEqual(r, s)
2906        r = query('select * from bytea_test where n=5').getresult()
2907        self.assertEqual(len(r), 1)
2908        r = r[0]
2909        self.assertEqual(len(r), 2)
2910        self.assertEqual(r[0], 5)
2911        r = r[1]
2912        if unescape:
2913            self.assertNotEqual(r, s)
2914            r = unescape(r)
2915        self.assertIsInstance(r, bytes)
2916        self.assertEqual(r, s)
2917        r = self.db.get('bytea_test', dict(n=5))
2918        self.assertIsInstance(r, dict)
2919        self.assertIn('n', r)
2920        self.assertEqual(r['n'], 5)
2921        self.assertIn('data', r)
2922        r = r['data']
2923        if unescape:
2924            self.assertNotEqual(r, s)
2925            r = pg.unescape_bytea(r)
2926        self.assertIsInstance(r, bytes)
2927        self.assertEqual(r, s)
2928
2929    def testUpsertBytea(self):
2930        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2931        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2932        r = dict(n=7, data=s)
2933        try:
2934            r = self.db.upsert('bytea_test', r)
2935        except pg.ProgrammingError as error:
2936            if self.db.server_version < 90500:
2937                self.skipTest('database does not support upsert')
2938            self.fail(str(error))
2939        self.assertIsInstance(r, dict)
2940        self.assertIn('n', r)
2941        self.assertEqual(r['n'], 7)
2942        self.assertIn('data', r)
2943        if pg.get_bytea_escaped():
2944            self.assertNotEqual(r['data'], s)
2945            r['data'] = pg.unescape_bytea(r['data'])
2946        self.assertIsInstance(r['data'], bytes)
2947        self.assertEqual(r['data'], s)
2948        r['data'] = None
2949        r = self.db.upsert('bytea_test', r)
2950        self.assertIsInstance(r, dict)
2951        self.assertIn('n', r)
2952        self.assertEqual(r['n'], 7)
2953        self.assertIn('data', r)
2954        self.assertIsNone(r['data'])
2955
2956    def testInsertGetJson(self):
2957        try:
2958            self.createTable('json_test', 'n smallint primary key, data json')
2959        except pg.ProgrammingError as error:
2960            if self.db.server_version < 90200:
2961                self.skipTest('database does not support json')
2962            self.fail(str(error))
2963        jsondecode = pg.get_jsondecode()
2964        # insert null value
2965        r = self.db.insert('json_test', n=0, data=None)
2966        self.assertIsInstance(r, dict)
2967        self.assertIn('n', r)
2968        self.assertEqual(r['n'], 0)
2969        self.assertIn('data', r)
2970        self.assertIsNone(r['data'])
2971        r = self.db.get('json_test', 0)
2972        self.assertIsInstance(r, dict)
2973        self.assertIn('n', r)
2974        self.assertEqual(r['n'], 0)
2975        self.assertIn('data', r)
2976        self.assertIsNone(r['data'])
2977        # insert JSON object
2978        data = {
2979            "id": 1, "name": "Foo", "price": 1234.5,
2980            "new": True, "note": None,
2981            "tags": ["Bar", "Eek"],
2982            "stock": {"warehouse": 300, "retail": 20}}
2983        r = self.db.insert('json_test', n=1, data=data)
2984        self.assertIsInstance(r, dict)
2985        self.assertIn('n', r)
2986        self.assertEqual(r['n'], 1)
2987        self.assertIn('data', r)
2988        r = r['data']
2989        if jsondecode is None:
2990            self.assertIsInstance(r, str)
2991            r = json.loads(r)
2992        self.assertIsInstance(r, dict)
2993        self.assertEqual(r, data)
2994        self.assertIsInstance(r['id'], int)
2995        self.assertIsInstance(r['name'], unicode)
2996        self.assertIsInstance(r['price'], float)
2997        self.assertIsInstance(r['new'], bool)
2998        self.assertIsInstance(r['tags'], list)
2999        self.assertIsInstance(r['stock'], dict)
3000        r = self.db.get('json_test', 1)
3001        self.assertIsInstance(r, dict)
3002        self.assertIn('n', r)
3003        self.assertEqual(r['n'], 1)
3004        self.assertIn('data', r)
3005        r = r['data']
3006        if jsondecode is None:
3007            self.assertIsInstance(r, str)
3008            r = json.loads(r)
3009        self.assertIsInstance(r, dict)
3010        self.assertEqual(r, data)
3011        self.assertIsInstance(r['id'], int)
3012        self.assertIsInstance(r['name'], unicode)
3013        self.assertIsInstance(r['price'], float)
3014        self.assertIsInstance(r['new'], bool)
3015        self.assertIsInstance(r['tags'], list)
3016        self.assertIsInstance(r['stock'], dict)
3017        # insert JSON object as text
3018        self.db.insert('json_test', n=2, data=json.dumps(data))
3019        q = "select data from json_test where n in (1, 2) order by n"
3020        r = self.db.query(q).getresult()
3021        self.assertEqual(len(r), 2)
3022        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
3023        self.assertEqual(r[0][0], r[1][0])
3024
3025    def testInsertGetJsonb(self):
3026        try:
3027            self.createTable('jsonb_test',
3028                             'n smallint primary key, data jsonb')
3029        except pg.ProgrammingError as error:
3030            if self.db.server_version < 90400:
3031                self.skipTest('database does not support jsonb')
3032            self.fail(str(error))
3033        jsondecode = pg.get_jsondecode()
3034        # insert null value
3035        r = self.db.insert('jsonb_test', n=0, data=None)
3036        self.assertIsInstance(r, dict)
3037        self.assertIn('n', r)
3038        self.assertEqual(r['n'], 0)
3039        self.assertIn('data', r)
3040        self.assertIsNone(r['data'])
3041        r = self.db.get('jsonb_test', 0)
3042        self.assertIsInstance(r, dict)
3043        self.assertIn('n', r)
3044        self.assertEqual(r['n'], 0)
3045        self.assertIn('data', r)
3046        self.assertIsNone(r['data'])
3047        # insert JSON object
3048        data = {
3049            "id": 1, "name": "Foo", "price": 1234.5,
3050            "new": True, "note": None,
3051            "tags": ["Bar", "Eek"],
3052            "stock": {"warehouse": 300, "retail": 20}}
3053        r = self.db.insert('jsonb_test', n=1, data=data)
3054        self.assertIsInstance(r, dict)
3055        self.assertIn('n', r)
3056        self.assertEqual(r['n'], 1)
3057        self.assertIn('data', r)
3058        r = r['data']
3059        if jsondecode is None:
3060            self.assertIsInstance(r, str)
3061            r = json.loads(r)
3062        self.assertIsInstance(r, dict)
3063        self.assertEqual(r, data)
3064        self.assertIsInstance(r['id'], int)
3065        self.assertIsInstance(r['name'], unicode)
3066        self.assertIsInstance(r['price'], float)
3067        self.assertIsInstance(r['new'], bool)
3068        self.assertIsInstance(r['tags'], list)
3069        self.assertIsInstance(r['stock'], dict)
3070        r = self.db.get('jsonb_test', 1)
3071        self.assertIsInstance(r, dict)
3072        self.assertIn('n', r)
3073        self.assertEqual(r['n'], 1)
3074        self.assertIn('data', r)
3075        r = r['data']
3076        if jsondecode is None:
3077            self.assertIsInstance(r, str)
3078            r = json.loads(r)
3079        self.assertIsInstance(r, dict)
3080        self.assertEqual(r, data)
3081        self.assertIsInstance(r['id'], int)
3082        self.assertIsInstance(r['name'], unicode)
3083        self.assertIsInstance(r['price'], float)
3084        self.assertIsInstance(r['new'], bool)
3085        self.assertIsInstance(r['tags'], list)
3086        self.assertIsInstance(r['stock'], dict)
3087
    def testArray(self):
        """Insert and get arrays of many different element types.

        When the module returns arrays (pg.get_array()), whole rows are
        compared with the original data; otherwise only the text form of
        the integer[] column is checked.  The row is written via insert(),
        read back via get() and via a plain query.
        """
        returns_arrays = pg.get_array()
        self.createTable('arraytest',
            'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
            ' d numeric[], f4 real[], f8 double precision[], m money[],'
            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
        r = self.db.get_attnames('arraytest')
        # the reported type names depend on the regtypes setting
        if self.regtypes:
            self.assertEqual(r, dict(
                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
                d='numeric[]', f4='real[]', f8='double precision[]',
                m='money[]', b='boolean[]',
                v4='character varying[]', c4='character[]', t='text[]'))
        else:
            self.assertEqual(r, dict(
                id='int', i2='int[]', i4='int[]', i8='int[]',
                d='num[]', f4='float[]', f8='float[]', m='money[]',
                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
        decimal = pg.get_decimal()
        # smaller values are used when the decimal type is not Decimal;
        # presumably because a float-like type would lose precision on
        # the longer numbers -- NOTE(review): confirm
        if decimal is Decimal:
            long_decimal = decimal('123456789.123456789')
            odd_money = decimal('1234567891234567.89')
        else:
            long_decimal = decimal('12345671234.5')
            odd_money = decimal('1234567123.25')
        # boolean values are either Python bools or 't'/'f' strings
        t, f = (True, False) if pg.get_bool() else ('t', 'f')
        # each array mixes ordinary values, NULL (None) and edge values;
        # c4 values are given blank-padded to the char(4) width so that
        # they compare equal after the round trip
        data = dict(id=42, i2=[42, 1234, None, 0, -1],
            i4=[42, 123456789, None, 0, 1, -1],
            i8=[long(42), long(123456789123456789), None,
                long(0), long(1), long(-1)],
            d=[decimal(42), long_decimal, None,
               decimal(0), decimal(1), decimal(-1), -long_decimal],
            f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
                float('inf'), float('-inf')],
            f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
                float('inf'), float('-inf')],
            m=[decimal('42.00'), odd_money, None,
               decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
            b=[t, f, t, None, f, t, None, None, t],
            v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
            t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
        # insert a copy, since insert() updates the passed dict in place
        r = data.copy()
        self.db.insert('arraytest', r)
        if returns_arrays:
            self.assertEqual(r, data)
        else:
            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
        # insert the dict returned by the first insert once more;
        # NOTE(review): looks like a deliberate check that values read
        # back can be written again -- confirm.  The table has no key,
        # so afterwards two rows share id 42.
        self.db.insert('arraytest', r)
        # look the row up by its id column
        r = self.db.get('arraytest', 42, 'id')
        if returns_arrays:
            self.assertEqual(r, data)
        else:
            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
        # read the row with a plain query as well
        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
        if returns_arrays:
            self.assertEqual(r, data)
        else:
            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3146
3147    def testArrayLiteral(self):
3148        insert = self.db.insert
3149        returns_arrays = pg.get_array()
3150        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3151        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3152        insert('arraytest', r)
3153        if returns_arrays:
3154            self.assertEqual(r['i'], [1, 2, 3])
3155            self.assertEqual(r['t'], ['a', 'b', 'c'])
3156        else:
3157            self.assertEqual(r['i'], '{1,2,3}')
3158            self.assertEqual(r['t'], '{a,b,c}')
3159        r = dict(i='{1,2,3}', t='{a,b,c}')
3160        self.db.insert('arraytest', r)
3161        if returns_arrays:
3162            self.assertEqual(r['i'], [1, 2, 3])
3163            self.assertEqual(r['t'], ['a', 'b', 'c'])
3164        else:
3165            self.assertEqual(r['i'], '{1,2,3}')
3166            self.assertEqual(r['t'], '{a,b,c}')
3167        L = pg.Literal
3168        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3169        self.db.insert('arraytest', r)
3170        if returns_arrays:
3171            self.assertEqual(r['i'], [1, 2, 3])
3172            self.assertEqual(r['t'], ['a', 'b', 'c'])
3173        else:
3174            self.assertEqual(r['i'], '{1,2,3}')
3175            self.assertEqual(r['t'], '{a,b,c}')
3176        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3177        self.assertRaises(pg.DataError, self.db.insert, 'arraytest', r)
3178
3179    def testArrayOfIds(self):
3180        array_on = pg.get_array()
3181        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3182        r = self.db.get_attnames('arraytest')
3183        if self.regtypes:
3184            self.assertEqual(r, dict(
3185                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3186        else:
3187            self.assertEqual(r, dict(
3188                oid='int', c='int[]', o='int[]', x='int[]'))
3189        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3190        r = data.copy()
3191        self.db.insert('arraytest', r)
3192        qoid = 'oid(arraytest)'
3193        oid = r.pop(qoid)
3194        if array_on:
3195            self.assertEqual(r, data)
3196        else:
3197            self.assertEqual(r['o'], '{21,22,23}')
3198        r = {qoid: oid}
3199        self.db.get('arraytest', r)
3200        self.assertEqual(oid, r.pop(qoid))
3201        if array_on:
3202            self.assertEqual(r, data)
3203        else:
3204            self.assertEqual(r['o'], '{21,22,23}')
3205
3206    def testArrayOfText(self):
3207        array_on = pg.get_array()
3208        self.createTable('arraytest', 'data text[]', oids=True)
3209        r = self.db.get_attnames('arraytest')
3210        self.assertEqual(r['data'], 'text[]')
3211        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3212                'null', 'NULL', 'Null', 'nulL',
3213                "It's all \\ kinds of\r nasty stuff!\n"]
3214        r = dict(data=data)
3215        self.db.insert('arraytest', r)
3216        if not array_on:
3217            r['data'] = pg.cast_array(r['data'])
3218        self.assertEqual(r['data'], data)
3219        self.assertIsInstance(r['data'][1], str)
3220        self.assertIsNone(r['data'][2])
3221        r['data'] = None
3222        self.db.get('arraytest', r)
3223        if not array_on:
3224            r['data'] = pg.cast_array(r['data'])
3225        self.assertEqual(r['data'], data)
3226        self.assertIsInstance(r['data'][1], str)
3227        self.assertIsNone(r['data'][2])
3228
3229    def testArrayOfBytea(self):
3230        array_on = pg.get_array()
3231        bytea_escaped = pg.get_bytea_escaped()
3232        self.createTable('arraytest', 'data bytea[]', oids=True)
3233        r = self.db.get_attnames('arraytest')
3234        self.assertEqual(r['data'], 'bytea[]')
3235        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3236                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3237        r = dict(data=data)
3238        self.db.insert('arraytest', r)
3239        if array_on:
3240            self.assertIsInstance(r['data'], list)
3241        if array_on and not bytea_escaped:
3242            self.assertEqual(r['data'], data)
3243            self.assertIsInstance(r['data'][1], bytes)
3244            self.assertIsNone(r['data'][2])
3245        else:
3246            self.assertNotEqual(r['data'], data)
3247        r['data'] = None
3248        self.db.get('arraytest', r)
3249        if array_on:
3250            self.assertIsInstance(r['data'], list)
3251        if array_on and not bytea_escaped:
3252            self.assertEqual(r['data'], data)
3253            self.assertIsInstance(r['data'][1], bytes)
3254            self.assertIsNone(r['data'][2])
3255        else:
3256            self.assertNotEqual(r['data'], data)
3257
3258    def testArrayOfJson(self):
3259        try:
3260            self.createTable('arraytest', 'data json[]', oids=True)
3261        except pg.ProgrammingError as error:
3262            if self.db.server_version < 90200:
3263                self.skipTest('database does not support json')
3264            self.fail(str(error))
3265        r = self.db.get_attnames('arraytest')
3266        self.assertEqual(r['data'], 'json[]')
3267        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3268        array_on = pg.get_array()
3269        jsondecode = pg.get_jsondecode()
3270        r = dict(data=data)
3271        self.db.insert('arraytest', r)
3272        if not array_on:
3273            r['data'] = pg.cast_array(r['data'], jsondecode)
3274        if jsondecode is None:
3275            r['data'] = [json.loads(d) for d in r['data']]
3276        self.assertEqual(r['data'], data)
3277        r['data'] = None
3278        self.db.get('arraytest', r)
3279        if not array_on:
3280            r['data'] = pg.cast_array(r['data'], jsondecode)
3281        if jsondecode is None:
3282            r['data'] = [json.loads(d) for d in r['data']]
3283        self.assertEqual(r['data'], data)
3284        r = dict(data=[json.dumps(d) for d in data])
3285        self.db.insert('arraytest', r)
3286        if not array_on:
3287            r['data'] = pg.cast_array(r['data'], jsondecode)
3288        if jsondecode is None:
3289            r['data'] = [json.loads(d) for d in r['data']]
3290        self.assertEqual(r['data'], data)
3291        r['data'] = None
3292        self.db.get('arraytest', r)
3293        # insert empty json values
3294        r = dict(data=['', None])
3295        self.db.insert('arraytest', r)
3296        r = r['data']
3297        if array_on:
3298            self.assertIsInstance(r, list)
3299            self.assertEqual(len(r), 2)
3300            self.assertIsNone(r[0])
3301            self.assertIsNone(r[1])
3302        else:
3303            self.assertEqual(r, '{NULL,NULL}')
3304
3305    def testArrayOfJsonb(self):
3306        try:
3307            self.createTable('arraytest', 'data jsonb[]', oids=True)
3308        except pg.ProgrammingError as error:
3309            if self.db.server_version < 90400:
3310                self.skipTest('database does not support jsonb')
3311            self.fail(str(error))
3312        r = self.db.get_attnames('arraytest')
3313        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3314        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3315        array_on = pg.get_array()
3316        jsondecode = pg.get_jsondecode()
3317        r = dict(data=data)
3318        self.db.insert('arraytest', r)
3319        if not array_on:
3320            r['data'] = pg.cast_array(r['data'], jsondecode)
3321        if jsondecode is None:
3322            r['data'] = [json.loads(d) for d in r['data']]
3323        self.assertEqual(r['data'], data)
3324        r['data'] = None
3325        self.db.get('arraytest', r)
3326        if not array_on:
3327            r['data'] = pg.cast_array(r['data'], jsondecode)
3328        if jsondecode is None:
3329            r['data'] = [json.loads(d) for d in r['data']]
3330        self.assertEqual(r['data'], data)
3331        r = dict(data=[json.dumps(d) for d in data])
3332        self.db.insert('arraytest', r)
3333        if not array_on:
3334            r['data'] = pg.cast_array(r['data'], jsondecode)
3335        if jsondecode is None:
3336            r['data'] = [json.loads(d) for d in r['data']]
3337        self.assertEqual(r['data'], data)
3338        r['data'] = None
3339        self.db.get('arraytest', r)
3340        # insert empty json values
3341        r = dict(data=['', None])
3342        self.db.insert('arraytest', r)
3343        r = r['data']
3344        if array_on:
3345            self.assertIsInstance(r, list)
3346            self.assertEqual(len(r), 2)
3347            self.assertIsNone(r[0])
3348            self.assertIsNone(r[1])
3349        else:
3350            self.assertEqual(r, '{NULL,NULL}')
3351
3352    def testDeepArray(self):
3353        array_on = pg.get_array()
3354        self.createTable('arraytest', 'data text[][][]', oids=True)
3355        r = self.db.get_attnames('arraytest')
3356        self.assertEqual(r['data'], 'text[]')
3357        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3358        r = dict(data=data)
3359        self.db.insert('arraytest', r)
3360        if array_on:
3361            self.assertEqual(r['data'], data)
3362        else:
3363            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3364        r['data'] = None
3365        self.db.get('arraytest', r)
3366        if array_on:
3367            self.assertEqual(r['data'], data)
3368        else:
3369            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3370
3371    def testInsertUpdateGetRecord(self):
3372        query = self.db.query
3373        query('create type test_person_type as'
3374              ' (name varchar, age smallint, married bool,'
3375              ' weight real, salary money)')
3376        self.addCleanup(query, 'drop type test_person_type')
3377        self.createTable('test_person', 'person test_person_type',
3378                         temporary=False, oids=True)
3379        attnames = self.db.get_attnames('test_person')
3380        self.assertEqual(len(attnames), 2)
3381        self.assertIn('oid', attnames)
3382        self.assertIn('person', attnames)
3383        person_typ = attnames['person']
3384        if self.regtypes:
3385            self.assertEqual(person_typ, 'test_person_type')
3386        else:
3387            self.assertEqual(person_typ, 'record')
3388        if self.regtypes:
3389            self.assertEqual(person_typ.attnames,
3390                dict(name='character varying', age='smallint',
3391                    married='boolean', weight='real', salary='money'))
3392        else:
3393            self.assertEqual(person_typ.attnames,
3394                dict(name='text', age='int', married='bool',
3395                    weight='float', salary='money'))
3396        decimal = pg.get_decimal()
3397        if pg.get_bool():
3398            bool_class = bool
3399            t, f = True, False
3400        else:
3401            bool_class = str
3402            t, f = 't', 'f'
3403        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3404        r = self.db.insert('test_person', None, person=person)
3405        p = r['person']
3406        self.assertIsInstance(p, tuple)
3407        self.assertEqual(p, person)
3408        self.assertEqual(p.name, 'John Doe')
3409        self.assertIsInstance(p.name, str)
3410        self.assertIsInstance(p.age, int)
3411        self.assertIsInstance(p.married, bool_class)
3412        self.assertIsInstance(p.weight, float)
3413        self.assertIsInstance(p.salary, decimal)
3414        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3415        r['person'] = person
3416        self.db.update('test_person', r)
3417        p = r['person']
3418        self.assertIsInstance(p, tuple)
3419        self.assertEqual(p, person)
3420        self.assertEqual(p.name, 'Jane Roe')
3421        self.assertIsInstance(p.name, str)
3422        self.assertIsInstance(p.age, int)
3423        self.assertIsInstance(p.married, bool_class)
3424        self.assertIsInstance(p.weight, float)
3425        self.assertIsInstance(p.salary, decimal)
3426        r['person'] = None
3427        self.db.get('test_person', r)
3428        p = r['person']
3429        self.assertIsInstance(p, tuple)
3430        self.assertEqual(p, person)
3431        self.assertEqual(p.name, 'Jane Roe')
3432        self.assertIsInstance(p.name, str)
3433        self.assertIsInstance(p.age, int)
3434        self.assertIsInstance(p.married, bool_class)
3435        self.assertIsInstance(p.weight, float)
3436        self.assertIsInstance(p.salary, decimal)
3437        person = (None,) * 5
3438        r = self.db.insert('test_person', None, person=person)
3439        p = r['person']
3440        self.assertIsInstance(p, tuple)
3441        self.assertIsNone(p.name)
3442        self.assertIsNone(p.age)
3443        self.assertIsNone(p.married)
3444        self.assertIsNone(p.weight)
3445        self.assertIsNone(p.salary)
3446        r['person'] = None
3447        self.db.get('test_person', r)
3448        p = r['person']
3449        self.assertIsInstance(p, tuple)
3450        self.assertIsNone(p.name)
3451        self.assertIsNone(p.age)
3452        self.assertIsNone(p.married)
3453        self.assertIsNone(p.weight)
3454        self.assertIsNone(p.salary)
3455        r = self.db.insert('test_person', None, person=None)
3456        self.assertIsNone(r['person'])
3457        r['person'] = None
3458        self.db.get('test_person', r)
3459        self.assertIsNone(r['person'])
3460
3461    def testRecordInsertBytea(self):
3462        query = self.db.query
3463        query('create type test_person_type as'
3464              ' (name text, picture bytea)')
3465        self.addCleanup(query, 'drop type test_person_type')
3466        self.createTable('test_person', 'person test_person_type',
3467                         temporary=False, oids=True)
3468        person_typ = self.db.get_attnames('test_person')['person']
3469        self.assertEqual(person_typ.attnames,
3470                         dict(name='text', picture='bytea'))
3471        person = ('John Doe', b'O\x00ps\xff!')
3472        r = self.db.insert('test_person', None, person=person)
3473        p = r['person']
3474        self.assertIsInstance(p, tuple)
3475        self.assertEqual(p, person)
3476        self.assertEqual(p.name, 'John Doe')
3477        self.assertIsInstance(p.name, str)
3478        self.assertEqual(p.picture, person[1])
3479        self.assertIsInstance(p.picture, bytes)
3480
3481    def testRecordInsertJson(self):
3482        query = self.db.query
3483        try:
3484            query('create type test_person_type as'
3485                  ' (name text, data json)')
3486        except pg.ProgrammingError as error:
3487            if self.db.server_version < 90200:
3488                self.skipTest('database does not support json')
3489            self.fail(str(error))
3490        self.addCleanup(query, 'drop type test_person_type')
3491        self.createTable('test_person', 'person test_person_type',
3492                         temporary=False, oids=True)
3493        person_typ = self.db.get_attnames('test_person')['person']
3494        self.assertEqual(person_typ.attnames,
3495                         dict(name='text', data='json'))
3496        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3497        r = self.db.insert('test_person', None, person=person)
3498        p = r['person']
3499        self.assertIsInstance(p, tuple)
3500        if pg.get_jsondecode() is None:
3501            p = p._replace(data=json.loads(p.data))
3502        self.assertEqual(p, person)
3503        self.assertEqual(p.name, 'John Doe')
3504        self.assertIsInstance(p.name, str)
3505        self.assertEqual(p.data, person[1])
3506        self.assertIsInstance(p.data, dict)
3507
3508    def testRecordLiteral(self):
3509        query = self.db.query
3510        query('create type test_person_type as'
3511              ' (name varchar, age smallint)')
3512        self.addCleanup(query, 'drop type test_person_type')
3513        self.createTable('test_person', 'person test_person_type',
3514                         temporary=False, oids=True)
3515        person_typ = self.db.get_attnames('test_person')['person']
3516        if self.regtypes:
3517            self.assertEqual(person_typ, 'test_person_type')
3518        else:
3519            self.assertEqual(person_typ, 'record')
3520        if self.regtypes:
3521            self.assertEqual(person_typ.attnames,
3522                             dict(name='character varying', age='smallint'))
3523        else:
3524            self.assertEqual(person_typ.attnames,
3525                             dict(name='text', age='int'))
3526        person = pg.Literal("('John Doe', 61)")
3527        r = self.db.insert('test_person', None, person=person)
3528        p = r['person']
3529        self.assertIsInstance(p, tuple)
3530        self.assertEqual(p.name, 'John Doe')
3531        self.assertIsInstance(p.name, str)
3532        self.assertEqual(p.age, 61)
3533        self.assertIsInstance(p.age, int)
3534
3535    def testDate(self):
3536        query = self.db.query
3537        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3538                'SQL, MDY', 'SQL, DMY', 'German'):
3539            self.db.set_parameter('datestyle', datestyle)
3540            d = date(2016, 3, 14)
3541            q = "select $1::date"
3542            r = query(q, (d,)).getresult()[0][0]
3543            self.assertIsInstance(r, date)
3544            self.assertEqual(r, d)
3545            q = "select '10000-08-01'::date, '0099-01-08 BC'::date"
3546            r = query(q).getresult()[0]
3547            self.assertIsInstance(r[0], date)
3548            self.assertIsInstance(r[1], date)
3549            self.assertEqual(r[0], date.max)
3550            self.assertEqual(r[1], date.min)
3551        q = "select 'infinity'::date, '-infinity'::date"
3552        r = query(q).getresult()[0]
3553        self.assertIsInstance(r[0], date)
3554        self.assertIsInstance(r[1], date)
3555        self.assertEqual(r[0], date.max)
3556        self.assertEqual(r[1], date.min)
3557
3558    def testTime(self):
3559        query = self.db.query
3560        d = time(15, 9, 26)
3561        q = "select $1::time"
3562        r = query(q, (d,)).getresult()[0][0]
3563        self.assertIsInstance(r, time)
3564        self.assertEqual(r, d)
3565        d = time(15, 9, 26, 535897)
3566        q = "select $1::time"
3567        r = query(q, (d,)).getresult()[0][0]
3568        self.assertIsInstance(r, time)
3569        self.assertEqual(r, d)
3570
3571    def testTimetz(self):
3572        query = self.db.query
3573        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3574        for timezone in sorted(timezones):
3575            offset = timezones[timezone]
3576            try:
3577                tzinfo = datetime.strptime('%+03d00' % offset, '%z').tzinfo
3578            except ValueError:  # Python < 3.3
3579                tzinfo = None
3580            self.db.set_parameter('timezone', timezone)
3581            d = time(15, 9, 26, tzinfo=tzinfo)
3582            q = "select $1::timetz"
3583            r = query(q, (d,)).getresult()[0][0]
3584            self.assertIsInstance(r, time)
3585            self.assertEqual(r, d)
3586            d = time(15, 9, 26, 535897, tzinfo)
3587            q = "select $1::timetz"
3588            r = query(q, (d,)).getresult()[0][0]
3589            self.assertIsInstance(r, time)
3590            self.assertEqual(r, d)
3591
3592    def testTimestamp(self):
3593        query = self.db.query
3594        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3595                'SQL, MDY', 'SQL, DMY', 'German'):
3596            self.db.set_parameter('datestyle', datestyle)
3597            d = datetime(2016, 3, 14)
3598            q = "select $1::timestamp"
3599            r = query(q, (d,)).getresult()[0][0]
3600            self.assertIsInstance(r, datetime)
3601            self.assertEqual(r, d)
3602            d = datetime(2016, 3, 14, 15, 9, 26)
3603            q = "select $1::timestamp"
3604            r = query(q, (d,)).getresult()[0][0]
3605            self.assertIsInstance(r, datetime)
3606            self.assertEqual(r, d)
3607            d = datetime(2016, 3, 14, 15, 9, 26, 535897)
3608            q = "select $1::timestamp"
3609            r = query(q, (d,)).getresult()[0][0]
3610            self.assertIsInstance(r, datetime)
3611            self.assertEqual(r, d)
3612            q = ("select '10000-08-01 AD'::timestamp,"
3613                " '0099-01-08 BC'::timestamp")
3614            r = query(q).getresult()[0]
3615            self.assertIsInstance(r[0], datetime)
3616            self.assertIsInstance(r[1], datetime)
3617            self.assertEqual(r[0], datetime.max)
3618            self.assertEqual(r[1], datetime.min)
3619        q = "select 'infinity'::timestamp, '-infinity'::timestamp"
3620        r = query(q).getresult()[0]
3621        self.assertIsInstance(r[0], datetime)
3622        self.assertIsInstance(r[1], datetime)
3623        self.assertEqual(r[0], datetime.max)
3624        self.assertEqual(r[1], datetime.min)
3625
3626    def testTimestamptz(self):
3627        query = self.db.query
3628        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3629        for timezone in sorted(timezones):
3630            offset = timezones[timezone]
3631            try:
3632                tzinfo = datetime.strptime('%+03d00' % offset, '%z').tzinfo
3633            except ValueError:  # Python < 3.3
3634                tzinfo = None
3635            self.db.set_parameter('timezone', timezone)
3636            for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3637                    'SQL, MDY', 'SQL, DMY', 'German'):
3638                self.db.set_parameter('datestyle', datestyle)
3639                d = datetime(2016, 3, 14, tzinfo=tzinfo)
3640                q = "select $1::timestamptz"
3641                r = query(q, (d,)).getresult()[0][0]
3642                self.assertIsInstance(r, datetime)
3643                self.assertEqual(r, d)
3644                d = datetime(2016, 3, 14, 15, 9, 26, tzinfo=tzinfo)
3645                q = "select $1::timestamptz"
3646                r = query(q, (d,)).getresult()[0][0]
3647                self.assertIsInstance(r, datetime)
3648                self.assertEqual(r, d)
3649                d = datetime(2016, 3, 14, 15, 9, 26, 535897, tzinfo)
3650                q = "select $1::timestamptz"
3651                r = query(q, (d,)).getresult()[0][0]
3652                self.assertIsInstance(r, datetime)
3653                self.assertEqual(r, d)
3654                q = ("select '10000-08-01 AD'::timestamptz,"
3655                    " '0099-01-08 BC'::timestamptz")
3656                r = query(q).getresult()[0]
3657                self.assertIsInstance(r[0], datetime)
3658                self.assertIsInstance(r[1], datetime)
3659                self.assertEqual(r[0], datetime.max)
3660                self.assertEqual(r[1], datetime.min)
3661        q = "select 'infinity'::timestamptz, '-infinity'::timestamptz"
3662        r = query(q).getresult()[0]
3663        self.assertIsInstance(r[0], datetime)
3664        self.assertIsInstance(r[1], datetime)
3665        self.assertEqual(r[0], datetime.max)
3666        self.assertEqual(r[1], datetime.min)
3667
3668    def testInterval(self):
3669        query = self.db.query
3670        for intervalstyle in (
3671                'sql_standard', 'postgres', 'postgres_verbose', 'iso_8601'):
3672            self.db.set_parameter('intervalstyle', intervalstyle)
3673            d = timedelta(3)
3674            q = "select $1::interval"
3675            r = query(q, (d,)).getresult()[0][0]
3676            self.assertIsInstance(r, timedelta)
3677            self.assertEqual(r, d)
3678            d = timedelta(-30)
3679            r = query(q, (d,)).getresult()[0][0]
3680            self.assertIsInstance(r, timedelta)
3681            self.assertEqual(r, d)
3682            d = timedelta(hours=3, minutes=31, seconds=42, microseconds=5678)
3683            q = "select $1::interval"
3684            r = query(q, (d,)).getresult()[0][0]
3685            self.assertIsInstance(r, timedelta)
3686            self.assertEqual(r, d)
3687
3688    def testDateAndTimeArrays(self):
3689        dt = (date(2016, 3, 14), time(15, 9, 26))
3690        q = "select ARRAY[$1::date], ARRAY[$2::time]"
3691        r = self.db.query(q, dt).getresult()[0]
3692        self.assertIsInstance(r[0], list)
3693        self.assertEqual(r[0][0], dt[0])
3694        self.assertIsInstance(r[1], list)
3695        self.assertEqual(r[1][0], dt[1])
3696
3697    def testHstore(self):
3698        try:
3699            self.db.query("select 'k=>v'::hstore")
3700        except pg.DatabaseError:
3701            try:
3702                self.db.query("create extension hstore")
3703            except pg.DatabaseError:
3704                self.skipTest("hstore extension not enabled")
3705        d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever',
3706            '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3',
3707            '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"',
3708            'None': None, 'NULL': 'NULL', 'empty': ''}
3709        q = "select $1::hstore"
3710        r = self.db.query(q, (pg.Hstore(d),)).getresult()[0][0]
3711        self.assertIsInstance(r, dict)
3712        self.assertEqual(r, d)
3713
3714    def testUuid(self):
3715        d = UUID('{12345678-1234-5678-1234-567812345678}')
3716        q = 'select $1::uuid'
3717        r = self.db.query(q, (d,)).getresult()[0][0]
3718        self.assertIsInstance(r, UUID)
3719        self.assertEqual(r, d)
3720
3721    def testDbTypesInfo(self):
3722        dbtypes = self.db.dbtypes
3723        self.assertIsInstance(dbtypes, dict)
3724        self.assertNotIn('numeric', dbtypes)
3725        typ = dbtypes['numeric']
3726        self.assertIn('numeric', dbtypes)
3727        self.assertEqual(typ, 'numeric' if self.regtypes else 'num')
3728        self.assertEqual(typ.oid, 1700)
3729        self.assertEqual(typ.pgtype, 'numeric')
3730        self.assertEqual(typ.regtype, 'numeric')
3731        self.assertEqual(typ.simple, 'num')
3732        self.assertEqual(typ.typtype, 'b')
3733        self.assertEqual(typ.category, 'N')
3734        self.assertEqual(typ.delim, ',')
3735        self.assertEqual(typ.relid, 0)
3736        self.assertIs(dbtypes[1700], typ)
3737        self.assertNotIn('pg_type', dbtypes)
3738        typ = dbtypes['pg_type']
3739        self.assertIn('pg_type', dbtypes)
3740        self.assertEqual(typ, 'pg_type' if self.regtypes else 'record')
3741        self.assertIsInstance(typ.oid, int)
3742        self.assertEqual(typ.pgtype, 'pg_type')
3743        self.assertEqual(typ.regtype, 'pg_type')
3744        self.assertEqual(typ.simple, 'record')
3745        self.assertEqual(typ.typtype, 'c')
3746        self.assertEqual(typ.category, 'C')
3747        self.assertEqual(typ.delim, ',')
3748        self.assertNotEqual(typ.relid, 0)
3749        attnames = typ.attnames
3750        self.assertIsInstance(attnames, dict)
3751        self.assertIs(attnames, dbtypes.get_attnames('pg_type'))
3752        self.assertEqual(list(attnames)[0], 'typname')
3753        typname = attnames['typname']
3754        self.assertEqual(typname, 'name' if self.regtypes else 'text')
3755        self.assertEqual(typname.typtype, 'b')  # base
3756        self.assertEqual(typname.category, 'S')  # string
3757        self.assertEqual(list(attnames)[3], 'typlen')
3758        typlen = attnames['typlen']
3759        self.assertEqual(typlen, 'smallint' if self.regtypes else 'int')
3760        self.assertEqual(typlen.typtype, 'b')  # base
3761        self.assertEqual(typlen.category, 'N')  # numeric
3762
3763    def testDbTypesTypecast(self):
3764        dbtypes = self.db.dbtypes
3765        self.assertIsInstance(dbtypes, dict)
3766        self.assertNotIn('int4', dbtypes)
3767        self.assertIs(dbtypes.get_typecast('int4'), int)
3768        dbtypes.set_typecast('int4', float)
3769        self.assertIs(dbtypes.get_typecast('int4'), float)
3770        dbtypes.reset_typecast('int4')
3771        self.assertIs(dbtypes.get_typecast('int4'), int)
3772        dbtypes.set_typecast('int4', float)
3773        self.assertIs(dbtypes.get_typecast('int4'), float)
3774        dbtypes.reset_typecast()
3775        self.assertIs(dbtypes.get_typecast('int4'), int)
3776        self.assertNotIn('circle', dbtypes)
3777        self.assertIsNone(dbtypes.get_typecast('circle'))
3778        squared_circle = lambda v: 'Squared Circle: %s' % v
3779        dbtypes.set_typecast('circle', squared_circle)
3780        self.assertIs(dbtypes.get_typecast('circle'), squared_circle)
3781        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3782        self.assertIn('circle', dbtypes)
3783        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3784        self.assertEqual(dbtypes.typecast('Impossible', 'circle'),
3785            'Squared Circle: Impossible')
3786        dbtypes.reset_typecast('circle')
3787        self.assertIsNone(dbtypes.get_typecast('circle'))
3788
3789    def testGetSetTypeCast(self):
3790        get_typecast = pg.get_typecast
3791        set_typecast = pg.set_typecast
3792        dbtypes = self.db.dbtypes
3793        self.assertIsInstance(dbtypes, dict)
3794        self.assertNotIn('int4', dbtypes)
3795        self.assertNotIn('real', dbtypes)
3796        self.assertNotIn('bool', dbtypes)
3797        self.assertIs(get_typecast('int4'), int)
3798        self.assertIs(get_typecast('float4'), float)
3799        self.assertIs(get_typecast('bool'), pg.cast_bool)
3800        cast_circle = get_typecast('circle')
3801        self.addCleanup(set_typecast, 'circle', cast_circle)
3802        squared_circle = lambda v: 'Squared Circle: %s' % v
3803        self.assertNotIn('circle', dbtypes)
3804        set_typecast('circle', squared_circle)
3805        self.assertNotIn('circle', dbtypes)
3806        self.assertIs(get_typecast('circle'), squared_circle)
3807        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3808        self.assertIn('circle', dbtypes)
3809        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3810        set_typecast('circle', cast_circle)
3811        self.assertIs(get_typecast('circle'), cast_circle)
3812
3813    def testNotificationHandler(self):
3814        # the notification handler itself is tested separately
3815        f = self.db.notification_handler
3816        callback = lambda arg_dict: None
3817        handler = f('test', callback)
3818        self.assertIsInstance(handler, pg.NotificationHandler)
3819        self.assertIs(handler.db, self.db)
3820        self.assertEqual(handler.event, 'test')
3821        self.assertEqual(handler.stop_event, 'stop_test')
3822        self.assertIs(handler.callback, callback)
3823        self.assertIsInstance(handler.arg_dict, dict)
3824        self.assertEqual(handler.arg_dict, {})
3825        self.assertIsNone(handler.timeout)
3826        self.assertFalse(handler.listening)
3827        handler.close()
3828        self.assertIsNone(handler.db)
3829        self.db.reopen()
3830        self.assertIsNone(handler.db)
3831        handler = f('test2', callback, timeout=2)
3832        self.assertIsInstance(handler, pg.NotificationHandler)
3833        self.assertIs(handler.db, self.db)
3834        self.assertEqual(handler.event, 'test2')
3835        self.assertEqual(handler.stop_event, 'stop_test2')
3836        self.assertIs(handler.callback, callback)
3837        self.assertIsInstance(handler.arg_dict, dict)
3838        self.assertEqual(handler.arg_dict, {})
3839        self.assertEqual(handler.timeout, 2)
3840        self.assertFalse(handler.listening)
3841        handler.close()
3842        self.assertIsNone(handler.db)
3843        self.db.reopen()
3844        self.assertIsNone(handler.db)
3845        arg_dict = {'testing': 3}
3846        handler = f('test3', callback, arg_dict=arg_dict)
3847        self.assertIsInstance(handler, pg.NotificationHandler)
3848        self.assertIs(handler.db, self.db)
3849        self.assertEqual(handler.event, 'test3')
3850        self.assertEqual(handler.stop_event, 'stop_test3')
3851        self.assertIs(handler.callback, callback)
3852        self.assertIs(handler.arg_dict, arg_dict)
3853        self.assertEqual(arg_dict['testing'], 3)
3854        self.assertIsNone(handler.timeout)
3855        self.assertFalse(handler.listening)
3856        handler.close()
3857        self.assertIsNone(handler.db)
3858        self.db.reopen()
3859        self.assertIsNone(handler.db)
3860        handler = f('test4', callback, stop_event='stop4')
3861        self.assertIsInstance(handler, pg.NotificationHandler)
3862        self.assertIs(handler.db, self.db)
3863        self.assertEqual(handler.event, 'test4')
3864        self.assertEqual(handler.stop_event, 'stop4')
3865        self.assertIs(handler.callback, callback)
3866        self.assertIsInstance(handler.arg_dict, dict)
3867        self.assertEqual(handler.arg_dict, {})
3868        self.assertIsNone(handler.timeout)
3869        self.assertFalse(handler.listening)
3870        handler.close()
3871        self.assertIsNone(handler.db)
3872        self.db.reopen()
3873        self.assertIsNone(handler.db)
3874        arg_dict = {'testing': 5}
3875        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
3876        self.assertIsInstance(handler, pg.NotificationHandler)
3877        self.assertIs(handler.db, self.db)
3878        self.assertEqual(handler.event, 'test5')
3879        self.assertEqual(handler.stop_event, 'stop5')
3880        self.assertIs(handler.callback, callback)
3881        self.assertIs(handler.arg_dict, arg_dict)
3882        self.assertEqual(arg_dict['testing'], 5)
3883        self.assertEqual(handler.timeout, 1.5)
3884        self.assertFalse(handler.listening)
3885        handler.close()
3886        self.assertIsNone(handler.db)
3887        self.db.reopen()
3888        self.assertIsNone(handler.db)
3889
3890
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        """Switch the global pg options to non-default values, then set up.

        The previous option values are remembered in cls.saved_options.
        unittest does not call tearDownClass when setUpClass raises, so on
        failure the saved options are restored here before re-raising.
        Otherwise the non-standard global options would leak into all
        test classes that run afterwards.
        """
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            not_array = not pg.get_array()
            cls.set_option('array', not_array)
            not_bytea_escaped = not pg.get_bytea_escaped()
            cls.set_option('bytea_escaped', not_bytea_escaped)
            cls.set_option('namedresult', None)
            cls.set_option('jsondecode', None)
            # this connection is only needed to query the regtypes setting,
            # so close it right away instead of leaving it to the GC
            db = DB()
            try:
                cls.regtypes = not db.use_regtypes()
            finally:
                db.close()
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            cls.reset_all_options()
            raise

    @classmethod
    def tearDownClass(cls):
        """Tear down the base class, then restore the saved global options.

        The options are restored even if the base tearDownClass raises.
        """
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls.reset_all_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option, then change it.

        Calls pg.get_<option>() to save the value and pg.set_<option>(value)
        to change it, returning whatever the setter returns.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_all_options(cls):
        """Restore every global pg option that was saved by set_option."""
        for option in list(cls.saved_options):
            cls.reset_option(option)
3927
3928
class TestDBClassAdapter(unittest.TestCase):
    """Test the adapter object associated with the DB class."""

    def setUp(self):
        self.db = DB()
        self.adapter = self.db.adapter

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            pass  # presumably the connection was already closed

    def testGuessSimpleType(self):
        """Check the simple type names guessed for Python values."""
        f = self.adapter.guess_simple_type
        self.assertEqual(f(pg.Bytea(b'test')), 'bytea')
        self.assertEqual(f('string'), 'text')
        self.assertEqual(f(b'string'), 'text')
        self.assertEqual(f(True), 'bool')
        self.assertEqual(f(3), 'int')
        self.assertEqual(f(2.75), 'float')
        self.assertEqual(f(Decimal('4.25')), 'num')
        self.assertEqual(f(date(2016, 1, 30)), 'date')
        self.assertEqual(f([1, 2, 3]), 'int[]')
        self.assertEqual(f([[[123]]]), 'int[]')
        self.assertEqual(f(['a', 'b', 'c']), 'text[]')
        self.assertEqual(f([[['abc']]]), 'text[]')
        self.assertEqual(f([False, True]), 'bool[]')
        self.assertEqual(f([[[False]]]), 'bool[]')
        # a tuple is guessed as a record with one guessed type per element
        r = f(('string', True, 3, 2.75, [1], [False]))
        self.assertEqual(r, 'record')
        self.assertEqual(list(r.attnames.values()),
            ['text', 'bool', 'int', 'float', 'int[]', 'bool[]'])

    def testAdaptQueryTypedList(self):
        """Check format_query with a list of values and explicit types."""
        format_query = self.adapter.format_query
        # mismatched numbers of values and types are expected to fail
        self.assertRaises(TypeError, format_query,
            '%s,%s', (1, 2), ('int2',))
        self.assertRaises(TypeError, format_query,
            '%s,%s', (1,), ('int2', 'int2'))
        values = (3, 7.5, 'hello', True)
        types = ('int4', 'float4', 'text', 'bool')
        sql, params = format_query("select %s,%s,%s,%s", values, types)
        self.assertEqual(sql, 'select $1,$2,$3,$4')
        self.assertEqual(params, [3, 7.5, 'hello', 't'])
        types = ('bool', 'bool', 'bool', 'bool')
        sql, params = format_query("select %s,%s,%s,%s", values, types)
        self.assertEqual(sql, 'select $1,$2,$3,$4')
        self.assertEqual(params, ['t', 't', 'f', 't'])
        # a date-typed 'current_date' is expected to be inlined in the SQL
        values = ('2016-01-30', 'current_date')
        types = ('date', 'date')
        sql, params = format_query("values(%s,%s)", values, types)
        self.assertEqual(sql, 'values($1,current_date)')
        self.assertEqual(params, ['2016-01-30'])
        values = ([1, 2, 3], ['a', 'b', 'c'])
        types = ('_int4', '_text')
        sql, params = format_query("%s::int4[],%s::text[]", values, types)
        self.assertEqual(sql, '$1::int4[],$2::text[]')
        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
        types = ('_bool', '_bool')
        sql, params = format_query("%s::bool[],%s::bool[]", values, types)
        self.assertEqual(sql, '$1::bool[],$2::bool[]')
        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
        t = self.adapter.simple_type
        typ = t('record')
        # give the record type a fixed list of columns
        typ._get_attnames = lambda _self: pg.AttrDict([
            ('i', t('int')), ('f', t('float')),
            ('t', t('text')), ('b', t('bool')),
            ('i3', t('int[]')), ('t3', t('text[]'))])
        types = [typ]
        sql, params = format_query('select %s', values, types)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryTypedDict(self):
        """Check format_query with a dict of values and explicit types."""
        format_query = self.adapter.format_query
        self.assertRaises(TypeError, format_query,
            '%s,%s', dict(i1=1, i2=2), dict(i1='int2'))
        values = dict(i=3, f=7.5, t='hello', b=True)
        types = dict(i='int4', f='float4',
            t='text', b='bool')
        # the expected numbering follows the sorted keys (b, f, i, t)
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
        self.assertEqual(sql, 'select $3,$2,$4,$1')
        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
        types = dict(i='bool', f='bool',
            t='bool', b='bool')
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
        self.assertEqual(sql, 'select $3,$2,$4,$1')
        self.assertEqual(params, ['t', 't', 't', 'f'])
        values = dict(d1='2016-01-30', d2='current_date')
        types = dict(d1='date', d2='date')
        sql, params = format_query("values(%(d1)s,%(d2)s)", values, types)
        self.assertEqual(sql, 'values($1,current_date)')
        self.assertEqual(params, ['2016-01-30'])
        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
        types = dict(i='_int4', t='_text')
        sql, params = format_query(
            "%(i)s::int4[],%(t)s::text[]", values, types)
        self.assertEqual(sql, '$1::int4[],$2::text[]')
        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
        types = dict(i='_bool', t='_bool')
        sql, params = format_query(
            "%(i)s::bool[],%(t)s::bool[]", values, types)
        self.assertEqual(sql, '$1::bool[],$2::bool[]')
        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
        t = self.adapter.simple_type
        typ = t('record')
        # give the record type a fixed list of columns
        typ._get_attnames = lambda _self: pg.AttrDict([
            ('i', t('int')), ('f', t('float')),
            ('t', t('text')), ('b', t('bool')),
            ('i3', t('int[]')), ('t3', t('text[]'))])
        types = dict(record=typ)
        sql, params = format_query('select %(record)s', values, types)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryUntypedList(self):
        """Check format_query with a list of values and guessed types."""
        format_query = self.adapter.format_query
        values = (3, 7.5, 'hello', True)
        sql, params = format_query("select %s,%s,%s,%s", values)
        self.assertEqual(sql, 'select $1,$2,$3,$4')
        self.assertEqual(params, [3, 7.5, 'hello', 't'])
        values = [date(2016, 1, 30), 'current_date']
        sql, params = format_query("values(%s,%s)", values)
        self.assertEqual(sql, 'values($1,$2)')
        self.assertEqual(params, values)
        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
        sql, params = format_query("%s,%s,%s", values)
        self.assertEqual(sql, "$1,$2,$3")
        self.assertEqual(params, ['{1,2,3}', '{a,b,c}', '{t,f,t}'])
        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
            [[True, False], [False, True]])
        sql, params = format_query("%s,%s,%s", values)
        self.assertEqual(sql, "$1,$2,$3")
        self.assertEqual(params, [
            '{{1,2},{3,4}}', '{{a,b},{c,d}}', '{{t,f},{f,t}}'])
        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
        sql, params = format_query('select %s', values)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryUntypedDict(self):
        """Check format_query with a dict of values and guessed types."""
        format_query = self.adapter.format_query
        values = dict(i=3, f=7.5, t='hello', b=True)
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values)
        self.assertEqual(sql, 'select $3,$2,$4,$1')
        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
        values = dict(d1='2016-01-30', d2='current_date')
        sql, params = format_query("values(%(d1)s,%(d2)s)", values)
        self.assertEqual(sql, 'values($1,$2)')
        self.assertEqual(params, [values['d1'], values['d2']])
        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
        self.assertEqual(sql, "$2,$3,$1")
        self.assertEqual(params, ['{t,f,t}', '{1,2,3}', '{a,b,c}'])
        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
            b=[[True, False], [False, True]])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
        self.assertEqual(sql, "$2,$3,$1")
        self.assertEqual(params, [
            '{{t,f},{f,t}}', '{{1,2},{3,4}}', '{{a,b},{c,d}}'])
        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
        sql, params = format_query('select %(record)s', values)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryInlineList(self):
        """Check format_query with a list of values adapted inline."""
        format_query = self.adapter.format_query
        values = (3, 7.5, 'hello', True)
        sql, params = format_query("select %s,%s,%s,%s", values, inline=True)
        self.assertEqual(sql, "select 3,7.5,'hello',true")
        self.assertEqual(params, [])
        values = [date(2016, 1, 30), 'current_date']
        sql, params = format_query("values(%s,%s)", values, inline=True)
        self.assertEqual(sql, "values('2016-01-30','current_date')")
        self.assertEqual(params, [])
        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
        sql, params = format_query("%s,%s,%s", values, inline=True)
        self.assertEqual(sql,
            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
        self.assertEqual(params, [])
        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
            [[True, False], [False, True]])
        sql, params = format_query("%s,%s,%s", values, inline=True)
        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
            "ARRAY[[true,false],[false,true]]")
        self.assertEqual(params, [])
        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
        sql, params = format_query('select %s', values, inline=True)
        self.assertEqual(sql,
            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
        self.assertEqual(params, [])

    def testAdaptQueryInlineDict(self):
        """Check format_query with a dict of values adapted inline."""
        format_query = self.adapter.format_query
        values = dict(i=3, f=7.5, t='hello', b=True)
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values, inline=True)
        self.assertEqual(sql, "select 3,7.5,'hello',true")
        self.assertEqual(params, [])
        values = dict(d1='2016-01-30', d2='current_date')
        sql, params = format_query(
            "values(%(d1)s,%(d2)s)", values, inline=True)
        self.assertEqual(sql, "values('2016-01-30','current_date')")
        self.assertEqual(params, [])
        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
        self.assertEqual(sql,
            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
        self.assertEqual(params, [])
        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
            b=[[True, False], [False, True]])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
            "ARRAY[[true,false],[false,true]]")
        self.assertEqual(params, [])
        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
        sql, params = format_query('select %(record)s', values, inline=True)
        self.assertEqual(sql,
            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
        self.assertEqual(params, [])

    def testAdaptQueryWithPgRepr(self):
        """Check that objects can adapt themselves via __pg_repr__."""
        format_query = self.adapter.format_query
        self.assertRaises(TypeError, format_query,
            '%s', object(), inline=True)
        class TestObject:
            def __pg_repr__(self):
                return "'adapted'"
        sql, params = format_query('select %s', [TestObject()], inline=True)
        self.assertEqual(sql, "select 'adapted'")
        self.assertEqual(params, [])
        sql, params = format_query('select %s', [[TestObject()]], inline=True)
        self.assertEqual(sql, "select ARRAY['adapted']")
        self.assertEqual(params, [])
4169
4170
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    # set to True by setUpClass after all schemas and tables were created
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create schemas s1 to s4 with tables t and t<n> (also in public).

        Every table holds one row with n = 1 and d = the schema number
        (0 for public), so a test can see which schema a lookup resolved to.
        The connection is closed even if creating a schema fails.
        """
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d" % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            db.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the schemas and public tables created by setUpClass.

        The connection is closed even if dropping an object fails.
        """
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.db = DB()

    def tearDown(self):
        # run the registered cleanups now, because they may still need
        # the connection, which is closed right afterwards
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        """Check that get_tables lists the tables of all schemas."""
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        """Check get_attnames with qualified names and the search path."""
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        """Check that get resolves table names via the search path."""
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        """Check that the oid key is named after the unqualified table."""
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
4289
4290
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class."""

    def setUp(self):
        self.db = DB()
        # remember the initial debug setting so tearDown can restore it
        self.debug = self.db.debug
        # capture everything printed to stdout during the test
        self.output = StringIO()
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        # restore stdout first so later output is not sent to the buffer
        sys.stdout = self.stdout
        self.output.close()
        # restore exactly the debug setting the connection started with
        self.db.debug = self.debug
        self.db.close()

    def get_output(self):
        """Return everything written to the captured stdout so far."""
        return self.output.getvalue()

    def send_queries(self):
        """Send the two queries "select 1" and "select 2"."""
        self.db.query("select 1")
        self.db.query("select 2")

    def testDebugDefault(self):
        """Without explicit setting, debug reflects the configured default."""
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        """A false debug value produces no output."""
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        """A true debug value prints each query to stdout."""
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        """A string debug value is used as a format for each query."""
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        """A file-like debug value receives the queries instead of stdout."""
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        """A callable debug value is called with each query."""
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")

    def testDebugMultipleArgs(self):
        """Multiple debug arguments are passed as one newline-joined string."""
        output = []
        self.db.debug = output.append
        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
        self.db._do_debug(*args)
        self.assertEqual(output, ['\n'.join(str(arg) for arg in args)])
        self.assertEqual(self.get_output(), "")
4360
4361
if __name__ == "__main__":
    # Run all test cases of this module when executed as a script.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.