source: trunk/tests/test_classic_dbwrapper.py @ 798

Last change on this file since 798 was 798, checked in by cito, 3 years ago

Port type cache and typecasting from pgdb to pg

So far, the typecasting in the classic module had only been done by
the C extension module and was not extensible through typecasting
functions in Python. This has now been made extensible by adding
a cast hook to the C extension module which has been hooked up to
a new type cache object that holds information on the types and the
associated typecast functions. All of this works very similarly to the
pgdb module now, except that the basic types are still handled by
the C extension module and the Python typecast functions are only
called via the hook for types which are not supported internally.

Also added tests and a chapter on the type cache in the documentation,
and cleaned up the error messages in the C extension module.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 152.5 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import tempfile
21import json
22
23import pg  # the module under test
24
25from decimal import Decimal
26from operator import itemgetter
27
# Connection settings for the test database. If a LOCAL_PyGreSQL module
# can be imported, its star import below overrides these defaults.
# The current user needs the privilege to create schemas in the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # if true, the DB wrapper prints debugging output

try:  # first look for local settings next to this test module
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:  # then fall back to a top-level module of the same name
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings, keep the defaults

# Compatibility aliases so that the tests run under Python 2 and 3.
try:
    long
except NameError:  # Python >= 3.0 has no separate long type
    long = int

try:
    unicode
except NameError:  # Python >= 3.0 uses str for text
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0 have no OrderedDict
    OrderedDict = dict

if str is not bytes:  # Python 3
    from io import StringIO
else:  # Python 2
    from StringIO import StringIO

windows = (os.name == 'nt')

# A known bug in libpq under Windows can crash the interface when
# PQhost() is called, so the tests asking for the host are skipped there.
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
71
72
def DB():
    """Open and return a DB wrapper connected to the test database.

    The module-level dbname, dbhost and dbport settings are used. The
    wrapper's debug attribute is set when the module-level debug flag is
    true, and client_min_messages is lowered to 'warning' on the session.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
80
81
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names."""

    cls = pg.AttrDict
    base = OrderedDict

    def testInit(self):
        a = self.cls()
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base())
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))
        # any iterable of pairs must be accepted, not only lists
        iteritems = iter(items)
        a = self.cls(iteritems)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))

    def testIter(self):
        a = self.cls()
        self.assertEqual(list(a), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a), keys)

    def testKeys(self):
        a = self.cls()
        self.assertEqual(list(a.keys()), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a.keys()), keys)

    def testValues(self):
        a = self.cls()
        self.assertEqual(list(a.values()), [])
        items = [('id', 'int'), ('name', 'text')]
        values = [item[1] for item in items]
        a = self.cls(items)
        self.assertEqual(list(a.values()), values)

    def testItems(self):
        a = self.cls()
        self.assertEqual(list(a.items()), [])
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertEqual(list(a.items()), items)

    def testGet(self):
        a = self.cls([('id', 1)])
        try:
            self.assertEqual(a['id'], 1)
        except KeyError:
            self.fail('AttrDict should be readable')

    def testSet(self):
        a = self.cls()
        try:
            a['id'] = 1
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testDel(self):
        a = self.cls([('id', 1)])
        try:
            del a['id']
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        """All mutating dict methods must be blocked with a TypeError.

        Each method is called with arguments that a normal dict accepts.
        The TypeError can then only come from the read-only guard, not from
        a wrong number of arguments or an unhashable key.
        """
        a = self.cls([('id', 1)])
        self.assertEqual(a['id'], 1)
        calls = [
            ('clear', ()),
            ('update', ({'name': 'text'},)),
            ('pop', ('id',)),
            ('setdefault', ('name', 'text')),
            ('popitem', ()),
        ]
        for name, args in calls:
            method = getattr(a, name)
            self.assertRaises(TypeError, method, *args)
        # none of the blocked calls may have changed the contents
        self.assertEqual(list(a.items()), [('id', 1)])
163
164
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            pass  # some tests have already closed the connection

    def testAllDBAttributes(self):
        # all public attributes, in the sorted order returned by dir()
        attributes = [
            'abort',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'dbtypes',
            'debug', 'decode_json', 'delete',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_cast_hook',
            'get_databases', 'get_notice_receiver',
            'get_parameter', 'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_cast_hook', 'set_notice_receiver',
            'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        # tolerate a leftover krb5_ message (presumably Kerberos-related)
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        # PostgreSQL 10 and later report version numbers of 100000 and
        # above, so only a generous upper bound is used as sanity check
        self.assertTrue(70400 <= server_version < 1000000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            self.assertEqual(error.sqlstate, '22012')
        else:
            # without this the test would pass silently if nothing is raised
            self.fail('Division by zero should raise a ProgrammingError')

    def testMethodEndcopy(self):
        # only checks that the method can be called; an IOError is allowed
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        # reset keeps the same connection object
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        # reopen replaces the connection object, even after close
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        # a wrapper around an existing connection does not close it
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        # any object carrying the connection as _cnx is accepted
        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
375
376
377class TestDBClass(unittest.TestCase):
378    """Test the methods of the DB class wrapped pg connection."""
379
380    maxDiff = 80 * 20
381
382    cls_set_up = False
383
384    regtypes = None
385
386    @classmethod
387    def setUpClass(cls):
388        db = DB()
389        db.query("drop table if exists test cascade")
390        db.query("create table test ("
391                 "i2 smallint, i4 integer, i8 bigint,"
392                 " d numeric, f4 real, f8 double precision, m money,"
393                 " v4 varchar(4), c4 char(4), t text)")
394        db.query("create or replace view test_view as"
395                 " select i4, v4 from test")
396        db.close()
397        cls.cls_set_up = True
398
399    @classmethod
400    def tearDownClass(cls):
401        db = DB()
402        db.query("drop table test cascade")
403        db.close()
404
405    def setUp(self):
406        self.assertTrue(self.cls_set_up)
407        self.db = DB()
408        if self.regtypes is None:
409            self.regtypes = self.db.use_regtypes()
410        else:
411            self.db.use_regtypes(self.regtypes)
412        query = self.db.query
413        query('set client_encoding=utf8')
414        query('set standard_conforming_strings=on')
415        query("set lc_monetary='C'")
416        query("set datestyle='ISO,YMD'")
417        query('set bytea_output=hex')
418
419    def tearDown(self):
420        self.doCleanups()
421        self.db.close()
422
423    def createTable(self, table, definition,
424                    temporary=True, oids=None, values=None):
425        query = self.db.query
426        if not '"' in table or '.' in table:
427            table = '"%s"' % table
428        if not temporary:
429            q = 'drop table if exists %s cascade' % table
430            query(q)
431            self.addCleanup(query, q)
432        temporary = 'temporary table' if temporary else 'table'
433        as_query = definition.startswith(('as ', 'AS '))
434        if not as_query and not definition.startswith('('):
435            definition = '(%s)' % definition
436        with_oids = 'with oids' if oids else 'without oids'
437        q = ['create', temporary, table]
438        if as_query:
439            q.extend([with_oids, definition])
440        else:
441            q.extend([definition, with_oids])
442        q = ' '.join(q)
443        query(q)
444        if values:
445            for params in values:
446                if not isinstance(params, (list, tuple)):
447                    params = [params]
448                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
449                q = "insert into %s values (%s)" % (table, values)
450                query(q, params)
451
452    def testClassName(self):
453        self.assertEqual(self.db.__class__.__name__, 'DB')
454
455    def testModuleName(self):
456        self.assertEqual(self.db.__module__, 'pg')
457        self.assertEqual(self.db.__class__.__module__, 'pg')
458
459    def testEscapeLiteral(self):
460        f = self.db.escape_literal
461        r = f(b"plain")
462        self.assertIsInstance(r, bytes)
463        self.assertEqual(r, b"'plain'")
464        r = f(u"plain")
465        self.assertIsInstance(r, unicode)
466        self.assertEqual(r, u"'plain'")
467        r = f(u"that's kÀse".encode('utf-8'))
468        self.assertIsInstance(r, bytes)
469        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
470        r = f(u"that's kÀse")
471        self.assertIsInstance(r, unicode)
472        self.assertEqual(r, u"'that''s kÀse'")
473        self.assertEqual(f(r"It's fine to have a \ inside."),
474                         r" E'It''s fine to have a \\ inside.'")
475        self.assertEqual(f('No "quotes" must be escaped.'),
476                         "'No \"quotes\" must be escaped.'")
477
478    def testEscapeIdentifier(self):
479        f = self.db.escape_identifier
480        r = f(b"plain")
481        self.assertIsInstance(r, bytes)
482        self.assertEqual(r, b'"plain"')
483        r = f(u"plain")
484        self.assertIsInstance(r, unicode)
485        self.assertEqual(r, u'"plain"')
486        r = f(u"that's kÀse".encode('utf-8'))
487        self.assertIsInstance(r, bytes)
488        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
489        r = f(u"that's kÀse")
490        self.assertIsInstance(r, unicode)
491        self.assertEqual(r, u'"that\'s kÀse"')
492        self.assertEqual(f(r"It's fine to have a \ inside."),
493                         '"It\'s fine to have a \\ inside."')
494        self.assertEqual(f('All "quotes" must be escaped.'),
495                         '"All ""quotes"" must be escaped."')
496
497    def testEscapeString(self):
498        f = self.db.escape_string
499        r = f(b"plain")
500        self.assertIsInstance(r, bytes)
501        self.assertEqual(r, b"plain")
502        r = f(u"plain")
503        self.assertIsInstance(r, unicode)
504        self.assertEqual(r, u"plain")
505        r = f(u"that's kÀse".encode('utf-8'))
506        self.assertIsInstance(r, bytes)
507        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
508        r = f(u"that's kÀse")
509        self.assertIsInstance(r, unicode)
510        self.assertEqual(r, u"that''s kÀse")
511        self.assertEqual(f(r"It's fine to have a \ inside."),
512                         r"It''s fine to have a \ inside.")
513
514    def testEscapeBytea(self):
515        f = self.db.escape_bytea
516        # note that escape_byte always returns hex output since Pg 9.0,
517        # regardless of the bytea_output setting
518        r = f(b'plain')
519        self.assertIsInstance(r, bytes)
520        self.assertEqual(r, b'\\x706c61696e')
521        r = f(u'plain')
522        self.assertIsInstance(r, unicode)
523        self.assertEqual(r, u'\\x706c61696e')
524        r = f(u"das is' kÀse".encode('utf-8'))
525        self.assertIsInstance(r, bytes)
526        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
527        r = f(u"das is' kÀse")
528        self.assertIsInstance(r, unicode)
529        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
530        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
531
532    def testUnescapeBytea(self):
533        f = self.db.unescape_bytea
534        r = f(b'plain')
535        self.assertIsInstance(r, bytes)
536        self.assertEqual(r, b'plain')
537        r = f(u'plain')
538        self.assertIsInstance(r, bytes)
539        self.assertEqual(r, b'plain')
540        r = f(b"das is' k\\303\\244se")
541        self.assertIsInstance(r, bytes)
542        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
543        r = f(u"das is' k\\303\\244se")
544        self.assertIsInstance(r, bytes)
545        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
546        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
547        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
548        self.assertEqual(f(r'\\x746861742773206be47365'),
549                         b'\\x746861742773206be47365')
550        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
551
552    def testDecodeJson(self):
553        f = self.db.decode_json
554        self.assertIsNone(f('null'))
555        data = {
556            "id": 1, "name": "Foo", "price": 1234.5,
557            "new": True, "note": None,
558            "tags": ["Bar", "Eek"],
559            "stock": {"warehouse": 300, "retail": 20}}
560        text = json.dumps(data)
561        r = f(text)
562        self.assertIsInstance(r, dict)
563        self.assertEqual(r, data)
564        self.assertIsInstance(r['id'], int)
565        self.assertIsInstance(r['name'], unicode)
566        self.assertIsInstance(r['price'], float)
567        self.assertIsInstance(r['new'], bool)
568        self.assertIsInstance(r['tags'], list)
569        self.assertIsInstance(r['stock'], dict)
570
571    def testEncodeJson(self):
572        f = self.db.encode_json
573        self.assertEqual(f(None), 'null')
574        data = {
575            "id": 1, "name": "Foo", "price": 1234.5,
576            "new": True, "note": None,
577            "tags": ["Bar", "Eek"],
578            "stock": {"warehouse": 300, "retail": 20}}
579        text = json.dumps(data)
580        r = f(data)
581        self.assertIsInstance(r, str)
582        self.assertEqual(r, text)
583
584    def testGetParameter(self):
585        f = self.db.get_parameter
586        self.assertRaises(TypeError, f)
587        self.assertRaises(TypeError, f, None)
588        self.assertRaises(TypeError, f, 42)
589        self.assertRaises(TypeError, f, '')
590        self.assertRaises(TypeError, f, [])
591        self.assertRaises(TypeError, f, [''])
592        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
593        r = f('standard_conforming_strings')
594        self.assertEqual(r, 'on')
595        r = f('lc_monetary')
596        self.assertEqual(r, 'C')
597        r = f('datestyle')
598        self.assertEqual(r, 'ISO, YMD')
599        r = f('bytea_output')
600        self.assertEqual(r, 'hex')
601        r = f(['bytea_output', 'lc_monetary'])
602        self.assertIsInstance(r, list)
603        self.assertEqual(r, ['hex', 'C'])
604        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
605        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
606        r = f(set(['bytea_output', 'lc_monetary']))
607        self.assertIsInstance(r, dict)
608        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
609        r = f(set(['Bytea_Output', ' LC_Monetary ']))
610        self.assertIsInstance(r, dict)
611        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
612        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
613        r = f(s)
614        self.assertIs(r, s)
615        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
616        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
617        r = f(s)
618        self.assertIs(r, s)
619        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
620
621    def testGetParameterServerVersion(self):
622        r = self.db.get_parameter('server_version_num')
623        self.assertIsInstance(r, str)
624        s = self.db.server_version
625        self.assertIsInstance(s, int)
626        self.assertEqual(r, str(s))
627
628    def testGetParameterAll(self):
629        f = self.db.get_parameter
630        r = f('all')
631        self.assertIsInstance(r, dict)
632        self.assertEqual(r['standard_conforming_strings'], 'on')
633        self.assertEqual(r['lc_monetary'], 'C')
634        self.assertEqual(r['DateStyle'], 'ISO, YMD')
635        self.assertEqual(r['bytea_output'], 'hex')
636
637    def testSetParameter(self):
638        f = self.db.set_parameter
639        g = self.db.get_parameter
640        self.assertRaises(TypeError, f)
641        self.assertRaises(TypeError, f, None)
642        self.assertRaises(TypeError, f, 42)
643        self.assertRaises(TypeError, f, '')
644        self.assertRaises(TypeError, f, [])
645        self.assertRaises(TypeError, f, [''])
646        self.assertRaises(ValueError, f, 'all', 'invalid')
647        self.assertRaises(ValueError, f, {
648            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
649        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
650        f('standard_conforming_strings', 'off')
651        self.assertEqual(g('standard_conforming_strings'), 'off')
652        f('datestyle', 'ISO, DMY')
653        self.assertEqual(g('datestyle'), 'ISO, DMY')
654        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
655        self.assertEqual(g('standard_conforming_strings'), 'on')
656        self.assertEqual(g('datestyle'), 'ISO, DMY')
657        f(['default_with_oids', 'standard_conforming_strings'], 'off')
658        self.assertEqual(g('default_with_oids'), 'off')
659        self.assertEqual(g('standard_conforming_strings'), 'off')
660        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
661        self.assertEqual(g('standard_conforming_strings'), 'on')
662        self.assertEqual(g('datestyle'), 'ISO, YMD')
663        f(('default_with_oids', 'standard_conforming_strings'), 'off')
664        self.assertEqual(g('default_with_oids'), 'off')
665        self.assertEqual(g('standard_conforming_strings'), 'off')
666        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
667        self.assertEqual(g('default_with_oids'), 'on')
668        self.assertEqual(g('standard_conforming_strings'), 'on')
669        self.assertRaises(ValueError, f, set([ 'default_with_oids',
670                                               'standard_conforming_strings']), ['off', 'on'])
671        f(set(['default_with_oids', 'standard_conforming_strings']),
672          ['off', 'off'])
673        self.assertEqual(g('default_with_oids'), 'off')
674        self.assertEqual(g('standard_conforming_strings'), 'off')
675        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
676        self.assertEqual(g('standard_conforming_strings'), 'on')
677        self.assertEqual(g('datestyle'), 'ISO, YMD')
678
679    def testResetParameter(self):
680        db = DB()
681        f = db.set_parameter
682        g = db.get_parameter
683        r = g('default_with_oids')
684        self.assertIn(r, ('on', 'off'))
685        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
686        r = g('standard_conforming_strings')
687        self.assertIn(r, ('on', 'off'))
688        scs, not_scs = r, 'off' if r == 'on' else 'on'
689        f('default_with_oids', not_dwi)
690        f('standard_conforming_strings', not_scs)
691        self.assertEqual(g('default_with_oids'), not_dwi)
692        self.assertEqual(g('standard_conforming_strings'), not_scs)
693        f('default_with_oids')
694        f('standard_conforming_strings', None)
695        self.assertEqual(g('default_with_oids'), dwi)
696        self.assertEqual(g('standard_conforming_strings'), scs)
697        f('default_with_oids', not_dwi)
698        f('standard_conforming_strings', not_scs)
699        self.assertEqual(g('default_with_oids'), not_dwi)
700        self.assertEqual(g('standard_conforming_strings'), not_scs)
701        f(['default_with_oids', 'standard_conforming_strings'], None)
702        self.assertEqual(g('default_with_oids'), dwi)
703        self.assertEqual(g('standard_conforming_strings'), scs)
704        f('default_with_oids', not_dwi)
705        f('standard_conforming_strings', not_scs)
706        self.assertEqual(g('default_with_oids'), not_dwi)
707        self.assertEqual(g('standard_conforming_strings'), not_scs)
708        f(('default_with_oids', 'standard_conforming_strings'))
709        self.assertEqual(g('default_with_oids'), dwi)
710        self.assertEqual(g('standard_conforming_strings'), scs)
711        f('default_with_oids', not_dwi)
712        f('standard_conforming_strings', not_scs)
713        self.assertEqual(g('default_with_oids'), not_dwi)
714        self.assertEqual(g('standard_conforming_strings'), not_scs)
715        f(set(['default_with_oids', 'standard_conforming_strings']))
716        self.assertEqual(g('default_with_oids'), dwi)
717        self.assertEqual(g('standard_conforming_strings'), scs)
718
719    def testResetParameterAll(self):
720        db = DB()
721        f = db.set_parameter
722        self.assertRaises(ValueError, f, 'all', 0)
723        self.assertRaises(ValueError, f, 'all', 'off')
724        g = db.get_parameter
725        r = g('default_with_oids')
726        self.assertIn(r, ('on', 'off'))
727        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
728        r = g('standard_conforming_strings')
729        self.assertIn(r, ('on', 'off'))
730        scs, not_scs = r, 'off' if r == 'on' else 'on'
731        f('default_with_oids', not_dwi)
732        f('standard_conforming_strings', not_scs)
733        self.assertEqual(g('default_with_oids'), not_dwi)
734        self.assertEqual(g('standard_conforming_strings'), not_scs)
735        f('all')
736        self.assertEqual(g('default_with_oids'), dwi)
737        self.assertEqual(g('standard_conforming_strings'), scs)
738
739    def testSetParameterLocal(self):
740        f = self.db.set_parameter
741        g = self.db.get_parameter
742        self.assertEqual(g('standard_conforming_strings'), 'on')
743        self.db.begin()
744        f('standard_conforming_strings', 'off', local=True)
745        self.assertEqual(g('standard_conforming_strings'), 'off')
746        self.db.end()
747        self.assertEqual(g('standard_conforming_strings'), 'on')
748
749    def testSetParameterSession(self):
750        f = self.db.set_parameter
751        g = self.db.get_parameter
752        self.assertEqual(g('standard_conforming_strings'), 'on')
753        self.db.begin()
754        f('standard_conforming_strings', 'off', local=False)
755        self.assertEqual(g('standard_conforming_strings'), 'off')
756        self.db.end()
757        self.assertEqual(g('standard_conforming_strings'), 'off')
758
759    def testReset(self):
760        db = DB()
761        default_datestyle = db.get_parameter('datestyle')
762        changed_datestyle = 'ISO, DMY'
763        if changed_datestyle == default_datestyle:
764            changed_datestyle == 'ISO, YMD'
765        self.db.set_parameter('datestyle', changed_datestyle)
766        r = self.db.get_parameter('datestyle')
767        self.assertEqual(r, changed_datestyle)
768        con = self.db.db
769        q = con.query("show datestyle")
770        self.db.reset()
771        r = q.getresult()[0][0]
772        self.assertEqual(r, changed_datestyle)
773        q = con.query("show datestyle")
774        r = q.getresult()[0][0]
775        self.assertEqual(r, default_datestyle)
776        r = self.db.get_parameter('datestyle')
777        self.assertEqual(r, default_datestyle)
778
779    def testReopen(self):
780        db = DB()
781        default_datestyle = db.get_parameter('datestyle')
782        changed_datestyle = 'ISO, DMY'
783        if changed_datestyle == default_datestyle:
784            changed_datestyle == 'ISO, YMD'
785        self.db.set_parameter('datestyle', changed_datestyle)
786        r = self.db.get_parameter('datestyle')
787        self.assertEqual(r, changed_datestyle)
788        con = self.db.db
789        q = con.query("show datestyle")
790        self.db.reopen()
791        r = q.getresult()[0][0]
792        self.assertEqual(r, changed_datestyle)
793        self.assertRaises(TypeError, getattr, con, 'query')
794        r = self.db.get_parameter('datestyle')
795        self.assertEqual(r, default_datestyle)
796
797    def testCreateTable(self):
798        table = 'test hello world'
799        values = [(2, "World!"), (1, "Hello")]
800        self.createTable(table, "n smallint, t varchar",
801                         temporary=True, oids=True, values=values)
802        r = self.db.query('select t from "%s" order by n' % table).getresult()
803        r = ', '.join(row[0] for row in r)
804        self.assertEqual(r, "Hello, World!")
805        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
806        self.assertIsInstance(r[0][0], int)
807
808    def testQuery(self):
809        query = self.db.query
810        table = 'test_table'
811        self.createTable(table, "n integer", oids=True)
812        q = "insert into test_table values (1)"
813        r = query(q)
814        self.assertIsInstance(r, int)
815        q = "insert into test_table select 2"
816        r = query(q)
817        self.assertIsInstance(r, int)
818        oid = r
819        q = "select oid from test_table where n=2"
820        r = query(q).getresult()
821        self.assertEqual(len(r), 1)
822        r = r[0]
823        self.assertEqual(len(r), 1)
824        r = r[0]
825        self.assertEqual(r, oid)
826        q = "insert into test_table select 3 union select 4 union select 5"
827        r = query(q)
828        self.assertIsInstance(r, str)
829        self.assertEqual(r, '3')
830        q = "update test_table set n=4 where n<5"
831        r = query(q)
832        self.assertIsInstance(r, str)
833        self.assertEqual(r, '4')
834        q = "delete from test_table"
835        r = query(q)
836        self.assertIsInstance(r, str)
837        self.assertEqual(r, '5')
838
839    def testMultipleQueries(self):
840        self.assertEqual(self.db.query(
841            "create temporary table test_multi (n integer);"
842            "insert into test_multi values (4711);"
843            "select n from test_multi").getresult()[0][0], 4711)
844
845    def testQueryWithParams(self):
846        query = self.db.query
847        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
848        q = "insert into test_table values ($1, $2)"
849        r = query(q, (1, 2))
850        self.assertIsInstance(r, int)
851        r = query(q, [3, 4])
852        self.assertIsInstance(r, int)
853        r = query(q, [5, 6])
854        self.assertIsInstance(r, int)
855        q = "select * from test_table order by 1, 2"
856        self.assertEqual(query(q).getresult(),
857                         [(1, 2), (3, 4), (5, 6)])
858        q = "select * from test_table where n1=$1 and n2=$2"
859        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
860        q = "update test_table set n2=$2 where n1=$1"
861        r = query(q, 3, 7)
862        self.assertEqual(r, '1')
863        q = "select * from test_table order by 1, 2"
864        self.assertEqual(query(q).getresult(),
865                         [(1, 2), (3, 7), (5, 6)])
866        q = "delete from test_table where n2!=$1"
867        r = query(q, 4)
868        self.assertEqual(r, '3')
869
870    def testEmptyQuery(self):
871        self.assertRaises(ValueError, self.db.query, '')
872
873    def testQueryProgrammingError(self):
874        try:
875            self.db.query("select 1/0")
876        except pg.ProgrammingError as error:
877            self.assertEqual(error.sqlstate, '22012')
878
879    def testPkey(self):
880        query = self.db.query
881        pkey = self.db.pkey
882        self.assertRaises(KeyError, pkey, 'test')
883        for t in ('pkeytest', 'primary key test'):
884            self.createTable('%s0' % t, 'a smallint')
885            self.createTable('%s1' % t, 'b smallint primary key')
886            self.createTable('%s2' % t,
887                             'c smallint, d smallint primary key')
888            self.createTable('%s3' % t,
889                             'e smallint, f smallint, g smallint, h smallint, i smallint,'
890                             ' primary key (f, h)')
891            self.createTable('%s4' % t,
892                             'e smallint, f smallint, g smallint, h smallint, i smallint,'
893                             ' primary key (h, f)')
894            self.createTable('%s5' % t,
895                             'more_than_one_letter varchar primary key')
896            self.createTable('%s6' % t,
897                             '"with space" date primary key')
898            self.createTable('%s7' % t,
899                             'a_very_long_column_name varchar, "with space" date, "42" int,'
900                             ' primary key (a_very_long_column_name, "with space", "42")')
901            self.assertRaises(KeyError, pkey, '%s0' % t)
902            self.assertEqual(pkey('%s1' % t), 'b')
903            self.assertEqual(pkey('%s1' % t, True), ('b',))
904            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
905            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
906            self.assertEqual(pkey('%s2' % t), 'd')
907            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
908            r = pkey('%s3' % t)
909            self.assertIsInstance(r, tuple)
910            self.assertEqual(r, ('f', 'h'))
911            r = pkey('%s3' % t, composite=False)
912            self.assertIsInstance(r, tuple)
913            self.assertEqual(r, ('f', 'h'))
914            r = pkey('%s4' % t)
915            self.assertIsInstance(r, tuple)
916            self.assertEqual(r, ('h', 'f'))
917            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
918            self.assertEqual(pkey('%s6' % t), 'with space')
919            r = pkey('%s7' % t)
920            self.assertIsInstance(r, tuple)
921            self.assertEqual(r, (
922                'a_very_long_column_name', 'with space', '42'))
923            # a newly added primary key will be detected
924            query('alter table "%s0" add primary key (a)' % t)
925            self.assertEqual(pkey('%s0' % t), 'a')
926            # a changed primary key will not be detected,
927            # indicating that the internal cache is operating
928            query('alter table "%s1" rename column b to x' % t)
929            self.assertEqual(pkey('%s1' % t), 'b')
930            # we get the changed primary key when the cache is flushed
931            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
932
933    def testGetDatabases(self):
934        databases = self.db.get_databases()
935        self.assertIn('template0', databases)
936        self.assertIn('template1', databases)
937        self.assertNotIn('not existing database', databases)
938        self.assertIn('postgres', databases)
939        self.assertIn(dbname, databases)
940
941    def testGetTables(self):
942        get_tables = self.db.get_tables
943        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
944                  'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
945                  'averyveryveryveryveryveryveryreallyreallylongtablename',
946                  'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
947        for t in tables:
948            self.db.query('drop table if exists "%s" cascade' % t)
949        before_tables = get_tables()
950        self.assertIsInstance(before_tables, list)
951        for t in before_tables:
952            t = t.split('.', 1)
953            self.assertGreaterEqual(len(t), 2)
954            if len(t) > 2:
955                self.assertTrue(t[1].startswith('"'))
956            t = t[0]
957            self.assertNotEqual(t, 'information_schema')
958            self.assertFalse(t.startswith('pg_'))
959        for t in tables:
960            self.createTable(t, 'as select 0', temporary=False)
961        current_tables = get_tables()
962        new_tables = [t for t in current_tables if t not in before_tables]
963        expected_new_tables = ['public.%s' % (
964            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
965        self.assertEqual(new_tables, expected_new_tables)
966        self.doCleanups()
967        after_tables = get_tables()
968        self.assertEqual(after_tables, before_tables)
969
970    def testGetRelations(self):
971        get_relations = self.db.get_relations
972        result = get_relations()
973        self.assertIn('public.test', result)
974        self.assertIn('public.test_view', result)
975        result = get_relations('rv')
976        self.assertIn('public.test', result)
977        self.assertIn('public.test_view', result)
978        result = get_relations('r')
979        self.assertIn('public.test', result)
980        self.assertNotIn('public.test_view', result)
981        result = get_relations('v')
982        self.assertNotIn('public.test', result)
983        self.assertIn('public.test_view', result)
984        result = get_relations('cisSt')
985        self.assertNotIn('public.test', result)
986        self.assertNotIn('public.test_view', result)
987
988    def testGetAttnames(self):
989        get_attnames = self.db.get_attnames
990        self.assertRaises(pg.ProgrammingError,
991                          self.db.get_attnames, 'does_not_exist')
992        self.assertRaises(pg.ProgrammingError,
993                          self.db.get_attnames, 'has.too.many.dots')
994        r = get_attnames('test')
995        self.assertIsInstance(r, dict)
996        if self.regtypes:
997            self.assertEqual(r, dict(
998                i2='smallint', i4='integer', i8='bigint', d='numeric',
999                f4='real', f8='double precision', m='money',
1000                v4='character varying', c4='character', t='text'))
1001        else:
1002            self.assertEqual(r, dict(
1003                i2='int', i4='int', i8='int', d='num',
1004                f4='float', f8='float', m='money',
1005                v4='text', c4='text', t='text'))
1006        self.createTable('test_table',
1007                         'n int, alpha smallint, beta bool,'
1008                         ' gamma char(5), tau text, v varchar(3)')
1009        r = get_attnames('test_table')
1010        self.assertIsInstance(r, dict)
1011        if self.regtypes:
1012            self.assertEqual(r, dict(
1013                n='integer', alpha='smallint', beta='boolean',
1014                gamma='character', tau='text', v='character varying'))
1015        else:
1016            self.assertEqual(r, dict(
1017                n='int', alpha='int', beta='bool',
1018                gamma='text', tau='text', v='text'))
1019
1020    def testGetAttnamesWithQuotes(self):
1021        get_attnames = self.db.get_attnames
1022        table = 'test table for get_attnames()'
1023        self.createTable(table,
1024                         '"Prime!" smallint, "much space" integer, "Questions?" text')
1025        r = get_attnames(table)
1026        self.assertIsInstance(r, dict)
1027        if self.regtypes:
1028            self.assertEqual(r, {
1029                'Prime!': 'smallint', 'much space': 'integer',
1030                'Questions?': 'text'})
1031        else:
1032            self.assertEqual(r, {
1033                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1034        table = 'yet another test table for get_attnames()'
1035        self.createTable(table,
1036                         'a smallint, b integer, c bigint,'
1037                         ' e numeric, f real, f2 double precision, m money,'
1038                         ' x smallint, y smallint, z smallint,'
1039                         ' Normal_NaMe smallint, "Special Name" smallint,'
1040                         ' t text, u char(2), v varchar(2),'
1041                         ' primary key (y, u)', oids=True)
1042        r = get_attnames(table)
1043        self.assertIsInstance(r, dict)
1044        if self.regtypes:
1045            self.assertEqual(r, {
1046                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1047                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1048                'm': 'money', 'normal_name': 'smallint',
1049                'Special Name': 'smallint', 'u': 'character',
1050                't': 'text', 'v': 'character varying', 'y': 'smallint',
1051                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1052        else:
1053            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1054                                 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1055                                 'normal_name': 'int', 'Special Name': 'int',
1056                                 'u': 'text', 't': 'text', 'v': 'text',
1057                                 'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1058
1059    def testGetAttnamesWithRegtypes(self):
1060        get_attnames = self.db.get_attnames
1061        self.createTable('test_table',
1062                         ' n int, alpha smallint, beta bool,'
1063                         ' gamma char(5), tau text, v varchar(3)')
1064        use_regtypes = self.db.use_regtypes
1065        regtypes = use_regtypes()
1066        self.assertEqual(regtypes, self.regtypes)
1067        use_regtypes(True)
1068        try:
1069            r = get_attnames("test_table")
1070            self.assertIsInstance(r, dict)
1071        finally:
1072            use_regtypes(regtypes)
1073        self.assertEqual(r, dict(
1074            n='integer', alpha='smallint', beta='boolean',
1075            gamma='character', tau='text', v='character varying'))
1076
1077    def testGetAttnamesWithoutRegtypes(self):
1078        get_attnames = self.db.get_attnames
1079        self.createTable('test_table',
1080                         ' n int, alpha smallint, beta bool,'
1081                         ' gamma char(5), tau text, v varchar(3)')
1082        use_regtypes = self.db.use_regtypes
1083        regtypes = use_regtypes()
1084        self.assertEqual(regtypes, self.regtypes)
1085        use_regtypes(False)
1086        try:
1087            r = get_attnames("test_table")
1088            self.assertIsInstance(r, dict)
1089        finally:
1090            use_regtypes(regtypes)
1091        self.assertEqual(r, dict(
1092            n='int', alpha='int', beta='bool',
1093            gamma='text', tau='text', v='text'))
1094
1095    def testGetAttnamesIsCached(self):
1096        get_attnames = self.db.get_attnames
1097        int_type = 'integer' if self.regtypes else 'int'
1098        text_type = 'text'
1099        query = self.db.query
1100        self.createTable('test_table', 'col int')
1101        r = get_attnames("test_table")
1102        self.assertIsInstance(r, dict)
1103        self.assertEqual(r, dict(col=int_type))
1104        query("alter table test_table alter column col type text")
1105        query("alter table test_table add column col2 int")
1106        r = get_attnames("test_table")
1107        self.assertEqual(r, dict(col=int_type))
1108        r = get_attnames("test_table", flush=True)
1109        self.assertEqual(r, dict(col=text_type, col2=int_type))
1110        query("alter table test_table drop column col2")
1111        r = get_attnames("test_table")
1112        self.assertEqual(r, dict(col=text_type, col2=int_type))
1113        r = get_attnames("test_table", flush=True)
1114        self.assertEqual(r, dict(col=text_type))
1115        query("alter table test_table drop column col")
1116        r = get_attnames("test_table")
1117        self.assertEqual(r, dict(col=text_type))
1118        r = get_attnames("test_table", flush=True)
1119        self.assertEqual(r, dict())
1120
1121    def testGetAttnamesIsOrdered(self):
1122        get_attnames = self.db.get_attnames
1123        r = get_attnames('test', flush=True)
1124        self.assertIsInstance(r, OrderedDict)
1125        if self.regtypes:
1126            self.assertEqual(r, OrderedDict([
1127                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1128                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1129                ('m', 'money'), ('v4', 'character varying'),
1130                ('c4', 'character'), ('t', 'text')]))
1131        else:
1132            self.assertEqual(r, OrderedDict([
1133                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1134                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1135                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1136        if OrderedDict is not dict:
1137            r = ' '.join(list(r.keys()))
1138            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1139        table = 'test table for get_attnames'
1140        self.createTable(table,
1141                         ' n int, alpha smallint, v varchar(3),'
1142                         ' gamma char(5), tau text, beta bool')
1143        r = get_attnames(table)
1144        self.assertIsInstance(r, OrderedDict)
1145        if self.regtypes:
1146            self.assertEqual(r, OrderedDict([
1147                ('n', 'integer'), ('alpha', 'smallint'),
1148                ('v', 'character varying'), ('gamma', 'character'),
1149                ('tau', 'text'), ('beta', 'boolean')]))
1150        else:
1151            self.assertEqual(r, OrderedDict([
1152                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1153                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1154        if OrderedDict is not dict:
1155            r = ' '.join(list(r.keys()))
1156            self.assertEqual(r, 'n alpha v gamma tau beta')
1157        else:
1158            self.skipTest('OrderedDict is not supported')
1159
1160    def testGetAttnamesIsAttrDict(self):
1161        AttrDict = pg.AttrDict
1162        get_attnames = self.db.get_attnames
1163        r = get_attnames('test', flush=True)
1164        self.assertIsInstance(r, AttrDict)
1165        if self.regtypes:
1166            self.assertEqual(r, AttrDict([
1167                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1168                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1169                ('m', 'money'), ('v4', 'character varying'),
1170                ('c4', 'character'), ('t', 'text')]))
1171        else:
1172            self.assertEqual(r, AttrDict([
1173                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1174                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1175                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1176        r = ' '.join(list(r.keys()))
1177        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1178        table = 'test table for get_attnames'
1179        self.createTable(table,
1180                         ' n int, alpha smallint, v varchar(3),'
1181                         ' gamma char(5), tau text, beta bool')
1182        r = get_attnames(table)
1183        self.assertIsInstance(r, AttrDict)
1184        if self.regtypes:
1185            self.assertEqual(r, AttrDict([
1186                ('n', 'integer'), ('alpha', 'smallint'),
1187                ('v', 'character varying'), ('gamma', 'character'),
1188                ('tau', 'text'), ('beta', 'boolean')]))
1189        else:
1190            self.assertEqual(r, AttrDict([
1191                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1192                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1193        r = ' '.join(list(r.keys()))
1194        self.assertEqual(r, 'n alpha v gamma tau beta')
1195
1196    def testHasTablePrivilege(self):
1197        can = self.db.has_table_privilege
1198        self.assertEqual(can('test'), True)
1199        self.assertEqual(can('test', 'select'), True)
1200        self.assertEqual(can('test', 'SeLeCt'), True)
1201        self.assertEqual(can('test', 'SELECT'), True)
1202        self.assertEqual(can('test', 'insert'), True)
1203        self.assertEqual(can('test', 'update'), True)
1204        self.assertEqual(can('test', 'delete'), True)
1205        self.assertEqual(can('pg_views', 'select'), True)
1206        self.assertEqual(can('pg_views', 'delete'), False)
1207        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
1208        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1209
1210    def testGet(self):
1211        get = self.db.get
1212        query = self.db.query
1213        table = 'get_test_table'
1214        self.assertRaises(TypeError, get)
1215        self.assertRaises(TypeError, get, table)
1216        self.createTable(table, 'n integer, t text',
1217                         values=enumerate('xyz', start=1))
1218        self.assertRaises(pg.ProgrammingError, get, table, 2)
1219        r = get(table, 2, 'n')
1220        self.assertIsInstance(r, dict)
1221        self.assertEqual(r, dict(n=2, t='y'))
1222        r = get(table, 1, 'n')
1223        self.assertEqual(r, dict(n=1, t='x'))
1224        r = get(table, (3,), ('n',))
1225        self.assertEqual(r, dict(n=3, t='z'))
1226        r = get(table, 'y', 't')
1227        self.assertEqual(r, dict(n=2, t='y'))
1228        self.assertRaises(pg.DatabaseError, get, table, 4)
1229        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
1230        self.assertRaises(pg.DatabaseError, get, table, 'y')
1231        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
1232        s = dict(n=3)
1233        self.assertRaises(pg.ProgrammingError, get, table, s)
1234        r = get(table, s, 'n')
1235        self.assertIs(r, s)
1236        self.assertEqual(r, dict(n=3, t='z'))
1237        s.update(t='x')
1238        r = get(table, s, 't')
1239        self.assertIs(r, s)
1240        self.assertEqual(s, dict(n=1, t='x'))
1241        r = get(table, s, ('n', 't'))
1242        self.assertIs(r, s)
1243        self.assertEqual(r, dict(n=1, t='x'))
1244        query('alter table "%s" alter n set not null' % table)
1245        query('alter table "%s" add primary key (n)' % table)
1246        r = get(table, 2)
1247        self.assertIsInstance(r, dict)
1248        self.assertEqual(r, dict(n=2, t='y'))
1249        self.assertEqual(get(table, 1)['t'], 'x')
1250        self.assertEqual(get(table, 3)['t'], 'z')
1251        self.assertEqual(get(table + '*', 2)['t'], 'y')
1252        self.assertEqual(get(table + ' *', 2)['t'], 'y')
1253        self.assertRaises(KeyError, get, table, (2, 2))
1254        s = dict(n=3)
1255        r = get(table, s)
1256        self.assertIs(r, s)
1257        self.assertEqual(r, dict(n=3, t='z'))
1258        s.update(n=1)
1259        self.assertEqual(get(table, s)['t'], 'x')
1260        s.update(n=2)
1261        self.assertEqual(get(table, r)['t'], 'y')
1262        s.pop('n')
1263        self.assertRaises(KeyError, get, table, s)
1264
1265    def testGetWithOid(self):
1266        get = self.db.get
1267        query = self.db.query
1268        table = 'get_with_oid_test_table'
1269        self.createTable(table, 'n integer, t text', oids=True,
1270                         values=enumerate('xyz', start=1))
1271        self.assertRaises(pg.ProgrammingError, get, table, 2)
1272        self.assertRaises(KeyError, get, table, {}, 'oid')
1273        r = get(table, 2, 'n')
1274        qoid = 'oid(%s)' % table
1275        self.assertIn(qoid, r)
1276        oid = r[qoid]
1277        self.assertIsInstance(oid, int)
1278        result = {'t': 'y', 'n': 2, qoid: oid}
1279        self.assertEqual(r, result)
1280        r = get(table, oid, 'oid')
1281        self.assertEqual(r, result)
1282        r = get(table, dict(oid=oid))
1283        self.assertEqual(r, result)
1284        r = get(table, dict(oid=oid), 'oid')
1285        self.assertEqual(r, result)
1286        r = get(table, {qoid: oid})
1287        self.assertEqual(r, result)
1288        r = get(table, {qoid: oid}, 'oid')
1289        self.assertEqual(r, result)
1290        self.assertEqual(get(table + '*', 2, 'n'), r)
1291        self.assertEqual(get(table + ' *', 2, 'n'), r)
1292        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
1293        self.assertEqual(get(table, 1, 'n')['t'], 'x')
1294        self.assertEqual(get(table, 3, 'n')['t'], 'z')
1295        self.assertEqual(get(table, 2, 'n')['t'], 'y')
1296        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
1297        r['n'] = 3
1298        self.assertEqual(get(table, r, 'n')['t'], 'z')
1299        self.assertEqual(get(table, 1, 'n')['t'], 'x')
1300        self.assertEqual(get(table, r, 'oid')['t'], 'z')
1301        query('alter table "%s" alter n set not null' % table)
1302        query('alter table "%s" add primary key (n)' % table)
1303        self.assertEqual(get(table, 3)['t'], 'z')
1304        self.assertEqual(get(table, 1)['t'], 'x')
1305        self.assertEqual(get(table, 2)['t'], 'y')
1306        r['n'] = 1
1307        self.assertEqual(get(table, r)['t'], 'x')
1308        r['n'] = 3
1309        self.assertEqual(get(table, r)['t'], 'z')
1310        r['n'] = 2
1311        self.assertEqual(get(table, r)['t'], 'y')
1312        r = get(table, oid, 'oid')
1313        self.assertEqual(r, result)
1314        r = get(table, dict(oid=oid))
1315        self.assertEqual(r, result)
1316        r = get(table, dict(oid=oid), 'oid')
1317        self.assertEqual(r, result)
1318        r = get(table, {qoid: oid})
1319        self.assertEqual(r, result)
1320        r = get(table, {qoid: oid}, 'oid')
1321        self.assertEqual(r, result)
1322        r = get(table, dict(oid=oid, n=1))
1323        self.assertEqual(r['n'], 1)
1324        self.assertNotEqual(r[qoid], oid)
1325        r = get(table, dict(oid=oid, t='z'), 't')
1326        self.assertEqual(r['n'], 3)
1327        self.assertNotEqual(r[qoid], oid)
1328
1329    def testGetWithCompositeKey(self):
1330        get = self.db.get
1331        query = self.db.query
1332        table = 'get_test_table_1'
1333        self.createTable(table, 'n integer primary key, t text',
1334                         values=enumerate('abc', start=1))
1335        self.assertEqual(get(table, 2)['t'], 'b')
1336        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1337        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1338        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1339        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1340        self.assertEqual(get(table, 'b', 't')['n'], 2)
1341        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1342        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1343        table = 'get_test_table_2'
1344        self.createTable(table,
1345                         'n integer, m integer, t text, primary key (n, m)',
1346                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1347                                 for n in range(3) for m in range(2)])
1348        self.assertRaises(KeyError, get, table, 2)
1349        self.assertEqual(get(table, (1, 1))['t'], 'a')
1350        self.assertEqual(get(table, (1, 2))['t'], 'b')
1351        self.assertEqual(get(table, (2, 1))['t'], 'c')
1352        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1353        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1354        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1355        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1356        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1357        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1358        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1359        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1360
1361    def testGetWithQuotedNames(self):
1362        get = self.db.get
1363        query = self.db.query
1364        table = 'test table for get()'
1365        self.createTable(table, '"Prime!" smallint primary key,'
1366                                ' "much space" integer, "Questions?" text',
1367                         values=[(17, 1001, 'No!')])
1368        r = get(table, 17)
1369        self.assertIsInstance(r, dict)
1370        self.assertEqual(r['Prime!'], 17)
1371        self.assertEqual(r['much space'], 1001)
1372        self.assertEqual(r['Questions?'], 'No!')
1373
1374    def testGetFromView(self):
1375        self.db.query('delete from test where i4=14')
1376        self.db.query('insert into test (i4, v4) values('
1377                      "14, 'abc4')")
1378        r = self.db.get('test_view', 14, 'i4')
1379        self.assertIn('v4', r)
1380        self.assertEqual(r['v4'], 'abc4')
1381
1382    def testGetLittleBobbyTables(self):
1383        get = self.db.get
1384        query = self.db.query
1385        self.createTable('test_students',
1386                         'firstname varchar primary key, nickname varchar, grade char(2)',
1387                         values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1388                                 ('Robert', 'Little Bobby Tables', 'D-')])
1389        r = get('test_students', 'Sheldon')
1390        self.assertEqual(r, dict(
1391            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1392        r = get('test_students', 'Robert')
1393        self.assertEqual(r, dict(
1394            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1395        r = get('test_students', "D'Arcy")
1396        self.assertEqual(r, dict(
1397            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1398        try:
1399            get('test_students', "D' Arcy")
1400        except pg.DatabaseError as error:
1401            self.assertEqual(str(error),
1402                             'No such record in test_students\nwhere "firstname" = $1\n'
1403                             'with $1="D\' Arcy"')
1404        try:
1405            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1406        except pg.DatabaseError as error:
1407            self.assertEqual(str(error),
1408                             'No such record in test_students\nwhere "firstname" = $1\n'
1409                             'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1410        q = "select * from test_students order by 1 limit 4"
1411        r = query(q).getresult()
1412        self.assertEqual(len(r), 3)
1413        self.assertEqual(r[1][2], 'D-')
1414
1415    def testInsert(self):
1416        insert = self.db.insert
1417        query = self.db.query
1418        bool_on = pg.get_bool()
1419        decimal = pg.get_decimal()
1420        table = 'insert_test_table'
1421        self.createTable(table,
1422                         'i2 smallint, i4 integer, i8 bigint,'
1423                         ' d numeric, f4 real, f8 double precision, m money,'
1424                         ' v4 varchar(4), c4 char(4), t text,'
1425                         ' b boolean, ts timestamp', oids=True)
1426        oid_table = 'oid(%s)' % table
1427        tests = [dict(i2=None, i4=None, i8=None),
1428                 (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1429                 (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1430                 dict(i2=42, i4=123456, i8=9876543210),
1431                 dict(i2=2 ** 15 - 1,
1432                      i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1433                 dict(d=None), (dict(d=''), dict(d=None)),
1434                 dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1435                 dict(f4=None, f8=None), dict(f4=0, f8=0),
1436                 (dict(f4='', f8=''), dict(f4=None, f8=None)),
1437                 (dict(d=1234.5, f4=1234.5, f8=1234.5),
1438                  dict(d=Decimal('1234.5'))),
1439                 dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1440                 dict(d=Decimal('123456789.9876543212345678987654321')),
1441                 dict(m=None), (dict(m=''), dict(m=None)),
1442                 dict(m=Decimal('-1234.56')),
1443                 (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1444                 dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1445                 (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1446                 (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1447                 (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1448                 (dict(m=123456), dict(m=Decimal('123456'))),
1449                 (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1450                 dict(b=None), (dict(b=''), dict(b=None)),
1451                 dict(b='f'), dict(b='t'),
1452                 (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1453                 (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1454                 (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1455                 (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1456                 (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1457                 (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1458                 dict(v4=None, c4=None, t=None),
1459                 (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1460                 dict(v4='1234', c4='1234', t='1234' * 10),
1461                 dict(v4='abcd', c4='abcd', t='abcdefg'),
1462                 (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1463                 dict(ts=None), (dict(ts=''), dict(ts=None)),
1464                 (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1465                 dict(ts='2012-12-21 00:00:00'),
1466                 (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1467                 dict(ts='2012-12-21 12:21:12'),
1468                 dict(ts='2013-01-05 12:13:14'),
1469                 dict(ts='current_timestamp')]
1470        for test in tests:
1471            if isinstance(test, dict):
1472                data = test
1473                change = {}
1474            else:
1475                data, change = test
1476            expect = data.copy()
1477            expect.update(change)
1478            if bool_on:
1479                b = expect.get('b')
1480                if b is not None:
1481                    expect['b'] = b == 't'
1482            if decimal is not Decimal:
1483                d = expect.get('d')
1484                if d is not None:
1485                    expect['d'] = decimal(d)
1486                m = expect.get('m')
1487                if m is not None:
1488                    expect['m'] = decimal(m)
1489            self.assertEqual(insert(table, data), data)
1490            self.assertIn(oid_table, data)
1491            oid = data[oid_table]
1492            self.assertIsInstance(oid, int)
1493            data = dict(item for item in data.items()
1494                        if item[0] in expect)
1495            ts = expect.get('ts')
1496            if ts == 'current_timestamp':
1497                ts = expect['ts'] = data['ts']
1498                if len(ts) > 19:
1499                    self.assertEqual(ts[19], '.')
1500                    ts = ts[:19]
1501                else:
1502                    self.assertEqual(len(ts), 19)
1503                self.assertTrue(ts[:4].isdigit())
1504                self.assertEqual(ts[4], '-')
1505                self.assertEqual(ts[10], ' ')
1506                self.assertTrue(ts[11:13].isdigit())
1507                self.assertEqual(ts[13], ':')
1508            self.assertEqual(data, expect)
1509            data = query(
1510                'select oid,* from "%s"' % table).dictresult()[0]
1511            self.assertEqual(data['oid'], oid)
1512            data = dict(item for item in data.items()
1513                        if item[0] in expect)
1514            self.assertEqual(data, expect)
1515            query('delete from "%s"' % table)
1516
    def testInsertWithOid(self):
        """Test insert() on a table that has oids.

        The oid of a newly inserted row is reported under the qualified
        key 'oid(test_table)' and never under a plain 'oid' key.  Passing
        an existing oid (as keyword or as 'oid' key) still creates a new
        row with a new oid.
        """
        insert = self.db.insert
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True)
        # the table has no column m, so this insert must fail
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
        r = insert('test_table', n=1)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        self.assertNotIn('oid', r)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        # an oid keyword does not prevent inserting a row with a new oid
        r = insert('test_table', n=2, oid=oid)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 2)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        r = insert('test_table', None, n=3)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 3)
        # when a row dict is passed, the very same dict object is returned
        s = r
        r = insert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        r = insert('test_table *', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # keyword arguments are merged into the passed row dict
        r = insert('test_table', r, n=4)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 4)
        self.assertNotIn('oid', r)
        self.assertIn(qoid, r)
        oid = r[qoid]
        r = insert('test_table', r, n=5, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # a plain 'oid' key in the row is removed from the returned dict
        r['oid'] = oid = r[qoid]
        r = insert('test_table', r, n=6)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 6)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        q = 'select n from test_table order by 1 limit 9'
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        # n=3 has been inserted three times above
        self.assertEqual(r, '1 2 3 3 3 4 5 6')
        query("truncate test_table")
        # with a unique constraint, inserting a duplicate n must fail
        query("alter table test_table add unique (n)")
        r = insert('test_table', dict(n=7))
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
        r['n'] = 6
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
        # the keyword value has been merged into r despite the failure
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        r['n'] = 6
        r = insert('test_table', r)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 6)
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '6 7')
1584
1585    def testInsertWithQuotedNames(self):
1586        insert = self.db.insert
1587        query = self.db.query
1588        table = 'test table for insert()'
1589        self.createTable(table, '"Prime!" smallint primary key,'
1590                                ' "much space" integer, "Questions?" text')
1591        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1592        r = insert(table, r)
1593        self.assertIsInstance(r, dict)
1594        self.assertEqual(r['Prime!'], 11)
1595        self.assertEqual(r['much space'], 2002)
1596        self.assertEqual(r['Questions?'], 'What?')
1597        r = query('select * from "%s" limit 2' % table).dictresult()
1598        self.assertEqual(len(r), 1)
1599        r = r[0]
1600        self.assertEqual(r['Prime!'], 11)
1601        self.assertEqual(r['much space'], 2002)
1602        self.assertEqual(r['Questions?'], 'What?')
1603
1604    def testInsertIntoView(self):
1605        insert = self.db.insert
1606        query = self.db.query
1607        query("truncate test")
1608        q = 'select * from test_view order by i4 limit 3'
1609        r = query(q).getresult()
1610        self.assertEqual(r, [])
1611        r = dict(i4=1234, v4='abcd')
1612        insert('test', r)
1613        self.assertIsNone(r['i2'])
1614        self.assertEqual(r['i4'], 1234)
1615        self.assertIsNone(r['i8'])
1616        self.assertEqual(r['v4'], 'abcd')
1617        self.assertIsNone(r['c4'])
1618        r = query(q).getresult()
1619        self.assertEqual(r, [(1234, 'abcd')])
1620        r = dict(i4=5678, v4='efgh')
1621        try:
1622            insert('test_view', r)
1623        except pg.ProgrammingError as error:
1624            if self.db.server_version < 90300:
1625                # must setup rules in older PostgreSQL versions
1626                self.skipTest('database cannot insert into view')
1627            self.fail(str(error))
1628        self.assertNotIn('i2', r)
1629        self.assertEqual(r['i4'], 5678)
1630        self.assertNotIn('i8', r)
1631        self.assertEqual(r['v4'], 'efgh')
1632        self.assertNotIn('c4', r)
1633        r = query(q).getresult()
1634        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1635
1636    def testUpdate(self):
1637        update = self.db.update
1638        query = self.db.query
1639        self.assertRaises(pg.ProgrammingError, update,
1640                          'test', i2=2, i4=4, i8=8)
1641        table = 'update_test_table'
1642        self.createTable(table, 'n integer, t text', oids=True,
1643                         values=enumerate('xyz', start=1))
1644        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1645        r = self.db.get(table, 2, 'n')
1646        r['t'] = 'u'
1647        s = update(table, r)
1648        self.assertEqual(s, r)
1649        q = 'select t from "%s" where n=2' % table
1650        r = query(q).getresult()[0][0]
1651        self.assertEqual(r, 'u')
1652
    def testUpdateWithOid(self):
        """Test update() on a table with oids, before and after a pkey.

        Without a primary key, rows are identified by their oid, which
        must be given under the qualified key 'oid(test_table)' or as an
        oid keyword argument.  A plain 'oid' key in the row is not used.
        Once a primary key exists, it is used to identify the row.
        """
        update = self.db.update
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        s = get('test_table', 1, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 1)
        s['n'] = 2
        r = update('test_table', s)
        # update() returns the same dict object that was passed in
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        r['n'] = 3
        # the oid can also be passed as a keyword argument
        oid = r.pop(qoid)
        r = update('test_table', r, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # without any oid, the row to update cannot be identified
        r.pop(qoid)
        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
        s = get('test_table', 3, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 3)
        # with only the oid left, nothing but the oid is returned
        s.pop('n')
        r = update('test_table', s)
        oid = r.pop(qoid)
        self.assertEqual(r, {})
        q = "select n from test_table limit 2"
        r = query(q).getresult()
        self.assertEqual(r, [(3,)])
        query("insert into test_table values (1)")
        # a plain 'oid' key is not used to identify the row
        self.assertRaises(pg.ProgrammingError,
                          update, 'test_table', dict(oid=oid, n=4))
        r = update('test_table', dict(n=4), oid=oid)
        self.assertEqual(r['n'], 4)
        r = update('test_table *', dict(n=5), oid=oid)
        self.assertEqual(r['n'], 5)
        # add a primary key; flush the cached table info to see it
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        s = dict(n=1, m=4)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 4)
        # keyword arguments supply the primary key value
        s = dict(m=7)
        r = update('test_table', s, n=5)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 7)
        q = "select n, m from test_table order by 1 limit 3"
        r = query(q).getresult()
        self.assertEqual(r, [(1, 4), (5, 7)])
        # the primary key is missing and a plain 'oid' key does not help
        s = dict(m=9, oid=oid)
        self.assertRaises(KeyError, update, 'test_table', s)
        r = update('test_table', s, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 9)
        # the primary key in the row selects row n=1 here
        s = dict(n=1, m=3, oid=oid)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        r = query(q).getresult()
        self.assertEqual(r, [(1, 3), (5, 9)])
1723
1724    def testUpdateWithoutOid(self):
1725        update = self.db.update
1726        query = self.db.query
1727        self.assertRaises(pg.ProgrammingError, update,
1728                          'test', i2=2, i4=4, i8=8)
1729        table = 'update_test_table'
1730        self.createTable(table, 'n integer primary key, t text', oids=False,
1731                         values=enumerate('xyz', start=1))
1732        r = self.db.get(table, 2)
1733        r['t'] = 'u'
1734        s = update(table, r)
1735        self.assertEqual(s, r)
1736        q = 'select t from "%s" where n=2' % table
1737        r = query(q).getresult()[0][0]
1738        self.assertEqual(r, 'u')
1739
1740    def testUpdateWithCompositeKey(self):
1741        update = self.db.update
1742        query = self.db.query
1743        table = 'update_test_table_1'
1744        self.createTable(table, 'n integer primary key, t text',
1745                         values=enumerate('abc', start=1))
1746        self.assertRaises(KeyError, update, table, dict(t='b'))
1747        s = dict(n=2, t='d')
1748        r = update(table, s)
1749        self.assertIs(r, s)
1750        self.assertEqual(r['n'], 2)
1751        self.assertEqual(r['t'], 'd')
1752        q = 'select t from "%s" where n=2' % table
1753        r = query(q).getresult()[0][0]
1754        self.assertEqual(r, 'd')
1755        s.update(dict(n=4, t='e'))
1756        r = update(table, s)
1757        self.assertEqual(r['n'], 4)
1758        self.assertEqual(r['t'], 'e')
1759        q = 'select t from "%s" where n=2' % table
1760        r = query(q).getresult()[0][0]
1761        self.assertEqual(r, 'd')
1762        q = 'select t from "%s" where n=4' % table
1763        r = query(q).getresult()
1764        self.assertEqual(len(r), 0)
1765        query('drop table "%s"' % table)
1766        table = 'update_test_table_2'
1767        self.createTable(table,
1768                         'n integer, m integer, t text, primary key (n, m)',
1769                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1770                                 for n in range(3) for m in range(2)])
1771        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1772        self.assertEqual(update(table,
1773                                dict(n=2, m=2, t='x'))['t'], 'x')
1774        q = 'select t from "%s" where n=2 order by m' % table
1775        r = [r[0] for r in query(q).getresult()]
1776        self.assertEqual(r, ['c', 'x'])
1777
1778    def testUpdateWithQuotedNames(self):
1779        update = self.db.update
1780        query = self.db.query
1781        table = 'test table for update()'
1782        self.createTable(table, '"Prime!" smallint primary key,'
1783                                ' "much space" integer, "Questions?" text',
1784                         values=[(13, 3003, 'Why!')])
1785        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1786        r = update(table, r)
1787        self.assertIsInstance(r, dict)
1788        self.assertEqual(r['Prime!'], 13)
1789        self.assertEqual(r['much space'], 7007)
1790        self.assertEqual(r['Questions?'], 'When?')
1791        r = query('select * from "%s" limit 2' % table).dictresult()
1792        self.assertEqual(len(r), 1)
1793        r = r[0]
1794        self.assertEqual(r['Prime!'], 13)
1795        self.assertEqual(r['much space'], 7007)
1796        self.assertEqual(r['Questions?'], 'When?')
1797
    def testUpsert(self):
        """Test upsert() on a table with a single-column primary key.

        Keyword arguments control what happens to a column when the row
        already exists: False keeps the stored value, True takes the new
        value, and a string is used as an SQL expression in which
        'included' refers to the stored row and 'excluded' to the row
        that was proposed for insertion.
        """
        upsert = self.db.upsert
        query = self.db.query
        # NOTE(review): presumably fails because 'test' has no primary
        # key -- TODO confirm
        self.assertRaises(pg.ProgrammingError, upsert,
                          'test', i2=2, i4=4, i8=8)
        table = 'upsert_test_table'
        self.createTable(table, 'n integer primary key, t text', oids=True)
        s = dict(n=1, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        # upsert() returns the same dict object that was passed in
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        s.update(n=2, t='y')
        # pass None as keyword value for every column of the row
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # without keywords, an existing row gets the new values
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # False keeps the stored value, also in the returned dict
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # True takes the new value
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # 'included' refers to the stored value ('y', not 'n')
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # 'excluded' refers to the proposed value ('y', not 'y2')
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        # not existing columns and oid parameter should be ignored
        s = dict(m=3, u='z')
        r = upsert(table, s, oid='invalid')
        self.assertIs(r, s)
1869
    def testUpsertWithOid(self):
        """Test upsert() on a table with oids, before and after a pkey.

        Before the primary key is added, upsert() fails even when an oid
        is given.  Afterwards, rows are identified by the primary key, and
        keyword arguments control how existing columns are updated.
        """
        upsert = self.db.upsert
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        # the table has no primary key yet, so upsert must fail
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2))
        r = get('test_table', 1, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        oid = r[qoid]
        # an oid in the row does not help either
        self.assertRaises(pg.ProgrammingError,
                          upsert, 'test_table', dict(n=2, oid=oid))
        # add a primary key; flush the cached table info to see it
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        s = dict(n=2)
        try:
            r = upsert('test_table', s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, None), (2, None)])
        # a plain 'oid' key (here the oid of row n=1) is dropped, and the
        # qualified key holds the oid of the upserted row n=2
        r['oid'] = oid
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertNotEqual(r[qoid], oid)
        r['m'] = 7
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        r.update(n=1, m=3)
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
        # the oid keyword does not change the result
        r = upsert('test_table', r, oid='invalid')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # False keeps the stored value
        r['m'] = 5
        r = upsert('test_table', r, m=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # True takes the new value
        r['m'] = 5
        r = upsert('test_table', r, m=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 5)
        # 'included' refers to the stored row (m=7 for n=2)
        r.update(n=2, m=1)
        r = upsert('test_table', r, m='included.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # 'excluded' refers to the proposed row
        r['m'] = 9
        r = upsert('test_table', r, m='excluded.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 9)
        r['m'] = 8
        r = upsert('test_table *', r, m='included.m + 1')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 10)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1953
    def testUpsertWithCompositeKey(self):
        """Test upsert() on a table with a two-column primary key (n, m).

        A row conflicts only if both n and m match; keyword arguments
        that control the update are only applied on such a conflict.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        self.createTable(table,
                         'n integer, m integer, t text, primary key (n, m)')
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # a different m means a different row, so (1, 3) is inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # False keeps the stored value
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # True takes the new value
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # no conflict for (2, 3): the row is inserted as given and the
        # update expression is not used
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # stored 'n' concatenated with proposed 'm'
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2020
2021    def testUpsertWithQuotedNames(self):
2022        upsert = self.db.upsert
2023        query = self.db.query
2024        table = 'test table for upsert()'
2025        self.createTable(table, '"Prime!" smallint primary key,'
2026                                ' "much space" integer, "Questions?" text')
2027        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2028        try:
2029            r = upsert(table, s)
2030        except pg.ProgrammingError as error:
2031            if self.db.server_version < 90500:
2032                self.skipTest('database does not support upsert')
2033            self.fail(str(error))
2034        self.assertIs(r, s)
2035        self.assertEqual(r['Prime!'], 31)
2036        self.assertEqual(r['much space'], 9009)
2037        self.assertEqual(r['Questions?'], 'Yes.')
2038        q = 'select * from "%s" limit 2' % table
2039        r = query(q).getresult()
2040        self.assertEqual(r, [(31, 9009, 'Yes.')])
2041        s.update({'Questions?': 'No.'})
2042        r = upsert(table, s)
2043        self.assertIs(r, s)
2044        self.assertEqual(r['Prime!'], 31)
2045        self.assertEqual(r['much space'], 9009)
2046        self.assertEqual(r['Questions?'], 'No.')
2047        r = query(q).getresult()
2048        self.assertEqual(r, [(31, 9009, 'No.')])
2049
2050    def testClear(self):
2051        clear = self.db.clear
2052        f = False if pg.get_bool() else 'f'
2053        r = clear('test')
2054        result = dict(
2055            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2056        self.assertEqual(r, result)
2057        table = 'clear_test_table'
2058        self.createTable(table,
2059                         'n integer, f float, b boolean, d date, t text', oids=True)
2060        r = clear(table)
2061        result = dict(n=0, f=0, b=f, d='', t='')
2062        self.assertEqual(r, result)
2063        r['a'] = r['f'] = r['n'] = 1
2064        r['d'] = r['t'] = 'x'
2065        r['b'] = 't'
2066        r['oid'] = long(1)
2067        r = clear(table, r)
2068        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2069        self.assertEqual(r, result)
2070
2071    def testClearWithQuotedNames(self):
2072        clear = self.db.clear
2073        table = 'test table for clear()'
2074        self.createTable(table, '"Prime!" smallint primary key,'
2075                                ' "much space" integer, "Questions?" text')
2076        r = clear(table)
2077        self.assertIsInstance(r, dict)
2078        self.assertEqual(r['Prime!'], 0)
2079        self.assertEqual(r['much space'], 0)
2080        self.assertEqual(r['Questions?'], '')
2081
    def testDelete(self):
        """Test delete() on a table with oids but without a primary key.

        delete() returns the number of deleted rows, so deleting the
        same row a second time returns 0.
        """
        delete = self.db.delete
        query = self.db.query
        # NOTE(review): presumably fails because 'test' has no primary
        # key -- TODO confirm
        self.assertRaises(pg.ProgrammingError, delete,
                          'test', dict(i2=2, i4=4, i8=8))
        table = 'delete_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
                         values=enumerate('xyz', start=1))
        # the table has no primary key, so get() needs a key name
        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
        r = self.db.get(table, 1, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        r = self.db.get(table, 3, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        s = delete(table, r)
        self.assertEqual(s, 0)
        # only the row with n=2 is left
        r = query('select * from "%s"' % table).dictresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        result = {'n': 2, 't': 'y'}
        self.assertEqual(r, result)
        r = self.db.get(table, 2, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        s = delete(table, r)
        self.assertEqual(s, 0)
        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
        # not existing columns and oid parameter should be ignored
        r.update(m=3, u='z', oid='invalid')
        s = delete(table, r)
        self.assertEqual(s, 0)
2114
    def testDeleteWithOid(self):
        """Test delete() on a table with oids, before and after a pkey.

        Rows are identified by the qualified key 'oid(test_table)' or by
        an oid keyword argument; a plain 'oid' key in the row is not used.
        Once a primary key exists, it can identify the row as well.
        delete() returns the number of deleted rows.
        """
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
        # without oid and primary key the row cannot be identified
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        s = get('test_table', 1, 'n')
        qoid = 'oid(test_table)'
        self.assertIn(qoid, s)
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        # smallest remaining n and number of remaining rows
        q = "select min(n),count(n) from test_table"
        self.assertEqual(query(q).getresult()[0], (2, 5))
        oid = get('test_table', 2, 'n')[qoid]
        # a plain 'oid' key is not used, but the oid keyword is
        s = dict(oid=oid, n=2)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        s = dict(oid=oid, n=2)
        oid = get('test_table', 3, 'n')[qoid]
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        s = get('test_table', 4, 'n')
        r = delete('test_table *', s)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # column m does not exist yet; the qualified oid key finds the row
        oid = get('test_table', 5, 'n')[qoid]
        s = {qoid: oid, 'm': 4}
        r = delete('test_table', s, m=6)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
        # add a primary key; flush the cached table info to see it
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        for i in range(5):
            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
        # the primary key is missing, and a plain 'oid' key does not help
        s = dict(m=2)
        self.assertRaises(KeyError, delete, 'test_table', s)
        s = dict(m=2, oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # oid still refers to the row n=5 that was deleted above
        r = delete('test_table', dict(m=2), oid=oid)
        self.assertEqual(r, 0)
        oid = get('test_table', 1, 'n')[qoid]
        s = dict(oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # without the primary key, the qualified oid key is used
        s = get('test_table', 2, 'n')
        del s['n']
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the primary key can also be passed as keyword argument
        r = delete('test_table', n=3)
        self.assertEqual(r, 1)
        r = delete('test_table', n=3)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 1)
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # the keyword value takes precedence over the value in the row
        s = dict(n=6)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 1)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
2203
2204    def testDeleteWithCompositeKey(self):
2205        query = self.db.query
2206        table = 'delete_test_table_1'
2207        self.createTable(table, 'n integer primary key, t text',
2208                         values=enumerate('abc', start=1))
2209        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2210        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2211        r = query('select t from "%s" where n=2' % table).getresult()
2212        self.assertEqual(r, [])
2213        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2214        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2215        self.assertEqual(r, 'c')
2216        table = 'delete_test_table_2'
2217        self.createTable(table,
2218                         'n integer, m integer, t text, primary key (n, m)',
2219                         values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2220                                 for n in range(3) for m in range(2)])
2221        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2222        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2223        r = [r[0] for r in query('select t from "%s" where n=2'
2224                                 ' order by m' % table).getresult()]
2225        self.assertEqual(r, ['c'])
2226        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2227        r = [r[0] for r in query('select t from "%s" where n=3'
2228                                 ' order by m' % table).getresult()]
2229        self.assertEqual(r, ['e', 'f'])
2230        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2231        r = [r[0] for r in query('select t from "%s" where n=3'
2232                                 ' order by m' % table).getresult()]
2233        self.assertEqual(r, ['f'])
2234
2235    def testDeleteWithQuotedNames(self):
2236        delete = self.db.delete
2237        query = self.db.query
2238        table = 'test table for delete()'
2239        self.createTable(table, '"Prime!" smallint primary key,'
2240                                ' "much space" integer, "Questions?" text',
2241                         values=[(19, 5005, 'Yes!')])
2242        r = {'Prime!': 17}
2243        r = delete(table, r)
2244        self.assertEqual(r, 0)
2245        r = query('select count(*) from "%s"' % table).getresult()
2246        self.assertEqual(r[0][0], 1)
2247        r = {'Prime!': 19}
2248        r = delete(table, r)
2249        self.assertEqual(r, 1)
2250        r = query('select count(*) from "%s"' % table).getresult()
2251        self.assertEqual(r[0][0], 0)
2252
2253    def testDeleteReferenced(self):
2254        delete = self.db.delete
2255        query = self.db.query
2256        self.createTable('test_parent',
2257                         'n smallint primary key', values=range(3))
2258        self.createTable('test_child',
2259                         'n smallint primary key references test_parent', values=range(3))
2260        q = ("select (select count(*) from test_parent),"
2261             " (select count(*) from test_child)")
2262        self.assertEqual(query(q).getresult()[0], (3, 3))
2263        self.assertRaises(pg.ProgrammingError,
2264                          delete, 'test_parent', None, n=2)
2265        self.assertRaises(pg.ProgrammingError,
2266                          delete, 'test_parent *', None, n=2)
2267        r = delete('test_child', None, n=2)
2268        self.assertEqual(r, 1)
2269        self.assertEqual(query(q).getresult()[0], (3, 2))
2270        r = delete('test_parent', None, n=2)
2271        self.assertEqual(r, 1)
2272        self.assertEqual(query(q).getresult()[0], (2, 2))
2273        self.assertRaises(pg.ProgrammingError,
2274                          delete, 'test_parent', dict(n=0))
2275        self.assertRaises(pg.ProgrammingError,
2276                          delete, 'test_parent *', dict(n=0))
2277        r = delete('test_child', dict(n=0))
2278        self.assertEqual(r, 1)
2279        self.assertEqual(query(q).getresult()[0], (2, 1))
2280        r = delete('test_child', dict(n=0))
2281        self.assertEqual(r, 0)
2282        r = delete('test_parent', dict(n=0))
2283        self.assertEqual(r, 1)
2284        self.assertEqual(query(q).getresult()[0], (1, 1))
2285        r = delete('test_parent', None, n=0)
2286        self.assertEqual(r, 0)
2287        q = "select n from test_parent natural join test_child limit 2"
2288        self.assertEqual(query(q).getresult(), [(1,)])
2289
2290    def testTruncate(self):
2291        truncate = self.db.truncate
2292        self.assertRaises(TypeError, truncate, None)
2293        self.assertRaises(TypeError, truncate, 42)
2294        self.assertRaises(TypeError, truncate, dict(test_table=None))
2295        query = self.db.query
2296        self.createTable('test_table', 'n smallint',
2297                         temporary=False, values=[1] * 3)
2298        q = "select count(*) from test_table"
2299        r = query(q).getresult()[0][0]
2300        self.assertEqual(r, 3)
2301        truncate('test_table')
2302        r = query(q).getresult()[0][0]
2303        self.assertEqual(r, 0)
2304        for i in range(3):
2305            query("insert into test_table values (1)")
2306        r = query(q).getresult()[0][0]
2307        self.assertEqual(r, 3)
2308        truncate('public.test_table')
2309        r = query(q).getresult()[0][0]
2310        self.assertEqual(r, 0)
2311        self.createTable('test_table_2', 'n smallint', temporary=True)
2312        for t in (list, tuple, set):
2313            for i in range(3):
2314                query("insert into test_table values (1)")
2315                query("insert into test_table_2 values (2)")
2316            q = ("select (select count(*) from test_table),"
2317                 " (select count(*) from test_table_2)")
2318            r = query(q).getresult()[0]
2319            self.assertEqual(r, (3, 3))
2320            truncate(t(['test_table', 'test_table_2']))
2321            r = query(q).getresult()[0]
2322            self.assertEqual(r, (0, 0))
2323
2324    def testTruncateRestart(self):
2325        truncate = self.db.truncate
2326        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2327        query = self.db.query
2328        self.createTable('test_table', 'n serial, t text')
2329        for n in range(3):
2330            query("insert into test_table (t) values ('test')")
2331        q = "select count(n), min(n), max(n) from test_table"
2332        r = query(q).getresult()[0]
2333        self.assertEqual(r, (3, 1, 3))
2334        truncate('test_table')
2335        r = query(q).getresult()[0]
2336        self.assertEqual(r, (0, None, None))
2337        for n in range(3):
2338            query("insert into test_table (t) values ('test')")
2339        r = query(q).getresult()[0]
2340        self.assertEqual(r, (3, 4, 6))
2341        truncate('test_table', restart=True)
2342        r = query(q).getresult()[0]
2343        self.assertEqual(r, (0, None, None))
2344        for n in range(3):
2345            query("insert into test_table (t) values ('test')")
2346        r = query(q).getresult()[0]
2347        self.assertEqual(r, (3, 1, 3))
2348
2349    def testTruncateCascade(self):
2350        truncate = self.db.truncate
2351        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2352        query = self.db.query
2353        self.createTable('test_parent', 'n smallint primary key',
2354                         values=range(3))
2355        self.createTable('test_child',
2356                         'n smallint primary key references test_parent (n)',
2357                         values=range(3))
2358        q = ("select (select count(*) from test_parent),"
2359             " (select count(*) from test_child)")
2360        r = query(q).getresult()[0]
2361        self.assertEqual(r, (3, 3))
2362        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2363        truncate(['test_parent', 'test_child'])
2364        r = query(q).getresult()[0]
2365        self.assertEqual(r, (0, 0))
2366        for n in range(3):
2367            query("insert into test_parent (n) values (%d)" % n)
2368            query("insert into test_child (n) values (%d)" % n)
2369        r = query(q).getresult()[0]
2370        self.assertEqual(r, (3, 3))
2371        truncate('test_parent', cascade=True)
2372        r = query(q).getresult()[0]
2373        self.assertEqual(r, (0, 0))
2374        for n in range(3):
2375            query("insert into test_parent (n) values (%d)" % n)
2376            query("insert into test_child (n) values (%d)" % n)
2377        r = query(q).getresult()[0]
2378        self.assertEqual(r, (3, 3))
2379        truncate('test_child')
2380        r = query(q).getresult()[0]
2381        self.assertEqual(r, (3, 0))
2382        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2383        truncate('test_parent', cascade=True)
2384        r = query(q).getresult()[0]
2385        self.assertEqual(r, (0, 0))
2386
2387    def testTruncateOnly(self):
2388        truncate = self.db.truncate
2389        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2390        query = self.db.query
2391        self.createTable('test_parent', 'n smallint')
2392        self.createTable('test_child', 'm smallint) inherits (test_parent')
2393        for n in range(3):
2394            query("insert into test_parent (n) values (1)")
2395            query("insert into test_child (n, m) values (2, 3)")
2396        q = ("select (select count(*) from test_parent),"
2397             " (select count(*) from test_child)")
2398        r = query(q).getresult()[0]
2399        self.assertEqual(r, (6, 3))
2400        truncate('test_parent')
2401        r = query(q).getresult()[0]
2402        self.assertEqual(r, (0, 0))
2403        for n in range(3):
2404            query("insert into test_parent (n) values (1)")
2405            query("insert into test_child (n, m) values (2, 3)")
2406        r = query(q).getresult()[0]
2407        self.assertEqual(r, (6, 3))
2408        truncate('test_parent*')
2409        r = query(q).getresult()[0]
2410        self.assertEqual(r, (0, 0))
2411        for n in range(3):
2412            query("insert into test_parent (n) values (1)")
2413            query("insert into test_child (n, m) values (2, 3)")
2414        r = query(q).getresult()[0]
2415        self.assertEqual(r, (6, 3))
2416        truncate('test_parent', only=True)
2417        r = query(q).getresult()[0]
2418        self.assertEqual(r, (3, 3))
2419        truncate('test_parent', only=False)
2420        r = query(q).getresult()[0]
2421        self.assertEqual(r, (0, 0))
2422        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2423        truncate('test_parent*', only=False)
2424        self.createTable('test_parent_2', 'n smallint')
2425        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2426        for t in '', '_2':
2427            for n in range(3):
2428                query("insert into test_parent%s (n) values (1)" % t)
2429                query("insert into test_child%s (n, m) values (2, 3)" % t)
2430        q = ("select (select count(*) from test_parent),"
2431             " (select count(*) from test_child),"
2432             " (select count(*) from test_parent_2),"
2433             " (select count(*) from test_child_2)")
2434        r = query(q).getresult()[0]
2435        self.assertEqual(r, (6, 3, 6, 3))
2436        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2437        r = query(q).getresult()[0]
2438        self.assertEqual(r, (0, 0, 3, 3))
2439        truncate(['test_parent', 'test_parent_2'], only=False)
2440        r = query(q).getresult()[0]
2441        self.assertEqual(r, (0, 0, 0, 0))
2442        self.assertRaises(ValueError, truncate,
2443                          ['test_parent*', 'test_child'], only=[True, False])
2444        truncate(['test_parent*', 'test_child'], only=[False, True])
2445
2446    def testTruncateQuoted(self):
2447        truncate = self.db.truncate
2448        query = self.db.query
2449        table = "test table for truncate()"
2450        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2451        q = 'select count(*) from "%s"' % table
2452        r = query(q).getresult()[0][0]
2453        self.assertEqual(r, 3)
2454        truncate(table)
2455        r = query(q).getresult()[0][0]
2456        self.assertEqual(r, 0)
2457        for i in range(3):
2458            query('insert into "%s" values (1)' % table)
2459        r = query(q).getresult()[0][0]
2460        self.assertEqual(r, 3)
2461        truncate('public."%s"' % table)
2462        r = query(q).getresult()[0][0]
2463        self.assertEqual(r, 0)
2464
2465    def testGetAsList(self):
2466        get_as_list = self.db.get_as_list
2467        self.assertRaises(TypeError, get_as_list)
2468        self.assertRaises(TypeError, get_as_list, None)
2469        query = self.db.query
2470        table = 'test_aslist'
2471        r = query('select 1 as colname').namedresult()[0]
2472        self.assertIsInstance(r, tuple)
2473        named = hasattr(r, 'colname')
2474        names = [(1, 'Homer'), (2, 'Marge'),
2475                 (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
2476        self.createTable(table,
2477                         'id smallint primary key, name varchar', values=names)
2478        r = get_as_list(table)
2479        self.assertIsInstance(r, list)
2480        self.assertEqual(r, names)
2481        for t, n in zip(r, names):
2482            self.assertIsInstance(t, tuple)
2483            self.assertEqual(t, n)
2484            if named:
2485                self.assertEqual(t.id, n[0])
2486                self.assertEqual(t.name, n[1])
2487                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
2488        r = get_as_list(table, what='name')
2489        self.assertIsInstance(r, list)
2490        expected = sorted((row[1],) for row in names)
2491        self.assertEqual(r, expected)
2492        r = get_as_list(table, what='name, id')
2493        self.assertIsInstance(r, list)
2494        expected = sorted(tuple(reversed(row)) for row in names)
2495        self.assertEqual(r, expected)
2496        r = get_as_list(table, what=['name', 'id'])
2497        self.assertIsInstance(r, list)
2498        self.assertEqual(r, expected)
2499        r = get_as_list(table, where="name like 'Ba%'")
2500        self.assertIsInstance(r, list)
2501        self.assertEqual(r, names[2:3])
2502        r = get_as_list(table, what='name', where="name like 'Ma%'")
2503        self.assertIsInstance(r, list)
2504        self.assertEqual(r, [('Maggie',), ('Marge',)])
2505        r = get_as_list(table, what='name',
2506                        where=["name like 'Ma%'", "name like '%r%'"])
2507        self.assertIsInstance(r, list)
2508        self.assertEqual(r, [('Marge',)])
2509        r = get_as_list(table, what='name', order='id')
2510        self.assertIsInstance(r, list)
2511        expected = [(row[1],) for row in names]
2512        self.assertEqual(r, expected)
2513        r = get_as_list(table, what=['name'], order=['id'])
2514        self.assertIsInstance(r, list)
2515        self.assertEqual(r, expected)
2516        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
2517        self.assertIsInstance(r, list)
2518        self.assertEqual(r, names)
2519        r = get_as_list(table, what='id * 2 as num', order='id desc')
2520        self.assertIsInstance(r, list)
2521        expected = [(n,) for n in range(10, 0, -2)]
2522        self.assertEqual(r, expected)
2523        r = get_as_list(table, limit=2)
2524        self.assertIsInstance(r, list)
2525        self.assertEqual(r, names[:2])
2526        r = get_as_list(table, offset=3)
2527        self.assertIsInstance(r, list)
2528        self.assertEqual(r, names[3:])
2529        r = get_as_list(table, limit=1, offset=2)
2530        self.assertIsInstance(r, list)
2531        self.assertEqual(r, names[2:3])
2532        r = get_as_list(table, scalar=True)
2533        self.assertIsInstance(r, list)
2534        self.assertEqual(r, list(range(1, 6)))
2535        r = get_as_list(table, what='name', scalar=True)
2536        self.assertIsInstance(r, list)
2537        expected = sorted(row[1] for row in names)
2538        self.assertEqual(r, expected)
2539        r = get_as_list(table, what='name', limit=1, scalar=True)
2540        self.assertIsInstance(r, list)
2541        self.assertEqual(r, expected[:1])
2542        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2543        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2544        names.insert(1, (1, 'Snowball'))
2545        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
2546        r = get_as_list(table)
2547        self.assertIsInstance(r, list)
2548        self.assertEqual(r, names)
2549        r = get_as_list(table, what='name', where='id=1', scalar=True)
2550        self.assertIsInstance(r, list)
2551        self.assertEqual(r, ['Homer', 'Snowball'])
2552        # test with unordered query
2553        r = get_as_list(table, order=False)
2554        self.assertIsInstance(r, list)
2555        self.assertEqual(set(r), set(names))
2556        # test with arbitrary from clause
2557        from_table = '(select lower(name) as n2 from "%s") as t2' % table
2558        r = get_as_list(from_table)
2559        self.assertIsInstance(r, list)
2560        r = set(row[0] for row in r)
2561        expected = set(row[1].lower() for row in names)
2562        self.assertEqual(r, expected)
2563        r = get_as_list(from_table, order='n2', scalar=True)
2564        self.assertIsInstance(r, list)
2565        self.assertEqual(r, sorted(expected))
2566        r = get_as_list(from_table, order='n2', limit=1)
2567        self.assertIsInstance(r, list)
2568        self.assertEqual(len(r), 1)
2569        t = r[0]
2570        self.assertIsInstance(t, tuple)
2571        if named:
2572            self.assertEqual(t.n2, 'bart')
2573            self.assertEqual(t._asdict(), dict(n2='bart'))
2574        else:
2575            self.assertEqual(t, ('bart',))
2576
2577    def testGetAsDict(self):
2578        get_as_dict = self.db.get_as_dict
2579        self.assertRaises(TypeError, get_as_dict)
2580        self.assertRaises(TypeError, get_as_dict, None)
2581        # the test table has no primary key
2582        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2583        query = self.db.query
2584        table = 'test_asdict'
2585        r = query('select 1 as colname').namedresult()[0]
2586        self.assertIsInstance(r, tuple)
2587        named = hasattr(r, 'colname')
2588        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2589                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2590        self.createTable(table,
2591                         'id smallint primary key, rgb char(7), name varchar',
2592                         values=colors)
2593        # keyname must be string, list or tuple
2594        self.assertRaises(KeyError, get_as_dict, table, 3)
2595        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2596        # missing keyname in row
2597        self.assertRaises(KeyError, get_as_dict, table,
2598                          keyname='rgb', what='name')
2599        r = get_as_dict(table)
2600        self.assertIsInstance(r, OrderedDict)
2601        expected = OrderedDict((row[0], row[1:]) for row in colors)
2602        self.assertEqual(r, expected)
2603        for key in r:
2604            self.assertIsInstance(key, int)
2605            self.assertIn(key, expected)
2606            row = r[key]
2607            self.assertIsInstance(row, tuple)
2608            t = expected[key]
2609            self.assertEqual(row, t)
2610            if named:
2611                self.assertEqual(row.rgb, t[0])
2612                self.assertEqual(row.name, t[1])
2613                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2614        if OrderedDict is not dict:  # Python > 2.6
2615            self.assertEqual(r.keys(), expected.keys())
2616        r = get_as_dict(table, keyname='rgb')
2617        self.assertIsInstance(r, OrderedDict)
2618        expected = OrderedDict((row[1], (row[0], row[2]))
2619                               for row in sorted(colors, key=itemgetter(1)))
2620        self.assertEqual(r, expected)
2621        for key in r:
2622            self.assertIsInstance(key, str)
2623            self.assertIn(key, expected)
2624            row = r[key]
2625            self.assertIsInstance(row, tuple)
2626            t = expected[key]
2627            self.assertEqual(row, t)
2628            if named:
2629                self.assertEqual(row.id, t[0])
2630                self.assertEqual(row.name, t[1])
2631                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2632        if OrderedDict is not dict:  # Python > 2.6
2633            self.assertEqual(r.keys(), expected.keys())
2634        r = get_as_dict(table, keyname=['id', 'rgb'])
2635        self.assertIsInstance(r, OrderedDict)
2636        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2637        self.assertEqual(r, expected)
2638        for key in r:
2639            self.assertIsInstance(key, tuple)
2640            self.assertIsInstance(key[0], int)
2641            self.assertIsInstance(key[1], str)
2642            if named:
2643                self.assertEqual(key, (key.id, key.rgb))
2644                self.assertEqual(key._fields, ('id', 'rgb'))
2645            row = r[key]
2646            self.assertIsInstance(row, tuple)
2647            self.assertIsInstance(row[0], str)
2648            t = expected[key]
2649            self.assertEqual(row, t)
2650            if named:
2651                self.assertEqual(row.name, t[0])
2652                self.assertEqual(row._asdict(), dict(name=t[0]))
2653        if OrderedDict is not dict:  # Python > 2.6
2654            self.assertEqual(r.keys(), expected.keys())
2655        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2656        self.assertIsInstance(r, OrderedDict)
2657        expected = OrderedDict((row[:2], row[2]) for row in colors)
2658        self.assertEqual(r, expected)
2659        for key in r:
2660            self.assertIsInstance(key, tuple)
2661            row = r[key]
2662            self.assertIsInstance(row, str)
2663            t = expected[key]
2664            self.assertEqual(row, t)
2665        if OrderedDict is not dict:  # Python > 2.6
2666            self.assertEqual(r.keys(), expected.keys())
2667        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2668        self.assertIsInstance(r, OrderedDict)
2669        expected = OrderedDict((row[1], row[2])
2670                               for row in sorted(colors, key=itemgetter(1)))
2671        self.assertEqual(r, expected)
2672        for key in r:
2673            self.assertIsInstance(key, str)
2674            row = r[key]
2675            self.assertIsInstance(row, str)
2676            t = expected[key]
2677            self.assertEqual(row, t)
2678        if OrderedDict is not dict:  # Python > 2.6
2679            self.assertEqual(r.keys(), expected.keys())
2680        r = get_as_dict(table, what='id, name',
2681                        where="rgb like '#b%'", scalar=True)
2682        self.assertIsInstance(r, OrderedDict)
2683        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2684        self.assertEqual(r, expected)
2685        for key in r:
2686            self.assertIsInstance(key, int)
2687            row = r[key]
2688            self.assertIsInstance(row, str)
2689            t = expected[key]
2690            self.assertEqual(row, t)
2691        if OrderedDict is not dict:  # Python > 2.6
2692            self.assertEqual(r.keys(), expected.keys())
2693        expected = r
2694        r = get_as_dict(table, what=['name', 'id'],
2695                        where=['id > 1', 'id < 4', "rgb like '#b%'",
2696                               "name not like 'A%'", "name not like '%t'"], scalar=True)
2697        self.assertEqual(r, expected)
2698        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2699        self.assertEqual(r, expected)
2700        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2701                        where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2702        self.assertEqual(r, expected)
2703        r = get_as_dict(table, limit=1)
2704        self.assertEqual(len(r), 1)
2705        self.assertEqual(r[1][1], 'Aero')
2706        r = get_as_dict(table, offset=3)
2707        self.assertEqual(len(r), 1)
2708        self.assertEqual(r[4][1], 'Desert')
2709        r = get_as_dict(table, order='id desc')
2710        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2711        self.assertEqual(r, expected)
2712        r = get_as_dict(table, where='id > 5')
2713        self.assertIsInstance(r, OrderedDict)
2714        self.assertEqual(len(r), 0)
2715        # test with unordered query
2716        expected = dict((row[0], row[1:]) for row in colors)
2717        r = get_as_dict(table, order=False)
2718        self.assertIsInstance(r, dict)
2719        self.assertEqual(r, expected)
2720        if dict is not OrderedDict:  # Python > 2.6
2721            self.assertNotIsInstance(self, OrderedDict)
2722        # test with arbitrary from clause
2723        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2724        # primary key must be passed explicitly in this case
2725        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2726        r = get_as_dict(from_table, 'id')
2727        self.assertIsInstance(r, OrderedDict)
2728        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2729        self.assertEqual(r, expected)
2730        # test without a primary key
2731        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2732        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2733        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2734        r = get_as_dict(table, keyname='id')
2735        expected = OrderedDict((row[0], row[1:]) for row in colors)
2736        self.assertIsInstance(r, dict)
2737        self.assertEqual(r, expected)
2738        r = (1, '#007fff', 'Azure')
2739        query('insert into "%s" values ($1, $2, $3)' % table, r)
2740        # the last entry will win
2741        expected[1] = r[1:]
2742        r = get_as_dict(table, keyname='id')
2743        self.assertEqual(r, expected)
2744
2745    def testTransaction(self):
2746        query = self.db.query
2747        self.createTable('test_table', 'n integer', temporary=False)
2748        self.db.begin()
2749        query("insert into test_table values (1)")
2750        query("insert into test_table values (2)")
2751        self.db.commit()
2752        self.db.begin()
2753        query("insert into test_table values (3)")
2754        query("insert into test_table values (4)")
2755        self.db.rollback()
2756        self.db.begin()
2757        query("insert into test_table values (5)")
2758        self.db.savepoint('before6')
2759        query("insert into test_table values (6)")
2760        self.db.rollback('before6')
2761        query("insert into test_table values (7)")
2762        self.db.commit()
2763        self.db.begin()
2764        self.db.savepoint('before8')
2765        query("insert into test_table values (8)")
2766        self.db.release('before8')
2767        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
2768        self.db.commit()
2769        self.db.start()
2770        query("insert into test_table values (9)")
2771        self.db.end()
2772        r = [r[0] for r in query(
2773            "select * from test_table order by 1").getresult()]
2774        self.assertEqual(r, [1, 2, 5, 7, 9])
2775        self.db.begin(mode='read only')
2776        self.assertRaises(pg.ProgrammingError,
2777                          query, "insert into test_table values (0)")
2778        self.db.rollback()
2779        self.db.start(mode='Read Only')
2780        self.assertRaises(pg.ProgrammingError,
2781                          query, "insert into test_table values (0)")
2782        self.db.abort()
2783
2784    def testTransactionAliases(self):
2785        self.assertEqual(self.db.begin, self.db.start)
2786        self.assertEqual(self.db.commit, self.db.end)
2787        self.assertEqual(self.db.rollback, self.db.abort)
2788
2789    def testContextManager(self):
2790        query = self.db.query
2791        self.createTable('test_table', 'n integer check(n>0)')
2792        with self.db:
2793            query("insert into test_table values (1)")
2794            query("insert into test_table values (2)")
2795        try:
2796            with self.db:
2797                query("insert into test_table values (3)")
2798                query("insert into test_table values (4)")
2799                raise ValueError('test transaction should rollback')
2800        except ValueError as error:
2801            self.assertEqual(str(error), 'test transaction should rollback')
2802        with self.db:
2803            query("insert into test_table values (5)")
2804        try:
2805            with self.db:
2806                query("insert into test_table values (6)")
2807                query("insert into test_table values (-1)")
2808        except pg.ProgrammingError as error:
2809            self.assertTrue('check' in str(error))
2810        with self.db:
2811            query("insert into test_table values (7)")
2812        r = [r[0] for r in query(
2813            "select * from test_table order by 1").getresult()]
2814        self.assertEqual(r, [1, 2, 5, 7])
2815
2816    def testBytea(self):
2817        query = self.db.query
2818        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2819        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2820        r = self.db.escape_bytea(s)
2821        query('insert into bytea_test values(3, $1)', (r,))
2822        r = query('select * from bytea_test where n=3').getresult()
2823        self.assertEqual(len(r), 1)
2824        r = r[0]
2825        self.assertEqual(len(r), 2)
2826        self.assertEqual(r[0], 3)
2827        r = r[1]
2828        self.assertIsInstance(r, bytes)
2829        self.assertEqual(r, s)
2830
2831    def testInsertUpdateGetBytea(self):
2832        query = self.db.query
2833        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2834        # insert null value
2835        r = self.db.insert('bytea_test', n=0, data=None)
2836        self.assertIsInstance(r, dict)
2837        self.assertIn('n', r)
2838        self.assertEqual(r['n'], 0)
2839        self.assertIn('data', r)
2840        self.assertIsNone(r['data'])
2841        s = b'None'
2842        r = self.db.update('bytea_test', n=0, data=s)
2843        self.assertIsInstance(r, dict)
2844        self.assertIn('n', r)
2845        self.assertEqual(r['n'], 0)
2846        self.assertIn('data', r)
2847        r = r['data']
2848        self.assertIsInstance(r, bytes)
2849        self.assertEqual(r, s)
2850        r = self.db.update('bytea_test', n=0, data=None)
2851        self.assertIsNone(r['data'])
2852        # insert as bytes
2853        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2854        r = self.db.insert('bytea_test', n=5, data=s)
2855        self.assertIsInstance(r, dict)
2856        self.assertIn('n', r)
2857        self.assertEqual(r['n'], 5)
2858        self.assertIn('data', r)
2859        r = r['data']
2860        self.assertIsInstance(r, bytes)
2861        self.assertEqual(r, s)
2862        # update as bytes
2863        s += b"and now even more \x00 nasty \t stuff!\f"
2864        r = self.db.update('bytea_test', n=5, data=s)
2865        self.assertIsInstance(r, dict)
2866        self.assertIn('n', r)
2867        self.assertEqual(r['n'], 5)
2868        self.assertIn('data', r)
2869        r = r['data']
2870        self.assertIsInstance(r, bytes)
2871        self.assertEqual(r, s)
2872        r = query('select * from bytea_test where n=5').getresult()
2873        self.assertEqual(len(r), 1)
2874        r = r[0]
2875        self.assertEqual(len(r), 2)
2876        self.assertEqual(r[0], 5)
2877        r = r[1]
2878        self.assertIsInstance(r, bytes)
2879        self.assertEqual(r, s)
2880        r = self.db.get('bytea_test', dict(n=5))
2881        self.assertIsInstance(r, dict)
2882        self.assertIn('n', r)
2883        self.assertEqual(r['n'], 5)
2884        self.assertIn('data', r)
2885        r = r['data']
2886        self.assertIsInstance(r, bytes)
2887        self.assertEqual(r, s)
2888
2889    def testUpsertBytea(self):
2890        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2891        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2892        r = dict(n=7, data=s)
2893        try:
2894            r = self.db.upsert('bytea_test', r)
2895        except pg.ProgrammingError as error:
2896            if self.db.server_version < 90500:
2897                self.skipTest('database does not support upsert')
2898            self.fail(str(error))
2899        self.assertIsInstance(r, dict)
2900        self.assertIn('n', r)
2901        self.assertEqual(r['n'], 7)
2902        self.assertIn('data', r)
2903        self.assertIsInstance(r['data'], bytes)
2904        self.assertEqual(r['data'], s)
2905        r['data'] = None
2906        r = self.db.upsert('bytea_test', r)
2907        self.assertIsInstance(r, dict)
2908        self.assertIn('n', r)
2909        self.assertEqual(r['n'], 7)
2910        self.assertIn('data', r)
2911        self.assertIsNone(r['data'], bytes)
2912
2913    def testInsertGetJson(self):
2914        try:
2915            self.createTable('json_test', 'n smallint primary key, data json')
2916        except pg.ProgrammingError as error:
2917            if self.db.server_version < 90200:
2918                self.skipTest('database does not support json')
2919            self.fail(str(error))
2920        jsondecode = pg.get_jsondecode()
2921        # insert null value
2922        r = self.db.insert('json_test', n=0, data=None)
2923        self.assertIsInstance(r, dict)
2924        self.assertIn('n', r)
2925        self.assertEqual(r['n'], 0)
2926        self.assertIn('data', r)
2927        self.assertIsNone(r['data'])
2928        r = self.db.get('json_test', 0)
2929        self.assertIsInstance(r, dict)
2930        self.assertIn('n', r)
2931        self.assertEqual(r['n'], 0)
2932        self.assertIn('data', r)
2933        self.assertIsNone(r['data'])
2934        # insert JSON object
2935        data = {
2936            "id": 1, "name": "Foo", "price": 1234.5,
2937            "new": True, "note": None,
2938            "tags": ["Bar", "Eek"],
2939            "stock": {"warehouse": 300, "retail": 20}}
2940        r = self.db.insert('json_test', n=1, data=data)
2941        self.assertIsInstance(r, dict)
2942        self.assertIn('n', r)
2943        self.assertEqual(r['n'], 1)
2944        self.assertIn('data', r)
2945        r = r['data']
2946        if jsondecode is None:
2947            self.assertIsInstance(r, str)
2948            r = json.loads(r)
2949        self.assertIsInstance(r, dict)
2950        self.assertEqual(r, data)
2951        self.assertIsInstance(r['id'], int)
2952        self.assertIsInstance(r['name'], unicode)
2953        self.assertIsInstance(r['price'], float)
2954        self.assertIsInstance(r['new'], bool)
2955        self.assertIsInstance(r['tags'], list)
2956        self.assertIsInstance(r['stock'], dict)
2957        r = self.db.get('json_test', 1)
2958        self.assertIsInstance(r, dict)
2959        self.assertIn('n', r)
2960        self.assertEqual(r['n'], 1)
2961        self.assertIn('data', r)
2962        r = r['data']
2963        if jsondecode is None:
2964            self.assertIsInstance(r, str)
2965            r = json.loads(r)
2966        self.assertIsInstance(r, dict)
2967        self.assertEqual(r, data)
2968        self.assertIsInstance(r['id'], int)
2969        self.assertIsInstance(r['name'], unicode)
2970        self.assertIsInstance(r['price'], float)
2971        self.assertIsInstance(r['new'], bool)
2972        self.assertIsInstance(r['tags'], list)
2973        self.assertIsInstance(r['stock'], dict)
2974        # insert JSON object as text
2975        self.db.insert('json_test', n=2, data=json.dumps(data))
2976        q = "select data from json_test where n in (1, 2) order by n"
2977        r = self.db.query(q).getresult()
2978        self.assertEqual(len(r), 2)
2979        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
2980        self.assertEqual(r[0][0], r[1][0])
2981
2982    def testInsertGetJsonb(self):
2983        try:
2984            self.createTable('jsonb_test',
2985                             'n smallint primary key, data jsonb')
2986        except pg.ProgrammingError as error:
2987            if self.db.server_version < 90400:
2988                self.skipTest('database does not support jsonb')
2989            self.fail(str(error))
2990        jsondecode = pg.get_jsondecode()
2991        # insert null value
2992        r = self.db.insert('jsonb_test', n=0, data=None)
2993        self.assertIsInstance(r, dict)
2994        self.assertIn('n', r)
2995        self.assertEqual(r['n'], 0)
2996        self.assertIn('data', r)
2997        self.assertIsNone(r['data'])
2998        r = self.db.get('jsonb_test', 0)
2999        self.assertIsInstance(r, dict)
3000        self.assertIn('n', r)
3001        self.assertEqual(r['n'], 0)
3002        self.assertIn('data', r)
3003        self.assertIsNone(r['data'])
3004        # insert JSON object
3005        data = {
3006            "id": 1, "name": "Foo", "price": 1234.5,
3007            "new": True, "note": None,
3008            "tags": ["Bar", "Eek"],
3009            "stock": {"warehouse": 300, "retail": 20}}
3010        r = self.db.insert('jsonb_test', n=1, data=data)
3011        self.assertIsInstance(r, dict)
3012        self.assertIn('n', r)
3013        self.assertEqual(r['n'], 1)
3014        self.assertIn('data', r)
3015        r = r['data']
3016        if jsondecode is None:
3017            self.assertIsInstance(r, str)
3018            r = json.loads(r)
3019        self.assertIsInstance(r, dict)
3020        self.assertEqual(r, data)
3021        self.assertIsInstance(r['id'], int)
3022        self.assertIsInstance(r['name'], unicode)
3023        self.assertIsInstance(r['price'], float)
3024        self.assertIsInstance(r['new'], bool)
3025        self.assertIsInstance(r['tags'], list)
3026        self.assertIsInstance(r['stock'], dict)
3027        r = self.db.get('jsonb_test', 1)
3028        self.assertIsInstance(r, dict)
3029        self.assertIn('n', r)
3030        self.assertEqual(r['n'], 1)
3031        self.assertIn('data', r)
3032        r = r['data']
3033        if jsondecode is None:
3034            self.assertIsInstance(r, str)
3035            r = json.loads(r)
3036        self.assertIsInstance(r, dict)
3037        self.assertEqual(r, data)
3038        self.assertIsInstance(r['id'], int)
3039        self.assertIsInstance(r['name'], unicode)
3040        self.assertIsInstance(r['price'], float)
3041        self.assertIsInstance(r['new'], bool)
3042        self.assertIsInstance(r['tags'], list)
3043        self.assertIsInstance(r['stock'], dict)
3044
3045    def testArray(self):
3046        self.createTable('arraytest',
3047                         'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
3048                         ' d numeric[], f4 real[], f8 double precision[], m money[],'
3049                         ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
3050        r = self.db.get_attnames('arraytest')
3051        if self.regtypes:
3052            self.assertEqual(r, dict(
3053                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
3054                d='numeric[]', f4='real[]', f8='double precision[]',
3055                m='money[]', b='boolean[]',
3056                v4='character varying[]', c4='character[]', t='text[]'))
3057        else:
3058            self.assertEqual(r, dict(
3059                id='int', i2='int[]', i4='int[]', i8='int[]',
3060                d='num[]', f4='float[]', f8='float[]', m='money[]',
3061                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
3062        decimal = pg.get_decimal()
3063        if decimal is Decimal:
3064            long_decimal = decimal('123456789.123456789')
3065            odd_money = decimal('1234567891234567.89')
3066        else:
3067            long_decimal = decimal('12345671234.5')
3068            odd_money = decimal('1234567123.25')
3069        t, f = (True, False) if pg.get_bool() else ('t', 'f')
3070        data = dict(id=42, i2=[42, 1234, None, 0, -1],
3071                    i4=[42, 123456789, None, 0, 1, -1],
3072                    i8=[long(42), long(123456789123456789), None,
3073                        long(0), long(1), long(-1)],
3074                    d=[decimal(42), long_decimal, None,
3075                       decimal(0), decimal(1), decimal(-1), -long_decimal],
3076                    f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
3077                        float('inf'), float('-inf')],
3078                    f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
3079                        float('inf'), float('-inf')],
3080                    m=[decimal('42.00'), odd_money, None,
3081                       decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
3082                    b=[t, f, t, None, f, t, None, None, t],
3083                    v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
3084                    t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
3085        r = data.copy()
3086        self.db.insert('arraytest', r)
3087        self.assertEqual(r, data)
3088        self.db.insert('arraytest', r)
3089        r = self.db.get('arraytest', 42, 'id')
3090        self.assertEqual(r, data)
3091        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
3092        self.assertEqual(r, data)
3093
3094    def testArrayLiteral(self):
3095        insert = self.db.insert
3096        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3097        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3098        insert('arraytest', r)
3099        self.assertEqual(r['i'], [1, 2, 3])
3100        self.assertEqual(r['t'], ['a', 'b', 'c'])
3101        r = dict(i='{1,2,3}', t='{a,b,c}')
3102        self.db.insert('arraytest', r)
3103        self.assertEqual(r['i'], [1, 2, 3])
3104        self.assertEqual(r['t'], ['a', 'b', 'c'])
3105        L = pg._Literal
3106        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3107        self.db.insert('arraytest', r)
3108        self.assertEqual(r['i'], [1, 2, 3])
3109        self.assertEqual(r['t'], ['a', 'b', 'c'])
3110        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3111        self.assertRaises(pg.ProgrammingError, self.db.insert, 'arraytest', r)
3112
3113    def testArrayOfIds(self):
3114        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3115        r = self.db.get_attnames('arraytest')
3116        if self.regtypes:
3117            self.assertEqual(r, dict(
3118                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3119        else:
3120            self.assertEqual(r, dict(
3121                oid='int', c='int[]', o='int[]', x='int[]'))
3122        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3123        r = data.copy()
3124        self.db.insert('arraytest', r)
3125        qoid = 'oid(arraytest)'
3126        oid = r.pop(qoid)
3127        self.assertEqual(r, data)
3128        r = {qoid: oid}
3129        self.db.get('arraytest', r)
3130        self.assertEqual(oid, r.pop(qoid))
3131        self.assertEqual(r, data)
3132
3133    def testArrayOfText(self):
3134        self.createTable('arraytest', 'data text[]', oids=True)
3135        r = self.db.get_attnames('arraytest')
3136        self.assertEqual(r['data'], 'text[]')
3137        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3138                'null', 'NULL', 'Null', 'nulL',
3139                "It's all \\ kinds of\r nasty stuff!\n"]
3140        r = dict(data=data)
3141        self.db.insert('arraytest', r)
3142        self.assertEqual(r['data'], data)
3143        self.assertIsInstance(r['data'][1], str)
3144        self.assertIsNone(r['data'][2])
3145        r['data'] = None
3146        self.db.get('arraytest', r)
3147        self.assertEqual(r['data'], data)
3148        self.assertIsInstance(r['data'][1], str)
3149        self.assertIsNone(r['data'][2])
3150
3151    def testArrayOfBytea(self):
3152        self.createTable('arraytest', 'data bytea[]', oids=True)
3153        r = self.db.get_attnames('arraytest')
3154        self.assertEqual(r['data'], 'bytea[]')
3155        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3156                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3157        r = dict(data=data)
3158        self.db.insert('arraytest', r)
3159        self.assertEqual(r['data'], data)
3160        self.assertIsInstance(r['data'][1], bytes)
3161        self.assertIsNone(r['data'][2])
3162        r['data'] = None
3163        self.db.get('arraytest', r)
3164        self.assertEqual(r['data'], data)
3165        self.assertIsInstance(r['data'][1], bytes)
3166        self.assertIsNone(r['data'][2])
3167
3168    def testArrayOfJson(self):
3169        try:
3170            self.createTable('arraytest', 'data json[]', oids=True)
3171        except pg.ProgrammingError as error:
3172            if self.db.server_version < 90200:
3173                self.skipTest('database does not support json')
3174            self.fail(str(error))
3175        r = self.db.get_attnames('arraytest')
3176        self.assertEqual(r['data'], 'json[]')
3177        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3178        jsondecode = pg.get_jsondecode()
3179        r = dict(data=data)
3180        self.db.insert('arraytest', r)
3181        if jsondecode is None:
3182            r['data'] = [json.loads(d) for d in r['data']]
3183        self.assertEqual(r['data'], data)
3184        r['data'] = None
3185        self.db.get('arraytest', r)
3186        if jsondecode is None:
3187            r['data'] = [json.loads(d) for d in r['data']]
3188        self.assertEqual(r['data'], data)
3189        r = dict(data=[json.dumps(d) for d in data])
3190        self.db.insert('arraytest', r)
3191        if jsondecode is None:
3192            r['data'] = [json.loads(d) for d in r['data']]
3193        self.assertEqual(r['data'], data)
3194        r['data'] = None
3195        self.db.get('arraytest', r)
3196        # insert empty json values
3197        r = dict(data=['', None])
3198        self.db.insert('arraytest', r)
3199        r = r['data']
3200        self.assertIsInstance(r, list)
3201        self.assertEqual(len(r), 2)
3202        self.assertIsNone(r[0])
3203        self.assertIsNone(r[1])
3204
3205    def testArrayOfJsonb(self):
3206        try:
3207            self.createTable('arraytest', 'data jsonb[]', oids=True)
3208        except pg.ProgrammingError as error:
3209            if self.db.server_version < 90400:
3210                self.skipTest('database does not support jsonb')
3211            self.fail(str(error))
3212        r = self.db.get_attnames('arraytest')
3213        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3214        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3215        jsondecode = pg.get_jsondecode()
3216        r = dict(data=data)
3217        self.db.insert('arraytest', r)
3218        if jsondecode is None:
3219            r['data'] = [json.loads(d) for d in r['data']]
3220        self.assertEqual(r['data'], data)
3221        r['data'] = None
3222        self.db.get('arraytest', r)
3223        if jsondecode is None:
3224            r['data'] = [json.loads(d) for d in r['data']]
3225        self.assertEqual(r['data'], data)
3226        r = dict(data=[json.dumps(d) for d in data])
3227        self.db.insert('arraytest', r)
3228        if jsondecode is None:
3229            r['data'] = [json.loads(d) for d in r['data']]
3230        self.assertEqual(r['data'], data)
3231        r['data'] = None
3232        self.db.get('arraytest', r)
3233        # insert empty json values
3234        r = dict(data=['', None])
3235        self.db.insert('arraytest', r)
3236        r = r['data']
3237        self.assertIsInstance(r, list)
3238        self.assertEqual(len(r), 2)
3239        self.assertIsNone(r[0])
3240        self.assertIsNone(r[1])
3241
3242    def testDeepArray(self):
3243        self.createTable('arraytest', 'data text[][][]', oids=True)
3244        r = self.db.get_attnames('arraytest')
3245        self.assertEqual(r['data'], 'text[]')
3246        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3247        r = dict(data=data)
3248        self.db.insert('arraytest', r)
3249        self.assertEqual(r['data'], data)
3250        r['data'] = None
3251        self.db.get('arraytest', r)
3252        self.assertEqual(r['data'], data)
3253
3254    def testInsertUpdateGetRecord(self):
3255        query = self.db.query
3256        query('create type test_person_type as'
3257              ' (name varchar, age smallint, married bool,'
3258              ' weight real, salary money)')
3259        self.addCleanup(query, 'drop type test_person_type')
3260        self.createTable('test_person', 'person test_person_type',
3261                         temporary=False, oids=True)
3262        attnames = self.db.get_attnames('test_person')
3263        self.assertEqual(len(attnames), 2)
3264        self.assertIn('oid', attnames)
3265        self.assertIn('person', attnames)
3266        person_typ = attnames['person']
3267        if self.regtypes:
3268            self.assertEqual(person_typ, 'test_person_type')
3269        else:
3270            self.assertEqual(person_typ, 'record')
3271        if self.regtypes:
3272            self.assertEqual(person_typ.attnames,
3273                             dict(name='character varying', age='smallint',
3274                                  married='boolean', weight='real', salary='money'))
3275        else:
3276            self.assertEqual(person_typ.attnames,
3277                             dict(name='text', age='int', married='bool',
3278                                  weight='float', salary='money'))
3279        decimal = pg.get_decimal()
3280        if pg.get_bool():
3281            bool_class = bool
3282            t, f = True, False
3283        else:
3284            bool_class = str
3285            t, f = 't', 'f'
3286        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3287        r = self.db.insert('test_person', None, person=person)
3288        p = r['person']
3289        self.assertIsInstance(p, tuple)
3290        self.assertEqual(p, person)
3291        self.assertEqual(p.name, 'John Doe')
3292        self.assertIsInstance(p.name, str)
3293        self.assertIsInstance(p.age, int)
3294        self.assertIsInstance(p.married, bool_class)
3295        self.assertIsInstance(p.weight, float)
3296        self.assertIsInstance(p.salary, decimal)
3297        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3298        r['person'] = person
3299        self.db.update('test_person', r)
3300        p = r['person']
3301        self.assertIsInstance(p, tuple)
3302        self.assertEqual(p, person)
3303        self.assertEqual(p.name, 'Jane Roe')
3304        self.assertIsInstance(p.name, str)
3305        self.assertIsInstance(p.age, int)
3306        self.assertIsInstance(p.married, bool_class)
3307        self.assertIsInstance(p.weight, float)
3308        self.assertIsInstance(p.salary, decimal)
3309        r['person'] = None
3310        self.db.get('test_person', r)
3311        p = r['person']
3312        self.assertIsInstance(p, tuple)
3313        self.assertEqual(p, person)
3314        self.assertEqual(p.name, 'Jane Roe')
3315        self.assertIsInstance(p.name, str)
3316        self.assertIsInstance(p.age, int)
3317        self.assertIsInstance(p.married, bool_class)
3318        self.assertIsInstance(p.weight, float)
3319        self.assertIsInstance(p.salary, decimal)
3320        person = (None,) * 5
3321        r = self.db.insert('test_person', None, person=person)
3322        p = r['person']
3323        self.assertIsInstance(p, tuple)
3324        self.assertIsNone(p.name)
3325        self.assertIsNone(p.age)
3326        self.assertIsNone(p.married)
3327        self.assertIsNone(p.weight)
3328        self.assertIsNone(p.salary)
3329        r['person'] = None
3330        self.db.get('test_person', r)
3331        p = r['person']
3332        self.assertIsInstance(p, tuple)
3333        self.assertIsNone(p.name)
3334        self.assertIsNone(p.age)
3335        self.assertIsNone(p.married)
3336        self.assertIsNone(p.weight)
3337        self.assertIsNone(p.salary)
3338        r = self.db.insert('test_person', None, person=None)
3339        self.assertIsNone(r['person'])
3340        r['person'] = None
3341        self.db.get('test_person', r)
3342        self.assertIsNone(r['person'])
3343
3344    def testRecordInsertBytea(self):
3345        query = self.db.query
3346        query('create type test_person_type as'
3347              ' (name text, picture bytea)')
3348        self.addCleanup(query, 'drop type test_person_type')
3349        self.createTable('test_person', 'person test_person_type',
3350                         temporary=False, oids=True)
3351        person_typ = self.db.get_attnames('test_person')['person']
3352        self.assertEqual(person_typ.attnames,
3353                         dict(name='text', picture='bytea'))
3354        person = ('John Doe', b'O\x00ps\xff!')
3355        r = self.db.insert('test_person', None, person=person)
3356        p = r['person']
3357        self.assertIsInstance(p, tuple)
3358        self.assertEqual(p, person)
3359        self.assertEqual(p.name, 'John Doe')
3360        self.assertIsInstance(p.name, str)
3361        self.assertEqual(p.picture, person[1])
3362        self.assertIsInstance(p.picture, bytes)
3363
3364    def testRecordInsertJson(self):
3365        query = self.db.query
3366        try:
3367            query('create type test_person_type as'
3368                  ' (name text, data json)')
3369        except pg.ProgrammingError as error:
3370            if self.db.server_version < 90200:
3371                self.skipTest('database does not support json')
3372            self.fail(str(error))
3373        self.addCleanup(query, 'drop type test_person_type')
3374        self.createTable('test_person', 'person test_person_type',
3375                         temporary=False, oids=True)
3376        person_typ = self.db.get_attnames('test_person')['person']
3377        self.assertEqual(person_typ.attnames,
3378                         dict(name='text', data='json'))
3379        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3380        r = self.db.insert('test_person', None, person=person)
3381        p = r['person']
3382        self.assertIsInstance(p, tuple)
3383        if pg.get_jsondecode() is None:
3384            p = p._replace(data=json.loads(p.data))
3385        self.assertEqual(p, person)
3386        self.assertEqual(p.name, 'John Doe')
3387        self.assertIsInstance(p.name, str)
3388        self.assertEqual(p.data, person[1])
3389        self.assertIsInstance(p.data, dict)
3390
3391    def testRecordLiteral(self):
3392        query = self.db.query
3393        query('create type test_person_type as'
3394              ' (name varchar, age smallint)')
3395        self.addCleanup(query, 'drop type test_person_type')
3396        self.createTable('test_person', 'person test_person_type',
3397                         temporary=False, oids=True)
3398        person_typ = self.db.get_attnames('test_person')['person']
3399        if self.regtypes:
3400            self.assertEqual(person_typ, 'test_person_type')
3401        else:
3402            self.assertEqual(person_typ, 'record')
3403        if self.regtypes:
3404            self.assertEqual(person_typ.attnames,
3405                             dict(name='character varying', age='smallint'))
3406        else:
3407            self.assertEqual(person_typ.attnames,
3408                             dict(name='text', age='int'))
3409        person = pg._Literal("('John Doe', 61)")
3410        r = self.db.insert('test_person', None, person=person)
3411        p = r['person']
3412        self.assertIsInstance(p, tuple)
3413        self.assertEqual(p.name, 'John Doe')
3414        self.assertIsInstance(p.name, str)
3415        self.assertEqual(p.age, 61)
3416        self.assertIsInstance(p.age, int)
3417
3418    def testDbTypesInfo(self):
3419        dbtypes = self.db.dbtypes
3420        self.assertIsInstance(dbtypes, dict)
3421        self.assertNotIn('numeric', dbtypes)
3422        typ = dbtypes['numeric']
3423        self.assertIn('numeric', dbtypes)
3424        self.assertEqual(typ, 'numeric' if self.regtypes else 'num')
3425        self.assertEqual(typ.oid, 1700)
3426        self.assertEqual(typ.pgtype, 'numeric')
3427        self.assertEqual(typ.regtype, 'numeric')
3428        self.assertEqual(typ.simple, 'num')
3429        self.assertEqual(typ.typtype, 'b')
3430        self.assertEqual(typ.category, 'N')
3431        self.assertEqual(typ.delim, ',')
3432        self.assertEqual(typ.relid, 0)
3433        self.assertIs(dbtypes[1700], typ)
3434        self.assertNotIn('pg_type', dbtypes)
3435        typ = dbtypes['pg_type']
3436        self.assertIn('pg_type', dbtypes)
3437        self.assertEqual(typ, 'pg_type' if self.regtypes else 'record')
3438        self.assertIsInstance(typ.oid, int)
3439        self.assertEqual(typ.pgtype, 'pg_type')
3440        self.assertEqual(typ.regtype, 'pg_type')
3441        self.assertEqual(typ.simple, 'record')
3442        self.assertEqual(typ.typtype, 'c')
3443        self.assertEqual(typ.category, 'C')
3444        self.assertEqual(typ.delim, ',')
3445        self.assertNotEqual(typ.relid, 0)
3446        attnames = typ.attnames
3447        self.assertIsInstance(attnames, dict)
3448        self.assertIs(attnames, dbtypes.get_attnames('pg_type'))
3449        self.assertEqual(list(attnames)[0], 'typname')
3450        typname = attnames['typname']
3451        self.assertEqual(typname, 'name' if self.regtypes else 'text')
3452        self.assertEqual(typname.typtype, 'b')  # base
3453        self.assertEqual(typname.category, 'S')  # string
3454        self.assertEqual(list(attnames)[3], 'typlen')
3455        typlen = attnames['typlen']
3456        self.assertEqual(typlen, 'smallint' if self.regtypes else 'int')
3457        self.assertEqual(typlen.typtype, 'b')  # base
3458        self.assertEqual(typlen.category, 'N')  # numeric
3459
3460    def testDbTypesTypecast(self):
3461        dbtypes = self.db.dbtypes
3462        self.assertIsInstance(dbtypes, dict)
3463        self.assertNotIn('circle', dbtypes)
3464        cast_circle = dbtypes.get_typecast('circle')
3465        squared_circle = lambda v: 'Squared Circle: %s' % v
3466        dbtypes.set_typecast('circle', squared_circle)
3467        self.assertIs(dbtypes.get_typecast('circle'), squared_circle)
3468        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3469        self.assertIn('circle', dbtypes)
3470        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3471        self.assertEqual(dbtypes.typecast('Impossible', 'circle'),
3472            'Squared Circle: Impossible')
3473        dbtypes.reset_typecast('circle')
3474        self.assertIs(dbtypes.get_typecast('circle'), cast_circle)
3475
3476    def testGetSetTypeCast(self):
3477        get_typecast = pg.get_typecast
3478        set_typecast = pg.set_typecast
3479        dbtypes = self.db.dbtypes
3480        self.assertIsInstance(dbtypes, dict)
3481        self.assertNotIn('int4', dbtypes)
3482        self.assertNotIn('real', dbtypes)
3483        self.assertNotIn('bool', dbtypes)
3484        self.assertIs(get_typecast('int4'), int)
3485        self.assertIs(get_typecast('float4'), float)
3486        self.assertIs(get_typecast('bool'), pg.cast_bool)
3487        cast_circle = get_typecast('circle')
3488        squared_circle = lambda v: 'Squared Circle: %s' % v
3489        self.assertNotIn('circle', dbtypes)
3490        set_typecast('circle', squared_circle)
3491        self.assertNotIn('circle', dbtypes)
3492        self.assertIs(get_typecast('circle'), squared_circle)
3493        r = self.db.query("select '0,0,1'::circle").getresult()[0][0]
3494        self.assertIn('circle', dbtypes)
3495        self.assertEqual(r, 'Squared Circle: <(0,0),1>')
3496        set_typecast('circle', cast_circle)
3497        self.assertIs(get_typecast('circle'), cast_circle)
3498
3499    def testNotificationHandler(self):
3500        # the notification handler itself is tested separately
3501        f = self.db.notification_handler
3502        callback = lambda arg_dict: None
3503        handler = f('test', callback)
3504        self.assertIsInstance(handler, pg.NotificationHandler)
3505        self.assertIs(handler.db, self.db)
3506        self.assertEqual(handler.event, 'test')
3507        self.assertEqual(handler.stop_event, 'stop_test')
3508        self.assertIs(handler.callback, callback)
3509        self.assertIsInstance(handler.arg_dict, dict)
3510        self.assertEqual(handler.arg_dict, {})
3511        self.assertIsNone(handler.timeout)
3512        self.assertFalse(handler.listening)
3513        handler.close()
3514        self.assertIsNone(handler.db)
3515        self.db.reopen()
3516        self.assertIsNone(handler.db)
3517        handler = f('test2', callback, timeout=2)
3518        self.assertIsInstance(handler, pg.NotificationHandler)
3519        self.assertIs(handler.db, self.db)
3520        self.assertEqual(handler.event, 'test2')
3521        self.assertEqual(handler.stop_event, 'stop_test2')
3522        self.assertIs(handler.callback, callback)
3523        self.assertIsInstance(handler.arg_dict, dict)
3524        self.assertEqual(handler.arg_dict, {})
3525        self.assertEqual(handler.timeout, 2)
3526        self.assertFalse(handler.listening)
3527        handler.close()
3528        self.assertIsNone(handler.db)
3529        self.db.reopen()
3530        self.assertIsNone(handler.db)
3531        arg_dict = {'testing': 3}
3532        handler = f('test3', callback, arg_dict=arg_dict)
3533        self.assertIsInstance(handler, pg.NotificationHandler)
3534        self.assertIs(handler.db, self.db)
3535        self.assertEqual(handler.event, 'test3')
3536        self.assertEqual(handler.stop_event, 'stop_test3')
3537        self.assertIs(handler.callback, callback)
3538        self.assertIs(handler.arg_dict, arg_dict)
3539        self.assertEqual(arg_dict['testing'], 3)
3540        self.assertIsNone(handler.timeout)
3541        self.assertFalse(handler.listening)
3542        handler.close()
3543        self.assertIsNone(handler.db)
3544        self.db.reopen()
3545        self.assertIsNone(handler.db)
3546        handler = f('test4', callback, stop_event='stop4')
3547        self.assertIsInstance(handler, pg.NotificationHandler)
3548        self.assertIs(handler.db, self.db)
3549        self.assertEqual(handler.event, 'test4')
3550        self.assertEqual(handler.stop_event, 'stop4')
3551        self.assertIs(handler.callback, callback)
3552        self.assertIsInstance(handler.arg_dict, dict)
3553        self.assertEqual(handler.arg_dict, {})
3554        self.assertIsNone(handler.timeout)
3555        self.assertFalse(handler.listening)
3556        handler.close()
3557        self.assertIsNone(handler.db)
3558        self.db.reopen()
3559        self.assertIsNone(handler.db)
3560        arg_dict = {'testing': 5}
3561        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
3562        self.assertIsInstance(handler, pg.NotificationHandler)
3563        self.assertIs(handler.db, self.db)
3564        self.assertEqual(handler.event, 'test5')
3565        self.assertEqual(handler.stop_event, 'stop5')
3566        self.assertIs(handler.callback, callback)
3567        self.assertIs(handler.arg_dict, arg_dict)
3568        self.assertEqual(arg_dict['testing'], 5)
3569        self.assertEqual(handler.timeout, 1.5)
3570        self.assertFalse(handler.listening)
3571        handler.close()
3572        self.assertIsNone(handler.db)
3573        self.db.reopen()
3574        self.assertIsNone(handler.db)
3575
3576
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    The tests inherited from TestDBClass are run after the global pg options
    'decimal', 'bool', 'namedresult' and 'jsondecode' have been changed, and
    with cls.regtypes set to the inverse of the default regtypes setting
    (presumably applied to the connections by TestDBClass - not visible here).
    The original global options are restored when the class is torn down.
    """

    @classmethod
    def setUpClass(cls):
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            cls.set_option('namedresult', None)
            cls.set_option('jsondecode', None)
            # open a connection only to get the default regtypes setting,
            # and close it again so that the connection is not leaked
            db = DB()
            try:
                cls.regtypes = not db.use_regtypes()
            finally:
                db.close()
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # tearDownClass is not run when setUpClass fails, so restore
            # the global options here to not affect other test classes
            cls._reset_saved_options()
            raise

    @classmethod
    def tearDownClass(cls):
        # restore the global options even if the parent teardown fails
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls._reset_saved_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option and set a new one.

        Returns whatever the corresponding pg.set_<option>() call returns.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _reset_saved_options(cls):
        """Restore all saved global options, in reverse order of setting.

        Options that have not been saved (because setUpClass failed before
        changing them) are skipped.
        """
        for option in ('jsondecode', 'namedresult', 'bool', 'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
3607
3608
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces).

    The class fixture consists of the schemas s1 to s4, each containing a
    table t and a table t<n> (n being the schema number), plus the tables
    t and t0 in the public schema. All tables are created with oids and
    have the columns n (always 1) and d (the number of the schema).
    """

    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        db = DB()
        # close the connection even if creating the fixture fails
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d" % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            db.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        db = DB()
        # close the connection even if dropping the fixture fails
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        # the tests are meaningless if the class fixture was not created
        self.assertTrue(self.cls_set_up)
        self.db = DB()

    def tearDown(self):
        # run the cleanups (which may query the database) explicitly
        # before the connection is closed
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        # all tables of the fixture are listed with qualified names
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        # the tables were created with oids, so 'oid' is also reported
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        # unqualified names must be resolved via the search path
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        # the column d tells which schema a row has been fetched from
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        # tables outside the search path are only found when qualified
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        # the oid is returned under a key containing the table name
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
3727
3728
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class.

    Everything printed to stdout during a test is captured in an in-memory
    buffer, so that the tests can check what the DB wrapper has printed.
    """

    def setUp(self):
        db = self.db = DB()
        self.query = db.query
        self.debug = db.debug
        # redirect stdout into a buffer, remembering the original stream
        self.output = StringIO()
        self.stdout = sys.stdout
        sys.stdout = self.output

    def tearDown(self):
        # undo the redirection first, then restore the module-level debug
        # setting on the connection before closing it
        sys.stdout = self.stdout
        self.output.close()
        self.db.debug = debug
        self.db.close()

    def get_output(self):
        """Return everything written to the captured stdout so far."""
        captured = self.output.getvalue()
        return captured

    def send_queries(self):
        """Run the two trivial queries used by the debug tests."""
        for sql in ("select 1", "select 2"):
            self.db.query(sql)

    def testDebugDefault(self):
        # without the module-level debug flag, the attribute must be None
        if not debug:
            self.assertIsNone(self.db.debug)
        else:
            self.assertEqual(self.db.debug, debug)

    def testDebugIsFalse(self):
        # a false debug attribute must not print anything
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        # debug set to True prints each query to stdout
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        # a string is used as a format template for each printed query
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(
            self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        # a file-like object receives the queries instead of stdout
        with tempfile.TemporaryFile('w+') as log_file:
            self.db.debug = log_file
            self.send_queries()
            log_file.seek(0)
            self.assertEqual(log_file.read(), "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        # a callable is called with each query instead of printing it
        received = []
        self.db.debug = received.append
        self.send_queries()
        self.assertEqual(received, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")

    def testDebugMultipleArgs(self):
        # several arguments are converted to str and joined by newlines
        received = []
        self.db.debug = received.append
        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
        self.db._do_debug(*args)
        self.assertEqual(received, ['\n'.join(map(str, args))])
        self.assertEqual(self.get_output(), "")
3798
3799
if __name__ == "__main__":
    # run all test cases of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.