source: trunk/tests/test_classic_dbwrapper.py @ 802

Last change on this file since 802 was 802, checked in by cito, 4 years ago

Make the unescaping of bytea configurable

By default, bytea is returned unescaped in 5.0, but the old
behavior can now be restored with set_escaped_bytea().

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 165.2 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import tempfile
21import json
22
23import pg  # the module under test
24
25from decimal import Decimal
26from datetime import date
27from operator import itemgetter
28
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Override the defaults above with the settings from LOCAL_PyGreSQL, if such
# a module exists: try a package-relative import first and fall back to an
# absolute import; if neither is found, the defaults stay in effect.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Compatibility aliases so that the tests run on Python 2 and Python 3.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

if str is bytes:
    from StringIO import StringIO
else:
    from io import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(); tests that ask the
# connection for its host are skipped there using these two settings.
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
72
73
def DB():
    """Return a new pg.DB wrapper connected to the test database.

    The connection uses the module-level ``dbname``, ``dbhost`` and
    ``dbport`` settings; when the module-level ``debug`` flag is true,
    it is also copied to the wrapper's ``debug`` attribute.
    """
    wrapper = pg.DB(dbname, dbhost, dbport)
    if debug:
        wrapper.debug = debug
    # keep the test output free of notices below warning level
    wrapper.query("set client_min_messages=warning")
    return wrapper
81
82
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names.

    The class under test (``cls``) must be an instance of ``base`` that can
    be read like any mapping, but rejects every modification with a
    TypeError.
    """

    cls = pg.AttrDict
    base = OrderedDict

    def testInit(self):
        a = self.cls()
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base())
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))
        # an iterator over pairs must work just like a list of pairs
        iteritems = iter(items)
        a = self.cls(iteritems)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))

    def testIter(self):
        a = self.cls()
        self.assertEqual(list(a), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a), keys)

    def testKeys(self):
        a = self.cls()
        self.assertEqual(list(a.keys()), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a.keys()), keys)

    def testValues(self):
        a = self.cls()
        self.assertEqual(list(a.values()), [])
        items = [('id', 'int'), ('name', 'text')]
        values = [item[1] for item in items]
        a = self.cls(items)
        self.assertEqual(list(a.values()), values)

    def testItems(self):
        a = self.cls()
        self.assertEqual(list(a.items()), [])
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertEqual(list(a.items()), items)

    def testGet(self):
        a = self.cls([('id', 1)])
        try:
            self.assertEqual(a['id'], 1)
        except KeyError:
            self.fail('AttrDict should be readable')

    def testSet(self):
        a = self.cls()
        try:
            a['id'] = 1
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testDel(self):
        a = self.cls([('id', 1)])
        try:
            del a['id']
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        a = self.cls([('id', 1)])
        self.assertEqual(a['id'], 1)
        # Call each mutating method with arguments that are valid for an
        # ordinary dict, so that a TypeError can only be caused by the
        # read-only protection and not by a wrong call signature.
        calls = (
            ('clear', ()),
            ('update', ({'id': 2},)),
            ('pop', ('id',)),
            ('setdefault', ('name', 'text')),
            ('popitem', ()),
        )
        for name, args in calls:
            method = getattr(a, name)
            self.assertRaises(TypeError, method, *args)
        # none of the rejected calls may have changed the contents
        self.assertEqual(list(a.items()), [('id', 1)])
164
165
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods.

    Every test runs against its own connection created by ``DB()``.
    """

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            # some tests close the connection themselves, and closing an
            # already closed connection raises InternalError
            pass

    def testAllDBAttributes(self):
        attributes = [
            'abort', 'adapter',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'dbtypes',
            'debug', 'decode_json', 'delete',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_cast_hook',
            'get_databases', 'get_notice_receiver',
            'get_parameter', 'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query', 'query_formatted',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_cast_hook', 'set_notice_receiver',
            'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        # dir() returns sorted names, so the expected list is kept sorted
        db_attributes = [a for a in dir(self.db)
                         if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            # 22012 is the SQLSTATE code for division_by_zero
            self.assertEqual(error.sqlstate, '22012')
        else:
            self.fail('Division by zero should raise an error')

    def testMethodEndcopy(self):
        # calling endcopy() here is allowed to raise IOError,
        # but must not raise any other kind of error
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        # reset() keeps the same connection object
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        # reopen() replaces the connection object, even after close()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        # a wrapper around a borrowed connection does not really close it
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        # any object with a _cnx attribute can provide the connection
        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
376
377
378class TestDBClass(unittest.TestCase):
379    """Test the methods of the DB class wrapped pg connection."""
380
381    maxDiff = 80 * 20
382
383    cls_set_up = False
384
385    regtypes = None
386
387    @classmethod
388    def setUpClass(cls):
389        db = DB()
390        db.query("drop table if exists test cascade")
391        db.query("create table test ("
392                 "i2 smallint, i4 integer, i8 bigint,"
393                 " d numeric, f4 real, f8 double precision, m money,"
394                 " v4 varchar(4), c4 char(4), t text)")
395        db.query("create or replace view test_view as"
396                 " select i4, v4 from test")
397        db.close()
398        cls.cls_set_up = True
399
400    @classmethod
401    def tearDownClass(cls):
402        db = DB()
403        db.query("drop table test cascade")
404        db.close()
405
406    def setUp(self):
407        self.assertTrue(self.cls_set_up)
408        self.db = DB()
409        if self.regtypes is None:
410            self.regtypes = self.db.use_regtypes()
411        else:
412            self.db.use_regtypes(self.regtypes)
413        query = self.db.query
414        query('set client_encoding=utf8')
415        query('set standard_conforming_strings=on')
416        query("set lc_monetary='C'")
417        query("set datestyle='ISO,YMD'")
418        query('set bytea_output=hex')
419
420    def tearDown(self):
421        self.doCleanups()
422        self.db.close()
423
424    def createTable(self, table, definition,
425                    temporary=True, oids=None, values=None):
426        query = self.db.query
427        if not '"' in table or '.' in table:
428            table = '"%s"' % table
429        if not temporary:
430            q = 'drop table if exists %s cascade' % table
431            query(q)
432            self.addCleanup(query, q)
433        temporary = 'temporary table' if temporary else 'table'
434        as_query = definition.startswith(('as ', 'AS '))
435        if not as_query and not definition.startswith('('):
436            definition = '(%s)' % definition
437        with_oids = 'with oids' if oids else 'without oids'
438        q = ['create', temporary, table]
439        if as_query:
440            q.extend([with_oids, definition])
441        else:
442            q.extend([definition, with_oids])
443        q = ' '.join(q)
444        query(q)
445        if values:
446            for params in values:
447                if not isinstance(params, (list, tuple)):
448                    params = [params]
449                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
450                q = "insert into %s values (%s)" % (table, values)
451                query(q, params)
452
453    def testClassName(self):
454        self.assertEqual(self.db.__class__.__name__, 'DB')
455
456    def testModuleName(self):
457        self.assertEqual(self.db.__module__, 'pg')
458        self.assertEqual(self.db.__class__.__module__, 'pg')
459
460    def testEscapeLiteral(self):
461        f = self.db.escape_literal
462        r = f(b"plain")
463        self.assertIsInstance(r, bytes)
464        self.assertEqual(r, b"'plain'")
465        r = f(u"plain")
466        self.assertIsInstance(r, unicode)
467        self.assertEqual(r, u"'plain'")
468        r = f(u"that's kÀse".encode('utf-8'))
469        self.assertIsInstance(r, bytes)
470        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
471        r = f(u"that's kÀse")
472        self.assertIsInstance(r, unicode)
473        self.assertEqual(r, u"'that''s kÀse'")
474        self.assertEqual(f(r"It's fine to have a \ inside."),
475                         r" E'It''s fine to have a \\ inside.'")
476        self.assertEqual(f('No "quotes" must be escaped.'),
477                         "'No \"quotes\" must be escaped.'")
478
479    def testEscapeIdentifier(self):
480        f = self.db.escape_identifier
481        r = f(b"plain")
482        self.assertIsInstance(r, bytes)
483        self.assertEqual(r, b'"plain"')
484        r = f(u"plain")
485        self.assertIsInstance(r, unicode)
486        self.assertEqual(r, u'"plain"')
487        r = f(u"that's kÀse".encode('utf-8'))
488        self.assertIsInstance(r, bytes)
489        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
490        r = f(u"that's kÀse")
491        self.assertIsInstance(r, unicode)
492        self.assertEqual(r, u'"that\'s kÀse"')
493        self.assertEqual(f(r"It's fine to have a \ inside."),
494                         '"It\'s fine to have a \\ inside."')
495        self.assertEqual(f('All "quotes" must be escaped.'),
496                         '"All ""quotes"" must be escaped."')
497
498    def testEscapeString(self):
499        f = self.db.escape_string
500        r = f(b"plain")
501        self.assertIsInstance(r, bytes)
502        self.assertEqual(r, b"plain")
503        r = f(u"plain")
504        self.assertIsInstance(r, unicode)
505        self.assertEqual(r, u"plain")
506        r = f(u"that's kÀse".encode('utf-8'))
507        self.assertIsInstance(r, bytes)
508        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
509        r = f(u"that's kÀse")
510        self.assertIsInstance(r, unicode)
511        self.assertEqual(r, u"that''s kÀse")
512        self.assertEqual(f(r"It's fine to have a \ inside."),
513                         r"It''s fine to have a \ inside.")
514
515    def testEscapeBytea(self):
516        f = self.db.escape_bytea
517        # note that escape_byte always returns hex output since Pg 9.0,
518        # regardless of the bytea_output setting
519        r = f(b'plain')
520        self.assertIsInstance(r, bytes)
521        self.assertEqual(r, b'\\x706c61696e')
522        r = f(u'plain')
523        self.assertIsInstance(r, unicode)
524        self.assertEqual(r, u'\\x706c61696e')
525        r = f(u"das is' kÀse".encode('utf-8'))
526        self.assertIsInstance(r, bytes)
527        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
528        r = f(u"das is' kÀse")
529        self.assertIsInstance(r, unicode)
530        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
531        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
532
533    def testUnescapeBytea(self):
534        f = self.db.unescape_bytea
535        r = f(b'plain')
536        self.assertIsInstance(r, bytes)
537        self.assertEqual(r, b'plain')
538        r = f(u'plain')
539        self.assertIsInstance(r, bytes)
540        self.assertEqual(r, b'plain')
541        r = f(b"das is' k\\303\\244se")
542        self.assertIsInstance(r, bytes)
543        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
544        r = f(u"das is' k\\303\\244se")
545        self.assertIsInstance(r, bytes)
546        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
547        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
548        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
549        self.assertEqual(f(r'\\x746861742773206be47365'),
550                         b'\\x746861742773206be47365')
551        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
552
553    def testDecodeJson(self):
554        f = self.db.decode_json
555        self.assertIsNone(f('null'))
556        data = {
557            "id": 1, "name": "Foo", "price": 1234.5,
558            "new": True, "note": None,
559            "tags": ["Bar", "Eek"],
560            "stock": {"warehouse": 300, "retail": 20}}
561        text = json.dumps(data)
562        r = f(text)
563        self.assertIsInstance(r, dict)
564        self.assertEqual(r, data)
565        self.assertIsInstance(r['id'], int)
566        self.assertIsInstance(r['name'], unicode)
567        self.assertIsInstance(r['price'], float)
568        self.assertIsInstance(r['new'], bool)
569        self.assertIsInstance(r['tags'], list)
570        self.assertIsInstance(r['stock'], dict)
571
572    def testEncodeJson(self):
573        f = self.db.encode_json
574        self.assertEqual(f(None), 'null')
575        data = {
576            "id": 1, "name": "Foo", "price": 1234.5,
577            "new": True, "note": None,
578            "tags": ["Bar", "Eek"],
579            "stock": {"warehouse": 300, "retail": 20}}
580        text = json.dumps(data)
581        r = f(data)
582        self.assertIsInstance(r, str)
583        self.assertEqual(r, text)
584
585    def testGetParameter(self):
586        f = self.db.get_parameter
587        self.assertRaises(TypeError, f)
588        self.assertRaises(TypeError, f, None)
589        self.assertRaises(TypeError, f, 42)
590        self.assertRaises(TypeError, f, '')
591        self.assertRaises(TypeError, f, [])
592        self.assertRaises(TypeError, f, [''])
593        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
594        r = f('standard_conforming_strings')
595        self.assertEqual(r, 'on')
596        r = f('lc_monetary')
597        self.assertEqual(r, 'C')
598        r = f('datestyle')
599        self.assertEqual(r, 'ISO, YMD')
600        r = f('bytea_output')
601        self.assertEqual(r, 'hex')
602        r = f(['bytea_output', 'lc_monetary'])
603        self.assertIsInstance(r, list)
604        self.assertEqual(r, ['hex', 'C'])
605        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
606        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
607        r = f(set(['bytea_output', 'lc_monetary']))
608        self.assertIsInstance(r, dict)
609        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
610        r = f(set(['Bytea_Output', ' LC_Monetary ']))
611        self.assertIsInstance(r, dict)
612        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
613        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
614        r = f(s)
615        self.assertIs(r, s)
616        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
617        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
618        r = f(s)
619        self.assertIs(r, s)
620        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
621
622    def testGetParameterServerVersion(self):
623        r = self.db.get_parameter('server_version_num')
624        self.assertIsInstance(r, str)
625        s = self.db.server_version
626        self.assertIsInstance(s, int)
627        self.assertEqual(r, str(s))
628
629    def testGetParameterAll(self):
630        f = self.db.get_parameter
631        r = f('all')
632        self.assertIsInstance(r, dict)
633        self.assertEqual(r['standard_conforming_strings'], 'on')
634        self.assertEqual(r['lc_monetary'], 'C')
635        self.assertEqual(r['DateStyle'], 'ISO, YMD')
636        self.assertEqual(r['bytea_output'], 'hex')
637
638    def testSetParameter(self):
639        f = self.db.set_parameter
640        g = self.db.get_parameter
641        self.assertRaises(TypeError, f)
642        self.assertRaises(TypeError, f, None)
643        self.assertRaises(TypeError, f, 42)
644        self.assertRaises(TypeError, f, '')
645        self.assertRaises(TypeError, f, [])
646        self.assertRaises(TypeError, f, [''])
647        self.assertRaises(ValueError, f, 'all', 'invalid')
648        self.assertRaises(ValueError, f, {
649            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
650        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
651        f('standard_conforming_strings', 'off')
652        self.assertEqual(g('standard_conforming_strings'), 'off')
653        f('datestyle', 'ISO, DMY')
654        self.assertEqual(g('datestyle'), 'ISO, DMY')
655        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
656        self.assertEqual(g('standard_conforming_strings'), 'on')
657        self.assertEqual(g('datestyle'), 'ISO, DMY')
658        f(['default_with_oids', 'standard_conforming_strings'], 'off')
659        self.assertEqual(g('default_with_oids'), 'off')
660        self.assertEqual(g('standard_conforming_strings'), 'off')
661        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
662        self.assertEqual(g('standard_conforming_strings'), 'on')
663        self.assertEqual(g('datestyle'), 'ISO, YMD')
664        f(('default_with_oids', 'standard_conforming_strings'), 'off')
665        self.assertEqual(g('default_with_oids'), 'off')
666        self.assertEqual(g('standard_conforming_strings'), 'off')
667        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
668        self.assertEqual(g('default_with_oids'), 'on')
669        self.assertEqual(g('standard_conforming_strings'), 'on')
670        self.assertRaises(ValueError, f, set(['default_with_oids',
671            'standard_conforming_strings']), ['off', 'on'])
672        f(set(['default_with_oids', 'standard_conforming_strings']),
673            ['off', 'off'])
674        self.assertEqual(g('default_with_oids'), 'off')
675        self.assertEqual(g('standard_conforming_strings'), 'off')
676        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
677        self.assertEqual(g('standard_conforming_strings'), 'on')
678        self.assertEqual(g('datestyle'), 'ISO, YMD')
679
680    def testResetParameter(self):
681        db = DB()
682        f = db.set_parameter
683        g = db.get_parameter
684        r = g('default_with_oids')
685        self.assertIn(r, ('on', 'off'))
686        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
687        r = g('standard_conforming_strings')
688        self.assertIn(r, ('on', 'off'))
689        scs, not_scs = r, 'off' if r == 'on' else 'on'
690        f('default_with_oids', not_dwi)
691        f('standard_conforming_strings', not_scs)
692        self.assertEqual(g('default_with_oids'), not_dwi)
693        self.assertEqual(g('standard_conforming_strings'), not_scs)
694        f('default_with_oids')
695        f('standard_conforming_strings', None)
696        self.assertEqual(g('default_with_oids'), dwi)
697        self.assertEqual(g('standard_conforming_strings'), scs)
698        f('default_with_oids', not_dwi)
699        f('standard_conforming_strings', not_scs)
700        self.assertEqual(g('default_with_oids'), not_dwi)
701        self.assertEqual(g('standard_conforming_strings'), not_scs)
702        f(['default_with_oids', 'standard_conforming_strings'], None)
703        self.assertEqual(g('default_with_oids'), dwi)
704        self.assertEqual(g('standard_conforming_strings'), scs)
705        f('default_with_oids', not_dwi)
706        f('standard_conforming_strings', not_scs)
707        self.assertEqual(g('default_with_oids'), not_dwi)
708        self.assertEqual(g('standard_conforming_strings'), not_scs)
709        f(('default_with_oids', 'standard_conforming_strings'))
710        self.assertEqual(g('default_with_oids'), dwi)
711        self.assertEqual(g('standard_conforming_strings'), scs)
712        f('default_with_oids', not_dwi)
713        f('standard_conforming_strings', not_scs)
714        self.assertEqual(g('default_with_oids'), not_dwi)
715        self.assertEqual(g('standard_conforming_strings'), not_scs)
716        f(set(['default_with_oids', 'standard_conforming_strings']))
717        self.assertEqual(g('default_with_oids'), dwi)
718        self.assertEqual(g('standard_conforming_strings'), scs)
719
720    def testResetParameterAll(self):
721        db = DB()
722        f = db.set_parameter
723        self.assertRaises(ValueError, f, 'all', 0)
724        self.assertRaises(ValueError, f, 'all', 'off')
725        g = db.get_parameter
726        r = g('default_with_oids')
727        self.assertIn(r, ('on', 'off'))
728        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
729        r = g('standard_conforming_strings')
730        self.assertIn(r, ('on', 'off'))
731        scs, not_scs = r, 'off' if r == 'on' else 'on'
732        f('default_with_oids', not_dwi)
733        f('standard_conforming_strings', not_scs)
734        self.assertEqual(g('default_with_oids'), not_dwi)
735        self.assertEqual(g('standard_conforming_strings'), not_scs)
736        f('all')
737        self.assertEqual(g('default_with_oids'), dwi)
738        self.assertEqual(g('standard_conforming_strings'), scs)
739
740    def testSetParameterLocal(self):
741        f = self.db.set_parameter
742        g = self.db.get_parameter
743        self.assertEqual(g('standard_conforming_strings'), 'on')
744        self.db.begin()
745        f('standard_conforming_strings', 'off', local=True)
746        self.assertEqual(g('standard_conforming_strings'), 'off')
747        self.db.end()
748        self.assertEqual(g('standard_conforming_strings'), 'on')
749
750    def testSetParameterSession(self):
751        f = self.db.set_parameter
752        g = self.db.get_parameter
753        self.assertEqual(g('standard_conforming_strings'), 'on')
754        self.db.begin()
755        f('standard_conforming_strings', 'off', local=False)
756        self.assertEqual(g('standard_conforming_strings'), 'off')
757        self.db.end()
758        self.assertEqual(g('standard_conforming_strings'), 'off')
759
760    def testReset(self):
761        db = DB()
762        default_datestyle = db.get_parameter('datestyle')
763        changed_datestyle = 'ISO, DMY'
764        if changed_datestyle == default_datestyle:
765            changed_datestyle == 'ISO, YMD'
766        self.db.set_parameter('datestyle', changed_datestyle)
767        r = self.db.get_parameter('datestyle')
768        self.assertEqual(r, changed_datestyle)
769        con = self.db.db
770        q = con.query("show datestyle")
771        self.db.reset()
772        r = q.getresult()[0][0]
773        self.assertEqual(r, changed_datestyle)
774        q = con.query("show datestyle")
775        r = q.getresult()[0][0]
776        self.assertEqual(r, default_datestyle)
777        r = self.db.get_parameter('datestyle')
778        self.assertEqual(r, default_datestyle)
779
780    def testReopen(self):
781        db = DB()
782        default_datestyle = db.get_parameter('datestyle')
783        changed_datestyle = 'ISO, DMY'
784        if changed_datestyle == default_datestyle:
785            changed_datestyle == 'ISO, YMD'
786        self.db.set_parameter('datestyle', changed_datestyle)
787        r = self.db.get_parameter('datestyle')
788        self.assertEqual(r, changed_datestyle)
789        con = self.db.db
790        q = con.query("show datestyle")
791        self.db.reopen()
792        r = q.getresult()[0][0]
793        self.assertEqual(r, changed_datestyle)
794        self.assertRaises(TypeError, getattr, con, 'query')
795        r = self.db.get_parameter('datestyle')
796        self.assertEqual(r, default_datestyle)
797
798    def testCreateTable(self):
799        table = 'test hello world'
800        values = [(2, "World!"), (1, "Hello")]
801        self.createTable(table, "n smallint, t varchar",
802                         temporary=True, oids=True, values=values)
803        r = self.db.query('select t from "%s" order by n' % table).getresult()
804        r = ', '.join(row[0] for row in r)
805        self.assertEqual(r, "Hello, World!")
806        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
807        self.assertIsInstance(r[0][0], int)
808
809    def testQuery(self):
810        query = self.db.query
811        table = 'test_table'
812        self.createTable(table, "n integer", oids=True)
813        q = "insert into test_table values (1)"
814        r = query(q)
815        self.assertIsInstance(r, int)
816        q = "insert into test_table select 2"
817        r = query(q)
818        self.assertIsInstance(r, int)
819        oid = r
820        q = "select oid from test_table where n=2"
821        r = query(q).getresult()
822        self.assertEqual(len(r), 1)
823        r = r[0]
824        self.assertEqual(len(r), 1)
825        r = r[0]
826        self.assertEqual(r, oid)
827        q = "insert into test_table select 3 union select 4 union select 5"
828        r = query(q)
829        self.assertIsInstance(r, str)
830        self.assertEqual(r, '3')
831        q = "update test_table set n=4 where n<5"
832        r = query(q)
833        self.assertIsInstance(r, str)
834        self.assertEqual(r, '4')
835        q = "delete from test_table"
836        r = query(q)
837        self.assertIsInstance(r, str)
838        self.assertEqual(r, '5')
839
840    def testMultipleQueries(self):
841        self.assertEqual(self.db.query(
842            "create temporary table test_multi (n integer);"
843            "insert into test_multi values (4711);"
844            "select n from test_multi").getresult()[0][0], 4711)
845
846    def testQueryWithParams(self):
847        query = self.db.query
848        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
849        q = "insert into test_table values ($1, $2)"
850        r = query(q, (1, 2))
851        self.assertIsInstance(r, int)
852        r = query(q, [3, 4])
853        self.assertIsInstance(r, int)
854        r = query(q, [5, 6])
855        self.assertIsInstance(r, int)
856        q = "select * from test_table order by 1, 2"
857        self.assertEqual(query(q).getresult(),
858                         [(1, 2), (3, 4), (5, 6)])
859        q = "select * from test_table where n1=$1 and n2=$2"
860        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
861        q = "update test_table set n2=$2 where n1=$1"
862        r = query(q, 3, 7)
863        self.assertEqual(r, '1')
864        q = "select * from test_table order by 1, 2"
865        self.assertEqual(query(q).getresult(),
866                         [(1, 2), (3, 7), (5, 6)])
867        q = "delete from test_table where n2!=$1"
868        r = query(q, 4)
869        self.assertEqual(r, '3')
870
871    def testEmptyQuery(self):
872        self.assertRaises(ValueError, self.db.query, '')
873
874    def testQueryProgrammingError(self):
875        try:
876            self.db.query("select 1/0")
877        except pg.ProgrammingError as error:
878            self.assertEqual(error.sqlstate, '22012')
879
880    def testQueryFormatted(self):
881        f = self.db.query_formatted
882        t = True if pg.get_bool() else 't'
883        q = f("select %s::int, %s::real, %s::text, %s::bool",
884              (3, 2.5, 'hello', True))
885        r = q.getresult()[0]
886        self.assertEqual(r, (3, 2.5, 'hello', t))
887        q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True)
888        r = q.getresult()[0]
889        self.assertEqual(r, (3, 2.5, 'hello', t))
890
    def testPkey(self):
        """Check pkey() for simple, composite and quoted primary keys.

        Also checks the composite flag and that results are cached until
        pkey() is called with flush=True.
        """
        query = self.db.query
        pkey = self.db.pkey
        self.assertRaises(KeyError, pkey, 'test')
        # run everything once with a plain and once with a quoted table name
        for t in ('pkeytest', 'primary key test'):
            self.createTable('%s0' % t, 'a smallint')
            self.createTable('%s1' % t, 'b smallint primary key')
            self.createTable('%s2' % t,
                'c smallint, d smallint primary key')
            self.createTable('%s3' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (f, h)')
            self.createTable('%s4' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (h, f)')
            self.createTable('%s5' % t,
                'more_than_one_letter varchar primary key')
            self.createTable('%s6' % t,
                '"with space" date primary key')
            self.createTable('%s7' % t,
                'a_very_long_column_name varchar, "with space" date, "42" int,'
                ' primary key (a_very_long_column_name, "with space", "42")')
            # a table without primary key raises KeyError
            self.assertRaises(KeyError, pkey, '%s0' % t)
            # a single key is returned as a string, or as a 1-tuple when
            # composite is requested (positionally or by keyword)
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s1' % t, True), ('b',))
            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
            self.assertEqual(pkey('%s2' % t), 'd')
            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
            # a composite key is always a tuple in key definition order,
            # even with composite=False
            r = pkey('%s3' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s3' % t, composite=False)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s4' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('h', 'f'))
            # column names are returned unquoted
            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s6' % t), 'with space')
            r = pkey('%s7' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (
                'a_very_long_column_name', 'with space', '42'))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
944
945    def testGetDatabases(self):
946        databases = self.db.get_databases()
947        self.assertIn('template0', databases)
948        self.assertIn('template1', databases)
949        self.assertNotIn('not existing database', databases)
950        self.assertIn('postgres', databases)
951        self.assertIn(dbname, databases)
952
953    def testGetTables(self):
954        get_tables = self.db.get_tables
955        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
956                  'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
957                  'averyveryveryveryveryveryveryreallyreallylongtablename',
958                  'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
959        for t in tables:
960            self.db.query('drop table if exists "%s" cascade' % t)
961        before_tables = get_tables()
962        self.assertIsInstance(before_tables, list)
963        for t in before_tables:
964            t = t.split('.', 1)
965            self.assertGreaterEqual(len(t), 2)
966            if len(t) > 2:
967                self.assertTrue(t[1].startswith('"'))
968            t = t[0]
969            self.assertNotEqual(t, 'information_schema')
970            self.assertFalse(t.startswith('pg_'))
971        for t in tables:
972            self.createTable(t, 'as select 0', temporary=False)
973        current_tables = get_tables()
974        new_tables = [t for t in current_tables if t not in before_tables]
975        expected_new_tables = ['public.%s' % (
976            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
977        self.assertEqual(new_tables, expected_new_tables)
978        self.doCleanups()
979        after_tables = get_tables()
980        self.assertEqual(after_tables, before_tables)
981
982    def testGetRelations(self):
983        get_relations = self.db.get_relations
984        result = get_relations()
985        self.assertIn('public.test', result)
986        self.assertIn('public.test_view', result)
987        result = get_relations('rv')
988        self.assertIn('public.test', result)
989        self.assertIn('public.test_view', result)
990        result = get_relations('r')
991        self.assertIn('public.test', result)
992        self.assertNotIn('public.test_view', result)
993        result = get_relations('v')
994        self.assertNotIn('public.test', result)
995        self.assertIn('public.test_view', result)
996        result = get_relations('cisSt')
997        self.assertNotIn('public.test', result)
998        self.assertNotIn('public.test_view', result)
999
1000    def testGetAttnames(self):
1001        get_attnames = self.db.get_attnames
1002        self.assertRaises(pg.ProgrammingError,
1003                          self.db.get_attnames, 'does_not_exist')
1004        self.assertRaises(pg.ProgrammingError,
1005                          self.db.get_attnames, 'has.too.many.dots')
1006        r = get_attnames('test')
1007        self.assertIsInstance(r, dict)
1008        if self.regtypes:
1009            self.assertEqual(r, dict(
1010                i2='smallint', i4='integer', i8='bigint', d='numeric',
1011                f4='real', f8='double precision', m='money',
1012                v4='character varying', c4='character', t='text'))
1013        else:
1014            self.assertEqual(r, dict(
1015                i2='int', i4='int', i8='int', d='num',
1016                f4='float', f8='float', m='money',
1017                v4='text', c4='text', t='text'))
1018        self.createTable('test_table',
1019                         'n int, alpha smallint, beta bool,'
1020                         ' gamma char(5), tau text, v varchar(3)')
1021        r = get_attnames('test_table')
1022        self.assertIsInstance(r, dict)
1023        if self.regtypes:
1024            self.assertEqual(r, dict(
1025                n='integer', alpha='smallint', beta='boolean',
1026                gamma='character', tau='text', v='character varying'))
1027        else:
1028            self.assertEqual(r, dict(
1029                n='int', alpha='int', beta='bool',
1030                gamma='text', tau='text', v='text'))
1031
1032    def testGetAttnamesWithQuotes(self):
1033        get_attnames = self.db.get_attnames
1034        table = 'test table for get_attnames()'
1035        self.createTable(table,
1036            '"Prime!" smallint, "much space" integer, "Questions?" text')
1037        r = get_attnames(table)
1038        self.assertIsInstance(r, dict)
1039        if self.regtypes:
1040            self.assertEqual(r, {
1041                'Prime!': 'smallint', 'much space': 'integer',
1042                'Questions?': 'text'})
1043        else:
1044            self.assertEqual(r, {
1045                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1046        table = 'yet another test table for get_attnames()'
1047        self.createTable(table,
1048                         'a smallint, b integer, c bigint,'
1049                         ' e numeric, f real, f2 double precision, m money,'
1050                         ' x smallint, y smallint, z smallint,'
1051                         ' Normal_NaMe smallint, "Special Name" smallint,'
1052                         ' t text, u char(2), v varchar(2),'
1053                         ' primary key (y, u)', oids=True)
1054        r = get_attnames(table)
1055        self.assertIsInstance(r, dict)
1056        if self.regtypes:
1057            self.assertEqual(r, {
1058                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1059                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1060                'm': 'money', 'normal_name': 'smallint',
1061                'Special Name': 'smallint', 'u': 'character',
1062                't': 'text', 'v': 'character varying', 'y': 'smallint',
1063                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1064        else:
1065            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1066                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1067                 'normal_name': 'int', 'Special Name': 'int',
1068                 'u': 'text', 't': 'text', 'v': 'text',
1069                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1070
1071    def testGetAttnamesWithRegtypes(self):
1072        get_attnames = self.db.get_attnames
1073        self.createTable('test_table',
1074                         ' n int, alpha smallint, beta bool,'
1075                         ' gamma char(5), tau text, v varchar(3)')
1076        use_regtypes = self.db.use_regtypes
1077        regtypes = use_regtypes()
1078        self.assertEqual(regtypes, self.regtypes)
1079        use_regtypes(True)
1080        try:
1081            r = get_attnames("test_table")
1082            self.assertIsInstance(r, dict)
1083        finally:
1084            use_regtypes(regtypes)
1085        self.assertEqual(r, dict(
1086            n='integer', alpha='smallint', beta='boolean',
1087            gamma='character', tau='text', v='character varying'))
1088
1089    def testGetAttnamesWithoutRegtypes(self):
1090        get_attnames = self.db.get_attnames
1091        self.createTable('test_table',
1092                         ' n int, alpha smallint, beta bool,'
1093                         ' gamma char(5), tau text, v varchar(3)')
1094        use_regtypes = self.db.use_regtypes
1095        regtypes = use_regtypes()
1096        self.assertEqual(regtypes, self.regtypes)
1097        use_regtypes(False)
1098        try:
1099            r = get_attnames("test_table")
1100            self.assertIsInstance(r, dict)
1101        finally:
1102            use_regtypes(regtypes)
1103        self.assertEqual(r, dict(
1104            n='int', alpha='int', beta='bool',
1105            gamma='text', tau='text', v='text'))
1106
1107    def testGetAttnamesIsCached(self):
1108        get_attnames = self.db.get_attnames
1109        int_type = 'integer' if self.regtypes else 'int'
1110        text_type = 'text'
1111        query = self.db.query
1112        self.createTable('test_table', 'col int')
1113        r = get_attnames("test_table")
1114        self.assertIsInstance(r, dict)
1115        self.assertEqual(r, dict(col=int_type))
1116        query("alter table test_table alter column col type text")
1117        query("alter table test_table add column col2 int")
1118        r = get_attnames("test_table")
1119        self.assertEqual(r, dict(col=int_type))
1120        r = get_attnames("test_table", flush=True)
1121        self.assertEqual(r, dict(col=text_type, col2=int_type))
1122        query("alter table test_table drop column col2")
1123        r = get_attnames("test_table")
1124        self.assertEqual(r, dict(col=text_type, col2=int_type))
1125        r = get_attnames("test_table", flush=True)
1126        self.assertEqual(r, dict(col=text_type))
1127        query("alter table test_table drop column col")
1128        r = get_attnames("test_table")
1129        self.assertEqual(r, dict(col=text_type))
1130        r = get_attnames("test_table", flush=True)
1131        self.assertEqual(r, dict())
1132
1133    def testGetAttnamesIsOrdered(self):
1134        get_attnames = self.db.get_attnames
1135        r = get_attnames('test', flush=True)
1136        self.assertIsInstance(r, OrderedDict)
1137        if self.regtypes:
1138            self.assertEqual(r, OrderedDict([
1139                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1140                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1141                ('m', 'money'), ('v4', 'character varying'),
1142                ('c4', 'character'), ('t', 'text')]))
1143        else:
1144            self.assertEqual(r, OrderedDict([
1145                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1146                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1147                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1148        if OrderedDict is not dict:
1149            r = ' '.join(list(r.keys()))
1150            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1151        table = 'test table for get_attnames'
1152        self.createTable(table,
1153                         ' n int, alpha smallint, v varchar(3),'
1154                         ' gamma char(5), tau text, beta bool')
1155        r = get_attnames(table)
1156        self.assertIsInstance(r, OrderedDict)
1157        if self.regtypes:
1158            self.assertEqual(r, OrderedDict([
1159                ('n', 'integer'), ('alpha', 'smallint'),
1160                ('v', 'character varying'), ('gamma', 'character'),
1161                ('tau', 'text'), ('beta', 'boolean')]))
1162        else:
1163            self.assertEqual(r, OrderedDict([
1164                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1165                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1166        if OrderedDict is not dict:
1167            r = ' '.join(list(r.keys()))
1168            self.assertEqual(r, 'n alpha v gamma tau beta')
1169        else:
1170            self.skipTest('OrderedDict is not supported')
1171
1172    def testGetAttnamesIsAttrDict(self):
1173        AttrDict = pg.AttrDict
1174        get_attnames = self.db.get_attnames
1175        r = get_attnames('test', flush=True)
1176        self.assertIsInstance(r, AttrDict)
1177        if self.regtypes:
1178            self.assertEqual(r, AttrDict([
1179                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1180                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1181                ('m', 'money'), ('v4', 'character varying'),
1182                ('c4', 'character'), ('t', 'text')]))
1183        else:
1184            self.assertEqual(r, AttrDict([
1185                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1186                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1187                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1188        r = ' '.join(list(r.keys()))
1189        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1190        table = 'test table for get_attnames'
1191        self.createTable(table,
1192                         ' n int, alpha smallint, v varchar(3),'
1193                         ' gamma char(5), tau text, beta bool')
1194        r = get_attnames(table)
1195        self.assertIsInstance(r, AttrDict)
1196        if self.regtypes:
1197            self.assertEqual(r, AttrDict([
1198                ('n', 'integer'), ('alpha', 'smallint'),
1199                ('v', 'character varying'), ('gamma', 'character'),
1200                ('tau', 'text'), ('beta', 'boolean')]))
1201        else:
1202            self.assertEqual(r, AttrDict([
1203                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1204                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1205        r = ' '.join(list(r.keys()))
1206        self.assertEqual(r, 'n alpha v gamma tau beta')
1207
1208    def testHasTablePrivilege(self):
1209        can = self.db.has_table_privilege
1210        self.assertEqual(can('test'), True)
1211        self.assertEqual(can('test', 'select'), True)
1212        self.assertEqual(can('test', 'SeLeCt'), True)
1213        self.assertEqual(can('test', 'SELECT'), True)
1214        self.assertEqual(can('test', 'insert'), True)
1215        self.assertEqual(can('test', 'update'), True)
1216        self.assertEqual(can('test', 'delete'), True)
1217        self.assertEqual(can('pg_views', 'select'), True)
1218        self.assertEqual(can('pg_views', 'delete'), False)
1219        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
1220        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1221
    def testGet(self):
        """Check get() with explicit key columns and with a primary key.

        Also checks that a dict passed as row is updated and returned
        in place.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # table name and row are required arguments
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        # rows (1, 'x'), (2, 'y'), (3, 'z'); no primary key yet
        self.createTable(table, 'n integer, t text',
                         values=enumerate('xyz', start=1))
        # without primary key, the key column must be given explicitly
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # lookups that cannot succeed (no matching row, or no key column
        # while the table has no primary key) raise a DatabaseError
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        # a dict passed as row is filled in and returned as the same object
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(t='x')
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        # now add a primary key so that no key column needs to be given
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # a trailing "*" on the table name is accepted
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # more key values than primary key columns
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        s.update(n=2)
        # r is the same dict object as s here
        self.assertEqual(get(table, r)['t'], 'y')
        # a row dict without the primary key column raises KeyError
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1276
    def testGetWithOid(self):
        """Check get() on a table with oids, before and after adding a pkey.

        Rows returned from such a table carry their oid under the key
        'oid(<table>)', and rows can be looked up by that oid.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        # no primary key and no key column given
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        # an oid lookup needs an oid value in the row dict
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        # the oid can be passed directly or in a dict under 'oid' or qoid
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        # the explicitly given key column decides which value is used
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        # now add a primary key
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # oid lookups still work when a primary key exists
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # if the dict also has the primary key, that key wins over the oid
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # an explicitly given key column wins over the oid as well
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1340
1341    def testGetWithCompositeKey(self):
1342        get = self.db.get
1343        query = self.db.query
1344        table = 'get_test_table_1'
1345        self.createTable(table, 'n integer primary key, t text',
1346                         values=enumerate('abc', start=1))
1347        self.assertEqual(get(table, 2)['t'], 'b')
1348        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1349        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1350        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1351        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1352        self.assertEqual(get(table, 'b', 't')['n'], 2)
1353        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1354        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1355        table = 'get_test_table_2'
1356        self.createTable(table,
1357                         'n integer, m integer, t text, primary key (n, m)',
1358                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1359                                 for n in range(3) for m in range(2)])
1360        self.assertRaises(KeyError, get, table, 2)
1361        self.assertEqual(get(table, (1, 1))['t'], 'a')
1362        self.assertEqual(get(table, (1, 2))['t'], 'b')
1363        self.assertEqual(get(table, (2, 1))['t'], 'c')
1364        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1365        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1366        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1367        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1368        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1369        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1370        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1371        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1372
1373    def testGetWithQuotedNames(self):
1374        get = self.db.get
1375        query = self.db.query
1376        table = 'test table for get()'
1377        self.createTable(table, '"Prime!" smallint primary key,'
1378                                ' "much space" integer, "Questions?" text',
1379                         values=[(17, 1001, 'No!')])
1380        r = get(table, 17)
1381        self.assertIsInstance(r, dict)
1382        self.assertEqual(r['Prime!'], 17)
1383        self.assertEqual(r['much space'], 1001)
1384        self.assertEqual(r['Questions?'], 'No!')
1385
1386    def testGetFromView(self):
1387        self.db.query('delete from test where i4=14')
1388        self.db.query('insert into test (i4, v4) values('
1389                      "14, 'abc4')")
1390        r = self.db.get('test_view', 14, 'i4')
1391        self.assertIn('v4', r)
1392        self.assertEqual(r['v4'], 'abc4')
1393
1394    def testGetLittleBobbyTables(self):
1395        get = self.db.get
1396        query = self.db.query
1397        self.createTable('test_students',
1398                         'firstname varchar primary key, nickname varchar, grade char(2)',
1399                         values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1400                                 ('Robert', 'Little Bobby Tables', 'D-')])
1401        r = get('test_students', 'Sheldon')
1402        self.assertEqual(r, dict(
1403            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1404        r = get('test_students', 'Robert')
1405        self.assertEqual(r, dict(
1406            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1407        r = get('test_students', "D'Arcy")
1408        self.assertEqual(r, dict(
1409            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1410        try:
1411            get('test_students', "D' Arcy")
1412        except pg.DatabaseError as error:
1413            self.assertEqual(str(error),
1414                'No such record in test_students\nwhere "firstname" = $1\n'
1415                'with $1="D\' Arcy"')
1416        try:
1417            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1418        except pg.DatabaseError as error:
1419            self.assertEqual(str(error),
1420                'No such record in test_students\nwhere "firstname" = $1\n'
1421                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1422        q = "select * from test_students order by 1 limit 4"
1423        r = query(q).getresult()
1424        self.assertEqual(len(r), 3)
1425        self.assertEqual(r[1][2], 'D-')
1426
1427    def testInsert(self):
1428        insert = self.db.insert
1429        query = self.db.query
1430        bool_on = pg.get_bool()
1431        decimal = pg.get_decimal()
1432        table = 'insert_test_table'
1433        self.createTable(table,
1434            'i2 smallint, i4 integer, i8 bigint,'
1435            ' d numeric, f4 real, f8 double precision, m money,'
1436            ' v4 varchar(4), c4 char(4), t text,'
1437            ' b boolean, ts timestamp', oids=True)
1438        oid_table = 'oid(%s)' % table
1439        tests = [dict(i2=None, i4=None, i8=None),
1440             (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1441             (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1442             dict(i2=42, i4=123456, i8=9876543210),
1443             dict(i2=2 ** 15 - 1,
1444                  i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1445             dict(d=None), (dict(d=''), dict(d=None)),
1446             dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1447             dict(f4=None, f8=None), dict(f4=0, f8=0),
1448             (dict(f4='', f8=''), dict(f4=None, f8=None)),
1449             (dict(d=1234.5, f4=1234.5, f8=1234.5),
1450              dict(d=Decimal('1234.5'))),
1451             dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1452             dict(d=Decimal('123456789.9876543212345678987654321')),
1453             dict(m=None), (dict(m=''), dict(m=None)),
1454             dict(m=Decimal('-1234.56')),
1455             (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1456             dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1457             (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1458             (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1459             (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1460             (dict(m=123456), dict(m=Decimal('123456'))),
1461             (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1462             dict(b=None), (dict(b=''), dict(b=None)),
1463             dict(b='f'), dict(b='t'),
1464             (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1465             (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1466             (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1467             (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1468             (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1469             (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1470             dict(v4=None, c4=None, t=None),
1471             (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1472             dict(v4='1234', c4='1234', t='1234' * 10),
1473             dict(v4='abcd', c4='abcd', t='abcdefg'),
1474             (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1475             dict(ts=None), (dict(ts=''), dict(ts=None)),
1476             (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1477             dict(ts='2012-12-21 00:00:00'),
1478             (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1479             dict(ts='2012-12-21 12:21:12'),
1480             dict(ts='2013-01-05 12:13:14'),
1481             dict(ts='current_timestamp')]
1482        for test in tests:
1483            if isinstance(test, dict):
1484                data = test
1485                change = {}
1486            else:
1487                data, change = test
1488            expect = data.copy()
1489            expect.update(change)
1490            if bool_on:
1491                b = expect.get('b')
1492                if b is not None:
1493                    expect['b'] = b == 't'
1494            if decimal is not Decimal:
1495                d = expect.get('d')
1496                if d is not None:
1497                    expect['d'] = decimal(d)
1498                m = expect.get('m')
1499                if m is not None:
1500                    expect['m'] = decimal(m)
1501            self.assertEqual(insert(table, data), data)
1502            self.assertIn(oid_table, data)
1503            oid = data[oid_table]
1504            self.assertIsInstance(oid, int)
1505            data = dict(item for item in data.items()
1506                        if item[0] in expect)
1507            ts = expect.get('ts')
1508            if ts == 'current_timestamp':
1509                ts = expect['ts'] = data['ts']
1510                if len(ts) > 19:
1511                    self.assertEqual(ts[19], '.')
1512                    ts = ts[:19]
1513                else:
1514                    self.assertEqual(len(ts), 19)
1515                self.assertTrue(ts[:4].isdigit())
1516                self.assertEqual(ts[4], '-')
1517                self.assertEqual(ts[10], ' ')
1518                self.assertTrue(ts[11:13].isdigit())
1519                self.assertEqual(ts[13], ':')
1520            self.assertEqual(data, expect)
1521            data = query(
1522                'select oid,* from "%s"' % table).dictresult()[0]
1523            self.assertEqual(data['oid'], oid)
1524            data = dict(item for item in data.items()
1525                        if item[0] in expect)
1526            self.assertEqual(data, expect)
1527            query('delete from "%s"' % table)
1528
    def testInsertWithOid(self):
        """Test insert() into a table that has OIDs but no primary key.

        The OID of an inserted row is reported under the qualified key
        'oid(test_table)'; a plain 'oid' key never appears in the result.
        """
        insert = self.db.insert
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True)
        # inserting only the unknown column 'm' raises ProgrammingError
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
        r = insert('test_table', n=1)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        self.assertNotIn('oid', r)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        # passing an existing OID does not reuse it; the new row gets its own
        r = insert('test_table', n=2, oid=oid)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 2)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # None as the row dict is accepted; values come from the keywords
        r = insert('test_table', None, n=3)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 3)
        s = r
        # inserting a given dict updates and returns that very same object
        r = insert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # the 'table *' form of the table name is accepted as well
        r = insert('test_table *', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # keyword arguments take precedence over the values in the dict
        r = insert('test_table', r, n=4)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 4)
        self.assertNotIn('oid', r)
        self.assertIn(qoid, r)
        oid = r[qoid]
        r = insert('test_table', r, n=5, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # a plain 'oid' key in the dict is not reused either and is removed
        r['oid'] = oid = r[qoid]
        r = insert('test_table', r, n=6)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 6)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # every successful insert above created a row (n=3 three times)
        q = 'select n from test_table order by 1 limit 9'
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '1 2 3 3 3 4 5 6')
        query("truncate test_table")
        query("alter table test_table add unique (n)")
        r = insert('test_table', dict(n=7))
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        # inserting the same n again violates the unique constraint
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
        r['n'] = 6
        # the keyword value is merged into the dict even if the insert fails
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        r['n'] = 6
        r = insert('test_table', r)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 6)
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '6 7')
1596
1597    def testInsertWithQuotedNames(self):
1598        insert = self.db.insert
1599        query = self.db.query
1600        table = 'test table for insert()'
1601        self.createTable(table, '"Prime!" smallint primary key,'
1602                                ' "much space" integer, "Questions?" text')
1603        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1604        r = insert(table, r)
1605        self.assertIsInstance(r, dict)
1606        self.assertEqual(r['Prime!'], 11)
1607        self.assertEqual(r['much space'], 2002)
1608        self.assertEqual(r['Questions?'], 'What?')
1609        r = query('select * from "%s" limit 2' % table).dictresult()
1610        self.assertEqual(len(r), 1)
1611        r = r[0]
1612        self.assertEqual(r['Prime!'], 11)
1613        self.assertEqual(r['much space'], 2002)
1614        self.assertEqual(r['Questions?'], 'What?')
1615
1616    def testInsertIntoView(self):
1617        insert = self.db.insert
1618        query = self.db.query
1619        query("truncate test")
1620        q = 'select * from test_view order by i4 limit 3'
1621        r = query(q).getresult()
1622        self.assertEqual(r, [])
1623        r = dict(i4=1234, v4='abcd')
1624        insert('test', r)
1625        self.assertIsNone(r['i2'])
1626        self.assertEqual(r['i4'], 1234)
1627        self.assertIsNone(r['i8'])
1628        self.assertEqual(r['v4'], 'abcd')
1629        self.assertIsNone(r['c4'])
1630        r = query(q).getresult()
1631        self.assertEqual(r, [(1234, 'abcd')])
1632        r = dict(i4=5678, v4='efgh')
1633        try:
1634            insert('test_view', r)
1635        except pg.ProgrammingError as error:
1636            if self.db.server_version < 90300:
1637                # must setup rules in older PostgreSQL versions
1638                self.skipTest('database cannot insert into view')
1639            self.fail(str(error))
1640        self.assertNotIn('i2', r)
1641        self.assertEqual(r['i4'], 5678)
1642        self.assertNotIn('i8', r)
1643        self.assertEqual(r['v4'], 'efgh')
1644        self.assertNotIn('c4', r)
1645        r = query(q).getresult()
1646        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1647
1648    def testUpdate(self):
1649        update = self.db.update
1650        query = self.db.query
1651        self.assertRaises(pg.ProgrammingError, update,
1652                          'test', i2=2, i4=4, i8=8)
1653        table = 'update_test_table'
1654        self.createTable(table, 'n integer, t text', oids=True,
1655                         values=enumerate('xyz', start=1))
1656        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1657        r = self.db.get(table, 2, 'n')
1658        r['t'] = 'u'
1659        s = update(table, r)
1660        self.assertEqual(s, r)
1661        q = 'select t from "%s" where n=2' % table
1662        r = query(q).getresult()[0][0]
1663        self.assertEqual(r, 'u')
1664
    def testUpdateWithOid(self):
        """Test update() on a table with OIDs, before and after adding a key.

        Without a primary key, the row is identified by the OID stored under
        the qualified key 'oid(test_table)' or passed as the oid keyword.
        """
        update = self.db.update
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        s = get('test_table', 1, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 1)
        s['n'] = 2
        # update() modifies and returns the very same dict
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        r['n'] = 3
        # the OID can also be passed as the oid keyword argument
        oid = r.pop(qoid)
        r = update('test_table', r, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # with neither OID nor primary key the row cannot be identified
        r.pop(qoid)
        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
        s = get('test_table', 3, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 3)
        # with no columns left to set, the row stays unchanged and the
        # result holds nothing but the qualified OID
        s.pop('n')
        r = update('test_table', s)
        oid = r.pop(qoid)
        self.assertEqual(r, {})
        q = "select n from test_table limit 2"
        r = query(q).getresult()
        self.assertEqual(r, [(3,)])
        query("insert into test_table values (1)")
        # a plain 'oid' key in the dict is not used to identify the row
        self.assertRaises(pg.ProgrammingError,
                          update, 'test_table', dict(oid=oid, n=4))
        r = update('test_table', dict(n=4), oid=oid)
        self.assertEqual(r['n'], 4)
        # the 'table *' form of the table name is accepted as well
        r = update('test_table *', dict(n=5), oid=oid)
        self.assertEqual(r['n'], 5)
        # now add a primary key; get_attnames/pkey are flushed to see it
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        s = dict(n=1, m=4)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 4)
        # the primary key value can also be passed as a keyword argument
        s = dict(m=7)
        r = update('test_table', s, n=5)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 7)
        q = "select n, m from test_table order by 1 limit 3"
        r = query(q).getresult()
        self.assertEqual(r, [(1, 4), (5, 7)])
        # with a primary key but no key value, a plain 'oid' key raises
        # KeyError, while the oid keyword still identifies the row (n=5)
        s = dict(m=9, oid=oid)
        self.assertRaises(KeyError, update, 'test_table', s)
        r = update('test_table', s, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 9)
        # when the key value is given, the plain 'oid' key is ignored
        s = dict(n=1, m=3, oid=oid)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        r = query(q).getresult()
        self.assertEqual(r, [(1, 3), (5, 9)])
1735
1736    def testUpdateWithoutOid(self):
1737        update = self.db.update
1738        query = self.db.query
1739        self.assertRaises(pg.ProgrammingError, update,
1740                          'test', i2=2, i4=4, i8=8)
1741        table = 'update_test_table'
1742        self.createTable(table, 'n integer primary key, t text', oids=False,
1743                         values=enumerate('xyz', start=1))
1744        r = self.db.get(table, 2)
1745        r['t'] = 'u'
1746        s = update(table, r)
1747        self.assertEqual(s, r)
1748        q = 'select t from "%s" where n=2' % table
1749        r = query(q).getresult()[0][0]
1750        self.assertEqual(r, 'u')
1751
1752    def testUpdateWithCompositeKey(self):
1753        update = self.db.update
1754        query = self.db.query
1755        table = 'update_test_table_1'
1756        self.createTable(table, 'n integer primary key, t text',
1757                         values=enumerate('abc', start=1))
1758        self.assertRaises(KeyError, update, table, dict(t='b'))
1759        s = dict(n=2, t='d')
1760        r = update(table, s)
1761        self.assertIs(r, s)
1762        self.assertEqual(r['n'], 2)
1763        self.assertEqual(r['t'], 'd')
1764        q = 'select t from "%s" where n=2' % table
1765        r = query(q).getresult()[0][0]
1766        self.assertEqual(r, 'd')
1767        s.update(dict(n=4, t='e'))
1768        r = update(table, s)
1769        self.assertEqual(r['n'], 4)
1770        self.assertEqual(r['t'], 'e')
1771        q = 'select t from "%s" where n=2' % table
1772        r = query(q).getresult()[0][0]
1773        self.assertEqual(r, 'd')
1774        q = 'select t from "%s" where n=4' % table
1775        r = query(q).getresult()
1776        self.assertEqual(len(r), 0)
1777        query('drop table "%s"' % table)
1778        table = 'update_test_table_2'
1779        self.createTable(table,
1780                         'n integer, m integer, t text, primary key (n, m)',
1781                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1782                                 for n in range(3) for m in range(2)])
1783        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1784        self.assertEqual(update(table,
1785                                dict(n=2, m=2, t='x'))['t'], 'x')
1786        q = 'select t from "%s" where n=2 order by m' % table
1787        r = [r[0] for r in query(q).getresult()]
1788        self.assertEqual(r, ['c', 'x'])
1789
1790    def testUpdateWithQuotedNames(self):
1791        update = self.db.update
1792        query = self.db.query
1793        table = 'test table for update()'
1794        self.createTable(table, '"Prime!" smallint primary key,'
1795                                ' "much space" integer, "Questions?" text',
1796                         values=[(13, 3003, 'Why!')])
1797        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1798        r = update(table, r)
1799        self.assertIsInstance(r, dict)
1800        self.assertEqual(r['Prime!'], 13)
1801        self.assertEqual(r['much space'], 7007)
1802        self.assertEqual(r['Questions?'], 'When?')
1803        r = query('select * from "%s" limit 2' % table).dictresult()
1804        self.assertEqual(len(r), 1)
1805        r = r[0]
1806        self.assertEqual(r['Prime!'], 13)
1807        self.assertEqual(r['much space'], 7007)
1808        self.assertEqual(r['Questions?'], 'When?')
1809
    def testUpsert(self):
        """Test upsert() on a table with a single-column primary key.

        Keyword arguments control a column on conflict: False keeps the
        existing value, True takes the new value, and a string is used as
        an SQL expression where 'included' refers to the existing row and
        'excluded' to the row being upserted.
        """
        upsert = self.db.upsert
        query = self.db.query
        # upserting into the 'test' table is expected to fail
        self.assertRaises(pg.ProgrammingError, upsert,
                          'test', i2=2, i4=4, i8=8)
        table = 'upsert_test_table'
        self.createTable(table, 'n integer primary key, t text', oids=True)
        s = dict(n=1, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # servers older than 9.5 do not support upsert
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        # a new key value inserts a second row
        s.update(n=2, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # by default, a conflicting row is updated with the new values
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # False keeps the existing value; the dict gets the stored value back
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # True takes the new value
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # 'included' refers to the existing row ('y' + '2')
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # 'excluded' refers to the row being upserted ('y' + '3')
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        # both can be combined in one expression ('x' + '2')
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        # not existing columns and oid parameter should be ignored
        s = dict(m=3, u='z')
        r = upsert(table, s, oid='invalid')
        self.assertIs(r, s)
1881
    def testUpsertWithOid(self):
        """Test upsert() on a table with OIDs, before and after adding a key.

        Without a primary key, upsert() raises ProgrammingError even when an
        OID is given. After a primary key has been added, the key is used
        and OIDs passed in are ignored.
        """
        upsert = self.db.upsert
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2))
        r = get('test_table', 1, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        oid = r[qoid]
        # an OID does not make upsert() work without a primary key
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2, oid=oid))
        # now add a primary key; get_attnames/pkey are flushed to see it
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        s = dict(n=2)
        try:
            r = upsert('test_table', s)
        except pg.ProgrammingError as error:
            # servers older than 9.5 do not support upsert
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, None), (2, None)])
        # a plain 'oid' key (here the OID of row n=1) is ignored and removed;
        # the result carries the own OID of row n=2 under the qualified key
        r['oid'] = oid
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertNotEqual(r[qoid], oid)
        # by default, a conflicting row is updated with the new values
        r['m'] = 7
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # the primary key, not the qualified OID still in r, selects the row
        r.update(n=1, m=3)
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
        # the oid keyword argument is ignored as well
        r = upsert('test_table', r, oid='invalid')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # False keeps the existing value of the column
        r['m'] = 5
        r = upsert('test_table', r, m=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # True takes the new value of the column
        r['m'] = 5
        r = upsert('test_table', r, m=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 5)
        # 'included' refers to the existing row (m=7 for n=2)
        r.update(n=2, m=1)
        r = upsert('test_table', r, m='included.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # 'excluded' refers to the row being upserted
        r['m'] = 9
        r = upsert('test_table', r, m='excluded.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 9)
        # arbitrary expressions and the 'table *' form work as well
        r['m'] = 8
        r = upsert('test_table *', r, m='included.m + 1')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 10)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1965
    def testUpsertWithCompositeKey(self):
        """Test upsert() on a table with a composite primary key (n, m).

        The conflict keywords work as with a single-column key: False keeps
        the existing value, True takes the new value, and in an SQL string
        'included' is the existing row and 'excluded' the upserted row.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        self.createTable(table,
                         'n integer, m integer, t text, primary key (n, m)')
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # servers older than 9.5 do not support upsert
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # a different m is a different key, so a second row is inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # by default, a conflicting row is updated with the new values
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # False keeps the existing value; the dict gets the stored value back
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # True takes the new value
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # without a conflict the row is inserted as given and the
        # expression for t is not applied
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # 'included' is the existing and 'excluded' the new row ('n' + 'm')
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2032
2033    def testUpsertWithQuotedNames(self):
2034        upsert = self.db.upsert
2035        query = self.db.query
2036        table = 'test table for upsert()'
2037        self.createTable(table, '"Prime!" smallint primary key,'
2038                                ' "much space" integer, "Questions?" text')
2039        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2040        try:
2041            r = upsert(table, s)
2042        except pg.ProgrammingError as error:
2043            if self.db.server_version < 90500:
2044                self.skipTest('database does not support upsert')
2045            self.fail(str(error))
2046        self.assertIs(r, s)
2047        self.assertEqual(r['Prime!'], 31)
2048        self.assertEqual(r['much space'], 9009)
2049        self.assertEqual(r['Questions?'], 'Yes.')
2050        q = 'select * from "%s" limit 2' % table
2051        r = query(q).getresult()
2052        self.assertEqual(r, [(31, 9009, 'Yes.')])
2053        s.update({'Questions?': 'No.'})
2054        r = upsert(table, s)
2055        self.assertIs(r, s)
2056        self.assertEqual(r['Prime!'], 31)
2057        self.assertEqual(r['much space'], 9009)
2058        self.assertEqual(r['Questions?'], 'No.')
2059        r = query(q).getresult()
2060        self.assertEqual(r, [(31, 9009, 'No.')])
2061
2062    def testClear(self):
2063        clear = self.db.clear
2064        f = False if pg.get_bool() else 'f'
2065        r = clear('test')
2066        result = dict(
2067            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2068        self.assertEqual(r, result)
2069        table = 'clear_test_table'
2070        self.createTable(table,
2071            'n integer, f float, b boolean, d date, t text', oids=True)
2072        r = clear(table)
2073        result = dict(n=0, f=0, b=f, d='', t='')
2074        self.assertEqual(r, result)
2075        r['a'] = r['f'] = r['n'] = 1
2076        r['d'] = r['t'] = 'x'
2077        r['b'] = 't'
2078        r['oid'] = long(1)
2079        r = clear(table, r)
2080        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2081        self.assertEqual(r, result)
2082
2083    def testClearWithQuotedNames(self):
2084        clear = self.db.clear
2085        table = 'test table for clear()'
2086        self.createTable(table, '"Prime!" smallint primary key,'
2087            ' "much space" integer, "Questions?" text')
2088        r = clear(table)
2089        self.assertIsInstance(r, dict)
2090        self.assertEqual(r['Prime!'], 0)
2091        self.assertEqual(r['much space'], 0)
2092        self.assertEqual(r['Questions?'], '')
2093
2094    def testDelete(self):
2095        delete = self.db.delete
2096        query = self.db.query
2097        self.assertRaises(pg.ProgrammingError, delete,
2098                          'test', dict(i2=2, i4=4, i8=8))
2099        table = 'delete_test_table'
2100        self.createTable(table, 'n integer, t text', oids=True,
2101                         values=enumerate('xyz', start=1))
2102        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
2103        r = self.db.get(table, 1, 'n')
2104        s = delete(table, r)
2105        self.assertEqual(s, 1)
2106        r = self.db.get(table, 3, 'n')
2107        s = delete(table, r)
2108        self.assertEqual(s, 1)
2109        s = delete(table, r)
2110        self.assertEqual(s, 0)
2111        r = query('select * from "%s"' % table).dictresult()
2112        self.assertEqual(len(r), 1)
2113        r = r[0]
2114        result = {'n': 2, 't': 'y'}
2115        self.assertEqual(r, result)
2116        r = self.db.get(table, 2, 'n')
2117        s = delete(table, r)
2118        self.assertEqual(s, 1)
2119        s = delete(table, r)
2120        self.assertEqual(s, 0)
2121        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
2122        # not existing columns and oid parameter should be ignored
2123        r.update(m=3, u='z', oid='invalid')
2124        s = delete(table, r)
2125        self.assertEqual(s, 0)
2126
    def testDeleteWithOid(self):
        """Test delete() on a table with OIDs, before and after adding a key.

        Every delete() returns the number of deleted rows, so deleting the
        same row a second time returns 0.
        """
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
        # with neither OID nor primary key the row cannot be identified
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        # a dict returned by get() carries the OID under the qualified key
        s = get('test_table', 1, 'n')
        qoid = 'oid(test_table)'
        self.assertIn(qoid, s)
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        # q yields the smallest remaining n and the number of remaining rows
        q = "select min(n),count(n) from test_table"
        self.assertEqual(query(q).getresult()[0], (2, 5))
        oid = get('test_table', 2, 'n')[qoid]
        # a plain 'oid' key in the dict is not used to identify the row
        s = dict(oid=oid, n=2)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        # but the oid keyword argument is, even without a row dict
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # s holds the OID of the already deleted row n=2 under 'oid'; the
        # oid keyword (now the OID of row n=3) decides which row is deleted
        s = dict(oid=oid, n=2)
        oid = get('test_table', 3, 'n')[qoid]
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        # the 'table *' form of the table name is accepted as well
        s = get('test_table', 4, 'n')
        r = delete('test_table *', s)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # column m does not exist yet, so m is ignored and the OID is used
        oid = get('test_table', 5, 'n')[qoid]
        s = {qoid: oid, 'm': 4}
        r = delete('test_table', s, m=6)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
        # now add a primary key and rows (n, n + 1) for n = 1..5
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        for i in range(5):
            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
        # without the key value and the qualified OID key, KeyError is
        # raised; a plain 'oid' key in the dict does not help
        s = dict(m=2)
        self.assertRaises(KeyError, delete, 'test_table', s)
        s = dict(m=2, oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # oid is still the OID of the row n=5 deleted above, so no match
        r = delete('test_table', dict(m=2), oid=oid)
        self.assertEqual(r, 0)
        # the oid keyword argument still works with a primary key
        oid = get('test_table', 1, 'n')[qoid]
        s = dict(oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # without the key value, the qualified OID key in the dict is used
        s = get('test_table', 2, 'n')
        del s['n']
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the key value can be passed as keyword argument, with or without
        # a None row dict
        r = delete('test_table', n=3)
        self.assertEqual(r, 1)
        r = delete('test_table', n=3)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 1)
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # a keyword argument takes precedence over the value in the dict
        s = dict(n=6)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 1)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
2215
2216    def testDeleteWithCompositeKey(self):
2217        query = self.db.query
2218        table = 'delete_test_table_1'
2219        self.createTable(table, 'n integer primary key, t text',
2220                         values=enumerate('abc', start=1))
2221        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2222        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2223        r = query('select t from "%s" where n=2' % table).getresult()
2224        self.assertEqual(r, [])
2225        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2226        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2227        self.assertEqual(r, 'c')
2228        table = 'delete_test_table_2'
2229        self.createTable(table,
2230             'n integer, m integer, t text, primary key (n, m)',
2231             values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2232                     for n in range(3) for m in range(2)])
2233        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2234        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2235        r = [r[0] for r in query('select t from "%s" where n=2'
2236            ' order by m' % table).getresult()]
2237        self.assertEqual(r, ['c'])
2238        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2239        r = [r[0] for r in query('select t from "%s" where n=3'
2240             ' order by m' % table).getresult()]
2241        self.assertEqual(r, ['e', 'f'])
2242        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2243        r = [r[0] for r in query('select t from "%s" where n=3'
2244             ' order by m' % table).getresult()]
2245        self.assertEqual(r, ['f'])
2246
2247    def testDeleteWithQuotedNames(self):
2248        delete = self.db.delete
2249        query = self.db.query
2250        table = 'test table for delete()'
2251        self.createTable(table, '"Prime!" smallint primary key,'
2252            ' "much space" integer, "Questions?" text',
2253            values=[(19, 5005, 'Yes!')])
2254        r = {'Prime!': 17}
2255        r = delete(table, r)
2256        self.assertEqual(r, 0)
2257        r = query('select count(*) from "%s"' % table).getresult()
2258        self.assertEqual(r[0][0], 1)
2259        r = {'Prime!': 19}
2260        r = delete(table, r)
2261        self.assertEqual(r, 1)
2262        r = query('select count(*) from "%s"' % table).getresult()
2263        self.assertEqual(r[0][0], 0)
2264
2265    def testDeleteReferenced(self):
2266        delete = self.db.delete
2267        query = self.db.query
2268        self.createTable('test_parent',
2269            'n smallint primary key', values=range(3))
2270        self.createTable('test_child',
2271            'n smallint primary key references test_parent', values=range(3))
2272        q = ("select (select count(*) from test_parent),"
2273             " (select count(*) from test_child)")
2274        self.assertEqual(query(q).getresult()[0], (3, 3))
2275        self.assertRaises(pg.ProgrammingError,
2276                          delete, 'test_parent', None, n=2)
2277        self.assertRaises(pg.ProgrammingError,
2278                          delete, 'test_parent *', None, n=2)
2279        r = delete('test_child', None, n=2)
2280        self.assertEqual(r, 1)
2281        self.assertEqual(query(q).getresult()[0], (3, 2))
2282        r = delete('test_parent', None, n=2)
2283        self.assertEqual(r, 1)
2284        self.assertEqual(query(q).getresult()[0], (2, 2))
2285        self.assertRaises(pg.ProgrammingError,
2286                          delete, 'test_parent', dict(n=0))
2287        self.assertRaises(pg.ProgrammingError,
2288                          delete, 'test_parent *', dict(n=0))
2289        r = delete('test_child', dict(n=0))
2290        self.assertEqual(r, 1)
2291        self.assertEqual(query(q).getresult()[0], (2, 1))
2292        r = delete('test_child', dict(n=0))
2293        self.assertEqual(r, 0)
2294        r = delete('test_parent', dict(n=0))
2295        self.assertEqual(r, 1)
2296        self.assertEqual(query(q).getresult()[0], (1, 1))
2297        r = delete('test_parent', None, n=0)
2298        self.assertEqual(r, 0)
2299        q = "select n from test_parent natural join test_child limit 2"
2300        self.assertEqual(query(q).getresult(), [(1,)])
2301
2302    def testTruncate(self):
2303        truncate = self.db.truncate
2304        self.assertRaises(TypeError, truncate, None)
2305        self.assertRaises(TypeError, truncate, 42)
2306        self.assertRaises(TypeError, truncate, dict(test_table=None))
2307        query = self.db.query
2308        self.createTable('test_table', 'n smallint',
2309                         temporary=False, values=[1] * 3)
2310        q = "select count(*) from test_table"
2311        r = query(q).getresult()[0][0]
2312        self.assertEqual(r, 3)
2313        truncate('test_table')
2314        r = query(q).getresult()[0][0]
2315        self.assertEqual(r, 0)
2316        for i in range(3):
2317            query("insert into test_table values (1)")
2318        r = query(q).getresult()[0][0]
2319        self.assertEqual(r, 3)
2320        truncate('public.test_table')
2321        r = query(q).getresult()[0][0]
2322        self.assertEqual(r, 0)
2323        self.createTable('test_table_2', 'n smallint', temporary=True)
2324        for t in (list, tuple, set):
2325            for i in range(3):
2326                query("insert into test_table values (1)")
2327                query("insert into test_table_2 values (2)")
2328            q = ("select (select count(*) from test_table),"
2329                 " (select count(*) from test_table_2)")
2330            r = query(q).getresult()[0]
2331            self.assertEqual(r, (3, 3))
2332            truncate(t(['test_table', 'test_table_2']))
2333            r = query(q).getresult()[0]
2334            self.assertEqual(r, (0, 0))
2335
2336    def testTruncateRestart(self):
2337        truncate = self.db.truncate
2338        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2339        query = self.db.query
2340        self.createTable('test_table', 'n serial, t text')
2341        for n in range(3):
2342            query("insert into test_table (t) values ('test')")
2343        q = "select count(n), min(n), max(n) from test_table"
2344        r = query(q).getresult()[0]
2345        self.assertEqual(r, (3, 1, 3))
2346        truncate('test_table')
2347        r = query(q).getresult()[0]
2348        self.assertEqual(r, (0, None, None))
2349        for n in range(3):
2350            query("insert into test_table (t) values ('test')")
2351        r = query(q).getresult()[0]
2352        self.assertEqual(r, (3, 4, 6))
2353        truncate('test_table', restart=True)
2354        r = query(q).getresult()[0]
2355        self.assertEqual(r, (0, None, None))
2356        for n in range(3):
2357            query("insert into test_table (t) values ('test')")
2358        r = query(q).getresult()[0]
2359        self.assertEqual(r, (3, 1, 3))
2360
2361    def testTruncateCascade(self):
2362        truncate = self.db.truncate
2363        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2364        query = self.db.query
2365        self.createTable('test_parent', 'n smallint primary key',
2366                         values=range(3))
2367        self.createTable('test_child',
2368                         'n smallint primary key references test_parent (n)',
2369                         values=range(3))
2370        q = ("select (select count(*) from test_parent),"
2371             " (select count(*) from test_child)")
2372        r = query(q).getresult()[0]
2373        self.assertEqual(r, (3, 3))
2374        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2375        truncate(['test_parent', 'test_child'])
2376        r = query(q).getresult()[0]
2377        self.assertEqual(r, (0, 0))
2378        for n in range(3):
2379            query("insert into test_parent (n) values (%d)" % n)
2380            query("insert into test_child (n) values (%d)" % n)
2381        r = query(q).getresult()[0]
2382        self.assertEqual(r, (3, 3))
2383        truncate('test_parent', cascade=True)
2384        r = query(q).getresult()[0]
2385        self.assertEqual(r, (0, 0))
2386        for n in range(3):
2387            query("insert into test_parent (n) values (%d)" % n)
2388            query("insert into test_child (n) values (%d)" % n)
2389        r = query(q).getresult()[0]
2390        self.assertEqual(r, (3, 3))
2391        truncate('test_child')
2392        r = query(q).getresult()[0]
2393        self.assertEqual(r, (3, 0))
2394        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2395        truncate('test_parent', cascade=True)
2396        r = query(q).getresult()[0]
2397        self.assertEqual(r, (0, 0))
2398
2399    def testTruncateOnly(self):
2400        truncate = self.db.truncate
2401        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2402        query = self.db.query
2403        self.createTable('test_parent', 'n smallint')
2404        self.createTable('test_child', 'm smallint) inherits (test_parent')
2405        for n in range(3):
2406            query("insert into test_parent (n) values (1)")
2407            query("insert into test_child (n, m) values (2, 3)")
2408        q = ("select (select count(*) from test_parent),"
2409             " (select count(*) from test_child)")
2410        r = query(q).getresult()[0]
2411        self.assertEqual(r, (6, 3))
2412        truncate('test_parent')
2413        r = query(q).getresult()[0]
2414        self.assertEqual(r, (0, 0))
2415        for n in range(3):
2416            query("insert into test_parent (n) values (1)")
2417            query("insert into test_child (n, m) values (2, 3)")
2418        r = query(q).getresult()[0]
2419        self.assertEqual(r, (6, 3))
2420        truncate('test_parent*')
2421        r = query(q).getresult()[0]
2422        self.assertEqual(r, (0, 0))
2423        for n in range(3):
2424            query("insert into test_parent (n) values (1)")
2425            query("insert into test_child (n, m) values (2, 3)")
2426        r = query(q).getresult()[0]
2427        self.assertEqual(r, (6, 3))
2428        truncate('test_parent', only=True)
2429        r = query(q).getresult()[0]
2430        self.assertEqual(r, (3, 3))
2431        truncate('test_parent', only=False)
2432        r = query(q).getresult()[0]
2433        self.assertEqual(r, (0, 0))
2434        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2435        truncate('test_parent*', only=False)
2436        self.createTable('test_parent_2', 'n smallint')
2437        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2438        for t in '', '_2':
2439            for n in range(3):
2440                query("insert into test_parent%s (n) values (1)" % t)
2441                query("insert into test_child%s (n, m) values (2, 3)" % t)
2442        q = ("select (select count(*) from test_parent),"
2443             " (select count(*) from test_child),"
2444             " (select count(*) from test_parent_2),"
2445             " (select count(*) from test_child_2)")
2446        r = query(q).getresult()[0]
2447        self.assertEqual(r, (6, 3, 6, 3))
2448        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2449        r = query(q).getresult()[0]
2450        self.assertEqual(r, (0, 0, 3, 3))
2451        truncate(['test_parent', 'test_parent_2'], only=False)
2452        r = query(q).getresult()[0]
2453        self.assertEqual(r, (0, 0, 0, 0))
2454        self.assertRaises(ValueError, truncate,
2455            ['test_parent*', 'test_child'], only=[True, False])
2456        truncate(['test_parent*', 'test_child'], only=[False, True])
2457
2458    def testTruncateQuoted(self):
2459        truncate = self.db.truncate
2460        query = self.db.query
2461        table = "test table for truncate()"
2462        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2463        q = 'select count(*) from "%s"' % table
2464        r = query(q).getresult()[0][0]
2465        self.assertEqual(r, 3)
2466        truncate(table)
2467        r = query(q).getresult()[0][0]
2468        self.assertEqual(r, 0)
2469        for i in range(3):
2470            query('insert into "%s" values (1)' % table)
2471        r = query(q).getresult()[0][0]
2472        self.assertEqual(r, 3)
2473        truncate('public."%s"' % table)
2474        r = query(q).getresult()[0][0]
2475        self.assertEqual(r, 0)
2476
2477    def testGetAsList(self):
2478        get_as_list = self.db.get_as_list
2479        self.assertRaises(TypeError, get_as_list)
2480        self.assertRaises(TypeError, get_as_list, None)
2481        query = self.db.query
2482        table = 'test_aslist'
2483        r = query('select 1 as colname').namedresult()[0]
2484        self.assertIsInstance(r, tuple)
2485        named = hasattr(r, 'colname')
2486        names = [(1, 'Homer'), (2, 'Marge'),
2487                 (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
2488        self.createTable(table,
2489            'id smallint primary key, name varchar', values=names)
2490        r = get_as_list(table)
2491        self.assertIsInstance(r, list)
2492        self.assertEqual(r, names)
2493        for t, n in zip(r, names):
2494            self.assertIsInstance(t, tuple)
2495            self.assertEqual(t, n)
2496            if named:
2497                self.assertEqual(t.id, n[0])
2498                self.assertEqual(t.name, n[1])
2499                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
2500        r = get_as_list(table, what='name')
2501        self.assertIsInstance(r, list)
2502        expected = sorted((row[1],) for row in names)
2503        self.assertEqual(r, expected)
2504        r = get_as_list(table, what='name, id')
2505        self.assertIsInstance(r, list)
2506        expected = sorted(tuple(reversed(row)) for row in names)
2507        self.assertEqual(r, expected)
2508        r = get_as_list(table, what=['name', 'id'])
2509        self.assertIsInstance(r, list)
2510        self.assertEqual(r, expected)
2511        r = get_as_list(table, where="name like 'Ba%'")
2512        self.assertIsInstance(r, list)
2513        self.assertEqual(r, names[2:3])
2514        r = get_as_list(table, what='name', where="name like 'Ma%'")
2515        self.assertIsInstance(r, list)
2516        self.assertEqual(r, [('Maggie',), ('Marge',)])
2517        r = get_as_list(table, what='name',
2518            where=["name like 'Ma%'", "name like '%r%'"])
2519        self.assertIsInstance(r, list)
2520        self.assertEqual(r, [('Marge',)])
2521        r = get_as_list(table, what='name', order='id')
2522        self.assertIsInstance(r, list)
2523        expected = [(row[1],) for row in names]
2524        self.assertEqual(r, expected)
2525        r = get_as_list(table, what=['name'], order=['id'])
2526        self.assertIsInstance(r, list)
2527        self.assertEqual(r, expected)
2528        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
2529        self.assertIsInstance(r, list)
2530        self.assertEqual(r, names)
2531        r = get_as_list(table, what='id * 2 as num', order='id desc')
2532        self.assertIsInstance(r, list)
2533        expected = [(n,) for n in range(10, 0, -2)]
2534        self.assertEqual(r, expected)
2535        r = get_as_list(table, limit=2)
2536        self.assertIsInstance(r, list)
2537        self.assertEqual(r, names[:2])
2538        r = get_as_list(table, offset=3)
2539        self.assertIsInstance(r, list)
2540        self.assertEqual(r, names[3:])
2541        r = get_as_list(table, limit=1, offset=2)
2542        self.assertIsInstance(r, list)
2543        self.assertEqual(r, names[2:3])
2544        r = get_as_list(table, scalar=True)
2545        self.assertIsInstance(r, list)
2546        self.assertEqual(r, list(range(1, 6)))
2547        r = get_as_list(table, what='name', scalar=True)
2548        self.assertIsInstance(r, list)
2549        expected = sorted(row[1] for row in names)
2550        self.assertEqual(r, expected)
2551        r = get_as_list(table, what='name', limit=1, scalar=True)
2552        self.assertIsInstance(r, list)
2553        self.assertEqual(r, expected[:1])
2554        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2555        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2556        names.insert(1, (1, 'Snowball'))
2557        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
2558        r = get_as_list(table)
2559        self.assertIsInstance(r, list)
2560        self.assertEqual(r, names)
2561        r = get_as_list(table, what='name', where='id=1', scalar=True)
2562        self.assertIsInstance(r, list)
2563        self.assertEqual(r, ['Homer', 'Snowball'])
2564        # test with unordered query
2565        r = get_as_list(table, order=False)
2566        self.assertIsInstance(r, list)
2567        self.assertEqual(set(r), set(names))
2568        # test with arbitrary from clause
2569        from_table = '(select lower(name) as n2 from "%s") as t2' % table
2570        r = get_as_list(from_table)
2571        self.assertIsInstance(r, list)
2572        r = set(row[0] for row in r)
2573        expected = set(row[1].lower() for row in names)
2574        self.assertEqual(r, expected)
2575        r = get_as_list(from_table, order='n2', scalar=True)
2576        self.assertIsInstance(r, list)
2577        self.assertEqual(r, sorted(expected))
2578        r = get_as_list(from_table, order='n2', limit=1)
2579        self.assertIsInstance(r, list)
2580        self.assertEqual(len(r), 1)
2581        t = r[0]
2582        self.assertIsInstance(t, tuple)
2583        if named:
2584            self.assertEqual(t.n2, 'bart')
2585            self.assertEqual(t._asdict(), dict(n2='bart'))
2586        else:
2587            self.assertEqual(t, ('bart',))
2588
2589    def testGetAsDict(self):
2590        get_as_dict = self.db.get_as_dict
2591        self.assertRaises(TypeError, get_as_dict)
2592        self.assertRaises(TypeError, get_as_dict, None)
2593        # the test table has no primary key
2594        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2595        query = self.db.query
2596        table = 'test_asdict'
2597        r = query('select 1 as colname').namedresult()[0]
2598        self.assertIsInstance(r, tuple)
2599        named = hasattr(r, 'colname')
2600        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2601                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2602        self.createTable(table,
2603            'id smallint primary key, rgb char(7), name varchar',
2604            values=colors)
2605        # keyname must be string, list or tuple
2606        self.assertRaises(KeyError, get_as_dict, table, 3)
2607        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2608        # missing keyname in row
2609        self.assertRaises(KeyError, get_as_dict, table,
2610                          keyname='rgb', what='name')
2611        r = get_as_dict(table)
2612        self.assertIsInstance(r, OrderedDict)
2613        expected = OrderedDict((row[0], row[1:]) for row in colors)
2614        self.assertEqual(r, expected)
2615        for key in r:
2616            self.assertIsInstance(key, int)
2617            self.assertIn(key, expected)
2618            row = r[key]
2619            self.assertIsInstance(row, tuple)
2620            t = expected[key]
2621            self.assertEqual(row, t)
2622            if named:
2623                self.assertEqual(row.rgb, t[0])
2624                self.assertEqual(row.name, t[1])
2625                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2626        if OrderedDict is not dict:  # Python > 2.6
2627            self.assertEqual(r.keys(), expected.keys())
2628        r = get_as_dict(table, keyname='rgb')
2629        self.assertIsInstance(r, OrderedDict)
2630        expected = OrderedDict((row[1], (row[0], row[2]))
2631                               for row in sorted(colors, key=itemgetter(1)))
2632        self.assertEqual(r, expected)
2633        for key in r:
2634            self.assertIsInstance(key, str)
2635            self.assertIn(key, expected)
2636            row = r[key]
2637            self.assertIsInstance(row, tuple)
2638            t = expected[key]
2639            self.assertEqual(row, t)
2640            if named:
2641                self.assertEqual(row.id, t[0])
2642                self.assertEqual(row.name, t[1])
2643                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2644        if OrderedDict is not dict:  # Python > 2.6
2645            self.assertEqual(r.keys(), expected.keys())
2646        r = get_as_dict(table, keyname=['id', 'rgb'])
2647        self.assertIsInstance(r, OrderedDict)
2648        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2649        self.assertEqual(r, expected)
2650        for key in r:
2651            self.assertIsInstance(key, tuple)
2652            self.assertIsInstance(key[0], int)
2653            self.assertIsInstance(key[1], str)
2654            if named:
2655                self.assertEqual(key, (key.id, key.rgb))
2656                self.assertEqual(key._fields, ('id', 'rgb'))
2657            row = r[key]
2658            self.assertIsInstance(row, tuple)
2659            self.assertIsInstance(row[0], str)
2660            t = expected[key]
2661            self.assertEqual(row, t)
2662            if named:
2663                self.assertEqual(row.name, t[0])
2664                self.assertEqual(row._asdict(), dict(name=t[0]))
2665        if OrderedDict is not dict:  # Python > 2.6
2666            self.assertEqual(r.keys(), expected.keys())
2667        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2668        self.assertIsInstance(r, OrderedDict)
2669        expected = OrderedDict((row[:2], row[2]) for row in colors)
2670        self.assertEqual(r, expected)
2671        for key in r:
2672            self.assertIsInstance(key, tuple)
2673            row = r[key]
2674            self.assertIsInstance(row, str)
2675            t = expected[key]
2676            self.assertEqual(row, t)
2677        if OrderedDict is not dict:  # Python > 2.6
2678            self.assertEqual(r.keys(), expected.keys())
2679        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2680        self.assertIsInstance(r, OrderedDict)
2681        expected = OrderedDict((row[1], row[2])
2682            for row in sorted(colors, key=itemgetter(1)))
2683        self.assertEqual(r, expected)
2684        for key in r:
2685            self.assertIsInstance(key, str)
2686            row = r[key]
2687            self.assertIsInstance(row, str)
2688            t = expected[key]
2689            self.assertEqual(row, t)
2690        if OrderedDict is not dict:  # Python > 2.6
2691            self.assertEqual(r.keys(), expected.keys())
2692        r = get_as_dict(table, what='id, name',
2693            where="rgb like '#b%'", scalar=True)
2694        self.assertIsInstance(r, OrderedDict)
2695        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2696        self.assertEqual(r, expected)
2697        for key in r:
2698            self.assertIsInstance(key, int)
2699            row = r[key]
2700            self.assertIsInstance(row, str)
2701            t = expected[key]
2702            self.assertEqual(row, t)
2703        if OrderedDict is not dict:  # Python > 2.6
2704            self.assertEqual(r.keys(), expected.keys())
2705        expected = r
2706        r = get_as_dict(table, what=['name', 'id'],
2707            where=['id > 1', 'id < 4', "rgb like '#b%'",
2708                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2709        self.assertEqual(r, expected)
2710        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2711        self.assertEqual(r, expected)
2712        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2713            where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2714        self.assertEqual(r, expected)
2715        r = get_as_dict(table, limit=1)
2716        self.assertEqual(len(r), 1)
2717        self.assertEqual(r[1][1], 'Aero')
2718        r = get_as_dict(table, offset=3)
2719        self.assertEqual(len(r), 1)
2720        self.assertEqual(r[4][1], 'Desert')
2721        r = get_as_dict(table, order='id desc')
2722        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2723        self.assertEqual(r, expected)
2724        r = get_as_dict(table, where='id > 5')
2725        self.assertIsInstance(r, OrderedDict)
2726        self.assertEqual(len(r), 0)
2727        # test with unordered query
2728        expected = dict((row[0], row[1:]) for row in colors)
2729        r = get_as_dict(table, order=False)
2730        self.assertIsInstance(r, dict)
2731        self.assertEqual(r, expected)
2732        if dict is not OrderedDict:  # Python > 2.6
2733            self.assertNotIsInstance(self, OrderedDict)
2734        # test with arbitrary from clause
2735        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2736        # primary key must be passed explicitly in this case
2737        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2738        r = get_as_dict(from_table, 'id')
2739        self.assertIsInstance(r, OrderedDict)
2740        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2741        self.assertEqual(r, expected)
2742        # test without a primary key
2743        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2744        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2745        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2746        r = get_as_dict(table, keyname='id')
2747        expected = OrderedDict((row[0], row[1:]) for row in colors)
2748        self.assertIsInstance(r, dict)
2749        self.assertEqual(r, expected)
2750        r = (1, '#007fff', 'Azure')
2751        query('insert into "%s" values ($1, $2, $3)' % table, r)
2752        # the last entry will win
2753        expected[1] = r[1:]
2754        r = get_as_dict(table, keyname='id')
2755        self.assertEqual(r, expected)
2756
    def testTransaction(self):
        """Test explicit transaction control through the DB wrapper.

        Covers begin/commit, rollback, savepoints with rollback and
        release, the start/end/abort methods and read only transactions.
        """
        query = self.db.query
        self.createTable('test_table', 'n integer', temporary=False)
        # committed transaction: rows 1 and 2 are kept
        self.db.begin()
        query("insert into test_table values (1)")
        query("insert into test_table values (2)")
        self.db.commit()
        # rolled back transaction: rows 3 and 4 are discarded
        self.db.begin()
        query("insert into test_table values (3)")
        query("insert into test_table values (4)")
        self.db.rollback()
        # rolling back to a savepoint discards only row 6
        self.db.begin()
        query("insert into test_table values (5)")
        self.db.savepoint('before6')
        query("insert into test_table values (6)")
        self.db.rollback('before6')
        query("insert into test_table values (7)")
        self.db.commit()
        # a released savepoint can no longer be rolled back to;
        # row 8 is missing from the result checked below, presumably
        # because the failed rollback aborted the whole transaction
        self.db.begin()
        self.db.savepoint('before8')
        query("insert into test_table values (8)")
        self.db.release('before8')
        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
        self.db.commit()
        # start() and end() are used here to commit row 9
        self.db.start()
        query("insert into test_table values (9)")
        self.db.end()
        r = [r[0] for r in query(
            "select * from test_table order by 1").getresult()]
        self.assertEqual(r, [1, 2, 5, 7, 9])
        # writing is rejected inside a read only transaction
        self.db.begin(mode='read only')
        self.assertRaises(pg.ProgrammingError,
                          query, "insert into test_table values (0)")
        self.db.rollback()
        # the mode is also accepted with different capitalization
        self.db.start(mode='Read Only')
        self.assertRaises(pg.ProgrammingError,
                          query, "insert into test_table values (0)")
        self.db.abort()
2795
2796    def testTransactionAliases(self):
2797        self.assertEqual(self.db.begin, self.db.start)
2798        self.assertEqual(self.db.commit, self.db.end)
2799        self.assertEqual(self.db.rollback, self.db.abort)
2800
2801    def testContextManager(self):
2802        query = self.db.query
2803        self.createTable('test_table', 'n integer check(n>0)')
2804        with self.db:
2805            query("insert into test_table values (1)")
2806            query("insert into test_table values (2)")
2807        try:
2808            with self.db:
2809                query("insert into test_table values (3)")
2810                query("insert into test_table values (4)")
2811                raise ValueError('test transaction should rollback')
2812        except ValueError as error:
2813            self.assertEqual(str(error), 'test transaction should rollback')
2814        with self.db:
2815            query("insert into test_table values (5)")
2816        try:
2817            with self.db:
2818                query("insert into test_table values (6)")
2819                query("insert into test_table values (-1)")
2820        except pg.ProgrammingError as error:
2821            self.assertTrue('check' in str(error))
2822        with self.db:
2823            query("insert into test_table values (7)")
2824        r = [r[0] for r in query(
2825            "select * from test_table order by 1").getresult()]
2826        self.assertEqual(r, [1, 2, 5, 7])
2827
2828    def testBytea(self):
2829        query = self.db.query
2830        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2831        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2832        r = self.db.escape_bytea(s)
2833        query('insert into bytea_test values(3, $1)', (r,))
2834        r = query('select * from bytea_test where n=3').getresult()
2835        self.assertEqual(len(r), 1)
2836        r = r[0]
2837        self.assertEqual(len(r), 2)
2838        self.assertEqual(r[0], 3)
2839        r = r[1]
2840        if pg.get_bytea_escaped():
2841            self.assertNotEqual(r, s)
2842            r = pg.unescape_bytea(r)
2843        self.assertIsInstance(r, bytes)
2844        self.assertEqual(r, s)
2845
2846    def testInsertUpdateGetBytea(self):
2847        query = self.db.query
2848        unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None
2849        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2850        # insert null value
2851        r = self.db.insert('bytea_test', n=0, data=None)
2852        self.assertIsInstance(r, dict)
2853        self.assertIn('n', r)
2854        self.assertEqual(r['n'], 0)
2855        self.assertIn('data', r)
2856        self.assertIsNone(r['data'])
2857        s = b'None'
2858        r = self.db.update('bytea_test', n=0, data=s)
2859        self.assertIsInstance(r, dict)
2860        self.assertIn('n', r)
2861        self.assertEqual(r['n'], 0)
2862        self.assertIn('data', r)
2863        r = r['data']
2864        if unescape:
2865            self.assertNotEqual(r, s)
2866            r = unescape(r)
2867        self.assertIsInstance(r, bytes)
2868        self.assertEqual(r, s)
2869        r = self.db.update('bytea_test', n=0, data=None)
2870        self.assertIsNone(r['data'])
2871        # insert as bytes
2872        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2873        r = self.db.insert('bytea_test', n=5, data=s)
2874        self.assertIsInstance(r, dict)
2875        self.assertIn('n', r)
2876        self.assertEqual(r['n'], 5)
2877        self.assertIn('data', r)
2878        r = r['data']
2879        if unescape:
2880            self.assertNotEqual(r, s)
2881            r = unescape(r)
2882        self.assertIsInstance(r, bytes)
2883        self.assertEqual(r, s)
2884        # update as bytes
2885        s += b"and now even more \x00 nasty \t stuff!\f"
2886        r = self.db.update('bytea_test', n=5, data=s)
2887        self.assertIsInstance(r, dict)
2888        self.assertIn('n', r)
2889        self.assertEqual(r['n'], 5)
2890        self.assertIn('data', r)
2891        r = r['data']
2892        if unescape:
2893            self.assertNotEqual(r, s)
2894            r = unescape(r)
2895        self.assertIsInstance(r, bytes)
2896        self.assertEqual(r, s)
2897        r = query('select * from bytea_test where n=5').getresult()
2898        self.assertEqual(len(r), 1)
2899        r = r[0]
2900        self.assertEqual(len(r), 2)
2901        self.assertEqual(r[0], 5)
2902        r = r[1]
2903        if unescape:
2904            self.assertNotEqual(r, s)
2905            r = unescape(r)
2906        self.assertIsInstance(r, bytes)
2907        self.assertEqual(r, s)
2908        r = self.db.get('bytea_test', dict(n=5))
2909        self.assertIsInstance(r, dict)
2910        self.assertIn('n', r)
2911        self.assertEqual(r['n'], 5)
2912        self.assertIn('data', r)
2913        r = r['data']
2914        if unescape:
2915            self.assertNotEqual(r, s)
2916            r = pg.unescape_bytea(r)
2917        self.assertIsInstance(r, bytes)
2918        self.assertEqual(r, s)
2919
2920    def testUpsertBytea(self):
2921        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2922        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2923        r = dict(n=7, data=s)
2924        try:
2925            r = self.db.upsert('bytea_test', r)
2926        except pg.ProgrammingError as error:
2927            if self.db.server_version < 90500:
2928                self.skipTest('database does not support upsert')
2929            self.fail(str(error))
2930        self.assertIsInstance(r, dict)
2931        self.assertIn('n', r)
2932        self.assertEqual(r['n'], 7)
2933        self.assertIn('data', r)
2934        if pg.get_bytea_escaped():
2935            self.assertNotEqual(r['data'], s)
2936            r['data'] = pg.unescape_bytea(r['data'])
2937        self.assertIsInstance(r['data'], bytes)
2938        self.assertEqual(r['data'], s)
2939        r['data'] = None
2940        r = self.db.upsert('bytea_test', r)
2941        self.assertIsInstance(r, dict)
2942        self.assertIn('n', r)
2943        self.assertEqual(r['n'], 7)
2944        self.assertIn('data', r)
2945        self.assertIsNone(r['data'])
2946
2947    def testInsertGetJson(self):
2948        try:
2949            self.createTable('json_test', 'n smallint primary key, data json')
2950        except pg.ProgrammingError as error:
2951            if self.db.server_version < 90200:
2952                self.skipTest('database does not support json')
2953            self.fail(str(error))
2954        jsondecode = pg.get_jsondecode()
2955        # insert null value
2956        r = self.db.insert('json_test', n=0, data=None)
2957        self.assertIsInstance(r, dict)
2958        self.assertIn('n', r)
2959        self.assertEqual(r['n'], 0)
2960        self.assertIn('data', r)
2961        self.assertIsNone(r['data'])
2962        r = self.db.get('json_test', 0)
2963        self.assertIsInstance(r, dict)
2964        self.assertIn('n', r)
2965        self.assertEqual(r['n'], 0)
2966        self.assertIn('data', r)
2967        self.assertIsNone(r['data'])
2968        # insert JSON object
2969        data = {
2970            "id": 1, "name": "Foo", "price": 1234.5,
2971            "new": True, "note": None,
2972            "tags": ["Bar", "Eek"],
2973            "stock": {"warehouse": 300, "retail": 20}}
2974        r = self.db.insert('json_test', n=1, data=data)
2975        self.assertIsInstance(r, dict)
2976        self.assertIn('n', r)
2977        self.assertEqual(r['n'], 1)
2978        self.assertIn('data', r)
2979        r = r['data']
2980        if jsondecode is None:
2981            self.assertIsInstance(r, str)
2982            r = json.loads(r)
2983        self.assertIsInstance(r, dict)
2984        self.assertEqual(r, data)
2985        self.assertIsInstance(r['id'], int)
2986        self.assertIsInstance(r['name'], unicode)
2987        self.assertIsInstance(r['price'], float)
2988        self.assertIsInstance(r['new'], bool)
2989        self.assertIsInstance(r['tags'], list)
2990        self.assertIsInstance(r['stock'], dict)
2991        r = self.db.get('json_test', 1)
2992        self.assertIsInstance(r, dict)
2993        self.assertIn('n', r)
2994        self.assertEqual(r['n'], 1)
2995        self.assertIn('data', r)
2996        r = r['data']
2997        if jsondecode is None:
2998            self.assertIsInstance(r, str)
2999            r = json.loads(r)
3000        self.assertIsInstance(r, dict)
3001        self.assertEqual(r, data)
3002        self.assertIsInstance(r['id'], int)
3003        self.assertIsInstance(r['name'], unicode)
3004        self.assertIsInstance(r['price'], float)
3005        self.assertIsInstance(r['new'], bool)
3006        self.assertIsInstance(r['tags'], list)
3007        self.assertIsInstance(r['stock'], dict)
3008        # insert JSON object as text
3009        self.db.insert('json_test', n=2, data=json.dumps(data))
3010        q = "select data from json_test where n in (1, 2) order by n"
3011        r = self.db.query(q).getresult()
3012        self.assertEqual(len(r), 2)
3013        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
3014        self.assertEqual(r[0][0], r[1][0])
3015
3016    def testInsertGetJsonb(self):
3017        try:
3018            self.createTable('jsonb_test',
3019                             'n smallint primary key, data jsonb')
3020        except pg.ProgrammingError as error:
3021            if self.db.server_version < 90400:
3022                self.skipTest('database does not support jsonb')
3023            self.fail(str(error))
3024        jsondecode = pg.get_jsondecode()
3025        # insert null value
3026        r = self.db.insert('jsonb_test', n=0, data=None)
3027        self.assertIsInstance(r, dict)
3028        self.assertIn('n', r)
3029        self.assertEqual(r['n'], 0)
3030        self.assertIn('data', r)
3031        self.assertIsNone(r['data'])
3032        r = self.db.get('jsonb_test', 0)
3033        self.assertIsInstance(r, dict)
3034        self.assertIn('n', r)
3035        self.assertEqual(r['n'], 0)
3036        self.assertIn('data', r)
3037        self.assertIsNone(r['data'])
3038        # insert JSON object
3039        data = {
3040            "id": 1, "name": "Foo", "price": 1234.5,
3041            "new": True, "note": None,
3042            "tags": ["Bar", "Eek"],
3043            "stock": {"warehouse": 300, "retail": 20}}
3044        r = self.db.insert('jsonb_test', n=1, data=data)
3045        self.assertIsInstance(r, dict)
3046        self.assertIn('n', r)
3047        self.assertEqual(r['n'], 1)
3048        self.assertIn('data', r)
3049        r = r['data']
3050        if jsondecode is None:
3051            self.assertIsInstance(r, str)
3052            r = json.loads(r)
3053        self.assertIsInstance(r, dict)
3054        self.assertEqual(r, data)
3055        self.assertIsInstance(r['id'], int)
3056        self.assertIsInstance(r['name'], unicode)
3057        self.assertIsInstance(r['price'], float)
3058        self.assertIsInstance(r['new'], bool)
3059        self.assertIsInstance(r['tags'], list)
3060        self.assertIsInstance(r['stock'], dict)
3061        r = self.db.get('jsonb_test', 1)
3062        self.assertIsInstance(r, dict)
3063        self.assertIn('n', r)
3064        self.assertEqual(r['n'], 1)
3065        self.assertIn('data', r)
3066        r = r['data']
3067        if jsondecode is None:
3068            self.assertIsInstance(r, str)
3069            r = json.loads(r)
3070        self.assertIsInstance(r, dict)
3071        self.assertEqual(r, data)
3072        self.assertIsInstance(r['id'], int)
3073        self.assertIsInstance(r['name'], unicode)
3074        self.assertIsInstance(r['price'], float)
3075        self.assertIsInstance(r['new'], bool)
3076        self.assertIsInstance(r['tags'], list)
3077        self.assertIsInstance(r['stock'], dict)
3078
3079    def testArray(self):
3080        self.createTable('arraytest',
3081            'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
3082            ' d numeric[], f4 real[], f8 double precision[], m money[],'
3083            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
3084        r = self.db.get_attnames('arraytest')
3085        if self.regtypes:
3086            self.assertEqual(r, dict(
3087                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
3088                d='numeric[]', f4='real[]', f8='double precision[]',
3089                m='money[]', b='boolean[]',
3090                v4='character varying[]', c4='character[]', t='text[]'))
3091        else:
3092            self.assertEqual(r, dict(
3093                id='int', i2='int[]', i4='int[]', i8='int[]',
3094                d='num[]', f4='float[]', f8='float[]', m='money[]',
3095                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
3096        decimal = pg.get_decimal()
3097        if decimal is Decimal:
3098            long_decimal = decimal('123456789.123456789')
3099            odd_money = decimal('1234567891234567.89')
3100        else:
3101            long_decimal = decimal('12345671234.5')
3102            odd_money = decimal('1234567123.25')
3103        t, f = (True, False) if pg.get_bool() else ('t', 'f')
3104        data = dict(id=42, i2=[42, 1234, None, 0, -1],
3105            i4=[42, 123456789, None, 0, 1, -1],
3106            i8=[long(42), long(123456789123456789), None,
3107                long(0), long(1), long(-1)],
3108            d=[decimal(42), long_decimal, None,
3109               decimal(0), decimal(1), decimal(-1), -long_decimal],
3110            f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
3111                float('inf'), float('-inf')],
3112            f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
3113                float('inf'), float('-inf')],
3114            m=[decimal('42.00'), odd_money, None,
3115               decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
3116            b=[t, f, t, None, f, t, None, None, t],
3117            v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
3118            t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
3119        r = data.copy()
3120        self.db.insert('arraytest', r)
3121        self.assertEqual(r, data)
3122        self.db.insert('arraytest', r)
3123        r = self.db.get('arraytest', 42, 'id')
3124        self.assertEqual(r, data)
3125        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
3126        self.assertEqual(r, data)
3127
3128    def testArrayLiteral(self):
3129        insert = self.db.insert
3130        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3131        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3132        insert('arraytest', r)
3133        self.assertEqual(r['i'], [1, 2, 3])
3134        self.assertEqual(r['t'], ['a', 'b', 'c'])
3135        r = dict(i='{1,2,3}', t='{a,b,c}')
3136        self.db.insert('arraytest', r)
3137        self.assertEqual(r['i'], [1, 2, 3])
3138        self.assertEqual(r['t'], ['a', 'b', 'c'])
3139        L = pg.Literal
3140        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3141        self.db.insert('arraytest', r)
3142        self.assertEqual(r['i'], [1, 2, 3])
3143        self.assertEqual(r['t'], ['a', 'b', 'c'])
3144        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3145        self.assertRaises(pg.ProgrammingError, self.db.insert, 'arraytest', r)
3146
3147    def testArrayOfIds(self):
3148        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3149        r = self.db.get_attnames('arraytest')
3150        if self.regtypes:
3151            self.assertEqual(r, dict(
3152                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3153        else:
3154            self.assertEqual(r, dict(
3155                oid='int', c='int[]', o='int[]', x='int[]'))
3156        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3157        r = data.copy()
3158        self.db.insert('arraytest', r)
3159        qoid = 'oid(arraytest)'
3160        oid = r.pop(qoid)
3161        self.assertEqual(r, data)
3162        r = {qoid: oid}
3163        self.db.get('arraytest', r)
3164        self.assertEqual(oid, r.pop(qoid))
3165        self.assertEqual(r, data)
3166
3167    def testArrayOfText(self):
3168        self.createTable('arraytest', 'data text[]', oids=True)
3169        r = self.db.get_attnames('arraytest')
3170        self.assertEqual(r['data'], 'text[]')
3171        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3172                'null', 'NULL', 'Null', 'nulL',
3173                "It's all \\ kinds of\r nasty stuff!\n"]
3174        r = dict(data=data)
3175        self.db.insert('arraytest', r)
3176        self.assertEqual(r['data'], data)
3177        self.assertIsInstance(r['data'][1], str)
3178        self.assertIsNone(r['data'][2])
3179        r['data'] = None
3180        self.db.get('arraytest', r)
3181        self.assertEqual(r['data'], data)
3182        self.assertIsInstance(r['data'][1], str)
3183        self.assertIsNone(r['data'][2])
3184
3185    def testArrayOfBytea(self):
3186        unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None
3187        self.createTable('arraytest', 'data bytea[]', oids=True)
3188        r = self.db.get_attnames('arraytest')
3189        self.assertEqual(r['data'], 'bytea[]')
3190        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3191                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3192        r = dict(data=data)
3193        self.db.insert('arraytest', r)
3194        if unescape:
3195            self.assertNotEqual(r['data'], data)
3196            r['data'] = [unescape(v) if v else v for v in r['data']]
3197        self.assertEqual(r['data'], data)
3198        self.assertIsInstance(r['data'][1], bytes)
3199        self.assertIsNone(r['data'][2])
3200        r['data'] = None
3201        self.db.get('arraytest', r)
3202        if unescape:
3203            self.assertNotEqual(r['data'], data)
3204            r['data'] = [unescape(v) if v else v for v in r['data']]
3205        self.assertEqual(r['data'], data)
3206        self.assertIsInstance(r['data'][1], bytes)
3207        self.assertIsNone(r['data'][2])
3208
3209    def testArrayOfJson(self):
3210        try:
3211            self.createTable('arraytest', 'data json[]', oids=True)
3212        except pg.ProgrammingError as error:
3213            if self.db.server_version < 90200:
3214                self.skipTest('database does not support json')
3215            self.fail(str(error))
3216        r = self.db.get_attnames('arraytest')
3217        self.assertEqual(r['data'], 'json[]')
3218        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3219        jsondecode = pg.get_jsondecode()
3220        r = dict(data=data)
3221        self.db.insert('arraytest', r)
3222        if jsondecode is None:
3223            r['data'] = [json.loads(d) for d in r['data']]
3224        self.assertEqual(r['data'], data)
3225        r['data'] = None
3226        self.db.get('arraytest', r)
3227        if jsondecode is None:
3228            r['data'] = [json.loads(d) for d in r['data']]
3229        self.assertEqual(r['data'], data)
3230        r = dict(data=[json.dumps(d) for d in data])
3231        self.db.insert('arraytest', r)
3232        if jsondecode is None:
3233            r['data'] = [json.loads(d) for d in r['data']]
3234        self.assertEqual(r['data'], data)
3235        r['data'] = None
3236        self.db.get('arraytest', r)
3237        # insert empty json values
3238        r = dict(data=['', None])
3239        self.db.insert('arraytest', r)
3240        r = r['data']
3241        self.assertIsInstance(r, list)
3242        self.assertEqual(len(r), 2)
3243        self.assertIsNone(r[0])
3244        self.assertIsNone(r[1])
3245
3246    def testArrayOfJsonb(self):
3247        try:
3248            self.createTable('arraytest', 'data jsonb[]', oids=True)
3249        except pg.ProgrammingError as error:
3250            if self.db.server_version < 90400:
3251                self.skipTest('database does not support jsonb')
3252            self.fail(str(error))
3253        r = self.db.get_attnames('arraytest')
3254        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3255        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3256        jsondecode = pg.get_jsondecode()
3257        r = dict(data=data)
3258        self.db.insert('arraytest', r)
3259        if jsondecode is None:
3260            r['data'] = [json.loads(d) for d in r['data']]
3261        self.assertEqual(r['data'], data)
3262        r['data'] = None
3263        self.db.get('arraytest', r)
3264        if jsondecode is None:
3265            r['data'] = [json.loads(d) for d in r['data']]
3266        self.assertEqual(r['data'], data)
3267        r = dict(data=[json.dumps(d) for d in data])
3268        self.db.insert('arraytest', r)
3269        if jsondecode is None:
3270            r['data'] = [json.loads(d) for d in r['data']]
3271        self.assertEqual(r['data'], data)
3272        r['data'] = None
3273        self.db.get('arraytest', r)
3274        # insert empty json values
3275        r = dict(data=['', None])
3276        self.db.insert('arraytest', r)
3277        r = r['data']
3278        self.assertIsInstance(r, list)
3279        self.assertEqual(len(r), 2)
3280        self.assertIsNone(r[0])
3281        self.assertIsNone(r[1])
3282
3283    def testDeepArray(self):
3284        self.createTable('arraytest', 'data text[][][]', oids=True)
3285        r = self.db.get_attnames('arraytest')
3286        self.assertEqual(r['data'], 'text[]')
3287        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3288        r = dict(data=data)
3289        self.db.insert('arraytest', r)
3290        self.assertEqual(r['data'], data)
3291        r['data'] = None
3292        self.db.get('arraytest', r)
3293        self.assertEqual(r['data'], data)
3294
3295    def testInsertUpdateGetRecord(self):
3296        query = self.db.query
3297        query('create type test_person_type as'
3298              ' (name varchar, age smallint, married bool,'
3299              ' weight real, salary money)')
3300        self.addCleanup(query, 'drop type test_person_type')
3301        self.createTable('test_person', 'person test_person_type',
3302                         temporary=False, oids=True)
3303        attnames = self.db.get_attnames('test_person')
3304        self.assertEqual(len(attnames), 2)
3305        self.assertIn('oid', attnames)
3306        self.assertIn('person', attnames)
3307        person_typ = attnames['person']
3308        if self.regtypes:
3309            self.assertEqual(person_typ, 'test_person_type')
3310        else:
3311            self.assertEqual(person_typ, 'record')
3312        if self.regtypes:
3313            self.assertEqual(person_typ.attnames,
3314                dict(name='character varying', age='smallint',
3315                    married='boolean', weight='real', salary='money'))
3316        else:
3317            self.assertEqual(person_typ.attnames,
3318                dict(name='text', age='int', married='bool',
3319                    weight='float', salary='money'))
3320        decimal = pg.get_decimal()
3321        if pg.get_bool():
3322            bool_class = bool
3323            t, f = True, False
3324        else:
3325            bool_class = str
3326            t, f = 't', 'f'
3327        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3328        r = self.db.insert('test_person', None, person=person)
3329        p = r['person']
3330        self.assertIsInstance(p, tuple)
3331        self.assertEqual(p, person)
3332        self.assertEqual(p.name, 'John Doe')
3333        self.assertIsInstance(p.name, str)
3334        self.assertIsInstance(p.age, int)
3335        self.assertIsInstance(p.married, bool_class)
3336        self.assertIsInstance(p.weight, float)
3337        self.assertIsInstance(p.salary, decimal)
3338        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3339        r['person'] = person
3340        self.db.update('test_person', r)
3341        p = r['person']
3342        self.assertIsInstance(p, tuple)
3343        self.assertEqual(p, person)
3344        self.assertEqual(p.name, 'Jane Roe')
3345        self.assertIsInstance(p.name, str)
3346        self.assertIsInstance(p.age, int)
3347        self.assertIsInstance(p.married, bool_class)
3348        self.assertIsInstance(p.weight, float)
3349        self.assertIsInstance(p.salary, decimal)
3350        r['person'] = None
3351        self.db.get('test_person', r)
3352        p = r['person']
3353        self.assertIsInstance(p, tuple)
3354        self.assertEqual(p, person)
3355        self.assertEqual(p.name, 'Jane Roe')
3356        self.assertIsInstance(p.name, str)
3357        self.assertIsInstance(p.age, int)
3358        self.assertIsInstance(p.married, bool_class)
3359        self.assertIsInstance(p.weight, float)
3360        self.assertIsInstance(p.salary, decimal)
3361        person = (None,) * 5
3362        r = self.db.insert('test_person', None, person=person)
3363        p = r['person']
3364        self.assertIsInstance(p, tuple)
3365        self.assertIsNone(p.name)
3366        self.assertIsNone(p.age)
3367        self.assertIsNone(p.married)
3368        self.assertIsNone(p.weight)
3369        self.assertIsNone(p.salary)
3370        r['person'] = None
3371        self.db.get('test_person', r)
3372        p = r['person']
3373        self.assertIsInstance(p, tuple)
3374        self.assertIsNone(p.name)
3375        self.assertIsNone(p.age)
3376        self.assertIsNone(p.married)
3377        self.assertIsNone(p.weight)
3378        self.assertIsNone(p.salary)
3379        r = self.db.insert('test_person', None, person=None)
3380        self.assertIsNone(r['person'])
3381        r['person'] = None
3382        self.db.get('test_person', r)
3383        self.assertIsNone(r['person'])
3384
3385    def testRecordInsertBytea(self):
3386        query = self.db.query
3387        query('create type test_person_type as'
3388              ' (name text, picture bytea)')
3389        self.addCleanup(query, 'drop type test_person_type')
3390        self.createTable('test_person', 'person test_person_type',
3391                         temporary=False, oids=True)
3392        person_typ = self.db.get_attnames('test_person')['person']
3393        self.assertEqual(person_typ.attnames,
3394                         dict(name='text', picture='bytea'))
3395        person = ('John Doe', b'O\x00ps\xff!')
3396        r = self.db.insert('test_person', None, person=person)
3397        p = r['person']
3398        self.assertIsInstance(p, tuple)
3399        self.assertEqual(p, person)
3400        self.assertEqual(p.name, 'John Doe')
3401        self.assertIsInstance(p.name, str)
3402        self.assertEqual(p.picture, person[1])
3403        self.assertIsInstance(p.picture, bytes)
3404
3405    def testRecordInsertJson(self):
3406        query = self.db.query
3407        try:
3408            query('create type test_person_type as'
3409                  ' (name text, data json)')
3410        except pg.ProgrammingError as error:
3411            if self.db.server_version < 90200:
3412                self.skipTest('database does not support json')
3413            self.fail(str(error))
3414        self.addCleanup(query, 'drop type test_person_type')
3415        self.createTable('test_person', 'person test_person_type',
3416                         temporary=False, oids=True)
3417        person_typ = self.db.get_attnames('test_person')['person']
3418        self.assertEqual(person_typ.attnames,
3419                         dict(name='text', data='json'))
3420        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3421        r = self.db.insert('test_person', None, person=person)
3422        p = r['person']
3423        self.assertIsInstance(p, tuple)
3424        if pg.get_jsondecode() is None:
3425            p = p._replace(data=json.loads(p.data))
3426        self.assertEqual(p, person)
3427        self.assertEqual(p.name, 'John Doe')
3428        self.assertIsInstance(p.name, str)
3429        self.assertEqual(p.data, person[1])
3430        self.assertIsInstance(p.data, dict)
3431
3432    def testRecordLiteral(self):
3433        query = self.db.query
3434        query('create type test_person_type as'
3435              ' (name varchar, age smallint)')
3436        self.addCleanup(query, 'drop type test_person_type')
3437        self.createTable('test_person', 'person test_person_type',
3438                         temporary=False, oids=True)
3439        person_typ = self.db.get_attnames('test_person')['person']
3440        if self.regtypes:
3441            self.assertEqual(person_typ, 'test_person_type')
3442        else:
3443            self.assertEqual(person_typ, 'record')
3444        if self.regtypes:
3445            self.assertEqual(person_typ.attnames,
3446                             dict(name='character varying', age='smallint'))
3447        else:
3448            self.assertEqual(person_typ.attnames,
3449                             dict(name='text', age='int'))
3450        person = pg.Literal("('John Doe', 61)")
3451        r = self.db.insert('test_person', None, person=person)
3452        p = r['person']
3453        self.assertIsInstance(p, tuple)
3454        self.assertEqual(p.name, 'John Doe')
3455        self.assertIsInstance(p.name, str)
3456        self.assertEqual(p.age, 61)
3457        self.assertIsInstance(p.age, int)
3458
3459    def testDbTypesInfo(self):
3460        dbtypes = self.db.dbtypes
3461        self.assertIsInstance(dbtypes, dict)
3462        self.assertNotIn('numeric', dbtypes)
3463        typ = dbtypes['numeric']
3464        self.assertIn('numeric', dbtypes)
3465        self.assertEqual(typ, 'numeric' if self.regtypes else 'num')
3466        self.assertEqual(typ.oid, 1700)
3467        self.assertEqual(typ.pgtype, 'numeric')
3468        self.assertEqual(typ.regtype, 'numeric')
3469        self.assertEqual(typ.simple, 'num')
3470        self.assertEqual(typ.typtype, 'b')
3471        self.assertEqual(typ.category, 'N')
3472        self.assertEqual(typ.delim, ',')
3473        self.assertEqual(typ.relid, 0)
3474        self.assertIs(dbtypes[1700], typ)
3475        self.assertNotIn('pg_type', dbtypes)
3476        typ = dbtypes['pg_type']
3477        self.assertIn('pg_type', dbtypes)
3478        self.assertEqual(typ, 'pg_type' if self.regtypes else 'record')
3479        self.assertIsInstance(typ.oid, int)
3480        self.assertEqual(typ.pgtype, 'pg_type')
3481        self.assertEqual(typ.regtype, 'pg_type')
3482        self.assertEqual(typ.simple, 'record')
3483        self.assertEqual(typ.typtype, 'c')
3484        self.assertEqual(typ.category, 'C')
3485        self.assertEqual(typ.delim, ',')
3486        self.assertNotEqual(typ.relid, 0)
3487        attnames = typ.attnames
3488        self.assertIsInstance(attnames, dict)
3489        self.assertIs(attnames, dbtypes.get_attnames('pg_type'))
3490        self.assertEqual(list(attnames)[0], 'typname')
3491        typname = attnames['typname']
3492        self.assertEqual(typname, 'name' if self.regtypes else 'text')
3493        self.assertEqual(typname.typtype, 'b')  # base
3494        self.assertEqual(typname.category, 'S')  # string
3495        self.assertEqual(list(attnames)[3], 'typlen')
3496        typlen = attnames['typlen']
3497        self.assertEqual(typlen, 'smallint' if self.regtypes else 'int')
3498        self.assertEqual(typlen.typtype, 'b')  # base
3499        self.assertEqual(typlen.category, 'N')  # numeric
3500
3501    def testDbTypesTypecast(self):
3502        dbtypes = self.db.dbtypes
3503        self.assertIsInstance(dbtypes, dict)
3504        self.assertNotIn('int4', dbtypes)
3505        self.assertIs(dbtypes.get_typecast('int4'), int)
3506        dbtypes.set_typecast('int4', float)
3507        self.assertIs(dbtypes.get_typecast('int4'), float)
3508        dbtypes.reset_typecast('int4')
3509        self.assertIs(dbtypes.get_typecast('int4'), int)
3510        dbtypes.set_typecast('int4', float)
3511        self.assertIs(dbtypes.get_typecast('int4'), float)
3512        dbtypes.reset_typecast()
3513        self.assertIs(dbtypes.get_typecast('int4'), int)
3514        self.assertNotIn('circle', dbtypes)
3515        self.assertIsNone(dbtypes.get_typecast('circle'))
3516        squared_circle = lambda v: 'Squared Circle: %s' % v
3517        dbtypes.set_typecast('circle', squared_circle)
3518        self.assertIs(dbtypes.get_typecast('circle'), squared_circle)
3519        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3520        self.assertIn('circle', dbtypes)
3521        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3522        self.assertEqual(dbtypes.typecast('Impossible', 'circle'),
3523            'Squared Circle: Impossible')
3524        dbtypes.reset_typecast('circle')
3525        self.assertIsNone(dbtypes.get_typecast('circle'))
3526
3527    def testGetSetTypeCast(self):
3528        get_typecast = pg.get_typecast
3529        set_typecast = pg.set_typecast
3530        dbtypes = self.db.dbtypes
3531        self.assertIsInstance(dbtypes, dict)
3532        self.assertNotIn('int4', dbtypes)
3533        self.assertNotIn('real', dbtypes)
3534        self.assertNotIn('bool', dbtypes)
3535        self.assertIs(get_typecast('int4'), int)
3536        self.assertIs(get_typecast('float4'), float)
3537        self.assertIs(get_typecast('bool'), pg.cast_bool)
3538        cast_circle = get_typecast('circle')
3539        self.addCleanup(set_typecast, 'circle', cast_circle)
3540        squared_circle = lambda v: 'Squared Circle: %s' % v
3541        self.assertNotIn('circle', dbtypes)
3542        set_typecast('circle', squared_circle)
3543        self.assertNotIn('circle', dbtypes)
3544        self.assertIs(get_typecast('circle'), squared_circle)
3545        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3546        self.assertIn('circle', dbtypes)
3547        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3548        set_typecast('circle', cast_circle)
3549        self.assertIs(get_typecast('circle'), cast_circle)
3550
3551    def testNotificationHandler(self):
3552        # the notification handler itself is tested separately
3553        f = self.db.notification_handler
3554        callback = lambda arg_dict: None
3555        handler = f('test', callback)
3556        self.assertIsInstance(handler, pg.NotificationHandler)
3557        self.assertIs(handler.db, self.db)
3558        self.assertEqual(handler.event, 'test')
3559        self.assertEqual(handler.stop_event, 'stop_test')
3560        self.assertIs(handler.callback, callback)
3561        self.assertIsInstance(handler.arg_dict, dict)
3562        self.assertEqual(handler.arg_dict, {})
3563        self.assertIsNone(handler.timeout)
3564        self.assertFalse(handler.listening)
3565        handler.close()
3566        self.assertIsNone(handler.db)
3567        self.db.reopen()
3568        self.assertIsNone(handler.db)
3569        handler = f('test2', callback, timeout=2)
3570        self.assertIsInstance(handler, pg.NotificationHandler)
3571        self.assertIs(handler.db, self.db)
3572        self.assertEqual(handler.event, 'test2')
3573        self.assertEqual(handler.stop_event, 'stop_test2')
3574        self.assertIs(handler.callback, callback)
3575        self.assertIsInstance(handler.arg_dict, dict)
3576        self.assertEqual(handler.arg_dict, {})
3577        self.assertEqual(handler.timeout, 2)
3578        self.assertFalse(handler.listening)
3579        handler.close()
3580        self.assertIsNone(handler.db)
3581        self.db.reopen()
3582        self.assertIsNone(handler.db)
3583        arg_dict = {'testing': 3}
3584        handler = f('test3', callback, arg_dict=arg_dict)
3585        self.assertIsInstance(handler, pg.NotificationHandler)
3586        self.assertIs(handler.db, self.db)
3587        self.assertEqual(handler.event, 'test3')
3588        self.assertEqual(handler.stop_event, 'stop_test3')
3589        self.assertIs(handler.callback, callback)
3590        self.assertIs(handler.arg_dict, arg_dict)
3591        self.assertEqual(arg_dict['testing'], 3)
3592        self.assertIsNone(handler.timeout)
3593        self.assertFalse(handler.listening)
3594        handler.close()
3595        self.assertIsNone(handler.db)
3596        self.db.reopen()
3597        self.assertIsNone(handler.db)
3598        handler = f('test4', callback, stop_event='stop4')
3599        self.assertIsInstance(handler, pg.NotificationHandler)
3600        self.assertIs(handler.db, self.db)
3601        self.assertEqual(handler.event, 'test4')
3602        self.assertEqual(handler.stop_event, 'stop4')
3603        self.assertIs(handler.callback, callback)
3604        self.assertIsInstance(handler.arg_dict, dict)
3605        self.assertEqual(handler.arg_dict, {})
3606        self.assertIsNone(handler.timeout)
3607        self.assertFalse(handler.listening)
3608        handler.close()
3609        self.assertIsNone(handler.db)
3610        self.db.reopen()
3611        self.assertIsNone(handler.db)
3612        arg_dict = {'testing': 5}
3613        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
3614        self.assertIsInstance(handler, pg.NotificationHandler)
3615        self.assertIs(handler.db, self.db)
3616        self.assertEqual(handler.event, 'test5')
3617        self.assertEqual(handler.stop_event, 'stop5')
3618        self.assertIs(handler.callback, callback)
3619        self.assertIs(handler.arg_dict, arg_dict)
3620        self.assertEqual(arg_dict['testing'], 5)
3621        self.assertEqual(handler.timeout, 1.5)
3622        self.assertFalse(handler.listening)
3623        handler.close()
3624        self.assertIsNone(handler.db)
3625        self.db.reopen()
3626        self.assertIsNone(handler.db)
3627
3628
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        # Change the global pg options for this class only; the previous
        # values are remembered in cls.saved_options and restored later.
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            not_bytea_escaped = not pg.get_bytea_escaped()
            cls.set_option('bytea_escaped', not_bytea_escaped)
            cls.set_option('namedresult', None)
            cls.set_option('jsondecode', None)
            # a throw-away connection is only needed to read the regtypes
            # default; close it explicitly instead of leaking it
            db = DB()
            try:
                cls.regtypes = not db.use_regtypes()
            finally:
                db.close()
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # tearDownClass is not run when setUpClass fails, so undo the
            # global option changes here to not affect other test classes
            cls.reset_options()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # always restore the global options, even if cleanup failed
            cls.reset_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current global value of option, then set it to value.

        Uses the pg.get_<option> and pg.set_<option> module functions and
        returns whatever pg.set_<option> returns.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore the global value of option saved by set_option."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_options(cls):
        """Restore all global options that have been saved by set_option."""
        for option in list(cls.saved_options):
            cls.reset_option(option)
3662
3663
class TestDBClassAdapter(unittest.TestCase):
    """Test the adapter object associated with the DB class."""

    def setUp(self):
        self.db = DB()
        self.adapter = self.db.adapter

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            # NOTE(review): presumably raised when the connection has
            # already been closed; ignored so tearDown does not fail
            pass

    def _record_type(self):
        """Return a 'record' type with a fixed list of typed attributes.

        The attribute lookup of the type is replaced by a fixed mapping
        (int, float, text, bool, int[], text[]) matching the record values
        used in the typed format_query tests below.
        """
        t = self.adapter.simple_type
        typ = t('record')
        typ._get_attnames = lambda _self: pg.AttrDict([
            ('i', t('int')), ('f', t('float')),
            ('t', t('text')), ('b', t('bool')),
            ('i3', t('int[]')), ('t3', t('text[]'))])
        return typ

    def testGuessSimpleType(self):
        f = self.adapter.guess_simple_type
        self.assertEqual(f(pg.Bytea(b'test')), 'bytea')
        self.assertEqual(f('string'), 'text')
        self.assertEqual(f(b'string'), 'text')
        self.assertEqual(f(True), 'bool')
        self.assertEqual(f(3), 'int')
        self.assertEqual(f(2.75), 'float')
        self.assertEqual(f(Decimal('4.25')), 'num')
        self.assertEqual(f(date(2016, 1, 30)), 'date')
        self.assertEqual(f([1, 2, 3]), 'int[]')
        self.assertEqual(f([[[123]]]), 'int[]')
        self.assertEqual(f(['a', 'b', 'c']), 'text[]')
        self.assertEqual(f([[['abc']]]), 'text[]')
        self.assertEqual(f([False, True]), 'bool[]')
        self.assertEqual(f([[[False]]]), 'bool[]')
        # tuples are guessed as records with per-attribute types
        r = f(('string', True, 3, 2.75, [1], [False]))
        self.assertEqual(r, 'record')
        self.assertEqual(list(r.attnames.values()),
            ['text', 'bool', 'int', 'float', 'int[]', 'bool[]'])

    def testAdaptQueryTypedList(self):
        format_query = self.adapter.format_query
        # the number of values and types must match
        self.assertRaises(TypeError, format_query,
            '%s,%s', (1, 2), ('int2',))
        self.assertRaises(TypeError, format_query,
            '%s,%s', (1,), ('int2', 'int2'))
        values = (3, 7.5, 'hello', True)
        types = ('int4', 'float4', 'text', 'bool')
        sql, params = format_query("select %s,%s,%s,%s", values, types)
        self.assertEqual(sql, 'select $1,$2,$3,$4')
        self.assertEqual(params, [3, 7.5, 'hello', 't'])
        types = ('bool', 'bool', 'bool', 'bool')
        sql, params = format_query("select %s,%s,%s,%s", values, types)
        self.assertEqual(sql, 'select $1,$2,$3,$4')
        self.assertEqual(params, ['t', 't', 'f', 't'])
        # with an explicit date type, 'current_date' goes into the SQL
        values = ('2016-01-30', 'current_date')
        types = ('date', 'date')
        sql, params = format_query("values(%s,%s)", values, types)
        self.assertEqual(sql, 'values($1,current_date)')
        self.assertEqual(params, ['2016-01-30'])
        values = ([1, 2, 3], ['a', 'b', 'c'])
        types = ('_int4', '_text')
        sql, params = format_query("%s::int4[],%s::text[]", values, types)
        self.assertEqual(sql, '$1::int4[],$2::text[]')
        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
        types = ('_bool', '_bool')
        sql, params = format_query("%s::bool[],%s::bool[]", values, types)
        self.assertEqual(sql, '$1::bool[],$2::bool[]')
        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
        types = [self._record_type()]
        sql, params = format_query('select %s', values, types)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryTypedDict(self):
        format_query = self.adapter.format_query
        self.assertRaises(TypeError, format_query,
            '%s,%s', dict(i1=1, i2=2), dict(i1='int2'))
        values = dict(i=3, f=7.5, t='hello', b=True)
        types = dict(i='int4', f='float4',
            t='text', b='bool')
        # placeholders are numbered in sorted order of the dict keys
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
        self.assertEqual(sql, 'select $3,$2,$4,$1')
        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
        types = dict(i='bool', f='bool',
            t='bool', b='bool')
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values, types)
        self.assertEqual(sql, 'select $3,$2,$4,$1')
        self.assertEqual(params, ['t', 't', 't', 'f'])
        values = dict(d1='2016-01-30', d2='current_date')
        types = dict(d1='date', d2='date')
        sql, params = format_query("values(%(d1)s,%(d2)s)", values, types)
        self.assertEqual(sql, 'values($1,current_date)')
        self.assertEqual(params, ['2016-01-30'])
        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
        types = dict(i='_int4', t='_text')
        sql, params = format_query(
            "%(i)s::int4[],%(t)s::text[]", values, types)
        self.assertEqual(sql, '$1::int4[],$2::text[]')
        self.assertEqual(params, ['{1,2,3}', '{a,b,c}'])
        types = dict(i='_bool', t='_bool')
        sql, params = format_query(
            "%(i)s::bool[],%(t)s::bool[]", values, types)
        self.assertEqual(sql, '$1::bool[],$2::bool[]')
        self.assertEqual(params, ['{t,t,t}', '{f,f,f}'])
        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
        types = dict(record=self._record_type())
        sql, params = format_query('select %(record)s', values, types)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryUntypedList(self):
        format_query = self.adapter.format_query
        values = (3, 7.5, 'hello', True)
        sql, params = format_query("select %s,%s,%s,%s", values)
        self.assertEqual(sql, 'select $1,$2,$3,$4')
        self.assertEqual(params, [3, 7.5, 'hello', 't'])
        # without types, 'current_date' is just a parameter string
        values = [date(2016, 1, 30), 'current_date']
        sql, params = format_query("values(%s,%s)", values)
        self.assertEqual(sql, 'values($1,$2)')
        self.assertEqual(params, values)
        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
        sql, params = format_query("%s,%s,%s", values)
        self.assertEqual(sql, "$1,$2,$3")
        self.assertEqual(params, ['{1,2,3}', '{a,b,c}', '{t,f,t}'])
        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
            [[True, False], [False, True]])
        sql, params = format_query("%s,%s,%s", values)
        self.assertEqual(sql, "$1,$2,$3")
        self.assertEqual(params, [
            '{{1,2},{3,4}}', '{{a,b},{c,d}}', '{{t,f},{f,t}}'])
        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
        sql, params = format_query('select %s', values)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryUntypedDict(self):
        format_query = self.adapter.format_query
        values = dict(i=3, f=7.5, t='hello', b=True)
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values)
        self.assertEqual(sql, 'select $3,$2,$4,$1')
        self.assertEqual(params, ['t', 7.5, 3, 'hello'])
        values = dict(d1='2016-01-30', d2='current_date')
        sql, params = format_query("values(%(d1)s,%(d2)s)", values)
        self.assertEqual(sql, 'values($1,$2)')
        self.assertEqual(params, [values['d1'], values['d2']])
        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
        self.assertEqual(sql, "$2,$3,$1")
        self.assertEqual(params, ['{t,f,t}', '{1,2,3}', '{a,b,c}'])
        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
            b=[[True, False], [False, True]])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values)
        self.assertEqual(sql, "$2,$3,$1")
        self.assertEqual(params, [
            '{{t,f},{f,t}}', '{{1,2},{3,4}}', '{{a,b},{c,d}}'])
        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
        sql, params = format_query('select %(record)s', values)
        self.assertEqual(sql, 'select $1')
        self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})'])

    def testAdaptQueryInlineList(self):
        # in inline mode all values are rendered as SQL literals
        # and no separate parameters are returned
        format_query = self.adapter.format_query
        values = (3, 7.5, 'hello', True)
        sql, params = format_query("select %s,%s,%s,%s", values, inline=True)
        self.assertEqual(sql, "select 3,7.5,'hello',true")
        self.assertEqual(params, [])
        values = [date(2016, 1, 30), 'current_date']
        sql, params = format_query("values(%s,%s)", values, inline=True)
        self.assertEqual(sql, "values('2016-01-30','current_date')")
        self.assertEqual(params, [])
        values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True])
        sql, params = format_query("%s,%s,%s", values, inline=True)
        self.assertEqual(sql,
            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
        self.assertEqual(params, [])
        values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']],
            [[True, False], [False, True]])
        sql, params = format_query("%s,%s,%s", values, inline=True)
        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
            "ARRAY[[true,false],[false,true]]")
        self.assertEqual(params, [])
        values = [(3, 7.5, 'hello', True, [123], ['abc'])]
        sql, params = format_query('select %s', values, inline=True)
        self.assertEqual(sql,
            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
        self.assertEqual(params, [])

    def testAdaptQueryInlineDict(self):
        format_query = self.adapter.format_query
        values = dict(i=3, f=7.5, t='hello', b=True)
        sql, params = format_query(
            "select %(i)s,%(f)s,%(t)s,%(b)s", values, inline=True)
        self.assertEqual(sql, "select 3,7.5,'hello',true")
        self.assertEqual(params, [])
        values = dict(d1='2016-01-30', d2='current_date')
        sql, params = format_query(
            "values(%(d1)s,%(d2)s)", values, inline=True)
        self.assertEqual(sql, "values('2016-01-30','current_date')")
        self.assertEqual(params, [])
        values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
        self.assertEqual(sql,
            "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]")
        self.assertEqual(params, [])
        values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']],
            b=[[True, False], [False, True]])
        sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True)
        self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']],"
            "ARRAY[[true,false],[false,true]]")
        self.assertEqual(params, [])
        values = dict(record=(3, 7.5, 'hello', True, [123], ['abc']))
        sql, params = format_query('select %(record)s', values, inline=True)
        self.assertEqual(sql,
            "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])")
        self.assertEqual(params, [])

    def testAdaptQueryWithPgRepr(self):
        format_query = self.adapter.format_query
        # NOTE(review): object() is passed as the whole values collection,
        # not as a single value inside a list; confirm whether an
        # unadaptable element (e.g. [object()]) should be tested as well.
        self.assertRaises(TypeError, format_query,
            '%s', object(), inline=True)

        class TestObject:
            """Object that adapts itself via the __pg_repr__ hook."""

            def __pg_repr__(self):
                return "'adapted'"

        sql, params = format_query('select %s', [TestObject()], inline=True)
        self.assertEqual(sql, "select 'adapted'")
        self.assertEqual(params, [])
        sql, params = format_query('select %s', [[TestObject()]], inline=True)
        self.assertEqual(sql, "select ARRAY['adapted']")
        self.assertEqual(params, [])
3904
3905
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    # set to True by setUpClass once all schemas and tables exist
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        # Create tables t and t0 in the public schema and schemas s1..s4,
        # each with tables t and t<num>; column d holds the schema number.
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d" % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # close the connection even if creating the fixtures failed
            db.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.db = DB()

    def tearDown(self):
        # run registered cleanups now, because they use self.db.query
        # and must execute before the connection is closed
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        # unqualified names are now resolved via the new search path
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        # with s2 first in the search path, "t" resolves to s2.t
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        # the oid column appears in returned rows as "oid(<table>)"
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
4024
4025
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class."""

    def setUp(self):
        self.db = DB()
        self.query = self.db.query
        # remember the initial debug setting so tearDown can restore it
        self.debug = self.db.debug
        # capture everything printed to stdout while the test runs
        self.output = StringIO()
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        sys.stdout = self.stdout
        self.output.close()
        # restore the value saved in setUp (previously the module-level
        # default was used, leaving the saved value unused)
        self.db.debug = self.debug
        self.db.close()

    def get_output(self):
        """Return everything written to the captured stdout so far."""
        return self.output.getvalue()

    def send_queries(self):
        """Run two trivial queries that produce debug output if enabled."""
        self.db.query("select 1")
        self.db.query("select 2")

    def testDebugDefault(self):
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        # a string is used as a format template for each query
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")

    def testDebugMultipleArgs(self):
        # several arguments are joined with newlines into one message
        output = []
        self.db.debug = output.append
        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
        self.db._do_debug(*args)
        self.assertEqual(output, ['\n'.join(str(arg) for arg in args)])
        self.assertEqual(self.get_output(), "")
4095
4096
if __name__ == '__main__':
    # Running this file directly executes all test cases defined in it.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.