source: trunk/tests/test_classic_dbwrapper.py @ 857

Last change on this file since 857 was 857, checked in by cito, 3 years ago

Add system parameter to get_relations()

Also fix a regression in the 4.x branch when using temporary tables,
related to filtering system tables (as discussed on the mailing list).

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 177.4 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import json
21import tempfile
22
23import pg  # the module under test
24
25from decimal import Decimal
26from datetime import date, time, datetime, timedelta
27from uuid import UUID
28from time import strftime
29from operator import itemgetter
30
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Optional local overrides of the settings above: try a package-relative
# import first, then a plain top-level import; silently keep the defaults
# if neither module exists.  (ValueError is what older Pythons raise for a
# relative import outside a package.)
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Python 2/3 compatibility aliases used throughout the tests.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

if str is bytes:  # Python 2
    from StringIO import StringIO
else:
    from io import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
74
75
def DB():
    """Return a pg.DB wrapper connected to the configured test database.

    The module-level ``debug`` flag is copied onto the wrapper when set,
    and server messages below warning level are suppressed.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
83
84
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names."""

    cls = pg.AttrDict
    base = OrderedDict

    def testInit(self):
        """An AttrDict can be built empty, from a list or from an iterator."""
        empty = self.cls()
        self.assertIsInstance(empty, self.base)
        self.assertEqual(empty, self.base())
        pairs = [('id', 'int'), ('name', 'text')]
        for source in (pairs, iter(pairs)):
            built = self.cls(source)
            self.assertIsInstance(built, self.base)
            self.assertEqual(built, self.base(pairs))

    def testIter(self):
        """Iterating yields the keys in insertion order."""
        self.assertEqual(list(self.cls()), [])
        names = ['id', 'name', 'age']
        filled = self.cls([(name, None) for name in names])
        self.assertEqual(list(filled), names)

    def testKeys(self):
        """keys() returns the keys in insertion order."""
        self.assertEqual(list(self.cls().keys()), [])
        names = ['id', 'name', 'age']
        filled = self.cls([(name, None) for name in names])
        self.assertEqual(list(filled.keys()), names)

    def testValues(self):
        """values() returns the values in insertion order."""
        self.assertEqual(list(self.cls().values()), [])
        pairs = [('id', 'int'), ('name', 'text')]
        filled = self.cls(pairs)
        self.assertEqual(list(filled.values()), [v for _, v in pairs])

    def testItems(self):
        """items() returns the (key, value) pairs in insertion order."""
        self.assertEqual(list(self.cls().items()), [])
        pairs = [('id', 'int'), ('name', 'text')]
        self.assertEqual(list(self.cls(pairs).items()), pairs)

    def testGet(self):
        """Reading an existing key works."""
        d = self.cls([('id', 1)])
        try:
            value = d['id']
        except KeyError:
            self.fail('AttrDict should be readable')
        self.assertEqual(value, 1)

    def testSet(self):
        """Item assignment is rejected with TypeError."""
        d = self.cls()
        try:
            d['id'] = 1
        except TypeError:
            return
        self.fail('AttrDict should be read-only')

    def testDel(self):
        """Item deletion is rejected with TypeError."""
        d = self.cls([('id', 1)])
        try:
            del d['id']
        except TypeError:
            return
        self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        """All mutating dict methods raise TypeError."""
        d = self.cls([('id', 1)])
        self.assertEqual(d['id'], 1)
        for name in ('clear', 'update', 'pop', 'setdefault', 'popitem'):
            self.assertRaises(TypeError, getattr(d, name), d)
166
167
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # Some tests close the connection themselves; closing it a second
        # time raises InternalError, which is harmless here.
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        # Expected public (non-underscore) attributes, listed alphabetically.
        attributes = [
            'abort', 'adapter',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'date_format', 'db', 'dbname', 'dbtypes',
            'debug', 'decode_json', 'delete',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_cast_hook',
            'get_databases', 'get_notice_receiver',
            'get_parameter', 'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query', 'query_formatted',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_cast_hook', 'set_notice_receiver',
            'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        # __dir__ is not called in Python 2.6 for old-style classes
        db_attributes = dir(self.db) if hasattr(
            self.db.__class__, '__class__') else self.db.__dir__()
        db_attributes = [a for a in db_attributes
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        # a fresh connection should have no error message; a message
        # mentioning krb5_ is tolerated
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # strip any hex prefix or escape backslashes; nothing else may remain
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        # parameters may be passed as separate arguments, a tuple or a list
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryDataError(self):
        # division by zero must raise a DataError with SQLSTATE 22012;
        # previously the test passed silently if nothing was raised
        try:
            self.db.query("select 1/0")
        except pg.DataError as error:
            self.assertEqual(error.sqlstate, '22012')
        else:
            self.fail('Division by zero should raise a DataError')

    def testMethodEndcopy(self):
        # calling endcopy outside a copy operation may raise IOError,
        # which is accepted here
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        # reset keeps the same connection object
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        # reopen always creates a new connection object, even after close
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        # pg.DB accepts a raw connection, another DB wrapper, a db= keyword
        # or any object with a _cnx attribute; a wrapper around an existing
        # connection does not actually close it
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
381
382
383class TestDBClass(unittest.TestCase):
384    """Test the methods of the DB class wrapped pg connection."""
385
386    maxDiff = 80 * 20
387
388    cls_set_up = False
389
390    regtypes = None
391
392    @classmethod
393    def setUpClass(cls):
394        db = DB()
395        db.query("drop table if exists test cascade")
396        db.query("create table test ("
397                 "i2 smallint, i4 integer, i8 bigint,"
398                 " d numeric, f4 real, f8 double precision, m money,"
399                 " v4 varchar(4), c4 char(4), t text)")
400        db.query("create or replace view test_view as"
401                 " select i4, v4 from test")
402        db.close()
403        cls.cls_set_up = True
404
405    @classmethod
406    def tearDownClass(cls):
407        db = DB()
408        db.query("drop table test cascade")
409        db.close()
410
411    def setUp(self):
412        self.assertTrue(self.cls_set_up)
413        self.db = DB()
414        if self.regtypes is None:
415            self.regtypes = self.db.use_regtypes()
416        else:
417            self.db.use_regtypes(self.regtypes)
418        query = self.db.query
419        query('set client_encoding=utf8')
420        query('set standard_conforming_strings=on')
421        query("set lc_monetary='C'")
422        query("set datestyle='ISO,YMD'")
423        query('set bytea_output=hex')
424
425    def tearDown(self):
426        self.doCleanups()
427        self.db.close()
428
429    def createTable(self, table, definition,
430                    temporary=True, oids=None, values=None):
431        query = self.db.query
432        if not '"' in table or '.' in table:
433            table = '"%s"' % table
434        if not temporary:
435            q = 'drop table if exists %s cascade' % table
436            query(q)
437            self.addCleanup(query, q)
438        temporary = 'temporary table' if temporary else 'table'
439        as_query = definition.startswith(('as ', 'AS '))
440        if not as_query and not definition.startswith('('):
441            definition = '(%s)' % definition
442        with_oids = 'with oids' if oids else 'without oids'
443        q = ['create', temporary, table]
444        if as_query:
445            q.extend([with_oids, definition])
446        else:
447            q.extend([definition, with_oids])
448        q = ' '.join(q)
449        query(q)
450        if values:
451            for params in values:
452                if not isinstance(params, (list, tuple)):
453                    params = [params]
454                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
455                q = "insert into %s values (%s)" % (table, values)
456                query(q, params)
457
458    def testClassName(self):
459        self.assertEqual(self.db.__class__.__name__, 'DB')
460
461    def testModuleName(self):
462        self.assertEqual(self.db.__module__, 'pg')
463        self.assertEqual(self.db.__class__.__module__, 'pg')
464
465    def testEscapeLiteral(self):
466        f = self.db.escape_literal
467        r = f(b"plain")
468        self.assertIsInstance(r, bytes)
469        self.assertEqual(r, b"'plain'")
470        r = f(u"plain")
471        self.assertIsInstance(r, unicode)
472        self.assertEqual(r, u"'plain'")
473        r = f(u"that's kÀse".encode('utf-8'))
474        self.assertIsInstance(r, bytes)
475        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
476        r = f(u"that's kÀse")
477        self.assertIsInstance(r, unicode)
478        self.assertEqual(r, u"'that''s kÀse'")
479        self.assertEqual(f(r"It's fine to have a \ inside."),
480                         r" E'It''s fine to have a \\ inside.'")
481        self.assertEqual(f('No "quotes" must be escaped.'),
482                         "'No \"quotes\" must be escaped.'")
483
484    def testEscapeIdentifier(self):
485        f = self.db.escape_identifier
486        r = f(b"plain")
487        self.assertIsInstance(r, bytes)
488        self.assertEqual(r, b'"plain"')
489        r = f(u"plain")
490        self.assertIsInstance(r, unicode)
491        self.assertEqual(r, u'"plain"')
492        r = f(u"that's kÀse".encode('utf-8'))
493        self.assertIsInstance(r, bytes)
494        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
495        r = f(u"that's kÀse")
496        self.assertIsInstance(r, unicode)
497        self.assertEqual(r, u'"that\'s kÀse"')
498        self.assertEqual(f(r"It's fine to have a \ inside."),
499                         '"It\'s fine to have a \\ inside."')
500        self.assertEqual(f('All "quotes" must be escaped.'),
501                         '"All ""quotes"" must be escaped."')
502
503    def testEscapeString(self):
504        f = self.db.escape_string
505        r = f(b"plain")
506        self.assertIsInstance(r, bytes)
507        self.assertEqual(r, b"plain")
508        r = f(u"plain")
509        self.assertIsInstance(r, unicode)
510        self.assertEqual(r, u"plain")
511        r = f(u"that's kÀse".encode('utf-8'))
512        self.assertIsInstance(r, bytes)
513        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
514        r = f(u"that's kÀse")
515        self.assertIsInstance(r, unicode)
516        self.assertEqual(r, u"that''s kÀse")
517        self.assertEqual(f(r"It's fine to have a \ inside."),
518                         r"It''s fine to have a \ inside.")
519
520    def testEscapeBytea(self):
521        f = self.db.escape_bytea
522        # note that escape_byte always returns hex output since Pg 9.0,
523        # regardless of the bytea_output setting
524        r = f(b'plain')
525        self.assertIsInstance(r, bytes)
526        self.assertEqual(r, b'\\x706c61696e')
527        r = f(u'plain')
528        self.assertIsInstance(r, unicode)
529        self.assertEqual(r, u'\\x706c61696e')
530        r = f(u"das is' kÀse".encode('utf-8'))
531        self.assertIsInstance(r, bytes)
532        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
533        r = f(u"das is' kÀse")
534        self.assertIsInstance(r, unicode)
535        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
536        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
537
538    def testUnescapeBytea(self):
539        f = self.db.unescape_bytea
540        r = f(b'plain')
541        self.assertIsInstance(r, bytes)
542        self.assertEqual(r, b'plain')
543        r = f(u'plain')
544        self.assertIsInstance(r, bytes)
545        self.assertEqual(r, b'plain')
546        r = f(b"das is' k\\303\\244se")
547        self.assertIsInstance(r, bytes)
548        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
549        r = f(u"das is' k\\303\\244se")
550        self.assertIsInstance(r, bytes)
551        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
552        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
553        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
554        self.assertEqual(f(r'\\x746861742773206be47365'),
555                         b'\\x746861742773206be47365')
556        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
557
558    def testDecodeJson(self):
559        f = self.db.decode_json
560        self.assertIsNone(f('null'))
561        data = {
562            "id": 1, "name": "Foo", "price": 1234.5,
563            "new": True, "note": None,
564            "tags": ["Bar", "Eek"],
565            "stock": {"warehouse": 300, "retail": 20}}
566        text = json.dumps(data)
567        r = f(text)
568        self.assertIsInstance(r, dict)
569        self.assertEqual(r, data)
570        self.assertIsInstance(r['id'], int)
571        self.assertIsInstance(r['name'], unicode)
572        self.assertIsInstance(r['price'], float)
573        self.assertIsInstance(r['new'], bool)
574        self.assertIsInstance(r['tags'], list)
575        self.assertIsInstance(r['stock'], dict)
576
577    def testEncodeJson(self):
578        f = self.db.encode_json
579        self.assertEqual(f(None), 'null')
580        data = {
581            "id": 1, "name": "Foo", "price": 1234.5,
582            "new": True, "note": None,
583            "tags": ["Bar", "Eek"],
584            "stock": {"warehouse": 300, "retail": 20}}
585        text = json.dumps(data)
586        r = f(data)
587        self.assertIsInstance(r, str)
588        self.assertEqual(r, text)
589
590    def testGetParameter(self):
591        f = self.db.get_parameter
592        self.assertRaises(TypeError, f)
593        self.assertRaises(TypeError, f, None)
594        self.assertRaises(TypeError, f, 42)
595        self.assertRaises(TypeError, f, '')
596        self.assertRaises(TypeError, f, [])
597        self.assertRaises(TypeError, f, [''])
598        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
599        r = f('standard_conforming_strings')
600        self.assertEqual(r, 'on')
601        r = f('lc_monetary')
602        self.assertEqual(r, 'C')
603        r = f('datestyle')
604        self.assertEqual(r, 'ISO, YMD')
605        r = f('bytea_output')
606        self.assertEqual(r, 'hex')
607        r = f(['bytea_output', 'lc_monetary'])
608        self.assertIsInstance(r, list)
609        self.assertEqual(r, ['hex', 'C'])
610        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
611        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
612        r = f(set(['bytea_output', 'lc_monetary']))
613        self.assertIsInstance(r, dict)
614        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
615        r = f(set(['Bytea_Output', ' LC_Monetary ']))
616        self.assertIsInstance(r, dict)
617        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
618        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
619        r = f(s)
620        self.assertIs(r, s)
621        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
622        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
623        r = f(s)
624        self.assertIs(r, s)
625        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
626
627    def testGetParameterServerVersion(self):
628        r = self.db.get_parameter('server_version_num')
629        self.assertIsInstance(r, str)
630        s = self.db.server_version
631        self.assertIsInstance(s, int)
632        self.assertEqual(r, str(s))
633
634    def testGetParameterAll(self):
635        f = self.db.get_parameter
636        r = f('all')
637        self.assertIsInstance(r, dict)
638        self.assertEqual(r['standard_conforming_strings'], 'on')
639        self.assertEqual(r['lc_monetary'], 'C')
640        self.assertEqual(r['DateStyle'], 'ISO, YMD')
641        self.assertEqual(r['bytea_output'], 'hex')
642
643    def testSetParameter(self):
644        f = self.db.set_parameter
645        g = self.db.get_parameter
646        self.assertRaises(TypeError, f)
647        self.assertRaises(TypeError, f, None)
648        self.assertRaises(TypeError, f, 42)
649        self.assertRaises(TypeError, f, '')
650        self.assertRaises(TypeError, f, [])
651        self.assertRaises(TypeError, f, [''])
652        self.assertRaises(ValueError, f, 'all', 'invalid')
653        self.assertRaises(ValueError, f, {
654            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
655        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
656        f('standard_conforming_strings', 'off')
657        self.assertEqual(g('standard_conforming_strings'), 'off')
658        f('datestyle', 'ISO, DMY')
659        self.assertEqual(g('datestyle'), 'ISO, DMY')
660        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
661        self.assertEqual(g('standard_conforming_strings'), 'on')
662        self.assertEqual(g('datestyle'), 'ISO, DMY')
663        f(['default_with_oids', 'standard_conforming_strings'], 'off')
664        self.assertEqual(g('default_with_oids'), 'off')
665        self.assertEqual(g('standard_conforming_strings'), 'off')
666        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
667        self.assertEqual(g('standard_conforming_strings'), 'on')
668        self.assertEqual(g('datestyle'), 'ISO, YMD')
669        f(('default_with_oids', 'standard_conforming_strings'), 'off')
670        self.assertEqual(g('default_with_oids'), 'off')
671        self.assertEqual(g('standard_conforming_strings'), 'off')
672        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
673        self.assertEqual(g('default_with_oids'), 'on')
674        self.assertEqual(g('standard_conforming_strings'), 'on')
675        self.assertRaises(ValueError, f, set(['default_with_oids',
676            'standard_conforming_strings']), ['off', 'on'])
677        f(set(['default_with_oids', 'standard_conforming_strings']),
678            ['off', 'off'])
679        self.assertEqual(g('default_with_oids'), 'off')
680        self.assertEqual(g('standard_conforming_strings'), 'off')
681        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
682        self.assertEqual(g('standard_conforming_strings'), 'on')
683        self.assertEqual(g('datestyle'), 'ISO, YMD')
684
685    def testResetParameter(self):
686        db = DB()
687        f = db.set_parameter
688        g = db.get_parameter
689        r = g('default_with_oids')
690        self.assertIn(r, ('on', 'off'))
691        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
692        r = g('standard_conforming_strings')
693        self.assertIn(r, ('on', 'off'))
694        scs, not_scs = r, 'off' if r == 'on' else 'on'
695        f('default_with_oids', not_dwi)
696        f('standard_conforming_strings', not_scs)
697        self.assertEqual(g('default_with_oids'), not_dwi)
698        self.assertEqual(g('standard_conforming_strings'), not_scs)
699        f('default_with_oids')
700        f('standard_conforming_strings', None)
701        self.assertEqual(g('default_with_oids'), dwi)
702        self.assertEqual(g('standard_conforming_strings'), scs)
703        f('default_with_oids', not_dwi)
704        f('standard_conforming_strings', not_scs)
705        self.assertEqual(g('default_with_oids'), not_dwi)
706        self.assertEqual(g('standard_conforming_strings'), not_scs)
707        f(['default_with_oids', 'standard_conforming_strings'], None)
708        self.assertEqual(g('default_with_oids'), dwi)
709        self.assertEqual(g('standard_conforming_strings'), scs)
710        f('default_with_oids', not_dwi)
711        f('standard_conforming_strings', not_scs)
712        self.assertEqual(g('default_with_oids'), not_dwi)
713        self.assertEqual(g('standard_conforming_strings'), not_scs)
714        f(('default_with_oids', 'standard_conforming_strings'))
715        self.assertEqual(g('default_with_oids'), dwi)
716        self.assertEqual(g('standard_conforming_strings'), scs)
717        f('default_with_oids', not_dwi)
718        f('standard_conforming_strings', not_scs)
719        self.assertEqual(g('default_with_oids'), not_dwi)
720        self.assertEqual(g('standard_conforming_strings'), not_scs)
721        f(set(['default_with_oids', 'standard_conforming_strings']))
722        self.assertEqual(g('default_with_oids'), dwi)
723        self.assertEqual(g('standard_conforming_strings'), scs)
724
725    def testResetParameterAll(self):
726        db = DB()
727        f = db.set_parameter
728        self.assertRaises(ValueError, f, 'all', 0)
729        self.assertRaises(ValueError, f, 'all', 'off')
730        g = db.get_parameter
731        r = g('default_with_oids')
732        self.assertIn(r, ('on', 'off'))
733        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
734        r = g('standard_conforming_strings')
735        self.assertIn(r, ('on', 'off'))
736        scs, not_scs = r, 'off' if r == 'on' else 'on'
737        f('default_with_oids', not_dwi)
738        f('standard_conforming_strings', not_scs)
739        self.assertEqual(g('default_with_oids'), not_dwi)
740        self.assertEqual(g('standard_conforming_strings'), not_scs)
741        f('all')
742        self.assertEqual(g('default_with_oids'), dwi)
743        self.assertEqual(g('standard_conforming_strings'), scs)
744
745    def testSetParameterLocal(self):
746        f = self.db.set_parameter
747        g = self.db.get_parameter
748        self.assertEqual(g('standard_conforming_strings'), 'on')
749        self.db.begin()
750        f('standard_conforming_strings', 'off', local=True)
751        self.assertEqual(g('standard_conforming_strings'), 'off')
752        self.db.end()
753        self.assertEqual(g('standard_conforming_strings'), 'on')
754
755    def testSetParameterSession(self):
756        f = self.db.set_parameter
757        g = self.db.get_parameter
758        self.assertEqual(g('standard_conforming_strings'), 'on')
759        self.db.begin()
760        f('standard_conforming_strings', 'off', local=False)
761        self.assertEqual(g('standard_conforming_strings'), 'off')
762        self.db.end()
763        self.assertEqual(g('standard_conforming_strings'), 'off')
764
765    def testReset(self):
766        db = DB()
767        default_datestyle = db.get_parameter('datestyle')
768        changed_datestyle = 'ISO, DMY'
769        if changed_datestyle == default_datestyle:
770            changed_datestyle == 'ISO, YMD'
771        self.db.set_parameter('datestyle', changed_datestyle)
772        r = self.db.get_parameter('datestyle')
773        self.assertEqual(r, changed_datestyle)
774        con = self.db.db
775        q = con.query("show datestyle")
776        self.db.reset()
777        r = q.getresult()[0][0]
778        self.assertEqual(r, changed_datestyle)
779        q = con.query("show datestyle")
780        r = q.getresult()[0][0]
781        self.assertEqual(r, default_datestyle)
782        r = self.db.get_parameter('datestyle')
783        self.assertEqual(r, default_datestyle)
784
785    def testReopen(self):
786        db = DB()
787        default_datestyle = db.get_parameter('datestyle')
788        changed_datestyle = 'ISO, DMY'
789        if changed_datestyle == default_datestyle:
790            changed_datestyle == 'ISO, YMD'
791        self.db.set_parameter('datestyle', changed_datestyle)
792        r = self.db.get_parameter('datestyle')
793        self.assertEqual(r, changed_datestyle)
794        con = self.db.db
795        q = con.query("show datestyle")
796        self.db.reopen()
797        r = q.getresult()[0][0]
798        self.assertEqual(r, changed_datestyle)
799        self.assertRaises(TypeError, getattr, con, 'query')
800        r = self.db.get_parameter('datestyle')
801        self.assertEqual(r, default_datestyle)
802
803    def testCreateTable(self):
804        table = 'test hello world'
805        values = [(2, "World!"), (1, "Hello")]
806        self.createTable(table, "n smallint, t varchar",
807                         temporary=True, oids=True, values=values)
808        r = self.db.query('select t from "%s" order by n' % table).getresult()
809        r = ', '.join(row[0] for row in r)
810        self.assertEqual(r, "Hello, World!")
811        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
812        self.assertIsInstance(r[0][0], int)
813
814    def testQuery(self):
815        query = self.db.query
816        table = 'test_table'
817        self.createTable(table, "n integer", oids=True)
818        q = "insert into test_table values (1)"
819        r = query(q)
820        self.assertIsInstance(r, int)
821        q = "insert into test_table select 2"
822        r = query(q)
823        self.assertIsInstance(r, int)
824        oid = r
825        q = "select oid from test_table where n=2"
826        r = query(q).getresult()
827        self.assertEqual(len(r), 1)
828        r = r[0]
829        self.assertEqual(len(r), 1)
830        r = r[0]
831        self.assertEqual(r, oid)
832        q = "insert into test_table select 3 union select 4 union select 5"
833        r = query(q)
834        self.assertIsInstance(r, str)
835        self.assertEqual(r, '3')
836        q = "update test_table set n=4 where n<5"
837        r = query(q)
838        self.assertIsInstance(r, str)
839        self.assertEqual(r, '4')
840        q = "delete from test_table"
841        r = query(q)
842        self.assertIsInstance(r, str)
843        self.assertEqual(r, '5')
844
845    def testMultipleQueries(self):
846        self.assertEqual(self.db.query(
847            "create temporary table test_multi (n integer);"
848            "insert into test_multi values (4711);"
849            "select n from test_multi").getresult()[0][0], 4711)
850
851    def testQueryWithParams(self):
852        query = self.db.query
853        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
854        q = "insert into test_table values ($1, $2)"
855        r = query(q, (1, 2))
856        self.assertIsInstance(r, int)
857        r = query(q, [3, 4])
858        self.assertIsInstance(r, int)
859        r = query(q, [5, 6])
860        self.assertIsInstance(r, int)
861        q = "select * from test_table order by 1, 2"
862        self.assertEqual(query(q).getresult(),
863                         [(1, 2), (3, 4), (5, 6)])
864        q = "select * from test_table where n1=$1 and n2=$2"
865        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
866        q = "update test_table set n2=$2 where n1=$1"
867        r = query(q, 3, 7)
868        self.assertEqual(r, '1')
869        q = "select * from test_table order by 1, 2"
870        self.assertEqual(query(q).getresult(),
871                         [(1, 2), (3, 7), (5, 6)])
872        q = "delete from test_table where n2!=$1"
873        r = query(q, 4)
874        self.assertEqual(r, '3')
875
876    def testEmptyQuery(self):
877        self.assertRaises(ValueError, self.db.query, '')
878
879    def testQueryDataError(self):
880        try:
881            self.db.query("select 1/0")
882        except pg.DataError as error:
883            self.assertEqual(error.sqlstate, '22012')
884
885    def testQueryFormatted(self):
886        f = self.db.query_formatted
887        t = True if pg.get_bool() else 't'
888        q = f("select %s::int, %s::real, %s::text, %s::bool",
889              (3, 2.5, 'hello', True))
890        r = q.getresult()[0]
891        self.assertEqual(r, (3, 2.5, 'hello', t))
892        q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True)
893        r = q.getresult()[0]
894        if isinstance(r[1], Decimal):
895            # Python 2.6 cannot compare float and Decimal
896            r = list(r)
897            r[1] = float(r[1])
898            r = tuple(r)
899        self.assertEqual(r, (3, 2.5, 'hello', t))
900
    def testPkey(self):
        """Test pkey() for single-column, composite and quoted primary keys.

        Also checks that pkey() caches its results: a renamed key column is
        only reported after calling pkey() with flush=True.
        """
        query = self.db.query
        pkey = self.db.pkey
        # the shared 'test' table has no primary key
        self.assertRaises(KeyError, pkey, 'test')
        # run all checks once with plain and once with quote-needing names
        for t in ('pkeytest', 'primary key test'):
            self.createTable('%s0' % t, 'a smallint')
            self.createTable('%s1' % t, 'b smallint primary key')
            self.createTable('%s2' % t,
                'c smallint, d smallint primary key')
            self.createTable('%s3' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (f, h)')
            self.createTable('%s4' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (h, f)')
            self.createTable('%s5' % t,
                'more_than_one_letter varchar primary key')
            self.createTable('%s6' % t,
                '"with space" date primary key')
            self.createTable('%s7' % t,
                'a_very_long_column_name varchar, "with space" date, "42" int,'
                ' primary key (a_very_long_column_name, "with space", "42")')
            self.assertRaises(KeyError, pkey, '%s0' % t)
            # a single-column key is a plain name, or a 1-tuple if composite
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s1' % t, True), ('b',))
            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
            self.assertEqual(pkey('%s2' % t), 'd')
            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
            # multi-column keys are always tuples, even with composite=False,
            # and keep the column order of the key definition
            r = pkey('%s3' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s3' % t, composite=False)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s4' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('h', 'f'))
            # column names are returned unquoted
            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s6' % t), 'with space')
            r = pkey('%s7' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (
                'a_very_long_column_name', 'with space', '42'))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
954
955    def testGetDatabases(self):
956        databases = self.db.get_databases()
957        self.assertIn('template0', databases)
958        self.assertIn('template1', databases)
959        self.assertNotIn('not existing database', databases)
960        self.assertIn('postgres', databases)
961        self.assertIn(dbname, databases)
962
963    def testGetTables(self):
964        get_tables = self.db.get_tables
965        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
966                  'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
967                  'averyveryveryveryveryveryveryreallyreallylongtablename',
968                  'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
969        for t in tables:
970            self.db.query('drop table if exists "%s" cascade' % t)
971        before_tables = get_tables()
972        self.assertIsInstance(before_tables, list)
973        for t in before_tables:
974            t = t.split('.', 1)
975            self.assertGreaterEqual(len(t), 2)
976            if len(t) > 2:
977                self.assertTrue(t[1].startswith('"'))
978            t = t[0]
979            self.assertNotEqual(t, 'information_schema')
980            self.assertFalse(t.startswith('pg_'))
981        for t in tables:
982            self.createTable(t, 'as select 0', temporary=False)
983        current_tables = get_tables()
984        new_tables = [t for t in current_tables if t not in before_tables]
985        expected_new_tables = ['public.%s' % (
986            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
987        self.assertEqual(new_tables, expected_new_tables)
988        self.doCleanups()
989        after_tables = get_tables()
990        self.assertEqual(after_tables, before_tables)
991
992    def testGetSystemTables(self):
993        get_tables = self.db.get_tables
994        result = get_tables()
995        self.assertNotIn('pg_catalog.pg_class', result)
996        self.assertNotIn('information_schema.tables', result)
997        result = get_tables(system=False)
998        self.assertNotIn('pg_catalog.pg_class', result)
999        self.assertNotIn('information_schema.tables', result)
1000        result = get_tables(system=True)
1001        self.assertIn('pg_catalog.pg_class', result)
1002        self.assertNotIn('information_schema.tables', result)
1003
1004    def testGetRelations(self):
1005        get_relations = self.db.get_relations
1006        result = get_relations()
1007        self.assertIn('public.test', result)
1008        self.assertIn('public.test_view', result)
1009        result = get_relations('rv')
1010        self.assertIn('public.test', result)
1011        self.assertIn('public.test_view', result)
1012        result = get_relations('r')
1013        self.assertIn('public.test', result)
1014        self.assertNotIn('public.test_view', result)
1015        result = get_relations('v')
1016        self.assertNotIn('public.test', result)
1017        self.assertIn('public.test_view', result)
1018        result = get_relations('cisSt')
1019        self.assertNotIn('public.test', result)
1020        self.assertNotIn('public.test_view', result)
1021
1022    def testGetSystemRelations(self):
1023        get_relations = self.db.get_relations
1024        result = get_relations()
1025        self.assertNotIn('pg_catalog.pg_class', result)
1026        self.assertNotIn('information_schema.tables', result)
1027        result = get_relations(system=False)
1028        self.assertNotIn('pg_catalog.pg_class', result)
1029        self.assertNotIn('information_schema.tables', result)
1030        result = get_relations(system=True)
1031        self.assertIn('pg_catalog.pg_class', result)
1032        self.assertIn('information_schema.tables', result)
1033
1034    def testGetAttnames(self):
1035        get_attnames = self.db.get_attnames
1036        self.assertRaises(pg.ProgrammingError,
1037                          self.db.get_attnames, 'does_not_exist')
1038        self.assertRaises(pg.ProgrammingError,
1039                          self.db.get_attnames, 'has.too.many.dots')
1040        r = get_attnames('test')
1041        self.assertIsInstance(r, dict)
1042        if self.regtypes:
1043            self.assertEqual(r, dict(
1044                i2='smallint', i4='integer', i8='bigint', d='numeric',
1045                f4='real', f8='double precision', m='money',
1046                v4='character varying', c4='character', t='text'))
1047        else:
1048            self.assertEqual(r, dict(
1049                i2='int', i4='int', i8='int', d='num',
1050                f4='float', f8='float', m='money',
1051                v4='text', c4='text', t='text'))
1052        self.createTable('test_table',
1053                         'n int, alpha smallint, beta bool,'
1054                         ' gamma char(5), tau text, v varchar(3)')
1055        r = get_attnames('test_table')
1056        self.assertIsInstance(r, dict)
1057        if self.regtypes:
1058            self.assertEqual(r, dict(
1059                n='integer', alpha='smallint', beta='boolean',
1060                gamma='character', tau='text', v='character varying'))
1061        else:
1062            self.assertEqual(r, dict(
1063                n='int', alpha='int', beta='bool',
1064                gamma='text', tau='text', v='text'))
1065
1066    def testGetAttnamesWithQuotes(self):
1067        get_attnames = self.db.get_attnames
1068        table = 'test table for get_attnames()'
1069        self.createTable(table,
1070            '"Prime!" smallint, "much space" integer, "Questions?" text')
1071        r = get_attnames(table)
1072        self.assertIsInstance(r, dict)
1073        if self.regtypes:
1074            self.assertEqual(r, {
1075                'Prime!': 'smallint', 'much space': 'integer',
1076                'Questions?': 'text'})
1077        else:
1078            self.assertEqual(r, {
1079                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1080        table = 'yet another test table for get_attnames()'
1081        self.createTable(table,
1082                         'a smallint, b integer, c bigint,'
1083                         ' e numeric, f real, f2 double precision, m money,'
1084                         ' x smallint, y smallint, z smallint,'
1085                         ' Normal_NaMe smallint, "Special Name" smallint,'
1086                         ' t text, u char(2), v varchar(2),'
1087                         ' primary key (y, u)', oids=True)
1088        r = get_attnames(table)
1089        self.assertIsInstance(r, dict)
1090        if self.regtypes:
1091            self.assertEqual(r, {
1092                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1093                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1094                'm': 'money', 'normal_name': 'smallint',
1095                'Special Name': 'smallint', 'u': 'character',
1096                't': 'text', 'v': 'character varying', 'y': 'smallint',
1097                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1098        else:
1099            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1100                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1101                 'normal_name': 'int', 'Special Name': 'int',
1102                 'u': 'text', 't': 'text', 'v': 'text',
1103                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1104
1105    def testGetAttnamesWithRegtypes(self):
1106        get_attnames = self.db.get_attnames
1107        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1108            ' gamma char(5), tau text, v varchar(3)')
1109        use_regtypes = self.db.use_regtypes
1110        regtypes = use_regtypes()
1111        self.assertEqual(regtypes, self.regtypes)
1112        use_regtypes(True)
1113        try:
1114            r = get_attnames("test_table")
1115            self.assertIsInstance(r, dict)
1116        finally:
1117            use_regtypes(regtypes)
1118        self.assertEqual(r, dict(
1119            n='integer', alpha='smallint', beta='boolean',
1120            gamma='character', tau='text', v='character varying'))
1121
1122    def testGetAttnamesWithoutRegtypes(self):
1123        get_attnames = self.db.get_attnames
1124        self.createTable('test_table', 'n int, alpha smallint, beta bool,'
1125            ' gamma char(5), tau text, v varchar(3)')
1126        use_regtypes = self.db.use_regtypes
1127        regtypes = use_regtypes()
1128        self.assertEqual(regtypes, self.regtypes)
1129        use_regtypes(False)
1130        try:
1131            r = get_attnames("test_table")
1132            self.assertIsInstance(r, dict)
1133        finally:
1134            use_regtypes(regtypes)
1135        self.assertEqual(r, dict(
1136            n='int', alpha='int', beta='bool',
1137            gamma='text', tau='text', v='text'))
1138
    def testGetAttnamesIsCached(self):
        """Test that get_attnames() caches its result per table.

        Changes to the table definition are only seen after calling
        get_attnames() with flush=True.
        """
        get_attnames = self.db.get_attnames
        # the expected integer type name depends on the regtypes setting
        int_type = 'integer' if self.regtypes else 'int'
        text_type = 'text'
        query = self.db.query
        self.createTable('test_table', 'col int')
        r = get_attnames("test_table")
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(col=int_type))
        query("alter table test_table alter column col type text")
        query("alter table test_table add column col2 int")
        # without flushing, the cached (old) definition is returned
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=int_type))
        # flushing picks up both the changed and the added column
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict(col=text_type, col2=int_type))
        query("alter table test_table drop column col2")
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=text_type, col2=int_type))
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict(col=text_type))
        query("alter table test_table drop column col")
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=text_type))
        # a table without any columns yields an empty mapping
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict())
1164
1165    def testGetAttnamesIsOrdered(self):
1166        get_attnames = self.db.get_attnames
1167        r = get_attnames('test', flush=True)
1168        self.assertIsInstance(r, OrderedDict)
1169        if self.regtypes:
1170            self.assertEqual(r, OrderedDict([
1171                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1172                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1173                ('m', 'money'), ('v4', 'character varying'),
1174                ('c4', 'character'), ('t', 'text')]))
1175        else:
1176            self.assertEqual(r, OrderedDict([
1177                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1178                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1179                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1180        if OrderedDict is not dict:
1181            r = ' '.join(list(r.keys()))
1182            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1183        table = 'test table for get_attnames'
1184        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1185                ' gamma char(5), tau text, beta bool')
1186        r = get_attnames(table)
1187        self.assertIsInstance(r, OrderedDict)
1188        if self.regtypes:
1189            self.assertEqual(r, OrderedDict([
1190                ('n', 'integer'), ('alpha', 'smallint'),
1191                ('v', 'character varying'), ('gamma', 'character'),
1192                ('tau', 'text'), ('beta', 'boolean')]))
1193        else:
1194            self.assertEqual(r, OrderedDict([
1195                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1196                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1197        if OrderedDict is not dict:
1198            r = ' '.join(list(r.keys()))
1199            self.assertEqual(r, 'n alpha v gamma tau beta')
1200        else:
1201            self.skipTest('OrderedDict is not supported')
1202
1203    def testGetAttnamesIsAttrDict(self):
1204        AttrDict = pg.AttrDict
1205        get_attnames = self.db.get_attnames
1206        r = get_attnames('test', flush=True)
1207        self.assertIsInstance(r, AttrDict)
1208        if self.regtypes:
1209            self.assertEqual(r, AttrDict([
1210                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1211                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1212                ('m', 'money'), ('v4', 'character varying'),
1213                ('c4', 'character'), ('t', 'text')]))
1214        else:
1215            self.assertEqual(r, AttrDict([
1216                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1217                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1218                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1219        r = ' '.join(list(r.keys()))
1220        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1221        table = 'test table for get_attnames'
1222        self.createTable(table, 'n int, alpha smallint, v varchar(3),'
1223                ' gamma char(5), tau text, beta bool')
1224        r = get_attnames(table)
1225        self.assertIsInstance(r, AttrDict)
1226        if self.regtypes:
1227            self.assertEqual(r, AttrDict([
1228                ('n', 'integer'), ('alpha', 'smallint'),
1229                ('v', 'character varying'), ('gamma', 'character'),
1230                ('tau', 'text'), ('beta', 'boolean')]))
1231        else:
1232            self.assertEqual(r, AttrDict([
1233                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1234                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1235        r = ' '.join(list(r.keys()))
1236        self.assertEqual(r, 'n alpha v gamma tau beta')
1237
1238    def testHasTablePrivilege(self):
1239        can = self.db.has_table_privilege
1240        self.assertEqual(can('test'), True)
1241        self.assertEqual(can('test', 'select'), True)
1242        self.assertEqual(can('test', 'SeLeCt'), True)
1243        self.assertEqual(can('test', 'SELECT'), True)
1244        self.assertEqual(can('test', 'insert'), True)
1245        self.assertEqual(can('test', 'update'), True)
1246        self.assertEqual(can('test', 'delete'), True)
1247        self.assertRaises(pg.DataError, can, 'test', 'foobar')
1248        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1249        r = self.db.query('select rolsuper FROM pg_roles'
1250            ' where rolname=current_user').getresult()[0][0]
1251        if not pg.get_bool():
1252            r = r == 't'
1253        if r:
1254            self.skipTest('must not be superuser')
1255        self.assertEqual(can('pg_views', 'select'), True)
1256        self.assertEqual(can('pg_views', 'delete'), False)
1257
    def testGet(self):
        """Test get() with explicit key names, with a primary key and dicts.

        Covers lookups by scalar, tuple and dict row values, failing
        lookups, and the in-place update of a dict passed as row.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # table and row arguments are required
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        self.createTable(table, 'n integer, t text',
                         values=enumerate('xyz', start=1))
        # without a primary key, the key column must be named explicitly
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # each of these lookups fails with a DatabaseError
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        # a dict passed as row is updated in place and returned
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(t='x')
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        # once there is a primary key, the key name can be omitted
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # a trailing '*' after the table name is accepted
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # a tuple with more values than key columns is rejected
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        s.update(n=2)
        # r is the same dict object as s here
        self.assertEqual(get(table, r)['t'], 'y')
        # a dict without the primary key column cannot be looked up
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1312
    def testGetWithOid(self):
        """Test get() on a table with oids, before and after adding a pkey.

        The oid of a fetched row is stored under the key 'oid(<table>)'.
        Rows can be looked up by oid given as a value with keyname 'oid',
        or in a dict under 'oid' or 'oid(<table>)'.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        # no primary key yet and no key name given
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        # a dict row without an oid cannot be fetched by oid
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        # all supported ways of passing the oid find the same row
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        # NOTE(review): presumably r was updated in place by the get() above,
        # so its oid now refers to the row with n=3 -- confirm
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        # with a primary key, no key name is needed
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # lookups by oid still work after adding the primary key
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # the primary key takes precedence over an 'oid' entry in the dict
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # an explicit key name also takes precedence over the oid
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1376
1377    def testGetWithCompositeKey(self):
1378        get = self.db.get
1379        query = self.db.query
1380        table = 'get_test_table_1'
1381        self.createTable(table, 'n integer primary key, t text',
1382                         values=enumerate('abc', start=1))
1383        self.assertEqual(get(table, 2)['t'], 'b')
1384        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1385        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1386        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1387        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1388        self.assertEqual(get(table, 'b', 't')['n'], 2)
1389        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1390        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1391        table = 'get_test_table_2'
1392        self.createTable(table,
1393                         'n integer, m integer, t text, primary key (n, m)',
1394                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1395                                 for n in range(3) for m in range(2)])
1396        self.assertRaises(KeyError, get, table, 2)
1397        self.assertEqual(get(table, (1, 1))['t'], 'a')
1398        self.assertEqual(get(table, (1, 2))['t'], 'b')
1399        self.assertEqual(get(table, (2, 1))['t'], 'c')
1400        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1401        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1402        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1403        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1404        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1405        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1406        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1407        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1408
1409    def testGetWithQuotedNames(self):
1410        get = self.db.get
1411        query = self.db.query
1412        table = 'test table for get()'
1413        self.createTable(table, '"Prime!" smallint primary key,'
1414                                ' "much space" integer, "Questions?" text',
1415                         values=[(17, 1001, 'No!')])
1416        r = get(table, 17)
1417        self.assertIsInstance(r, dict)
1418        self.assertEqual(r['Prime!'], 17)
1419        self.assertEqual(r['much space'], 1001)
1420        self.assertEqual(r['Questions?'], 'No!')
1421
1422    def testGetFromView(self):
1423        self.db.query('delete from test where i4=14')
1424        self.db.query('insert into test (i4, v4) values('
1425                      "14, 'abc4')")
1426        r = self.db.get('test_view', 14, 'i4')
1427        self.assertIn('v4', r)
1428        self.assertEqual(r['v4'], 'abc4')
1429
1430    def testGetLittleBobbyTables(self):
1431        get = self.db.get
1432        query = self.db.query
1433        self.createTable('test_students',
1434                         'firstname varchar primary key, nickname varchar, grade char(2)',
1435                         values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1436                                 ('Robert', 'Little Bobby Tables', 'D-')])
1437        r = get('test_students', 'Sheldon')
1438        self.assertEqual(r, dict(
1439            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1440        r = get('test_students', 'Robert')
1441        self.assertEqual(r, dict(
1442            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1443        r = get('test_students', "D'Arcy")
1444        self.assertEqual(r, dict(
1445            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1446        try:
1447            get('test_students', "D' Arcy")
1448        except pg.DatabaseError as error:
1449            self.assertEqual(str(error),
1450                'No such record in test_students\nwhere "firstname" = $1\n'
1451                'with $1="D\' Arcy"')
1452        try:
1453            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1454        except pg.DatabaseError as error:
1455            self.assertEqual(str(error),
1456                'No such record in test_students\nwhere "firstname" = $1\n'
1457                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1458        q = "select * from test_students order by 1 limit 4"
1459        r = query(q).getresult()
1460        self.assertEqual(len(r), 3)
1461        self.assertEqual(r[1][2], 'D-')
1462
1463    def testInsert(self):
1464        insert = self.db.insert
1465        query = self.db.query
1466        bool_on = pg.get_bool()
1467        decimal = pg.get_decimal()
1468        table = 'insert_test_table'
1469        self.createTable(table,
1470            'i2 smallint, i4 integer, i8 bigint,'
1471            ' d numeric, f4 real, f8 double precision, m money,'
1472            ' v4 varchar(4), c4 char(4), t text,'
1473            ' b boolean, ts timestamp', oids=True)
1474        oid_table = 'oid(%s)' % table
1475        tests = [dict(i2=None, i4=None, i8=None),
1476             (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1477             (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1478             dict(i2=42, i4=123456, i8=9876543210),
1479             dict(i2=2 ** 15 - 1,
1480                  i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1481             dict(d=None), (dict(d=''), dict(d=None)),
1482             dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1483             dict(f4=None, f8=None), dict(f4=0, f8=0),
1484             (dict(f4='', f8=''), dict(f4=None, f8=None)),
1485             (dict(d=1234.5, f4=1234.5, f8=1234.5),
1486              dict(d=Decimal('1234.5'))),
1487             dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1488             dict(d=Decimal('123456789.9876543212345678987654321')),
1489             dict(m=None), (dict(m=''), dict(m=None)),
1490             dict(m=Decimal('-1234.56')),
1491             (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1492             dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1493             (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1494             (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1495             (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1496             (dict(m=123456), dict(m=Decimal('123456'))),
1497             (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1498             dict(b=None), (dict(b=''), dict(b=None)),
1499             dict(b='f'), dict(b='t'),
1500             (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1501             (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1502             (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1503             (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1504             (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1505             (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1506             dict(v4=None, c4=None, t=None),
1507             (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1508             dict(v4='1234', c4='1234', t='1234' * 10),
1509             dict(v4='abcd', c4='abcd', t='abcdefg'),
1510             (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1511             dict(ts=None), (dict(ts=''), dict(ts=None)),
1512             (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1513             dict(ts='2012-12-21 00:00:00'),
1514             (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1515             dict(ts='2012-12-21 12:21:12'),
1516             dict(ts='2013-01-05 12:13:14'),
1517             dict(ts='current_timestamp')]
1518        for test in tests:
1519            if isinstance(test, dict):
1520                data = test
1521                change = {}
1522            else:
1523                data, change = test
1524            expect = data.copy()
1525            expect.update(change)
1526            if bool_on:
1527                b = expect.get('b')
1528                if b is not None:
1529                    expect['b'] = b == 't'
1530            if decimal is not Decimal:
1531                d = expect.get('d')
1532                if d is not None:
1533                    expect['d'] = decimal(d)
1534                m = expect.get('m')
1535                if m is not None:
1536                    expect['m'] = decimal(m)
1537            self.assertEqual(insert(table, data), data)
1538            self.assertIn(oid_table, data)
1539            oid = data[oid_table]
1540            self.assertIsInstance(oid, int)
1541            data = dict(item for item in data.items()
1542                        if item[0] in expect)
1543            ts = expect.get('ts')
1544            if ts:
1545                if ts == 'current_timestamp':
1546                    ts = data['ts']
1547                    self.assertIsInstance(ts, datetime)
1548                    self.assertEqual(ts.strftime('%Y-%m-%d'),
1549                        strftime('%Y-%m-%d'))
1550                else:
1551                    ts = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S')
1552                expect['ts'] = ts
1553            self.assertEqual(data, expect)
1554            data = query(
1555                'select oid,* from "%s"' % table).dictresult()[0]
1556            self.assertEqual(data['oid'], oid)
1557            data = dict(item for item in data.items()
1558                        if item[0] in expect)
1559            self.assertEqual(data, expect)
1560            query('delete from "%s"' % table)
1561
1562    def testInsertWithOid(self):
1563        insert = self.db.insert
1564        query = self.db.query
1565        self.createTable('test_table', 'n int', oids=True)
1566        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
1567        r = insert('test_table', n=1)
1568        self.assertIsInstance(r, dict)
1569        self.assertEqual(r['n'], 1)
1570        self.assertNotIn('oid', r)
1571        qoid = 'oid(test_table)'
1572        self.assertIn(qoid, r)
1573        oid = r[qoid]
1574        self.assertEqual(sorted(r.keys()), ['n', qoid])
1575        r = insert('test_table', n=2, oid=oid)
1576        self.assertIsInstance(r, dict)
1577        self.assertEqual(r['n'], 2)
1578        self.assertIn(qoid, r)
1579        self.assertNotEqual(r[qoid], oid)
1580        self.assertNotIn('oid', r)
1581        r = insert('test_table', None, n=3)
1582        self.assertIsInstance(r, dict)
1583        self.assertEqual(r['n'], 3)
1584        s = r
1585        r = insert('test_table', r)
1586        self.assertIs(r, s)
1587        self.assertEqual(r['n'], 3)
1588        r = insert('test_table *', r)
1589        self.assertIs(r, s)
1590        self.assertEqual(r['n'], 3)
1591        r = insert('test_table', r, n=4)
1592        self.assertIs(r, s)
1593        self.assertEqual(r['n'], 4)
1594        self.assertNotIn('oid', r)
1595        self.assertIn(qoid, r)
1596        oid = r[qoid]
1597        r = insert('test_table', r, n=5, oid=oid)
1598        self.assertIs(r, s)
1599        self.assertEqual(r['n'], 5)
1600        self.assertIn(qoid, r)
1601        self.assertNotEqual(r[qoid], oid)
1602        self.assertNotIn('oid', r)
1603        r['oid'] = oid = r[qoid]
1604        r = insert('test_table', r, n=6)
1605        self.assertIs(r, s)
1606        self.assertEqual(r['n'], 6)
1607        self.assertIn(qoid, r)
1608        self.assertNotEqual(r[qoid], oid)
1609        self.assertNotIn('oid', r)
1610        q = 'select n from test_table order by 1 limit 9'
1611        r = ' '.join(str(row[0]) for row in query(q).getresult())
1612        self.assertEqual(r, '1 2 3 3 3 4 5 6')
1613        query("truncate test_table")
1614        query("alter table test_table add unique (n)")
1615        r = insert('test_table', dict(n=7))
1616        self.assertIsInstance(r, dict)
1617        self.assertEqual(r['n'], 7)
1618        self.assertRaises(pg.IntegrityError, insert, 'test_table', r)
1619        r['n'] = 6
1620        self.assertRaises(pg.IntegrityError, insert, 'test_table', r, n=7)
1621        self.assertIsInstance(r, dict)
1622        self.assertEqual(r['n'], 7)
1623        r['n'] = 6
1624        r = insert('test_table', r)
1625        self.assertIsInstance(r, dict)
1626        self.assertEqual(r['n'], 6)
1627        r = ' '.join(str(row[0]) for row in query(q).getresult())
1628        self.assertEqual(r, '6 7')
1629
1630    def testInsertWithQuotedNames(self):
1631        insert = self.db.insert
1632        query = self.db.query
1633        table = 'test table for insert()'
1634        self.createTable(table, '"Prime!" smallint primary key,'
1635                                ' "much space" integer, "Questions?" text')
1636        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1637        r = insert(table, r)
1638        self.assertIsInstance(r, dict)
1639        self.assertEqual(r['Prime!'], 11)
1640        self.assertEqual(r['much space'], 2002)
1641        self.assertEqual(r['Questions?'], 'What?')
1642        r = query('select * from "%s" limit 2' % table).dictresult()
1643        self.assertEqual(len(r), 1)
1644        r = r[0]
1645        self.assertEqual(r['Prime!'], 11)
1646        self.assertEqual(r['much space'], 2002)
1647        self.assertEqual(r['Questions?'], 'What?')
1648
1649    def testInsertIntoView(self):
1650        insert = self.db.insert
1651        query = self.db.query
1652        query("truncate test")
1653        q = 'select * from test_view order by i4 limit 3'
1654        r = query(q).getresult()
1655        self.assertEqual(r, [])
1656        r = dict(i4=1234, v4='abcd')
1657        insert('test', r)
1658        self.assertIsNone(r['i2'])
1659        self.assertEqual(r['i4'], 1234)
1660        self.assertIsNone(r['i8'])
1661        self.assertEqual(r['v4'], 'abcd')
1662        self.assertIsNone(r['c4'])
1663        r = query(q).getresult()
1664        self.assertEqual(r, [(1234, 'abcd')])
1665        r = dict(i4=5678, v4='efgh')
1666        try:
1667            insert('test_view', r)
1668        except pg.NotSupportedError as error:
1669            if self.db.server_version < 90300:
1670                # must setup rules in older PostgreSQL versions
1671                self.skipTest('database cannot insert into view')
1672            self.fail(str(error))
1673        self.assertNotIn('i2', r)
1674        self.assertEqual(r['i4'], 5678)
1675        self.assertNotIn('i8', r)
1676        self.assertEqual(r['v4'], 'efgh')
1677        self.assertNotIn('c4', r)
1678        r = query(q).getresult()
1679        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1680
1681    def testUpdate(self):
1682        update = self.db.update
1683        query = self.db.query
1684        self.assertRaises(pg.ProgrammingError, update,
1685                          'test', i2=2, i4=4, i8=8)
1686        table = 'update_test_table'
1687        self.createTable(table, 'n integer, t text', oids=True,
1688                         values=enumerate('xyz', start=1))
1689        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1690        r = self.db.get(table, 2, 'n')
1691        r['t'] = 'u'
1692        s = update(table, r)
1693        self.assertEqual(s, r)
1694        q = 'select t from "%s" where n=2' % table
1695        r = query(q).getresult()[0][0]
1696        self.assertEqual(r, 'u')
1697
1698    def testUpdateWithOid(self):
1699        update = self.db.update
1700        get = self.db.get
1701        query = self.db.query
1702        self.createTable('test_table', 'n int', oids=True, values=[1])
1703        s = get('test_table', 1, 'n')
1704        self.assertIsInstance(s, dict)
1705        self.assertEqual(s['n'], 1)
1706        s['n'] = 2
1707        r = update('test_table', s)
1708        self.assertIs(r, s)
1709        self.assertEqual(r['n'], 2)
1710        qoid = 'oid(test_table)'
1711        self.assertIn(qoid, r)
1712        self.assertNotIn('oid', r)
1713        self.assertEqual(sorted(r.keys()), ['n', qoid])
1714        r['n'] = 3
1715        oid = r.pop(qoid)
1716        r = update('test_table', r, oid=oid)
1717        self.assertIs(r, s)
1718        self.assertEqual(r['n'], 3)
1719        r.pop(qoid)
1720        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
1721        s = get('test_table', 3, 'n')
1722        self.assertIsInstance(s, dict)
1723        self.assertEqual(s['n'], 3)
1724        s.pop('n')
1725        r = update('test_table', s)
1726        oid = r.pop(qoid)
1727        self.assertEqual(r, {})
1728        q = "select n from test_table limit 2"
1729        r = query(q).getresult()
1730        self.assertEqual(r, [(3,)])
1731        query("insert into test_table values (1)")
1732        self.assertRaises(pg.ProgrammingError,
1733                          update, 'test_table', dict(oid=oid, n=4))
1734        r = update('test_table', dict(n=4), oid=oid)
1735        self.assertEqual(r['n'], 4)
1736        r = update('test_table *', dict(n=5), oid=oid)
1737        self.assertEqual(r['n'], 5)
1738        query("alter table test_table add column m int")
1739        query("alter table test_table add primary key (n)")
1740        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
1741        self.assertEqual('n', self.db.pkey('test_table', flush=True))
1742        s = dict(n=1, m=4)
1743        r = update('test_table', s)
1744        self.assertIs(r, s)
1745        self.assertEqual(r['n'], 1)
1746        self.assertEqual(r['m'], 4)
1747        s = dict(m=7)
1748        r = update('test_table', s, n=5)
1749        self.assertIs(r, s)
1750        self.assertEqual(r['n'], 5)
1751        self.assertEqual(r['m'], 7)
1752        q = "select n, m from test_table order by 1 limit 3"
1753        r = query(q).getresult()
1754        self.assertEqual(r, [(1, 4), (5, 7)])
1755        s = dict(m=9, oid=oid)
1756        self.assertRaises(KeyError, update, 'test_table', s)
1757        r = update('test_table', s, oid=oid)
1758        self.assertIs(r, s)
1759        self.assertEqual(r['n'], 5)
1760        self.assertEqual(r['m'], 9)
1761        s = dict(n=1, m=3, oid=oid)
1762        r = update('test_table', s)
1763        self.assertIs(r, s)
1764        self.assertEqual(r['n'], 1)
1765        self.assertEqual(r['m'], 3)
1766        r = query(q).getresult()
1767        self.assertEqual(r, [(1, 3), (5, 9)])
1768
1769    def testUpdateWithoutOid(self):
1770        update = self.db.update
1771        query = self.db.query
1772        self.assertRaises(pg.ProgrammingError, update,
1773                          'test', i2=2, i4=4, i8=8)
1774        table = 'update_test_table'
1775        self.createTable(table, 'n integer primary key, t text', oids=False,
1776                         values=enumerate('xyz', start=1))
1777        r = self.db.get(table, 2)
1778        r['t'] = 'u'
1779        s = update(table, r)
1780        self.assertEqual(s, r)
1781        q = 'select t from "%s" where n=2' % table
1782        r = query(q).getresult()[0][0]
1783        self.assertEqual(r, 'u')
1784
1785    def testUpdateWithCompositeKey(self):
1786        update = self.db.update
1787        query = self.db.query
1788        table = 'update_test_table_1'
1789        self.createTable(table, 'n integer primary key, t text',
1790                         values=enumerate('abc', start=1))
1791        self.assertRaises(KeyError, update, table, dict(t='b'))
1792        s = dict(n=2, t='d')
1793        r = update(table, s)
1794        self.assertIs(r, s)
1795        self.assertEqual(r['n'], 2)
1796        self.assertEqual(r['t'], 'd')
1797        q = 'select t from "%s" where n=2' % table
1798        r = query(q).getresult()[0][0]
1799        self.assertEqual(r, 'd')
1800        s.update(dict(n=4, t='e'))
1801        r = update(table, s)
1802        self.assertEqual(r['n'], 4)
1803        self.assertEqual(r['t'], 'e')
1804        q = 'select t from "%s" where n=2' % table
1805        r = query(q).getresult()[0][0]
1806        self.assertEqual(r, 'd')
1807        q = 'select t from "%s" where n=4' % table
1808        r = query(q).getresult()
1809        self.assertEqual(len(r), 0)
1810        query('drop table "%s"' % table)
1811        table = 'update_test_table_2'
1812        self.createTable(table,
1813                         'n integer, m integer, t text, primary key (n, m)',
1814                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1815                                 for n in range(3) for m in range(2)])
1816        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1817        self.assertEqual(update(table,
1818                                dict(n=2, m=2, t='x'))['t'], 'x')
1819        q = 'select t from "%s" where n=2 order by m' % table
1820        r = [r[0] for r in query(q).getresult()]
1821        self.assertEqual(r, ['c', 'x'])
1822
1823    def testUpdateWithQuotedNames(self):
1824        update = self.db.update
1825        query = self.db.query
1826        table = 'test table for update()'
1827        self.createTable(table, '"Prime!" smallint primary key,'
1828                                ' "much space" integer, "Questions?" text',
1829                         values=[(13, 3003, 'Why!')])
1830        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1831        r = update(table, r)
1832        self.assertIsInstance(r, dict)
1833        self.assertEqual(r['Prime!'], 13)
1834        self.assertEqual(r['much space'], 7007)
1835        self.assertEqual(r['Questions?'], 'When?')
1836        r = query('select * from "%s" limit 2' % table).dictresult()
1837        self.assertEqual(len(r), 1)
1838        r = r[0]
1839        self.assertEqual(r['Prime!'], 13)
1840        self.assertEqual(r['much space'], 7007)
1841        self.assertEqual(r['Questions?'], 'When?')
1842
1843    def testUpsert(self):
1844        upsert = self.db.upsert
1845        query = self.db.query
1846        self.assertRaises(pg.ProgrammingError, upsert,
1847                          'test', i2=2, i4=4, i8=8)
1848        table = 'upsert_test_table'
1849        self.createTable(table, 'n integer primary key, t text', oids=True)
1850        s = dict(n=1, t='x')
1851        try:
1852            r = upsert(table, s)
1853        except pg.ProgrammingError as error:
1854            if self.db.server_version < 90500:
1855                self.skipTest('database does not support upsert')
1856            self.fail(str(error))
1857        self.assertIs(r, s)
1858        self.assertEqual(r['n'], 1)
1859        self.assertEqual(r['t'], 'x')
1860        s.update(n=2, t='y')
1861        r = upsert(table, s, **dict.fromkeys(s))
1862        self.assertIs(r, s)
1863        self.assertEqual(r['n'], 2)
1864        self.assertEqual(r['t'], 'y')
1865        q = 'select n, t from "%s" order by n limit 3' % table
1866        r = query(q).getresult()
1867        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1868        s.update(t='z')
1869        r = upsert(table, s)
1870        self.assertIs(r, s)
1871        self.assertEqual(r['n'], 2)
1872        self.assertEqual(r['t'], 'z')
1873        r = query(q).getresult()
1874        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1875        s.update(t='n')
1876        r = upsert(table, s, t=False)
1877        self.assertIs(r, s)
1878        self.assertEqual(r['n'], 2)
1879        self.assertEqual(r['t'], 'z')
1880        r = query(q).getresult()
1881        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1882        s.update(t='y')
1883        r = upsert(table, s, t=True)
1884        self.assertIs(r, s)
1885        self.assertEqual(r['n'], 2)
1886        self.assertEqual(r['t'], 'y')
1887        r = query(q).getresult()
1888        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1889        s.update(t='n')
1890        r = upsert(table, s, t="included.t || '2'")
1891        self.assertIs(r, s)
1892        self.assertEqual(r['n'], 2)
1893        self.assertEqual(r['t'], 'y2')
1894        r = query(q).getresult()
1895        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1896        s.update(t='y')
1897        r = upsert(table, s, t="excluded.t || '3'")
1898        self.assertIs(r, s)
1899        self.assertEqual(r['n'], 2)
1900        self.assertEqual(r['t'], 'y3')
1901        r = query(q).getresult()
1902        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1903        s.update(n=1, t='2')
1904        r = upsert(table, s, t="included.t || excluded.t")
1905        self.assertIs(r, s)
1906        self.assertEqual(r['n'], 1)
1907        self.assertEqual(r['t'], 'x2')
1908        r = query(q).getresult()
1909        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1910        # not existing columns and oid parameter should be ignored
1911        s = dict(m=3, u='z')
1912        r = upsert(table, s, oid='invalid')
1913        self.assertIs(r, s)
1914
1915    def testUpsertWithOid(self):
1916        upsert = self.db.upsert
1917        get = self.db.get
1918        query = self.db.query
1919        self.createTable('test_table', 'n int', oids=True, values=[1])
1920        self.assertRaises(pg.ProgrammingError,
1921                          upsert, 'test_table', dict(n=2))
1922        r = get('test_table', 1, 'n')
1923        self.assertIsInstance(r, dict)
1924        self.assertEqual(r['n'], 1)
1925        qoid = 'oid(test_table)'
1926        self.assertIn(qoid, r)
1927        self.assertNotIn('oid', r)
1928        oid = r[qoid]
1929        self.assertRaises(pg.ProgrammingError,
1930                          upsert, 'test_table', dict(n=2, oid=oid))
1931        query("alter table test_table add column m int")
1932        query("alter table test_table add primary key (n)")
1933        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
1934        self.assertEqual('n', self.db.pkey('test_table', flush=True))
1935        s = dict(n=2)
1936        try:
1937            r = upsert('test_table', s)
1938        except pg.ProgrammingError as error:
1939            if self.db.server_version < 90500:
1940                self.skipTest('database does not support upsert')
1941            self.fail(str(error))
1942        self.assertIs(r, s)
1943        self.assertEqual(r['n'], 2)
1944        self.assertIsNone(r['m'])
1945        q = query("select n, m from test_table order by n limit 3")
1946        self.assertEqual(q.getresult(), [(1, None), (2, None)])
1947        r['oid'] = oid
1948        r = upsert('test_table', r)
1949        self.assertIs(r, s)
1950        self.assertEqual(r['n'], 2)
1951        self.assertIsNone(r['m'])
1952        self.assertIn(qoid, r)
1953        self.assertNotIn('oid', r)
1954        self.assertNotEqual(r[qoid], oid)
1955        r['m'] = 7
1956        r = upsert('test_table', r)
1957        self.assertIs(r, s)
1958        self.assertEqual(r['n'], 2)
1959        self.assertEqual(r['m'], 7)
1960        r.update(n=1, m=3)
1961        r = upsert('test_table', r)
1962        self.assertIs(r, s)
1963        self.assertEqual(r['n'], 1)
1964        self.assertEqual(r['m'], 3)
1965        q = query("select n, m from test_table order by n limit 3")
1966        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
1967        r = upsert('test_table', r, oid='invalid')
1968        self.assertIs(r, s)
1969        self.assertEqual(r['n'], 1)
1970        self.assertEqual(r['m'], 3)
1971        r['m'] = 5
1972        r = upsert('test_table', r, m=False)
1973        self.assertIs(r, s)
1974        self.assertEqual(r['n'], 1)
1975        self.assertEqual(r['m'], 3)
1976        r['m'] = 5
1977        r = upsert('test_table', r, m=True)
1978        self.assertIs(r, s)
1979        self.assertEqual(r['n'], 1)
1980        self.assertEqual(r['m'], 5)
1981        r.update(n=2, m=1)
1982        r = upsert('test_table', r, m='included.m')
1983        self.assertIs(r, s)
1984        self.assertEqual(r['n'], 2)
1985        self.assertEqual(r['m'], 7)
1986        r['m'] = 9
1987        r = upsert('test_table', r, m='excluded.m')
1988        self.assertIs(r, s)
1989        self.assertEqual(r['n'], 2)
1990        self.assertEqual(r['m'], 9)
1991        r['m'] = 8
1992        r = upsert('test_table *', r, m='included.m + 1')
1993        self.assertIs(r, s)
1994        self.assertEqual(r['n'], 2)
1995        self.assertEqual(r['m'], 10)
1996        q = query("select n, m from test_table order by n limit 3")
1997        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1998
1999    def testUpsertWithCompositeKey(self):
2000        upsert = self.db.upsert
2001        query = self.db.query
2002        table = 'upsert_test_table_2'
2003        self.createTable(table,
2004                         'n integer, m integer, t text, primary key (n, m)')
2005        s = dict(n=1, m=2, t='x')
2006        try:
2007            r = upsert(table, s)
2008        except pg.ProgrammingError as error:
2009            if self.db.server_version < 90500:
2010                self.skipTest('database does not support upsert')
2011            self.fail(str(error))
2012        self.assertIs(r, s)
2013        self.assertEqual(r['n'], 1)
2014        self.assertEqual(r['m'], 2)
2015        self.assertEqual(r['t'], 'x')
2016        s.update(m=3, t='y')
2017        r = upsert(table, s, **dict.fromkeys(s))
2018        self.assertIs(r, s)
2019        self.assertEqual(r['n'], 1)
2020        self.assertEqual(r['m'], 3)
2021        self.assertEqual(r['t'], 'y')
2022        q = 'select n, m, t from "%s" order by n, m limit 3' % table
2023        r = query(q).getresult()
2024        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
2025        s.update(t='z')
2026        r = upsert(table, s)
2027        self.assertIs(r, s)
2028        self.assertEqual(r['n'], 1)
2029        self.assertEqual(r['m'], 3)
2030        self.assertEqual(r['t'], 'z')
2031        r = query(q).getresult()
2032        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
2033        s.update(t='n')
2034        r = upsert(table, s, t=False)
2035        self.assertIs(r, s)
2036        self.assertEqual(r['n'], 1)
2037        self.assertEqual(r['m'], 3)
2038        self.assertEqual(r['t'], 'z')
2039        r = query(q).getresult()
2040        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
2041        s.update(t='n')
2042        r = upsert(table, s, t=True)
2043        self.assertIs(r, s)
2044        self.assertEqual(r['n'], 1)
2045        self.assertEqual(r['m'], 3)
2046        self.assertEqual(r['t'], 'n')
2047        r = query(q).getresult()
2048        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
2049        s.update(n=2, t='y')
2050        r = upsert(table, s, t="'z'")
2051        self.assertIs(r, s)
2052        self.assertEqual(r['n'], 2)
2053        self.assertEqual(r['m'], 3)
2054        self.assertEqual(r['t'], 'y')
2055        r = query(q).getresult()
2056        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
2057        s.update(n=1, t='m')
2058        r = upsert(table, s, t='included.t || excluded.t')
2059        self.assertIs(r, s)
2060        self.assertEqual(r['n'], 1)
2061        self.assertEqual(r['m'], 3)
2062        self.assertEqual(r['t'], 'nm')
2063        r = query(q).getresult()
2064        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2065
2066    def testUpsertWithQuotedNames(self):
2067        upsert = self.db.upsert
2068        query = self.db.query
2069        table = 'test table for upsert()'
2070        self.createTable(table, '"Prime!" smallint primary key,'
2071                                ' "much space" integer, "Questions?" text')
2072        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2073        try:
2074            r = upsert(table, s)
2075        except pg.ProgrammingError as error:
2076            if self.db.server_version < 90500:
2077                self.skipTest('database does not support upsert')
2078            self.fail(str(error))
2079        self.assertIs(r, s)
2080        self.assertEqual(r['Prime!'], 31)
2081        self.assertEqual(r['much space'], 9009)
2082        self.assertEqual(r['Questions?'], 'Yes.')
2083        q = 'select * from "%s" limit 2' % table
2084        r = query(q).getresult()
2085        self.assertEqual(r, [(31, 9009, 'Yes.')])
2086        s.update({'Questions?': 'No.'})
2087        r = upsert(table, s)
2088        self.assertIs(r, s)
2089        self.assertEqual(r['Prime!'], 31)
2090        self.assertEqual(r['much space'], 9009)
2091        self.assertEqual(r['Questions?'], 'No.')
2092        r = query(q).getresult()
2093        self.assertEqual(r, [(31, 9009, 'No.')])
2094
2095    def testClear(self):
2096        clear = self.db.clear
2097        f = False if pg.get_bool() else 'f'
2098        r = clear('test')
2099        result = dict(
2100            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2101        self.assertEqual(r, result)
2102        table = 'clear_test_table'
2103        self.createTable(table,
2104            'n integer, f float, b boolean, d date, t text', oids=True)
2105        r = clear(table)
2106        result = dict(n=0, f=0, b=f, d='', t='')
2107        self.assertEqual(r, result)
2108        r['a'] = r['f'] = r['n'] = 1
2109        r['d'] = r['t'] = 'x'
2110        r['b'] = 't'
2111        r['oid'] = long(1)
2112        r = clear(table, r)
2113        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2114        self.assertEqual(r, result)
2115
2116    def testClearWithQuotedNames(self):
2117        clear = self.db.clear
2118        table = 'test table for clear()'
2119        self.createTable(table, '"Prime!" smallint primary key,'
2120            ' "much space" integer, "Questions?" text')
2121        r = clear(table)
2122        self.assertIsInstance(r, dict)
2123        self.assertEqual(r['Prime!'], 0)
2124        self.assertEqual(r['much space'], 0)
2125        self.assertEqual(r['Questions?'], '')
2126
2127    def testDelete(self):
2128        delete = self.db.delete
2129        query = self.db.query
2130        self.assertRaises(pg.ProgrammingError, delete,
2131                          'test', dict(i2=2, i4=4, i8=8))
2132        table = 'delete_test_table'
2133        self.createTable(table, 'n integer, t text', oids=True,
2134                         values=enumerate('xyz', start=1))
2135        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
2136        r = self.db.get(table, 1, 'n')
2137        s = delete(table, r)
2138        self.assertEqual(s, 1)
2139        r = self.db.get(table, 3, 'n')
2140        s = delete(table, r)
2141        self.assertEqual(s, 1)
2142        s = delete(table, r)
2143        self.assertEqual(s, 0)
2144        r = query('select * from "%s"' % table).dictresult()
2145        self.assertEqual(len(r), 1)
2146        r = r[0]
2147        result = {'n': 2, 't': 'y'}
2148        self.assertEqual(r, result)
2149        r = self.db.get(table, 2, 'n')
2150        s = delete(table, r)
2151        self.assertEqual(s, 1)
2152        s = delete(table, r)
2153        self.assertEqual(s, 0)
2154        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
2155        # not existing columns and oid parameter should be ignored
2156        r.update(m=3, u='z', oid='invalid')
2157        s = delete(table, r)
2158        self.assertEqual(s, 0)
2159
    def testDeleteWithOid(self):
        """Test delete() on a table with oids, first without a primary key.

        The first part identifies rows only via their oid; in the second
        part a column and a primary key are added to the same table.
        """
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
        # no primary key and no oid given: the row cannot be identified
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        s = get('test_table', 1, 'n')
        # get() returns the oid under this table-qualified key
        qoid = 'oid(test_table)'
        self.assertIn(qoid, s)
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        # deleting the same row again affects nothing
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        # (min(n), count(n)) tracks which rows are still present
        q = "select min(n),count(n) from test_table"
        self.assertEqual(query(q).getresult()[0], (2, 5))
        oid = get('test_table', 2, 'n')[qoid]
        # a plain 'oid' key in the row does not identify the row ...
        s = dict(oid=oid, n=2)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        # ... but the oid can be passed as a keyword argument
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # s still holds the plain oid of the deleted row 2, while the
        # oid keyword below refers to row 3
        s = dict(oid=oid, n=2)
        oid = get('test_table', 3, 'n')[qoid]
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        # a ' *' suffix on the table name is accepted as well
        s = get('test_table', 4, 'n')
        r = delete('test_table *', s)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # the qualified oid in the row identifies row 5; m is not a column
        # of the table at this point and apparently does not matter
        oid = get('test_table', 5, 'n')[qoid]
        s = {qoid: oid, 'm': 4}
        r = delete('test_table', s, m=6)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
        # second part: add a column and make n the primary key
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        # add rows n=1..5 (with m=n+1) next to the remaining row n=6
        for i in range(5):
            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
        # rows without the primary key column raise KeyError,
        # even if they carry a plain oid
        s = dict(m=2)
        self.assertRaises(KeyError, delete, 'test_table', s)
        s = dict(m=2, oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # with the oid keyword no error is raised; nothing is deleted,
        # presumably because this oid belonged to the row deleted above
        r = delete('test_table', dict(m=2), oid=oid)
        self.assertEqual(r, 0)
        oid = get('test_table', 1, 'n')[qoid]
        s = dict(oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # a row from get() without its key column is still deleted,
        # since it carries the qualified oid
        s = get('test_table', 2, 'n')
        del s['n']
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the primary key can be passed as a keyword argument only ...
        r = delete('test_table', n=3)
        self.assertEqual(r, 1)
        r = delete('test_table', n=3)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 1)
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # ... and the keyword value n=5 wins over n=6 in the row
        s = dict(n=6)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 1)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
2248
2249    def testDeleteWithCompositeKey(self):
2250        query = self.db.query
2251        table = 'delete_test_table_1'
2252        self.createTable(table, 'n integer primary key, t text',
2253                         values=enumerate('abc', start=1))
2254        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2255        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2256        r = query('select t from "%s" where n=2' % table).getresult()
2257        self.assertEqual(r, [])
2258        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2259        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2260        self.assertEqual(r, 'c')
2261        table = 'delete_test_table_2'
2262        self.createTable(table,
2263             'n integer, m integer, t text, primary key (n, m)',
2264             values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2265                     for n in range(3) for m in range(2)])
2266        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2267        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2268        r = [r[0] for r in query('select t from "%s" where n=2'
2269            ' order by m' % table).getresult()]
2270        self.assertEqual(r, ['c'])
2271        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2272        r = [r[0] for r in query('select t from "%s" where n=3'
2273             ' order by m' % table).getresult()]
2274        self.assertEqual(r, ['e', 'f'])
2275        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2276        r = [r[0] for r in query('select t from "%s" where n=3'
2277             ' order by m' % table).getresult()]
2278        self.assertEqual(r, ['f'])
2279
2280    def testDeleteWithQuotedNames(self):
2281        delete = self.db.delete
2282        query = self.db.query
2283        table = 'test table for delete()'
2284        self.createTable(table, '"Prime!" smallint primary key,'
2285            ' "much space" integer, "Questions?" text',
2286            values=[(19, 5005, 'Yes!')])
2287        r = {'Prime!': 17}
2288        r = delete(table, r)
2289        self.assertEqual(r, 0)
2290        r = query('select count(*) from "%s"' % table).getresult()
2291        self.assertEqual(r[0][0], 1)
2292        r = {'Prime!': 19}
2293        r = delete(table, r)
2294        self.assertEqual(r, 1)
2295        r = query('select count(*) from "%s"' % table).getresult()
2296        self.assertEqual(r[0][0], 0)
2297
2298    def testDeleteReferenced(self):
2299        delete = self.db.delete
2300        query = self.db.query
2301        self.createTable('test_parent',
2302            'n smallint primary key', values=range(3))
2303        self.createTable('test_child',
2304            'n smallint primary key references test_parent', values=range(3))
2305        q = ("select (select count(*) from test_parent),"
2306             " (select count(*) from test_child)")
2307        self.assertEqual(query(q).getresult()[0], (3, 3))
2308        self.assertRaises(pg.IntegrityError,
2309                          delete, 'test_parent', None, n=2)
2310        self.assertRaises(pg.IntegrityError,
2311                          delete, 'test_parent *', None, n=2)
2312        r = delete('test_child', None, n=2)
2313        self.assertEqual(r, 1)
2314        self.assertEqual(query(q).getresult()[0], (3, 2))
2315        r = delete('test_parent', None, n=2)
2316        self.assertEqual(r, 1)
2317        self.assertEqual(query(q).getresult()[0], (2, 2))
2318        self.assertRaises(pg.IntegrityError,
2319                          delete, 'test_parent', dict(n=0))
2320        self.assertRaises(pg.IntegrityError,
2321                          delete, 'test_parent *', dict(n=0))
2322        r = delete('test_child', dict(n=0))
2323        self.assertEqual(r, 1)
2324        self.assertEqual(query(q).getresult()[0], (2, 1))
2325        r = delete('test_child', dict(n=0))
2326        self.assertEqual(r, 0)
2327        r = delete('test_parent', dict(n=0))
2328        self.assertEqual(r, 1)
2329        self.assertEqual(query(q).getresult()[0], (1, 1))
2330        r = delete('test_parent', None, n=0)
2331        self.assertEqual(r, 0)
2332        q = "select n from test_parent natural join test_child limit 2"
2333        self.assertEqual(query(q).getresult(), [(1,)])
2334
2335    def testTempCrud(self):
2336        table = 'test_temp_table'
2337        self.createTable(table, "n int primary key, t varchar", temporary=True)
2338        self.db.insert(table, dict(n=1, t='one'))
2339        self.db.insert(table, dict(n=2, t='too'))
2340        self.db.insert(table, dict(n=3, t='three'))
2341        r = self.db.get(table, 2)
2342        self.assertEqual(r['t'], 'too')
2343        self.db.update(table, dict(n=2, t='two'))
2344        r = self.db.get(table, 2)
2345        self.assertEqual(r['t'], 'two')
2346        self.db.delete(table, r)
2347        r = self.db.query('select n, t from %s order by 1' % table).getresult()
2348        self.assertEqual(r, [(1, 'one'), (3, 'three')])
2349
2350    def testTruncate(self):
2351        truncate = self.db.truncate
2352        self.assertRaises(TypeError, truncate, None)
2353        self.assertRaises(TypeError, truncate, 42)
2354        self.assertRaises(TypeError, truncate, dict(test_table=None))
2355        query = self.db.query
2356        self.createTable('test_table', 'n smallint',
2357                         temporary=False, values=[1] * 3)
2358        q = "select count(*) from test_table"
2359        r = query(q).getresult()[0][0]
2360        self.assertEqual(r, 3)
2361        truncate('test_table')
2362        r = query(q).getresult()[0][0]
2363        self.assertEqual(r, 0)
2364        for i in range(3):
2365            query("insert into test_table values (1)")
2366        r = query(q).getresult()[0][0]
2367        self.assertEqual(r, 3)
2368        truncate('public.test_table')
2369        r = query(q).getresult()[0][0]
2370        self.assertEqual(r, 0)
2371        self.createTable('test_table_2', 'n smallint', temporary=True)
2372        for t in (list, tuple, set):
2373            for i in range(3):
2374                query("insert into test_table values (1)")
2375                query("insert into test_table_2 values (2)")
2376            q = ("select (select count(*) from test_table),"
2377                 " (select count(*) from test_table_2)")
2378            r = query(q).getresult()[0]
2379            self.assertEqual(r, (3, 3))
2380            truncate(t(['test_table', 'test_table_2']))
2381            r = query(q).getresult()[0]
2382            self.assertEqual(r, (0, 0))
2383
2384    def testTruncateRestart(self):
2385        truncate = self.db.truncate
2386        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2387        query = self.db.query
2388        self.createTable('test_table', 'n serial, t text')
2389        for n in range(3):
2390            query("insert into test_table (t) values ('test')")
2391        q = "select count(n), min(n), max(n) from test_table"
2392        r = query(q).getresult()[0]
2393        self.assertEqual(r, (3, 1, 3))
2394        truncate('test_table')
2395        r = query(q).getresult()[0]
2396        self.assertEqual(r, (0, None, None))
2397        for n in range(3):
2398            query("insert into test_table (t) values ('test')")
2399        r = query(q).getresult()[0]
2400        self.assertEqual(r, (3, 4, 6))
2401        truncate('test_table', restart=True)
2402        r = query(q).getresult()[0]
2403        self.assertEqual(r, (0, None, None))
2404        for n in range(3):
2405            query("insert into test_table (t) values ('test')")
2406        r = query(q).getresult()[0]
2407        self.assertEqual(r, (3, 1, 3))
2408
2409    def testTruncateCascade(self):
2410        truncate = self.db.truncate
2411        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2412        query = self.db.query
2413        self.createTable('test_parent', 'n smallint primary key',
2414                         values=range(3))
2415        self.createTable('test_child',
2416                         'n smallint primary key references test_parent (n)',
2417                         values=range(3))
2418        q = ("select (select count(*) from test_parent),"
2419             " (select count(*) from test_child)")
2420        r = query(q).getresult()[0]
2421        self.assertEqual(r, (3, 3))
2422        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2423        truncate(['test_parent', 'test_child'])
2424        r = query(q).getresult()[0]
2425        self.assertEqual(r, (0, 0))
2426        for n in range(3):
2427            query("insert into test_parent (n) values (%d)" % n)
2428            query("insert into test_child (n) values (%d)" % n)
2429        r = query(q).getresult()[0]
2430        self.assertEqual(r, (3, 3))
2431        truncate('test_parent', cascade=True)
2432        r = query(q).getresult()[0]
2433        self.assertEqual(r, (0, 0))
2434        for n in range(3):
2435            query("insert into test_parent (n) values (%d)" % n)
2436            query("insert into test_child (n) values (%d)" % n)
2437        r = query(q).getresult()[0]
2438        self.assertEqual(r, (3, 3))
2439        truncate('test_child')
2440        r = query(q).getresult()[0]
2441        self.assertEqual(r, (3, 0))
2442        self.assertRaises(pg.NotSupportedError, truncate, 'test_parent')
2443        truncate('test_parent', cascade=True)
2444        r = query(q).getresult()[0]
2445        self.assertEqual(r, (0, 0))
2446
2447    def testTruncateOnly(self):
2448        truncate = self.db.truncate
2449        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2450        query = self.db.query
2451        self.createTable('test_parent', 'n smallint')
2452        self.createTable('test_child', 'm smallint) inherits (test_parent')
2453        for n in range(3):
2454            query("insert into test_parent (n) values (1)")
2455            query("insert into test_child (n, m) values (2, 3)")
2456        q = ("select (select count(*) from test_parent),"
2457             " (select count(*) from test_child)")
2458        r = query(q).getresult()[0]
2459        self.assertEqual(r, (6, 3))
2460        truncate('test_parent')
2461        r = query(q).getresult()[0]
2462        self.assertEqual(r, (0, 0))
2463        for n in range(3):
2464            query("insert into test_parent (n) values (1)")
2465            query("insert into test_child (n, m) values (2, 3)")
2466        r = query(q).getresult()[0]
2467        self.assertEqual(r, (6, 3))
2468        truncate('test_parent*')
2469        r = query(q).getresult()[0]
2470        self.assertEqual(r, (0, 0))
2471        for n in range(3):
2472            query("insert into test_parent (n) values (1)")
2473            query("insert into test_child (n, m) values (2, 3)")
2474        r = query(q).getresult()[0]
2475        self.assertEqual(r, (6, 3))
2476        truncate('test_parent', only=True)
2477        r = query(q).getresult()[0]
2478        self.assertEqual(r, (3, 3))
2479        truncate('test_parent', only=False)
2480        r = query(q).getresult()[0]
2481        self.assertEqual(r, (0, 0))
2482        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2483        truncate('test_parent*', only=False)
2484        self.createTable('test_parent_2', 'n smallint')
2485        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2486        for t in '', '_2':
2487            for n in range(3):
2488                query("insert into test_parent%s (n) values (1)" % t)
2489                query("insert into test_child%s (n, m) values (2, 3)" % t)
2490        q = ("select (select count(*) from test_parent),"
2491             " (select count(*) from test_child),"
2492             " (select count(*) from test_parent_2),"
2493             " (select count(*) from test_child_2)")
2494        r = query(q).getresult()[0]
2495        self.assertEqual(r, (6, 3, 6, 3))
2496        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2497        r = query(q).getresult()[0]
2498        self.assertEqual(r, (0, 0, 3, 3))
2499        truncate(['test_parent', 'test_parent_2'], only=False)
2500        r = query(q).getresult()[0]
2501        self.assertEqual(r, (0, 0, 0, 0))
2502        self.assertRaises(ValueError, truncate,
2503            ['test_parent*', 'test_child'], only=[True, False])
2504        truncate(['test_parent*', 'test_child'], only=[False, True])
2505
2506    def testTruncateQuoted(self):
2507        truncate = self.db.truncate
2508        query = self.db.query
2509        table = "test table for truncate()"
2510        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2511        q = 'select count(*) from "%s"' % table
2512        r = query(q).getresult()[0][0]
2513        self.assertEqual(r, 3)
2514        truncate(table)
2515        r = query(q).getresult()[0][0]
2516        self.assertEqual(r, 0)
2517        for i in range(3):
2518            query('insert into "%s" values (1)' % table)
2519        r = query(q).getresult()[0][0]
2520        self.assertEqual(r, 3)
2521        truncate('public."%s"' % table)
2522        r = query(q).getresult()[0][0]
2523        self.assertEqual(r, 0)
2524
    def testGetAsList(self):
        """Test get_as_list() with what/where/order/limit/offset/scalar."""
        get_as_list = self.db.get_as_list
        self.assertRaises(TypeError, get_as_list)
        self.assertRaises(TypeError, get_as_list, None)
        query = self.db.query
        table = 'test_aslist'
        # find out whether namedresult() yields tuples with named fields,
        # so that attribute access is only checked when it is supported
        r = query('select 1 as colname').namedresult()[0]
        self.assertIsInstance(r, tuple)
        named = hasattr(r, 'colname')
        names = [(1, 'Homer'), (2, 'Marge'),
                 (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
        self.createTable(table,
            'id smallint primary key, name varchar', values=names)
        # all columns, expected in primary key order
        r = get_as_list(table)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names)
        for t, n in zip(r, names):
            self.assertIsInstance(t, tuple)
            self.assertEqual(t, n)
            if named:
                self.assertEqual(t.id, n[0])
                self.assertEqual(t.name, n[1])
                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
        # the expected values below assume that, without an explicit
        # order, the rows are sorted by the selected columns
        r = get_as_list(table, what='name')
        self.assertIsInstance(r, list)
        expected = sorted((row[1],) for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(table, what='name, id')
        self.assertIsInstance(r, list)
        expected = sorted(tuple(reversed(row)) for row in names)
        self.assertEqual(r, expected)
        # what can also be given as a list
        r = get_as_list(table, what=['name', 'id'])
        self.assertIsInstance(r, list)
        self.assertEqual(r, expected)
        r = get_as_list(table, where="name like 'Ba%'")
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[2:3])
        r = get_as_list(table, what='name', where="name like 'Ma%'")
        self.assertIsInstance(r, list)
        self.assertEqual(r, [('Maggie',), ('Marge',)])
        # several where conditions must all be satisfied
        r = get_as_list(table, what='name',
            where=["name like 'Ma%'", "name like '%r%'"])
        self.assertIsInstance(r, list)
        self.assertEqual(r, [('Marge',)])
        # explicit order, as string or list
        r = get_as_list(table, what='name', order='id')
        self.assertIsInstance(r, list)
        expected = [(row[1],) for row in names]
        self.assertEqual(r, expected)
        r = get_as_list(table, what=['name'], order=['id'])
        self.assertIsInstance(r, list)
        self.assertEqual(r, expected)
        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
        self.assertIsInstance(r, list)
        self.assertEqual(r, names)
        # what may contain arbitrary expressions
        r = get_as_list(table, what='id * 2 as num', order='id desc')
        self.assertIsInstance(r, list)
        expected = [(n,) for n in range(10, 0, -2)]
        self.assertEqual(r, expected)
        r = get_as_list(table, limit=2)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[:2])
        r = get_as_list(table, offset=3)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[3:])
        r = get_as_list(table, limit=1, offset=2)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[2:3])
        # scalar=True returns the values of the first column only
        r = get_as_list(table, scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, list(range(1, 6)))
        r = get_as_list(table, what='name', scalar=True)
        self.assertIsInstance(r, list)
        expected = sorted(row[1] for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(table, what='name', limit=1, scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, expected[:1])
        # drop the primary key and add a row with a duplicate id
        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
        names.insert(1, (1, 'Snowball'))
        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
        # the expected order assumes that, without a primary key, rows are
        # sorted by all columns, so 'Snowball' follows 'Homer' (both id 1)
        r = get_as_list(table)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names)
        r = get_as_list(table, what='name', where='id=1', scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, ['Homer', 'Snowball'])
        # test with unordered query
        r = get_as_list(table, order=False)
        self.assertIsInstance(r, list)
        self.assertEqual(set(r), set(names))
        # test with arbitrary from clause
        from_table = '(select lower(name) as n2 from "%s") as t2' % table
        r = get_as_list(from_table)
        self.assertIsInstance(r, list)
        r = set(row[0] for row in r)
        expected = set(row[1].lower() for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(from_table, order='n2', scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, sorted(expected))
        r = get_as_list(from_table, order='n2', limit=1)
        self.assertIsInstance(r, list)
        self.assertEqual(len(r), 1)
        t = r[0]
        self.assertIsInstance(t, tuple)
        if named:
            self.assertEqual(t.n2, 'bart')
            self.assertEqual(t._asdict(), dict(n2='bart'))
        else:
            self.assertEqual(t, ('bart',))
2636
2637    def testGetAsDict(self):
2638        get_as_dict = self.db.get_as_dict
2639        self.assertRaises(TypeError, get_as_dict)
2640        self.assertRaises(TypeError, get_as_dict, None)
2641        # the test table has no primary key
2642        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2643        query = self.db.query
2644        table = 'test_asdict'
2645        r = query('select 1 as colname').namedresult()[0]
2646        self.assertIsInstance(r, tuple)
2647        named = hasattr(r, 'colname')
2648        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2649                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2650        self.createTable(table,
2651            'id smallint primary key, rgb char(7), name varchar',
2652            values=colors)
2653        # keyname must be string, list or tuple
2654        self.assertRaises(KeyError, get_as_dict, table, 3)
2655        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2656        # missing keyname in row
2657        self.assertRaises(KeyError, get_as_dict, table,
2658                          keyname='rgb', what='name')
2659        r = get_as_dict(table)
2660        self.assertIsInstance(r, OrderedDict)
2661        expected = OrderedDict((row[0], row[1:]) for row in colors)
2662        self.assertEqual(r, expected)
2663        for key in r:
2664            self.assertIsInstance(key, int)
2665            self.assertIn(key, expected)
2666            row = r[key]
2667            self.assertIsInstance(row, tuple)
2668            t = expected[key]
2669            self.assertEqual(row, t)
2670            if named:
2671                self.assertEqual(row.rgb, t[0])
2672                self.assertEqual(row.name, t[1])
2673                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2674        if OrderedDict is not dict:  # Python > 2.6
2675            self.assertEqual(r.keys(), expected.keys())
2676        r = get_as_dict(table, keyname='rgb')
2677        self.assertIsInstance(r, OrderedDict)
2678        expected = OrderedDict((row[1], (row[0], row[2]))
2679                               for row in sorted(colors, key=itemgetter(1)))
2680        self.assertEqual(r, expected)
2681        for key in r:
2682            self.assertIsInstance(key, str)
2683            self.assertIn(key, expected)
2684            row = r[key]
2685            self.assertIsInstance(row, tuple)
2686            t = expected[key]
2687            self.assertEqual(row, t)
2688            if named:
2689                self.assertEqual(row.id, t[0])
2690                self.assertEqual(row.name, t[1])
2691                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2692        if OrderedDict is not dict:  # Python > 2.6
2693            self.assertEqual(r.keys(), expected.keys())
2694        r = get_as_dict(table, keyname=['id', 'rgb'])
2695        self.assertIsInstance(r, OrderedDict)
2696        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2697        self.assertEqual(r, expected)
2698        for key in r:
2699            self.assertIsInstance(key, tuple)
2700            self.assertIsInstance(key[0], int)
2701            self.assertIsInstance(key[1], str)
2702            if named:
2703                self.assertEqual(key, (key.id, key.rgb))
2704                self.assertEqual(key._fields, ('id', 'rgb'))
2705            row = r[key]
2706            self.assertIsInstance(row, tuple)
2707            self.assertIsInstance(row[0], str)
2708            t = expected[key]
2709            self.assertEqual(row, t)
2710            if named:
2711                self.assertEqual(row.name, t[0])
2712                self.assertEqual(row._asdict(), dict(name=t[0]))
2713        if OrderedDict is not dict:  # Python > 2.6
2714            self.assertEqual(r.keys(), expected.keys())
2715        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2716        self.assertIsInstance(r, OrderedDict)
2717        expected = OrderedDict((row[:2], row[2]) for row in colors)
2718        self.assertEqual(r, expected)
2719        for key in r:
2720            self.assertIsInstance(key, tuple)
2721            row = r[key]
2722            self.assertIsInstance(row, str)
2723            t = expected[key]
2724            self.assertEqual(row, t)
2725        if OrderedDict is not dict:  # Python > 2.6
2726            self.assertEqual(r.keys(), expected.keys())
2727        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2728        self.assertIsInstance(r, OrderedDict)
2729        expected = OrderedDict((row[1], row[2])
2730            for row in sorted(colors, key=itemgetter(1)))
2731        self.assertEqual(r, expected)
2732        for key in r:
2733            self.assertIsInstance(key, str)
2734            row = r[key]
2735            self.assertIsInstance(row, str)
2736            t = expected[key]
2737            self.assertEqual(row, t)
2738        if OrderedDict is not dict:  # Python > 2.6
2739            self.assertEqual(r.keys(), expected.keys())
2740        r = get_as_dict(table, what='id, name',
2741            where="rgb like '#b%'", scalar=True)
2742        self.assertIsInstance(r, OrderedDict)
2743        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2744        self.assertEqual(r, expected)
2745        for key in r:
2746            self.assertIsInstance(key, int)
2747            row = r[key]
2748            self.assertIsInstance(row, str)
2749            t = expected[key]
2750            self.assertEqual(row, t)
2751        if OrderedDict is not dict:  # Python > 2.6
2752            self.assertEqual(r.keys(), expected.keys())
2753        expected = r
2754        r = get_as_dict(table, what=['name', 'id'],
2755            where=['id > 1', 'id < 4', "rgb like '#b%'",
2756                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2757        self.assertEqual(r, expected)
2758        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2759        self.assertEqual(r, expected)
2760        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2761            where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2762        self.assertEqual(r, expected)
2763        r = get_as_dict(table, limit=1)
2764        self.assertEqual(len(r), 1)
2765        self.assertEqual(r[1][1], 'Aero')
2766        r = get_as_dict(table, offset=3)
2767        self.assertEqual(len(r), 1)
2768        self.assertEqual(r[4][1], 'Desert')
2769        r = get_as_dict(table, order='id desc')
2770        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2771        self.assertEqual(r, expected)
2772        r = get_as_dict(table, where='id > 5')
2773        self.assertIsInstance(r, OrderedDict)
2774        self.assertEqual(len(r), 0)
2775        # test with unordered query
2776        expected = dict((row[0], row[1:]) for row in colors)
2777        r = get_as_dict(table, order=False)
2778        self.assertIsInstance(r, dict)
2779        self.assertEqual(r, expected)
2780        if dict is not OrderedDict:  # Python > 2.6
2781            self.assertNotIsInstance(self, OrderedDict)
2782        # test with arbitrary from clause
2783        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2784        # primary key must be passed explicitly in this case
2785        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2786        r = get_as_dict(from_table, 'id')
2787        self.assertIsInstance(r, OrderedDict)
2788        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2789        self.assertEqual(r, expected)
2790        # test without a primary key
2791        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2792        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2793        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2794        r = get_as_dict(table, keyname='id')
2795        expected = OrderedDict((row[0], row[1:]) for row in colors)
2796        self.assertIsInstance(r, dict)
2797        self.assertEqual(r, expected)
2798        r = (1, '#007fff', 'Azure')
2799        query('insert into "%s" values ($1, $2, $3)' % table, r)
2800        # the last entry will win
2801        expected[1] = r[1:]
2802        r = get_as_dict(table, keyname='id')
2803        self.assertEqual(r, expected)
2804
2805    def testTransaction(self):
2806        query = self.db.query
2807        self.createTable('test_table', 'n integer', temporary=False)
2808        self.db.begin()
2809        query("insert into test_table values (1)")
2810        query("insert into test_table values (2)")
2811        self.db.commit()
2812        self.db.begin()
2813        query("insert into test_table values (3)")
2814        query("insert into test_table values (4)")
2815        self.db.rollback()
2816        self.db.begin()
2817        query("insert into test_table values (5)")
2818        self.db.savepoint('before6')
2819        query("insert into test_table values (6)")
2820        self.db.rollback('before6')
2821        query("insert into test_table values (7)")
2822        self.db.commit()
2823        self.db.begin()
2824        self.db.savepoint('before8')
2825        query("insert into test_table values (8)")
2826        self.db.release('before8')
2827        self.assertRaises(pg.InternalError, self.db.rollback, 'before8')
2828        self.db.commit()
2829        self.db.start()
2830        query("insert into test_table values (9)")
2831        self.db.end()
2832        r = [r[0] for r in query(
2833            "select * from test_table order by 1").getresult()]
2834        self.assertEqual(r, [1, 2, 5, 7, 9])
2835        self.db.begin(mode='read only')
2836        self.assertRaises(pg.InternalError,
2837                          query, "insert into test_table values (0)")
2838        self.db.rollback()
2839        self.db.start(mode='Read Only')
2840        self.assertRaises(pg.InternalError,
2841                          query, "insert into test_table values (0)")
2842        self.db.abort()
2843
2844    def testTransactionAliases(self):
2845        self.assertEqual(self.db.begin, self.db.start)
2846        self.assertEqual(self.db.commit, self.db.end)
2847        self.assertEqual(self.db.rollback, self.db.abort)
2848
2849    def testContextManager(self):
2850        query = self.db.query
2851        self.createTable('test_table', 'n integer check(n>0)')
2852        with self.db:
2853            query("insert into test_table values (1)")
2854            query("insert into test_table values (2)")
2855        try:
2856            with self.db:
2857                query("insert into test_table values (3)")
2858                query("insert into test_table values (4)")
2859                raise ValueError('test transaction should rollback')
2860        except ValueError as error:
2861            self.assertEqual(str(error), 'test transaction should rollback')
2862        with self.db:
2863            query("insert into test_table values (5)")
2864        try:
2865            with self.db:
2866                query("insert into test_table values (6)")
2867                query("insert into test_table values (-1)")
2868        except pg.IntegrityError as error:
2869            self.assertTrue('check' in str(error))
2870        with self.db:
2871            query("insert into test_table values (7)")
2872        r = [r[0] for r in query(
2873            "select * from test_table order by 1").getresult()]
2874        self.assertEqual(r, [1, 2, 5, 7])
2875
2876    def testBytea(self):
2877        query = self.db.query
2878        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2879        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2880        r = self.db.escape_bytea(s)
2881        query('insert into bytea_test values(3, $1)', (r,))
2882        r = query('select * from bytea_test where n=3').getresult()
2883        self.assertEqual(len(r), 1)
2884        r = r[0]
2885        self.assertEqual(len(r), 2)
2886        self.assertEqual(r[0], 3)
2887        r = r[1]
2888        if pg.get_bytea_escaped():
2889            self.assertNotEqual(r, s)
2890            r = pg.unescape_bytea(r)
2891        self.assertIsInstance(r, bytes)
2892        self.assertEqual(r, s)
2893
2894    def testInsertUpdateGetBytea(self):
2895        query = self.db.query
2896        unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None
2897        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2898        # insert null value
2899        r = self.db.insert('bytea_test', n=0, data=None)
2900        self.assertIsInstance(r, dict)
2901        self.assertIn('n', r)
2902        self.assertEqual(r['n'], 0)
2903        self.assertIn('data', r)
2904        self.assertIsNone(r['data'])
2905        s = b'None'
2906        r = self.db.update('bytea_test', n=0, data=s)
2907        self.assertIsInstance(r, dict)
2908        self.assertIn('n', r)
2909        self.assertEqual(r['n'], 0)
2910        self.assertIn('data', r)
2911        r = r['data']
2912        if unescape:
2913            self.assertNotEqual(r, s)
2914            r = unescape(r)
2915        self.assertIsInstance(r, bytes)
2916        self.assertEqual(r, s)
2917        r = self.db.update('bytea_test', n=0, data=None)
2918        self.assertIsNone(r['data'])
2919        # insert as bytes
2920        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2921        r = self.db.insert('bytea_test', n=5, data=s)
2922        self.assertIsInstance(r, dict)
2923        self.assertIn('n', r)
2924        self.assertEqual(r['n'], 5)
2925        self.assertIn('data', r)
2926        r = r['data']
2927        if unescape:
2928            self.assertNotEqual(r, s)
2929            r = unescape(r)
2930        self.assertIsInstance(r, bytes)
2931        self.assertEqual(r, s)
2932        # update as bytes
2933        s += b"and now even more \x00 nasty \t stuff!\f"
2934        r = self.db.update('bytea_test', n=5, data=s)
2935        self.assertIsInstance(r, dict)
2936        self.assertIn('n', r)
2937        self.assertEqual(r['n'], 5)
2938        self.assertIn('data', r)
2939        r = r['data']
2940        if unescape:
2941            self.assertNotEqual(r, s)
2942            r = unescape(r)
2943        self.assertIsInstance(r, bytes)
2944        self.assertEqual(r, s)
2945        r = query('select * from bytea_test where n=5').getresult()
2946        self.assertEqual(len(r), 1)
2947        r = r[0]
2948        self.assertEqual(len(r), 2)
2949        self.assertEqual(r[0], 5)
2950        r = r[1]
2951        if unescape:
2952            self.assertNotEqual(r, s)
2953            r = unescape(r)
2954        self.assertIsInstance(r, bytes)
2955        self.assertEqual(r, s)
2956        r = self.db.get('bytea_test', dict(n=5))
2957        self.assertIsInstance(r, dict)
2958        self.assertIn('n', r)
2959        self.assertEqual(r['n'], 5)
2960        self.assertIn('data', r)
2961        r = r['data']
2962        if unescape:
2963            self.assertNotEqual(r, s)
2964            r = pg.unescape_bytea(r)
2965        self.assertIsInstance(r, bytes)
2966        self.assertEqual(r, s)
2967
2968    def testUpsertBytea(self):
2969        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2970        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2971        r = dict(n=7, data=s)
2972        try:
2973            r = self.db.upsert('bytea_test', r)
2974        except pg.ProgrammingError as error:
2975            if self.db.server_version < 90500:
2976                self.skipTest('database does not support upsert')
2977            self.fail(str(error))
2978        self.assertIsInstance(r, dict)
2979        self.assertIn('n', r)
2980        self.assertEqual(r['n'], 7)
2981        self.assertIn('data', r)
2982        if pg.get_bytea_escaped():
2983            self.assertNotEqual(r['data'], s)
2984            r['data'] = pg.unescape_bytea(r['data'])
2985        self.assertIsInstance(r['data'], bytes)
2986        self.assertEqual(r['data'], s)
2987        r['data'] = None
2988        r = self.db.upsert('bytea_test', r)
2989        self.assertIsInstance(r, dict)
2990        self.assertIn('n', r)
2991        self.assertEqual(r['n'], 7)
2992        self.assertIn('data', r)
2993        self.assertIsNone(r['data'])
2994
2995    def testInsertGetJson(self):
2996        try:
2997            self.createTable('json_test', 'n smallint primary key, data json')
2998        except pg.ProgrammingError as error:
2999            if self.db.server_version < 90200:
3000                self.skipTest('database does not support json')
3001            self.fail(str(error))
3002        jsondecode = pg.get_jsondecode()
3003        # insert null value
3004        r = self.db.insert('json_test', n=0, data=None)
3005        self.assertIsInstance(r, dict)
3006        self.assertIn('n', r)
3007        self.assertEqual(r['n'], 0)
3008        self.assertIn('data', r)
3009        self.assertIsNone(r['data'])
3010        r = self.db.get('json_test', 0)
3011        self.assertIsInstance(r, dict)
3012        self.assertIn('n', r)
3013        self.assertEqual(r['n'], 0)
3014        self.assertIn('data', r)
3015        self.assertIsNone(r['data'])
3016        # insert JSON object
3017        data = {
3018            "id": 1, "name": "Foo", "price": 1234.5,
3019            "new": True, "note": None,
3020            "tags": ["Bar", "Eek"],
3021            "stock": {"warehouse": 300, "retail": 20}}
3022        r = self.db.insert('json_test', n=1, data=data)
3023        self.assertIsInstance(r, dict)
3024        self.assertIn('n', r)
3025        self.assertEqual(r['n'], 1)
3026        self.assertIn('data', r)
3027        r = r['data']
3028        if jsondecode is None:
3029            self.assertIsInstance(r, str)
3030            r = json.loads(r)
3031        self.assertIsInstance(r, dict)
3032        self.assertEqual(r, data)
3033        self.assertIsInstance(r['id'], int)
3034        self.assertIsInstance(r['name'], unicode)
3035        self.assertIsInstance(r['price'], float)
3036        self.assertIsInstance(r['new'], bool)
3037        self.assertIsInstance(r['tags'], list)
3038        self.assertIsInstance(r['stock'], dict)
3039        r = self.db.get('json_test', 1)
3040        self.assertIsInstance(r, dict)
3041        self.assertIn('n', r)
3042        self.assertEqual(r['n'], 1)
3043        self.assertIn('data', r)
3044        r = r['data']
3045        if jsondecode is None:
3046            self.assertIsInstance(r, str)
3047            r = json.loads(r)
3048        self.assertIsInstance(r, dict)
3049        self.assertEqual(r, data)
3050        self.assertIsInstance(r['id'], int)
3051        self.assertIsInstance(r['name'], unicode)
3052        self.assertIsInstance(r['price'], float)
3053        self.assertIsInstance(r['new'], bool)
3054        self.assertIsInstance(r['tags'], list)
3055        self.assertIsInstance(r['stock'], dict)
3056        # insert JSON object as text
3057        self.db.insert('json_test', n=2, data=json.dumps(data))
3058        q = "select data from json_test where n in (1, 2) order by n"
3059        r = self.db.query(q).getresult()
3060        self.assertEqual(len(r), 2)
3061        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
3062        self.assertEqual(r[0][0], r[1][0])
3063
3064    def testInsertGetJsonb(self):
3065        try:
3066            self.createTable('jsonb_test',
3067                             'n smallint primary key, data jsonb')
3068        except pg.ProgrammingError as error:
3069            if self.db.server_version < 90400:
3070                self.skipTest('database does not support jsonb')
3071            self.fail(str(error))
3072        jsondecode = pg.get_jsondecode()
3073        # insert null value
3074        r = self.db.insert('jsonb_test', n=0, data=None)
3075        self.assertIsInstance(r, dict)
3076        self.assertIn('n', r)
3077        self.assertEqual(r['n'], 0)
3078        self.assertIn('data', r)
3079        self.assertIsNone(r['data'])
3080        r = self.db.get('jsonb_test', 0)
3081        self.assertIsInstance(r, dict)
3082        self.assertIn('n', r)
3083        self.assertEqual(r['n'], 0)
3084        self.assertIn('data', r)
3085        self.assertIsNone(r['data'])
3086        # insert JSON object
3087        data = {
3088            "id": 1, "name": "Foo", "price": 1234.5,
3089            "new": True, "note": None,
3090            "tags": ["Bar", "Eek"],
3091            "stock": {"warehouse": 300, "retail": 20}}
3092        r = self.db.insert('jsonb_test', n=1, data=data)
3093        self.assertIsInstance(r, dict)
3094        self.assertIn('n', r)
3095        self.assertEqual(r['n'], 1)
3096        self.assertIn('data', r)
3097        r = r['data']
3098        if jsondecode is None:
3099            self.assertIsInstance(r, str)
3100            r = json.loads(r)
3101        self.assertIsInstance(r, dict)
3102        self.assertEqual(r, data)
3103        self.assertIsInstance(r['id'], int)
3104        self.assertIsInstance(r['name'], unicode)
3105        self.assertIsInstance(r['price'], float)
3106        self.assertIsInstance(r['new'], bool)
3107        self.assertIsInstance(r['tags'], list)
3108        self.assertIsInstance(r['stock'], dict)
3109        r = self.db.get('jsonb_test', 1)
3110        self.assertIsInstance(r, dict)
3111        self.assertIn('n', r)
3112        self.assertEqual(r['n'], 1)
3113        self.assertIn('data', r)
3114        r = r['data']
3115        if jsondecode is None:
3116            self.assertIsInstance(r, str)
3117            r = json.loads(r)
3118        self.assertIsInstance(r, dict)
3119        self.assertEqual(r, data)
3120        self.assertIsInstance(r['id'], int)
3121        self.assertIsInstance(r['name'], unicode)
3122        self.assertIsInstance(r['price'], float)
3123        self.assertIsInstance(r['new'], bool)
3124        self.assertIsInstance(r['tags'], list)
3125        self.assertIsInstance(r['stock'], dict)
3126
3127    def testArray(self):
3128        returns_arrays = pg.get_array()
3129        self.createTable('arraytest',
3130            'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
3131            ' d numeric[], f4 real[], f8 double precision[], m money[],'
3132            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
3133        r = self.db.get_attnames('arraytest')
3134        if self.regtypes:
3135            self.assertEqual(r, dict(
3136                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
3137                d='numeric[]', f4='real[]', f8='double precision[]',
3138                m='money[]', b='boolean[]',
3139                v4='character varying[]', c4='character[]', t='text[]'))
3140        else:
3141            self.assertEqual(r, dict(
3142                id='int', i2='int[]', i4='int[]', i8='int[]',
3143                d='num[]', f4='float[]', f8='float[]', m='money[]',
3144                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
3145        decimal = pg.get_decimal()
3146        if decimal is Decimal:
3147            long_decimal = decimal('123456789.123456789')
3148            odd_money = decimal('1234567891234567.89')
3149        else:
3150            long_decimal = decimal('12345671234.5')
3151            odd_money = decimal('1234567123.25')
3152        t, f = (True, False) if pg.get_bool() else ('t', 'f')
3153        data = dict(id=42, i2=[42, 1234, None, 0, -1],
3154            i4=[42, 123456789, None, 0, 1, -1],
3155            i8=[long(42), long(123456789123456789), None,
3156                long(0), long(1), long(-1)],
3157            d=[decimal(42), long_decimal, None,
3158               decimal(0), decimal(1), decimal(-1), -long_decimal],
3159            f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
3160                float('inf'), float('-inf')],
3161            f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
3162                float('inf'), float('-inf')],
3163            m=[decimal('42.00'), odd_money, None,
3164               decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
3165            b=[t, f, t, None, f, t, None, None, t],
3166            v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
3167            t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
3168        r = data.copy()
3169        self.db.insert('arraytest', r)
3170        if returns_arrays:
3171            self.assertEqual(r, data)
3172        else:
3173            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3174        self.db.insert('arraytest', r)
3175        r = self.db.get('arraytest', 42, 'id')
3176        if returns_arrays:
3177            self.assertEqual(r, data)
3178        else:
3179            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3180        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
3181        if returns_arrays:
3182            self.assertEqual(r, data)
3183        else:
3184            self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}')
3185
3186    def testArrayLiteral(self):
3187        insert = self.db.insert
3188        returns_arrays = pg.get_array()
3189        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3190        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3191        insert('arraytest', r)
3192        if returns_arrays:
3193            self.assertEqual(r['i'], [1, 2, 3])
3194            self.assertEqual(r['t'], ['a', 'b', 'c'])
3195        else:
3196            self.assertEqual(r['i'], '{1,2,3}')
3197            self.assertEqual(r['t'], '{a,b,c}')
3198        r = dict(i='{1,2,3}', t='{a,b,c}')
3199        self.db.insert('arraytest', r)
3200        if returns_arrays:
3201            self.assertEqual(r['i'], [1, 2, 3])
3202            self.assertEqual(r['t'], ['a', 'b', 'c'])
3203        else:
3204            self.assertEqual(r['i'], '{1,2,3}')
3205            self.assertEqual(r['t'], '{a,b,c}')
3206        L = pg.Literal
3207        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3208        self.db.insert('arraytest', r)
3209        if returns_arrays:
3210            self.assertEqual(r['i'], [1, 2, 3])
3211            self.assertEqual(r['t'], ['a', 'b', 'c'])
3212        else:
3213            self.assertEqual(r['i'], '{1,2,3}')
3214            self.assertEqual(r['t'], '{a,b,c}')
3215        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3216        self.assertRaises(pg.DataError, self.db.insert, 'arraytest', r)
3217
3218    def testArrayOfIds(self):
3219        array_on = pg.get_array()
3220        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3221        r = self.db.get_attnames('arraytest')
3222        if self.regtypes:
3223            self.assertEqual(r, dict(
3224                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3225        else:
3226            self.assertEqual(r, dict(
3227                oid='int', c='int[]', o='int[]', x='int[]'))
3228        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3229        r = data.copy()
3230        self.db.insert('arraytest', r)
3231        qoid = 'oid(arraytest)'
3232        oid = r.pop(qoid)
3233        if array_on:
3234            self.assertEqual(r, data)
3235        else:
3236            self.assertEqual(r['o'], '{21,22,23}')
3237        r = {qoid: oid}
3238        self.db.get('arraytest', r)
3239        self.assertEqual(oid, r.pop(qoid))
3240        if array_on:
3241            self.assertEqual(r, data)
3242        else:
3243            self.assertEqual(r['o'], '{21,22,23}')
3244
3245    def testArrayOfText(self):
3246        array_on = pg.get_array()
3247        self.createTable('arraytest', 'data text[]', oids=True)
3248        r = self.db.get_attnames('arraytest')
3249        self.assertEqual(r['data'], 'text[]')
3250        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3251                'null', 'NULL', 'Null', 'nulL',
3252                "It's all \\ kinds of\r nasty stuff!\n"]
3253        r = dict(data=data)
3254        self.db.insert('arraytest', r)
3255        if not array_on:
3256            r['data'] = pg.cast_array(r['data'])
3257        self.assertEqual(r['data'], data)
3258        self.assertIsInstance(r['data'][1], str)
3259        self.assertIsNone(r['data'][2])
3260        r['data'] = None
3261        self.db.get('arraytest', r)
3262        if not array_on:
3263            r['data'] = pg.cast_array(r['data'])
3264        self.assertEqual(r['data'], data)
3265        self.assertIsInstance(r['data'][1], str)
3266        self.assertIsNone(r['data'][2])
3267
3268    def testArrayOfBytea(self):
3269        array_on = pg.get_array()
3270        bytea_escaped = pg.get_bytea_escaped()
3271        self.createTable('arraytest', 'data bytea[]', oids=True)
3272        r = self.db.get_attnames('arraytest')
3273        self.assertEqual(r['data'], 'bytea[]')
3274        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3275                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3276        r = dict(data=data)
3277        self.db.insert('arraytest', r)
3278        if array_on:
3279            self.assertIsInstance(r['data'], list)
3280        if array_on and not bytea_escaped:
3281            self.assertEqual(r['data'], data)
3282            self.assertIsInstance(r['data'][1], bytes)
3283            self.assertIsNone(r['data'][2])
3284        else:
3285            self.assertNotEqual(r['data'], data)
3286        r['data'] = None
3287        self.db.get('arraytest', r)
3288        if array_on:
3289            self.assertIsInstance(r['data'], list)
3290        if array_on and not bytea_escaped:
3291            self.assertEqual(r['data'], data)
3292            self.assertIsInstance(r['data'][1], bytes)
3293            self.assertIsNone(r['data'][2])
3294        else:
3295            self.assertNotEqual(r['data'], data)
3296
3297    def testArrayOfJson(self):
3298        try:
3299            self.createTable('arraytest', 'data json[]', oids=True)
3300        except pg.ProgrammingError as error:
3301            if self.db.server_version < 90200:
3302                self.skipTest('database does not support json')
3303            self.fail(str(error))
3304        r = self.db.get_attnames('arraytest')
3305        self.assertEqual(r['data'], 'json[]')
3306        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3307        array_on = pg.get_array()
3308        jsondecode = pg.get_jsondecode()
3309        r = dict(data=data)
3310        self.db.insert('arraytest', r)
3311        if not array_on:
3312            r['data'] = pg.cast_array(r['data'], jsondecode)
3313        if jsondecode is None:
3314            r['data'] = [json.loads(d) for d in r['data']]
3315        self.assertEqual(r['data'], data)
3316        r['data'] = None
3317        self.db.get('arraytest', r)
3318        if not array_on:
3319            r['data'] = pg.cast_array(r['data'], jsondecode)
3320        if jsondecode is None:
3321            r['data'] = [json.loads(d) for d in r['data']]
3322        self.assertEqual(r['data'], data)
3323        r = dict(data=[json.dumps(d) for d in data])
3324        self.db.insert('arraytest', r)
3325        if not array_on:
3326            r['data'] = pg.cast_array(r['data'], jsondecode)
3327        if jsondecode is None:
3328            r['data'] = [json.loads(d) for d in r['data']]
3329        self.assertEqual(r['data'], data)
3330        r['data'] = None
3331        self.db.get('arraytest', r)
3332        # insert empty json values
3333        r = dict(data=['', None])
3334        self.db.insert('arraytest', r)
3335        r = r['data']
3336        if array_on:
3337            self.assertIsInstance(r, list)
3338            self.assertEqual(len(r), 2)
3339            self.assertIsNone(r[0])
3340            self.assertIsNone(r[1])
3341        else:
3342            self.assertEqual(r, '{NULL,NULL}')
3343
3344    def testArrayOfJsonb(self):
3345        try:
3346            self.createTable('arraytest', 'data jsonb[]', oids=True)
3347        except pg.ProgrammingError as error:
3348            if self.db.server_version < 90400:
3349                self.skipTest('database does not support jsonb')
3350            self.fail(str(error))
3351        r = self.db.get_attnames('arraytest')
3352        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3353        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3354        array_on = pg.get_array()
3355        jsondecode = pg.get_jsondecode()
3356        r = dict(data=data)
3357        self.db.insert('arraytest', r)
3358        if not array_on:
3359            r['data'] = pg.cast_array(r['data'], jsondecode)
3360        if jsondecode is None:
3361            r['data'] = [json.loads(d) for d in r['data']]
3362        self.assertEqual(r['data'], data)
3363        r['data'] = None
3364        self.db.get('arraytest', r)
3365        if not array_on:
3366            r['data'] = pg.cast_array(r['data'], jsondecode)
3367        if jsondecode is None:
3368            r['data'] = [json.loads(d) for d in r['data']]
3369        self.assertEqual(r['data'], data)
3370        r = dict(data=[json.dumps(d) for d in data])
3371        self.db.insert('arraytest', r)
3372        if not array_on:
3373            r['data'] = pg.cast_array(r['data'], jsondecode)
3374        if jsondecode is None:
3375            r['data'] = [json.loads(d) for d in r['data']]
3376        self.assertEqual(r['data'], data)
3377        r['data'] = None
3378        self.db.get('arraytest', r)
3379        # insert empty json values
3380        r = dict(data=['', None])
3381        self.db.insert('arraytest', r)
3382        r = r['data']
3383        if array_on:
3384            self.assertIsInstance(r, list)
3385            self.assertEqual(len(r), 2)
3386            self.assertIsNone(r[0])
3387            self.assertIsNone(r[1])
3388        else:
3389            self.assertEqual(r, '{NULL,NULL}')
3390
3391    def testDeepArray(self):
3392        array_on = pg.get_array()
3393        self.createTable('arraytest', 'data text[][][]', oids=True)
3394        r = self.db.get_attnames('arraytest')
3395        self.assertEqual(r['data'], 'text[]')
3396        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3397        r = dict(data=data)
3398        self.db.insert('arraytest', r)
3399        if array_on:
3400            self.assertEqual(r['data'], data)
3401        else:
3402            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3403        r['data'] = None
3404        self.db.get('arraytest', r)
3405        if array_on:
3406            self.assertEqual(r['data'], data)
3407        else:
3408            self.assertTrue(r['data'].startswith('{{{"Hello,'))
3409
3410    def testInsertUpdateGetRecord(self):
3411        query = self.db.query
3412        query('create type test_person_type as'
3413              ' (name varchar, age smallint, married bool,'
3414              ' weight real, salary money)')
3415        self.addCleanup(query, 'drop type test_person_type')
3416        self.createTable('test_person', 'person test_person_type',
3417                         temporary=False, oids=True)
3418        attnames = self.db.get_attnames('test_person')
3419        self.assertEqual(len(attnames), 2)
3420        self.assertIn('oid', attnames)
3421        self.assertIn('person', attnames)
3422        person_typ = attnames['person']
3423        if self.regtypes:
3424            self.assertEqual(person_typ, 'test_person_type')
3425        else:
3426            self.assertEqual(person_typ, 'record')
3427        if self.regtypes:
3428            self.assertEqual(person_typ.attnames,
3429                dict(name='character varying', age='smallint',
3430                    married='boolean', weight='real', salary='money'))
3431        else:
3432            self.assertEqual(person_typ.attnames,
3433                dict(name='text', age='int', married='bool',
3434                    weight='float', salary='money'))
3435        decimal = pg.get_decimal()
3436        if pg.get_bool():
3437            bool_class = bool
3438            t, f = True, False
3439        else:
3440            bool_class = str
3441            t, f = 't', 'f'
3442        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3443        r = self.db.insert('test_person', None, person=person)
3444        p = r['person']
3445        self.assertIsInstance(p, tuple)
3446        self.assertEqual(p, person)
3447        self.assertEqual(p.name, 'John Doe')
3448        self.assertIsInstance(p.name, str)
3449        self.assertIsInstance(p.age, int)
3450        self.assertIsInstance(p.married, bool_class)
3451        self.assertIsInstance(p.weight, float)
3452        self.assertIsInstance(p.salary, decimal)
3453        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3454        r['person'] = person
3455        self.db.update('test_person', r)
3456        p = r['person']
3457        self.assertIsInstance(p, tuple)
3458        self.assertEqual(p, person)
3459        self.assertEqual(p.name, 'Jane Roe')
3460        self.assertIsInstance(p.name, str)
3461        self.assertIsInstance(p.age, int)
3462        self.assertIsInstance(p.married, bool_class)
3463        self.assertIsInstance(p.weight, float)
3464        self.assertIsInstance(p.salary, decimal)
3465        r['person'] = None
3466        self.db.get('test_person', r)
3467        p = r['person']
3468        self.assertIsInstance(p, tuple)
3469        self.assertEqual(p, person)
3470        self.assertEqual(p.name, 'Jane Roe')
3471        self.assertIsInstance(p.name, str)
3472        self.assertIsInstance(p.age, int)
3473        self.assertIsInstance(p.married, bool_class)
3474        self.assertIsInstance(p.weight, float)
3475        self.assertIsInstance(p.salary, decimal)
3476        person = (None,) * 5
3477        r = self.db.insert('test_person', None, person=person)
3478        p = r['person']
3479        self.assertIsInstance(p, tuple)
3480        self.assertIsNone(p.name)
3481        self.assertIsNone(p.age)
3482        self.assertIsNone(p.married)
3483        self.assertIsNone(p.weight)
3484        self.assertIsNone(p.salary)
3485        r['person'] = None
3486        self.db.get('test_person', r)
3487        p = r['person']
3488        self.assertIsInstance(p, tuple)
3489        self.assertIsNone(p.name)
3490        self.assertIsNone(p.age)
3491        self.assertIsNone(p.married)
3492        self.assertIsNone(p.weight)
3493        self.assertIsNone(p.salary)
3494        r = self.db.insert('test_person', None, person=None)
3495        self.assertIsNone(r['person'])
3496        r['person'] = None
3497        self.db.get('test_person', r)
3498        self.assertIsNone(r['person'])
3499
3500    def testRecordInsertBytea(self):
3501        query = self.db.query
3502        query('create type test_person_type as'
3503              ' (name text, picture bytea)')
3504        self.addCleanup(query, 'drop type test_person_type')
3505        self.createTable('test_person', 'person test_person_type',
3506                         temporary=False, oids=True)
3507        person_typ = self.db.get_attnames('test_person')['person']
3508        self.assertEqual(person_typ.attnames,
3509                         dict(name='text', picture='bytea'))
3510        person = ('John Doe', b'O\x00ps\xff!')
3511        r = self.db.insert('test_person', None, person=person)
3512        p = r['person']
3513        self.assertIsInstance(p, tuple)
3514        self.assertEqual(p, person)
3515        self.assertEqual(p.name, 'John Doe')
3516        self.assertIsInstance(p.name, str)
3517        self.assertEqual(p.picture, person[1])
3518        self.assertIsInstance(p.picture, bytes)
3519
3520    def testRecordInsertJson(self):
3521        query = self.db.query
3522        try:
3523            query('create type test_person_type as'
3524                  ' (name text, data json)')
3525        except pg.ProgrammingError as error:
3526            if self.db.server_version < 90200:
3527                self.skipTest('database does not support json')
3528            self.fail(str(error))
3529        self.addCleanup(query, 'drop type test_person_type')
3530        self.createTable('test_person', 'person test_person_type',
3531                         temporary=False, oids=True)
3532        person_typ = self.db.get_attnames('test_person')['person']
3533        self.assertEqual(person_typ.attnames,
3534                         dict(name='text', data='json'))
3535        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3536        r = self.db.insert('test_person', None, person=person)
3537        p = r['person']
3538        self.assertIsInstance(p, tuple)
3539        if pg.get_jsondecode() is None:
3540            p = p._replace(data=json.loads(p.data))
3541        self.assertEqual(p, person)
3542        self.assertEqual(p.name, 'John Doe')
3543        self.assertIsInstance(p.name, str)
3544        self.assertEqual(p.data, person[1])
3545        self.assertIsInstance(p.data, dict)
3546
3547    def testRecordLiteral(self):
3548        query = self.db.query
3549        query('create type test_person_type as'
3550              ' (name varchar, age smallint)')
3551        self.addCleanup(query, 'drop type test_person_type')
3552        self.createTable('test_person', 'person test_person_type',
3553                         temporary=False, oids=True)
3554        person_typ = self.db.get_attnames('test_person')['person']
3555        if self.regtypes:
3556            self.assertEqual(person_typ, 'test_person_type')
3557        else:
3558            self.assertEqual(person_typ, 'record')
3559        if self.regtypes:
3560            self.assertEqual(person_typ.attnames,
3561                             dict(name='character varying', age='smallint'))
3562        else:
3563            self.assertEqual(person_typ.attnames,
3564                             dict(name='text', age='int'))
3565        person = pg.Literal("('John Doe', 61)")
3566        r = self.db.insert('test_person', None, person=person)
3567        p = r['person']
3568        self.assertIsInstance(p, tuple)
3569        self.assertEqual(p.name, 'John Doe')
3570        self.assertIsInstance(p.name, str)
3571        self.assertEqual(p.age, 61)
3572        self.assertIsInstance(p.age, int)
3573
3574    def testDate(self):
3575        query = self.db.query
3576        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3577                'SQL, MDY', 'SQL, DMY', 'German'):
3578            self.db.set_parameter('datestyle', datestyle)
3579            d = date(2016, 3, 14)
3580            q = "select $1::date"
3581            r = query(q, (d,)).getresult()[0][0]
3582            self.assertIsInstance(r, date)
3583            self.assertEqual(r, d)
3584            q = "select '10000-08-01'::date, '0099-01-08 BC'::date"
3585            r = query(q).getresult()[0]
3586            self.assertIsInstance(r[0], date)
3587            self.assertIsInstance(r[1], date)
3588            self.assertEqual(r[0], date.max)
3589            self.assertEqual(r[1], date.min)
3590        q = "select 'infinity'::date, '-infinity'::date"
3591        r = query(q).getresult()[0]
3592        self.assertIsInstance(r[0], date)
3593        self.assertIsInstance(r[1], date)
3594        self.assertEqual(r[0], date.max)
3595        self.assertEqual(r[1], date.min)
3596
3597    def testTime(self):
3598        query = self.db.query
3599        d = time(15, 9, 26)
3600        q = "select $1::time"
3601        r = query(q, (d,)).getresult()[0][0]
3602        self.assertIsInstance(r, time)
3603        self.assertEqual(r, d)
3604        d = time(15, 9, 26, 535897)
3605        q = "select $1::time"
3606        r = query(q, (d,)).getresult()[0][0]
3607        self.assertIsInstance(r, time)
3608        self.assertEqual(r, d)
3609
3610    def testTimetz(self):
3611        query = self.db.query
3612        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3613        for timezone in sorted(timezones):
3614            tz = '%+03d00' % timezones[timezone]
3615            try:
3616                tzinfo = datetime.strptime(tz, '%z').tzinfo
3617            except ValueError:  # Python < 3.2
3618                tzinfo = pg._get_timezone(tz)
3619            self.db.set_parameter('timezone', timezone)
3620            d = time(15, 9, 26, tzinfo=tzinfo)
3621            q = "select $1::timetz"
3622            r = query(q, (d,)).getresult()[0][0]
3623            self.assertIsInstance(r, time)
3624            self.assertEqual(r, d)
3625            d = time(15, 9, 26, 535897, tzinfo)
3626            q = "select $1::timetz"
3627            r = query(q, (d,)).getresult()[0][0]
3628            self.assertIsInstance(r, time)
3629            self.assertEqual(r, d)
3630
3631    def testTimestamp(self):
3632        query = self.db.query
3633        for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3634                'SQL, MDY', 'SQL, DMY', 'German'):
3635            self.db.set_parameter('datestyle', datestyle)
3636            d = datetime(2016, 3, 14)
3637            q = "select $1::timestamp"
3638            r = query(q, (d,)).getresult()[0][0]
3639            self.assertIsInstance(r, datetime)
3640            self.assertEqual(r, d)
3641            d = datetime(2016, 3, 14, 15, 9, 26)
3642            q = "select $1::timestamp"
3643            r = query(q, (d,)).getresult()[0][0]
3644            self.assertIsInstance(r, datetime)
3645            self.assertEqual(r, d)
3646            d = datetime(2016, 3, 14, 15, 9, 26, 535897)
3647            q = "select $1::timestamp"
3648            r = query(q, (d,)).getresult()[0][0]
3649            self.assertIsInstance(r, datetime)
3650            self.assertEqual(r, d)
3651            q = ("select '10000-08-01 AD'::timestamp,"
3652                " '0099-01-08 BC'::timestamp")
3653            r = query(q).getresult()[0]
3654            self.assertIsInstance(r[0], datetime)
3655            self.assertIsInstance(r[1], datetime)
3656            self.assertEqual(r[0], datetime.max)
3657            self.assertEqual(r[1], datetime.min)
3658        q = "select 'infinity'::timestamp, '-infinity'::timestamp"
3659        r = query(q).getresult()[0]
3660        self.assertIsInstance(r[0], datetime)
3661        self.assertIsInstance(r[1], datetime)
3662        self.assertEqual(r[0], datetime.max)
3663        self.assertEqual(r[1], datetime.min)
3664
3665    def testTimestamptz(self):
3666        query = self.db.query
3667        timezones = dict(CET=1, EET=2, EST=-5, UTC=0)
3668        for timezone in sorted(timezones):
3669            tz = '%+03d00' % timezones[timezone]
3670            try:
3671                tzinfo = datetime.strptime(tz, '%z').tzinfo
3672            except ValueError:  # Python < 3.2
3673                tzinfo = pg._get_timezone(tz)
3674            self.db.set_parameter('timezone', timezone)
3675            for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY',
3676                    'SQL, MDY', 'SQL, DMY', 'German'):
3677                self.db.set_parameter('datestyle', datestyle)
3678                d = datetime(2016, 3, 14, tzinfo=tzinfo)
3679                q = "select $1::timestamptz"
3680                r = query(q, (d,)).getresult()[0][0]
3681                self.assertIsInstance(r, datetime)
3682                self.assertEqual(r, d)
3683                d = datetime(2016, 3, 14, 15, 9, 26, tzinfo=tzinfo)
3684                q = "select $1::timestamptz"
3685                r = query(q, (d,)).getresult()[0][0]
3686                self.assertIsInstance(r, datetime)
3687                self.assertEqual(r, d)
3688                d = datetime(2016, 3, 14, 15, 9, 26, 535897, tzinfo)
3689                q = "select $1::timestamptz"
3690                r = query(q, (d,)).getresult()[0][0]
3691                self.assertIsInstance(r, datetime)
3692                self.assertEqual(r, d)
3693                q = ("select '10000-08-01 AD'::timestamptz,"
3694                    " '0099-01-08 BC'::timestamptz")
3695                r = query(q).getresult()[0]
3696                self.assertIsInstance(r[0], datetime)
3697                self.assertIsInstance(r[1], datetime)
3698                self.assertEqual(r[0], datetime.max)
3699                self.assertEqual(r[1], datetime.min)
3700        q = "select 'infinity'::timestamptz, '-infinity'::timestamptz"
3701        r = query(q).getresult()[0]
3702        self.assertIsInstance(r[0], datetime)
3703        self.assertIsInstance(r[1], datetime)
3704        self.assertEqual(r[0], datetime.max)
3705        self.assertEqual(r[1], datetime.min)
3706
3707    def testInterval(self):
3708        query = self.db.query
3709        for intervalstyle in (
3710                'sql_standard', 'postgres', 'postgres_verbose', 'iso_8601'):
3711            self.db.set_parameter('intervalstyle', intervalstyle)
3712            d = timedelta(3)
3713            q = "select $1::interval"
3714            r = query(q, (d,)).getresult()[0][0]
3715            self.assertIsInstance(r, timedelta)
3716            self.assertEqual(r, d)
3717            d = timedelta(-30)
3718            r = query(q, (d,)).getresult()[0][0]
3719            self.assertIsInstance(r, timedelta)
3720            self.assertEqual(r, d)
3721            d = timedelta(hours=3, minutes=31, seconds=42, microseconds=5678)
3722            q = "select $1::interval"
3723            r = query(q, (d,)).getresult()[0][0]
3724            self.assertIsInstance(r, timedelta)
3725            self.assertEqual(r, d)
3726
3727    def testDateAndTimeArrays(self):
3728        dt = (date(2016, 3, 14), time(15, 9, 26))
3729        q = "select ARRAY[$1::date], ARRAY[$2::time]"
3730        r = self.db.query(q, dt).getresult()[0]
3731        self.assertIsInstance(r[0], list)
3732        self.assertEqual(r[0][0], dt[0])
3733        self.assertIsInstance(r[1], list)
3734        self.assertEqual(r[1][0], dt[1])
3735
3736    def testHstore(self):
3737        try:
3738            self.db.query("select 'k=>v'::hstore")
3739        except pg.DatabaseError:
3740            try:
3741                self.db.query("create extension hstore")
3742            except pg.DatabaseError:
3743                self.skipTest("hstore extension not enabled")
3744        d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever',
3745            '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3',
3746            '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"',
3747            'None': None, 'NULL': 'NULL', 'empty': ''}
3748        q = "select $1::hstore"
3749        r = self.db.query(q, (pg.Hstore(d),)).getresult()[0][0]
3750        self.assertIsInstance(r, dict)
3751        self.assertEqual(r, d)
3752
3753    def testUuid(self):
3754        d = UUID('{12345678-1234-5678-1234-567812345678}')
3755        q = 'select $1::uuid'
3756        r = self.db.query(q, (d,)).getresult()[0][0]
3757        self.assertIsInstance(r, UUID)
3758        self.assertEqual(r, d)
3759
3760    def testDbTypesInfo(self):
3761        dbtypes = self.db.dbtypes
3762        self.assertIsInstance(dbtypes, dict)
3763        self.assertNotIn('numeric', dbtypes)
3764        typ = dbtypes['numeric']
3765        self.assertIn('numeric', dbtypes)
3766        self.assertEqual(typ, 'numeric' if self.regtypes else 'num')
3767        self.assertEqual(typ.oid, 1700)
3768        self.assertEqual(typ.pgtype, 'numeric')
3769        self.assertEqual(typ.regtype, 'numeric')
3770        self.assertEqual(typ.simple, 'num')
3771        self.assertEqual(typ.typtype, 'b')
3772        self.assertEqual(typ.category, 'N')
3773        self.assertEqual(typ.delim, ',')
3774        self.assertEqual(typ.relid, 0)
3775        self.assertIs(dbtypes[1700], typ)
3776        self.assertNotIn('pg_type', dbtypes)
3777        typ = dbtypes['pg_type']
3778        self.assertIn('pg_type', dbtypes)
3779        self.assertEqual(typ, 'pg_type' if self.regtypes else 'record')
3780        self.assertIsInstance(typ.oid, int)
3781        self.assertEqual(typ.pgtype, 'pg_type')
3782        self.assertEqual(typ.regtype, 'pg_type')
3783        self.assertEqual(typ.simple, 'record')
3784        self.assertEqual(typ.typtype, 'c')
3785        self.assertEqual(typ.category, 'C')
3786        self.assertEqual(typ.delim, ',')
3787        self.assertNotEqual(typ.relid, 0)
3788        attnames = typ.attnames
3789        self.assertIsInstance(attnames, dict)
3790        self.assertIs(attnames, dbtypes.get_attnames('pg_type'))
3791        self.assertEqual(list(attnames)[0], 'typname')
3792        typname = attnames['typname']
3793        self.assertEqual(typname, 'name' if self.regtypes else 'text')
3794        self.assertEqual(typname.typtype, 'b')  # base
3795        self.assertEqual(typname.category, 'S')  # string
3796        self.assertEqual(list(attnames)[3], 'typlen')
3797        typlen = attnames['typlen']
3798        self.assertEqual(typlen, 'smallint' if self.regtypes else 'int')
3799        self.assertEqual(typlen.typtype, 'b')  # base
3800        self.assertEqual(typlen.category, 'N')  # numeric
3801
3802    def testDbTypesTypecast(self):
3803        dbtypes = self.db.dbtypes
3804        self.assertIsInstance(dbtypes, dict)
3805        self.assertNotIn('int4', dbtypes)
3806        self.assertIs(dbtypes.get_typecast('int4'), int)
3807        dbtypes.set_typecast('int4', float)
3808        self.assertIs(dbtypes.get_typecast('int4'), float)
3809        dbtypes.reset_typecast('int4')
3810        self.assertIs(dbtypes.get_typecast('int4'), int)
3811        dbtypes.set_typecast('int4', float)
3812        self.assertIs(dbtypes.get_typecast('int4'), float)
3813        dbtypes.reset_typecast()
3814        self.assertIs(dbtypes.get_typecast('int4'), int)
3815        self.assertNotIn('circle', dbtypes)
3816        self.assertIsNone(dbtypes.get_typecast('circle'))
3817        squared_circle = lambda v: 'Squared Circle: %s' % v
3818        dbtypes.set_typecast('circle', squared_circle)
3819        self.assertIs(dbtypes.get_typecast('circle'), squared_circle)
3820        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3821        self.assertIn('circle', dbtypes)
3822        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3823        self.assertEqual(dbtypes.typecast('Impossible', 'circle'),
3824            'Squared Circle: Impossible')
3825        dbtypes.reset_typecast('circle')
3826        self.assertIsNone(dbtypes.get_typecast('circle'))
3827
3828    def testGetSetTypeCast(self):
3829        get_typecast = pg.get_typecast
3830        set_typecast = pg.set_typecast
3831        dbtypes = self.db.dbtypes
3832        self.assertIsInstance(dbtypes, dict)
3833        self.assertNotIn('int4', dbtypes)
3834        self.assertNotIn('real', dbtypes)
3835        self.assertNotIn('bool', dbtypes)
3836        self.assertIs(get_typecast('int4'), int)
3837        self.assertIs(get_typecast('float4'), float)
3838        self.assertIs(get_typecast('bool'), pg.cast_bool)
3839        cast_circle = get_typecast('circle')
3840        self.addCleanup(set_typecast, 'circle', cast_circle)
3841        squared_circle = lambda v: 'Squared Circle: %s' % v
3842        self.assertNotIn('circle', dbtypes)
3843        set_typecast('circle', squared_circle)
3844        self.assertNotIn('circle', dbtypes)
3845        self.assertIs(get_typecast('circle'), squared_circle)
3846        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3847        self.assertIn('circle', dbtypes)
3848        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3849        set_typecast('circle', cast_circle)
3850        self.assertIs(get_typecast('circle'), cast_circle)
3851
3852    def testNotificationHandler(self):
3853        # the notification handler itself is tested separately
3854        f = self.db.notification_handler
3855        callback = lambda arg_dict: None
3856        handler = f('test', callback)
3857        self.assertIsInstance(handler, pg.NotificationHandler)
3858        self.assertIs(handler.db, self.db)
3859        self.assertEqual(handler.event, 'test')
3860        self.assertEqual(handler.stop_event, 'stop_test')
3861        self.assertIs(handler.callback, callback)
3862        self.assertIsInstance(handler.arg_dict, dict)
3863        self.assertEqual(handler.arg_dict, {})
3864        self.assertIsNone(handler.timeout)
3865        self.assertFalse(handler.listening)
3866        handler.close()
3867        self.assertIsNone(handler.db)
3868        self.db.reopen()
3869        self.assertIsNone(handler.db)
3870        handler = f('test2', callback, timeout=2)
3871        self.assertIsInstance(handler, pg.NotificationHandler)
3872        self.assertIs(handler.db, self.db)
3873        self.assertEqual(handler.event, 'test2')
3874        self.assertEqual(handler.stop_event, 'stop_test2')
3875        self.assertIs(handler.callback, callback)
3876        self.assertIsInstance(handler.arg_dict, dict)
3877        self.assertEqual(handler.arg_dict, {})
3878        self.assertEqual(handler.timeout, 2)
3879        self.assertFalse(handler.listening)
3880        handler.close()
3881        self.assertIsNone(handler.db)
3882        self.db.reopen()
3883        self.assertIsNone(handler.db)
3884        arg_dict = {'testing': 3}
3885        handler = f('test3', callback, arg_dict=arg_dict)
3886        self.assertIsInstance(handler, pg.NotificationHandler)
3887        self.assertIs(handler.db, self.db)
3888        self.assertEqual(handler.event, 'test3')
3889        self.assertEqual(handler.stop_event, 'stop_test3')
3890        self.assertIs(handler.callback, callback)
3891        self.assertIs(handler.arg_dict, arg_dict)
3892        self.assertEqual(arg_dict['testing'], 3)
3893        self.assertIsNone(handler.timeout)
3894        self.assertFalse(handler.listening)
3895        handler.close()
3896        self.assertIsNone(handler.db)
3897        self.db.reopen()
3898        self.assertIsNone(handler.db)
3899        handler = f('test4', callback, stop_event='stop4')
3900        self.assertIsInstance(handler, pg.NotificationHandler)
3901        self.assertIs(handler.db, self.db)
3902        self.assertEqual(handler.event, 'test4')
3903        self.assertEqual(handler.stop_event, 'stop4')
3904        self.assertIs(handler.callback, callback)
3905        self.assertIsInstance(handler.arg_dict, dict)
3906        self.assertEqual(handler.arg_dict, {})
3907        self.assertIsNone(handler.timeout)
3908        self.assertFalse(handler.listening)
3909        handler.close()
3910        self.assertIsNone(handler.db)
3911        self.db.reopen()
3912        self.assertIsNone(handler.db)
3913        arg_dict = {'testing': 5}
3914        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
3915        self.assertIsInstance(handler, pg.NotificationHandler)
3916        self.assertIs(handler.db, self.db)
3917        self.assertEqual(handler.event, 'test5')
3918        self.assertEqual(handler.stop_event, 'stop5')
3919        self.assertIs(handler.callback, callback)
3920        self.assertIs(handler.arg_dict, arg_dict)
3921        self.assertEqual(arg_dict['testing'], 5)
3922        self.assertEqual(handler.timeout, 1.5)
3923        self.assertFalse(handler.listening)
3924        handler.close()
3925        self.assertIsNone(handler.db)
3926        self.db.reopen()
3927        self.assertIsNone(handler.db)
3928
3929
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        # Flip the global pg options before the inherited class fixture
        # runs, so that all inherited tests see non-default settings.
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            not_array = not pg.get_array()
            cls.set_option('array', not_array)
            not_bytea_escaped = not pg.get_bytea_escaped()
            cls.set_option('bytea_escaped', not_bytea_escaped)
            cls.set_option('namedresult', None)
            cls.set_option('jsondecode', None)
            # the connection is only needed to query the regtypes
            # setting, so close it right away instead of leaking it
            db = DB()
            try:
                cls.regtypes = not db.use_regtypes()
            finally:
                db.close()
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # tearDownClass is not run when setUpClass fails, so restore
            # the global options here to not affect other test classes
            cls._restore_options()
            raise

    @classmethod
    def tearDownClass(cls):
        # restore the global options even if the inherited teardown fails
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls._restore_options()

    @classmethod
    def set_option(cls, option, value):
        """Remember the current global option value, then set a new one."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _restore_options(cls):
        """Restore every global option that was changed via set_option()."""
        for option in list(cls.saved_options):
            cls.reset_option(option)
3966
3967
class TestDBClassAdapter(unittest.TestCase):
    """Test the adapter object associated with the DB class."""

    def setUp(self):
        self.db = DB()
        self.adapter = self.db.adapter

    def tearDown(self):
        # an InternalError on close is ignored (presumably raised when the
        # connection is no longer usable)
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def _assertFormatted(self, expected_sql, expected_params,
                         *args, **kwargs):
        """Call format_query() and compare the SQL and parameters returned.

        All arguments after the two expected values are passed on to the
        format_query() method of the adapter unchanged.
        """
        sql, params = self.adapter.format_query(*args, **kwargs)
        self.assertEqual(sql, expected_sql)
        self.assertEqual(params, expected_params)

    def testGuessSimpleType(self):
        guess = self.adapter.guess_simple_type
        samples = [
            (pg.Bytea(b'test'), 'bytea'), ('string', 'text'),
            (b'string', 'text'), (True, 'bool'), (3, 'int'),
            (2.75, 'float'), (Decimal('4.25'), 'num'),
            (date(2016, 1, 30), 'date'), ([1, 2, 3], 'int[]'),
            ([[[123]]], 'int[]'), (['a', 'b', 'c'], 'text[]'),
            ([[['abc']]], 'text[]'), ([False, True], 'bool[]'),
            ([[[False]]], 'bool[]')]
        for value, expected in samples:
            self.assertEqual(guess(value), expected)
        # a tuple is guessed as a record with typed attributes
        record = guess(('string', True, 3, 2.75, [1], [False]))
        self.assertEqual(record, 'record')
        self.assertEqual(list(record.attnames.values()),
                         ['text', 'bool', 'int', 'float', 'int[]', 'bool[]'])

    def testAdaptQueryTypedList(self):
        check = self._assertFormatted
        format_query = self.adapter.format_query
        # the number of types must match the number of values
        self.assertRaises(TypeError, format_query,
                          '%s,%s', (1, 2), ('int2',))
        self.assertRaises(TypeError, format_query,
                          '%s,%s', (1,), ('int2', 'int2'))
        scalars = (3, 7.5, 'hello', True)
        check('select $1,$2,$3,$4', [3, 7.5, 'hello', 't'],
              "select %s,%s,%s,%s", scalars,
              ('int4', 'float4', 'text', 'bool'))
        check('select $1,$2,$3,$4', ['t', 't', 'f', 't'],
              "select %s,%s,%s,%s", scalars,
              ('bool', 'bool', 'bool', 'bool'))
        # typed as date, 'current_date' ends up in the SQL itself
        check('values($1,current_date)', ['2016-01-30'],
              "values(%s,%s)", ('2016-01-30', 'current_date'),
              ('date', 'date'))
        arrays = ([1, 2, 3], ['a', 'b', 'c'])
        check('$1::int4[],$2::text[]', ['{1,2,3}', '{a,b,c}'],
              "%s::int4[],%s::text[]", arrays, ('_int4', '_text'))
        check('$1::bool[],$2::bool[]', ['{t,t,t}', '{f,f,f}'],
              "%s::bool[],%s::bool[]", arrays, ('_bool', '_bool'))
        # a record type with explicitly declared attribute types
        simple = self.adapter.simple_type
        record = simple('record')
        record._get_attnames = lambda _self: pg.AttrDict([
            ('i', simple('int')), ('f', simple('float')),
            ('t', simple('text')), ('b', simple('bool')),
            ('i3', simple('int[]')), ('t3', simple('text[]'))])
        check('select $1', ['(3,7.5,hello,t,{123},{abc})'],
              'select %s', [(3, 7.5, 'hello', True, [123], ['abc'])],
              [record])

    def testAdaptQueryTypedDict(self):
        # dict parameters are numbered in the sorted order of their names
        check = self._assertFormatted
        format_query = self.adapter.format_query
        self.assertRaises(TypeError, format_query,
                          '%s,%s', dict(i1=1, i2=2), dict(i1='int2'))
        scalars = dict(i=3, f=7.5, t='hello', b=True)
        check('select $3,$2,$4,$1', ['t', 7.5, 3, 'hello'],
              "select %(i)s,%(f)s,%(t)s,%(b)s", scalars,
              dict(i='int4', f='float4', t='text', b='bool'))
        check('select $3,$2,$4,$1', ['t', 't', 't', 'f'],
              "select %(i)s,%(f)s,%(t)s,%(b)s", scalars,
              dict(i='bool', f='bool', t='bool', b='bool'))
        check('values($1,current_date)', ['2016-01-30'],
              "values(%(d1)s,%(d2)s)",
              dict(d1='2016-01-30', d2='current_date'),
              dict(d1='date', d2='date'))
        arrays = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
        check('$1::int4[],$2::text[]', ['{1,2,3}', '{a,b,c}'],
              "%(i)s::int4[],%(t)s::text[]", arrays,
              dict(i='_int4', t='_text'))
        check('$1::bool[],$2::bool[]', ['{t,t,t}', '{f,f,f}'],
              "%(i)s::bool[],%(t)s::bool[]", arrays,
              dict(i='_bool', t='_bool'))
        simple = self.adapter.simple_type
        record = simple('record')
        record._get_attnames = lambda _self: pg.AttrDict([
            ('i', simple('int')), ('f', simple('float')),
            ('t', simple('text')), ('b', simple('bool')),
            ('i3', simple('int[]')), ('t3', simple('text[]'))])
        check('select $1', ['(3,7.5,hello,t,{123},{abc})'],
              'select %(record)s',
              dict(record=(3, 7.5, 'hello', True, [123], ['abc'])),
              dict(record=record))

    def testAdaptQueryUntypedList(self):
        check = self._assertFormatted
        check('select $1,$2,$3,$4', [3, 7.5, 'hello', 't'],
              "select %s,%s,%s,%s", (3, 7.5, 'hello', True))
        dates = [date(2016, 1, 30), 'current_date']
        check('values($1,$2)', dates, "values(%s,%s)", dates)
        check("$1,$2,$3", ['{1,2,3}', '{a,b,c}', '{t,f,t}'],
              "%s,%s,%s", ([1, 2, 3], ['a', 'b', 'c'], [True, False, True]))
        check("$1,$2,$3",
              ['{{1,2},{3,4}}', '{{a,b},{c,d}}', '{{t,f},{f,t}}'],
              "%s,%s,%s", ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
                           [[True, False], [False, True]]))
        check('select $1', ['(3,7.5,hello,t,{123},{abc})'],
              'select %s', [(3, 7.5, 'hello', True, [123], ['abc'])])

    def testAdaptQueryUntypedDict(self):
        check = self._assertFormatted
        check('select $3,$2,$4,$1', ['t', 7.5, 3, 'hello'],
              "select %(i)s,%(f)s,%(t)s,%(b)s",
              dict(i=3, f=7.5, t='hello', b=True))
        dates = dict(d1='2016-01-30', d2='current_date')
        check('values($1,$2)', [dates['d1'], dates['d2']],
              "values(%(d1)s,%(d2)s)", dates)
        check("$2,$3,$1", ['{t,f,t}', '{1,2,3}', '{a,b,c}'],
              "%(i)s,%(t)s,%(b)s",
              dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True]))
        check("$2,$3,$1",
              ['{{t,f},{f,t}}', '{{1,2},{3,4}}', '{{a,b},{c,d}}'],
              "%(i)s,%(t)s,%(b)s",
              dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
                   b=[[True, False], [False, True]]))
        check('select $1', ['(3,7.5,hello,t,{123},{abc})'],
              'select %(record)s',
              dict(record=(3, 7.5, 'hello', True, [123], ['abc'])))

    def testAdaptQueryInlineList(self):
        # with inline=True all values are rendered into the SQL directly
        check = self._assertFormatted
        check("select 3,7.5,'hello',true", [],
              "select %s,%s,%s,%s", (3, 7.5, 'hello', True), inline=True)
        check("values('2016-01-30','current_date')", [],
              "values(%s,%s)", [date(2016, 1, 30), 'current_date'],
              inline=True)
        check("ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]", [],
              "%s,%s,%s", ([1, 2, 3], ['a', 'b', 'c'], [True, False, True]),
              inline=True)
        check("ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
              "ARRAY[[true,false],[false,true]]", [],
              "%s,%s,%s", ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
                           [[True, False], [False, True]]), inline=True)
        check("select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])", [],
              'select %s', [(3, 7.5, 'hello', True, [123], ['abc'])],
              inline=True)

    def testAdaptQueryInlineDict(self):
        # with inline=True all dict values are rendered into the SQL directly
        check = self._assertFormatted
        check("select 3,7.5,'hello',true", [],
              "select %(i)s,%(f)s,%(t)s,%(b)s",
              dict(i=3, f=7.5, t='hello', b=True), inline=True)
        check("values('2016-01-30','current_date')", [],
              "values(%(d1)s,%(d2)s)",
              dict(d1='2016-01-30', d2='current_date'), inline=True)
        check("ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]", [],
              "%(i)s,%(t)s,%(b)s",
              dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True]),
              inline=True)
        check("ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
              "ARRAY[[true,false],[false,true]]", [],
              "%(i)s,%(t)s,%(b)s",
              dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
                   b=[[True, False], [False, True]]), inline=True)
        check("select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])", [],
              'select %(record)s',
              dict(record=(3, 7.5, 'hello', True, [123], ['abc'])),
              inline=True)

    def testAdaptQueryWithPgRepr(self):
        # a plain object() given as the values raises TypeError, while
        # objects with a __pg_repr__ method are inlined using that method
        self.assertRaises(TypeError, self.adapter.format_query,
                          '%s', object(), inline=True)

        class TestObject:
            def __pg_repr__(self):
                return "'adapted'"

        check = self._assertFormatted
        check("select 'adapted'", [],
              'select %s', [TestObject()], inline=True)
        check("select ARRAY['adapted']", [],
              'select %s', [[TestObject()]], inline=True)
4208
4209
4210class TestSchemas(unittest.TestCase):
4211    """Test correct handling of schemas (namespaces)."""
4212
4213    cls_set_up = False
4214
4215    @classmethod
4216    def setUpClass(cls):
4217        db = DB()
4218        query = db.query
4219        for num_schema in range(5):
4220            if num_schema:
4221                schema = "s%d" % num_schema
4222                query("drop schema if exists %s cascade" % (schema,))
4223                try:
4224                    query("create schema %s" % (schema,))
4225                except pg.ProgrammingError:
4226                    raise RuntimeError("The test user cannot create schemas.\n"
4227                        "Grant create on database %s to the user"
4228                        " for running these tests." % dbname)
4229            else:
4230                schema = "public"
4231                query("drop table if exists %s.t" % (schema,))
4232                query("drop table if exists %s.t%d" % (schema, num_schema))
4233            query("create table %s.t with oids as select 1 as n, %d as d"
4234                  % (schema, num_schema))
4235            query("create table %s.t%d with oids as select 1 as n, %d as d"
4236                  % (schema, num_schema, num_schema))
4237        db.close()
4238        cls.cls_set_up = True
4239
4240    @classmethod
4241    def tearDownClass(cls):
4242        db = DB()
4243        query = db.query
4244        for num_schema in range(5):
4245            if num_schema:
4246                schema = "s%d" % num_schema
4247                query("drop schema %s cascade" % (schema,))
4248            else:
4249                schema = "public"
4250                query("drop table %s.t" % (schema,))
4251                query("drop table %s.t%d" % (schema, num_schema))
4252        db.close()
4253
4254    def setUp(self):
4255        self.assertTrue(self.cls_set_up)
4256        self.db = DB()
4257
4258    def tearDown(self):
4259        self.doCleanups()
4260        self.db.close()
4261
4262    def testGetTables(self):
4263        tables = self.db.get_tables()
4264        for num_schema in range(5):
4265            if num_schema:
4266                schema = "s" + str(num_schema)
4267            else:
4268                schema = "public"
4269            for t in (schema + ".t",
4270                    schema + ".t" + str(num_schema)):
4271                self.assertIn(t, tables)
4272
4273    def testGetAttnames(self):
4274        get_attnames = self.db.get_attnames
4275        query = self.db.query
4276        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
4277        r = get_attnames("t")
4278        self.assertEqual(r, result)
4279        r = get_attnames("s4.t4")
4280        self.assertEqual(r, result)
4281        query("drop table if exists s3.t3m")
4282        self.addCleanup(query, "drop table s3.t3m")
4283        query("create table s3.t3m with oids as select 1 as m")
4284        result_m = {'oid': 'int', 'm': 'int'}
4285        r = get_attnames("s3.t3m")
4286        self.assertEqual(r, result_m)
4287        query("set search_path to s1,s3")
4288        r = get_attnames("t3")
4289        self.assertEqual(r, result)
4290        r = get_attnames("t3m")
4291        self.assertEqual(r, result_m)
4292
4293    def testGet(self):
4294        get = self.db.get
4295        query = self.db.query
4296        PrgError = pg.ProgrammingError
4297        self.assertEqual(get("t", 1, 'n')['d'], 0)
4298        self.assertEqual(get("t0", 1, 'n')['d'], 0)
4299        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
4300        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
4301        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
4302        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
4303        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
4304        query("set search_path to s2,s4")
4305        self.assertRaises(PrgError, get, "t1", 1, 'n')
4306        self.assertEqual(get("t4", 1, 'n')['d'], 4)
4307        self.assertRaises(PrgError, get, "t3", 1, 'n')
4308        self.assertEqual(get("t", 1, 'n')['d'], 2)
4309        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
4310        query("set search_path to s1,s3")
4311        self.assertRaises(PrgError, get, "t2", 1, 'n')
4312        self.assertEqual(get("t3", 1, 'n')['d'], 3)
4313        self.assertRaises(PrgError, get, "t4", 1, 'n')
4314        self.assertEqual(get("t", 1, 'n')['d'], 1)
4315        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
4316
4317    def testMunging(self):
4318        get = self.db.get
4319        query = self.db.query
4320        r = get("t", 1, 'n')
4321        self.assertIn('oid(t)', r)
4322        query("set search_path to s2")
4323        r = get("t2", 1, 'n')
4324        self.assertIn('oid(t2)', r)
4325        query("set search_path to s3")
4326        r = get("t", 1, 'n')
4327        self.assertIn('oid(t)', r)
4328
4329
4330class TestDebug(unittest.TestCase):
4331    """Test the debug attribute of the DB class."""
4332
4333    def setUp(self):
4334        self.db = DB()
4335        self.query = self.db.query
4336        self.debug = self.db.debug
4337        self.output = StringIO()
4338        self.stdout, sys.stdout = sys.stdout, self.output
4339
4340    def tearDown(self):
4341        sys.stdout = self.stdout
4342        self.output.close()
4343        self.db.debug = debug
4344        self.db.close()
4345
4346    def get_output(self):
4347        return self.output.getvalue()
4348
4349    def send_queries(self):
4350        self.db.query("select 1")
4351        self.db.query("select 2")
4352
4353    def testDebugDefault(self):
4354        if debug:
4355            self.assertEqual(self.db.debug, debug)
4356        else:
4357            self.assertIsNone(self.db.debug)
4358
4359    def testDebugIsFalse(self):
4360        self.db.debug = False
4361        self.send_queries()
4362        self.assertEqual(self.get_output(), "")
4363
4364    def testDebugIsTrue(self):
4365        self.db.debug = True
4366        self.send_queries()
4367        self.assertEqual(self.get_output(), "select 1\nselect 2\n")
4368
4369    def testDebugIsString(self):
4370        self.db.debug = "Test with string: %s."
4371        self.send_queries()
4372        self.assertEqual(self.get_output(),
4373            "Test with string: select 1.\nTest with string: select 2.\n")
4374
4375    def testDebugIsFileLike(self):
4376        with tempfile.TemporaryFile('w+') as debug_file:
4377            self.db.debug = debug_file
4378            self.send_queries()
4379            debug_file.seek(0)
4380            output = debug_file.read()
4381            self.assertEqual(output, "select 1\nselect 2\n")
4382            self.assertEqual(self.get_output(), "")
4383
4384    def testDebugIsCallable(self):
4385        output = []
4386        self.db.debug = output.append
4387        self.db.query("select 1")
4388        self.db.query("select 2")
4389        self.assertEqual(output, ["select 1", "select 2"])
4390        self.assertEqual(self.get_output(), "")
4391
4392    def testDebugMultipleArgs(self):
4393        output = []
4394        self.db.debug = output.append
4395        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
4396        self.db._do_debug(*args)
4397        self.assertEqual(output, ['\n'.join(str(arg) for arg in args)])
4398        self.assertEqual(self.get_output(), "")
4399
4400
4401if __name__ == '__main__':
4402    unittest.main()
Note: See TracBrowser for help on using the repository browser.