source: trunk/tests/test_classic_dbwrapper.py @ 793

Last change on this file since 793 was 793, checked in by cito, 4 years ago

Improve quoting and typecasting in the pg module

Larger refactoring of the code for adapting and typecasting in the pg module.
Things are now a lot cleaner and clearer.

The _Adapt class is responsible for all adapting of Python objects to their
PostgreSQL equivalents when sending data to the database. The typecasting
from PostgreSQL on output happens in the C module, except for the typecasting
of records which is new and provided by the _CastRecord class.

The classic module also did not work properly when regular type names were
switched on with use_regtypes(True), since the adapting of types relied on
the PyGreSQL type names. This has been solved by adding a new _PgType class
that is essentially the old type name, but augmented with all the information
necessary to adapt types, particularly record types.

All tests in test_classic_dbwrapper now run twice, using opposite settings
for the various configuration settings like use_bool() or use_regtypes(),
in order to make sure that no internal functions rely on default settings.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 146.5 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import tempfile
21import json
22
23import pg  # the module under test
24
25from decimal import Decimal
26from operator import itemgetter
27
# Connection parameters of the test database.  A LOCAL_PyGreSQL.py module
# may override them; otherwise the defaults below are used.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

try:  # local settings living next to this test module inside the package
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:  # local settings as a top-level module
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings, keep the defaults

# Python 3 has no separate long and unicode types; alias them so the
# tests can use the same names under both major versions.
try:
    long
except NameError:
    long = int

try:
    unicode
except NameError:
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # OrderedDict is missing in Python 2.6 and 3.0
    OrderedDict = dict

# StringIO lives in its own module in Python 2 and in io in Python 3.
if str is not bytes:
    from io import StringIO
else:
    from StringIO import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can crash
# the interface when PQhost() is called, so do not ask for the host there.
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
71
72
def DB():
    """Return a new pg.DB wrapper connected to the test database.

    The connection uses the module-level dbname, dbhost and dbport settings.
    Debug output is switched on when the module-level debug flag is set,
    and client_min_messages is set to warning on the new session.
    """
    wrapper = pg.DB(dbname, dbhost, dbport)
    if debug:
        wrapper.debug = debug
    wrapper.query("set client_min_messages=warning")
    return wrapper
80
81
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names."""

    cls = pg.AttrDict  # the class under test
    base = OrderedDict  # the class it must behave like

    def testInit(self):
        a = self.cls()
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base())
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))
        iteritems = iter(items)
        a = self.cls(iteritems)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))

    def testIter(self):
        a = self.cls()
        self.assertEqual(list(a), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a), keys)

    def testKeys(self):
        a = self.cls()
        self.assertEqual(list(a.keys()), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a.keys()), keys)

    def testValues(self):
        a = self.cls()
        self.assertEqual(list(a.values()), [])
        items = [('id', 'int'), ('name', 'text')]
        values = [item[1] for item in items]
        a = self.cls(items)
        self.assertEqual(list(a.values()), values)

    def testItems(self):
        a = self.cls()
        self.assertEqual(list(a.items()), [])
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertEqual(list(a.items()), items)

    def testGet(self):
        a = self.cls([('id', 1)])
        try:
            self.assertEqual(a['id'], 1)
        except KeyError:
            self.fail('AttrDict should be readable')

    def testSet(self):
        a = self.cls()
        try:
            a['id'] = 1
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testDel(self):
        a = self.cls([('id', 1)])
        try:
            del a['id']
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        a = self.cls([('id', 1)])
        self.assertEqual(a['id'], 1)
        # Call each mutating method with arguments that a writable dict
        # would accept, so that a TypeError can only come from the
        # read-only protection and not from a wrong call signature
        # (e.g. clear() with an argument, or an unhashable key for pop()).
        calls = [
            ('clear', ()),
            ('update', ({'id': 2},)),
            ('pop', ('id',)),
            ('setdefault', ('name', 'text')),
            ('popitem', ()),
        ]
        for name, args in calls:
            method = getattr(a, name)
            self.assertRaises(TypeError, method, *args)
        # none of the attempts may have modified the dictionary
        self.assertEqual(list(a.items()), [('id', 1)])
163
164
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            pass  # some tests close the connection themselves

    def testAllDBAttributes(self):
        # the complete, sorted list of public attributes of the DB wrapper
        attributes = [
            'abort',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'decode_json', 'delete',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        # parameters may be passed as separate arguments or as a sequence
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            # SQLSTATE 22012 is division_by_zero
            self.assertEqual(error.sqlstate, '22012')
        else:
            # without this, the test would silently pass if no error occurs
            self.fail('Division by zero should raise an error')

    def testMethodEndcopy(self):
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
373
374
375class TestDBClass(unittest.TestCase):
376    """Test the methods of the DB class wrapped pg connection."""
377
378    maxDiff = 80 * 20
379
380    cls_set_up = False
381
382    regtypes = None
383
384    @classmethod
385    def setUpClass(cls):
386        db = DB()
387        db.query("drop table if exists test cascade")
388        db.query("create table test ("
389            "i2 smallint, i4 integer, i8 bigint,"
390            " d numeric, f4 real, f8 double precision, m money,"
391            " v4 varchar(4), c4 char(4), t text)")
392        db.query("create or replace view test_view as"
393            " select i4, v4 from test")
394        db.close()
395        cls.cls_set_up = True
396
397    @classmethod
398    def tearDownClass(cls):
399        db = DB()
400        db.query("drop table test cascade")
401        db.close()
402
403    def setUp(self):
404        self.assertTrue(self.cls_set_up)
405        self.db = DB()
406        if self.regtypes is None:
407            self.regtypes = self.db.use_regtypes()
408        else:
409            self.db.use_regtypes(self.regtypes)
410        query = self.db.query
411        query('set client_encoding=utf8')
412        query('set standard_conforming_strings=on')
413        query("set lc_monetary='C'")
414        query("set datestyle='ISO,YMD'")
415        query('set bytea_output=hex')
416
417    def tearDown(self):
418        self.doCleanups()
419        self.db.close()
420
421    def createTable(self, table, definition,
422            temporary=True, oids=None, values=None):
423        query = self.db.query
424        if not '"' in table or '.' in table:
425            table = '"%s"' % table
426        if not temporary:
427            q = 'drop table if exists %s cascade' % table
428            query(q)
429            self.addCleanup(query, q)
430        temporary = 'temporary table' if temporary else 'table'
431        as_query = definition.startswith(('as ', 'AS '))
432        if not as_query and not definition.startswith('('):
433            definition = '(%s)' % definition
434        with_oids = 'with oids' if oids else 'without oids'
435        q = ['create', temporary, table]
436        if as_query:
437            q.extend([with_oids, definition])
438        else:
439            q.extend([definition, with_oids])
440        q = ' '.join(q)
441        query(q)
442        if values:
443            for params in values:
444                if not isinstance(params, (list, tuple)):
445                    params = [params]
446                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
447                q = "insert into %s values (%s)" % (table, values)
448                query(q, params)
449
450    def testClassName(self):
451        self.assertEqual(self.db.__class__.__name__, 'DB')
452
453    def testModuleName(self):
454        self.assertEqual(self.db.__module__, 'pg')
455        self.assertEqual(self.db.__class__.__module__, 'pg')
456
457    def testEscapeLiteral(self):
458        f = self.db.escape_literal
459        r = f(b"plain")
460        self.assertIsInstance(r, bytes)
461        self.assertEqual(r, b"'plain'")
462        r = f(u"plain")
463        self.assertIsInstance(r, unicode)
464        self.assertEqual(r, u"'plain'")
465        r = f(u"that's kÀse".encode('utf-8'))
466        self.assertIsInstance(r, bytes)
467        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
468        r = f(u"that's kÀse")
469        self.assertIsInstance(r, unicode)
470        self.assertEqual(r, u"'that''s kÀse'")
471        self.assertEqual(f(r"It's fine to have a \ inside."),
472            r" E'It''s fine to have a \\ inside.'")
473        self.assertEqual(f('No "quotes" must be escaped.'),
474            "'No \"quotes\" must be escaped.'")
475
476    def testEscapeIdentifier(self):
477        f = self.db.escape_identifier
478        r = f(b"plain")
479        self.assertIsInstance(r, bytes)
480        self.assertEqual(r, b'"plain"')
481        r = f(u"plain")
482        self.assertIsInstance(r, unicode)
483        self.assertEqual(r, u'"plain"')
484        r = f(u"that's kÀse".encode('utf-8'))
485        self.assertIsInstance(r, bytes)
486        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
487        r = f(u"that's kÀse")
488        self.assertIsInstance(r, unicode)
489        self.assertEqual(r, u'"that\'s kÀse"')
490        self.assertEqual(f(r"It's fine to have a \ inside."),
491            '"It\'s fine to have a \\ inside."')
492        self.assertEqual(f('All "quotes" must be escaped.'),
493            '"All ""quotes"" must be escaped."')
494
495    def testEscapeString(self):
496        f = self.db.escape_string
497        r = f(b"plain")
498        self.assertIsInstance(r, bytes)
499        self.assertEqual(r, b"plain")
500        r = f(u"plain")
501        self.assertIsInstance(r, unicode)
502        self.assertEqual(r, u"plain")
503        r = f(u"that's kÀse".encode('utf-8'))
504        self.assertIsInstance(r, bytes)
505        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
506        r = f(u"that's kÀse")
507        self.assertIsInstance(r, unicode)
508        self.assertEqual(r, u"that''s kÀse")
509        self.assertEqual(f(r"It's fine to have a \ inside."),
510            r"It''s fine to have a \ inside.")
511
512    def testEscapeBytea(self):
513        f = self.db.escape_bytea
514        # note that escape_byte always returns hex output since Pg 9.0,
515        # regardless of the bytea_output setting
516        r = f(b'plain')
517        self.assertIsInstance(r, bytes)
518        self.assertEqual(r, b'\\x706c61696e')
519        r = f(u'plain')
520        self.assertIsInstance(r, unicode)
521        self.assertEqual(r, u'\\x706c61696e')
522        r = f(u"das is' kÀse".encode('utf-8'))
523        self.assertIsInstance(r, bytes)
524        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
525        r = f(u"das is' kÀse")
526        self.assertIsInstance(r, unicode)
527        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
528        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
529
530    def testUnescapeBytea(self):
531        f = self.db.unescape_bytea
532        r = f(b'plain')
533        self.assertIsInstance(r, bytes)
534        self.assertEqual(r, b'plain')
535        r = f(u'plain')
536        self.assertIsInstance(r, bytes)
537        self.assertEqual(r, b'plain')
538        r = f(b"das is' k\\303\\244se")
539        self.assertIsInstance(r, bytes)
540        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
541        r = f(u"das is' k\\303\\244se")
542        self.assertIsInstance(r, bytes)
543        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
544        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
545        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
546        self.assertEqual(f(r'\\x746861742773206be47365'),
547            b'\\x746861742773206be47365')
548        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
549
550    def testDecodeJson(self):
551        f = self.db.decode_json
552        self.assertIsNone(f('null'))
553        data = {
554          "id": 1, "name": "Foo", "price": 1234.5,
555          "new": True, "note": None,
556          "tags": ["Bar", "Eek"],
557          "stock": {"warehouse": 300, "retail": 20}}
558        text = json.dumps(data)
559        r = f(text)
560        self.assertIsInstance(r, dict)
561        self.assertEqual(r, data)
562        self.assertIsInstance(r['id'], int)
563        self.assertIsInstance(r['name'], unicode)
564        self.assertIsInstance(r['price'], float)
565        self.assertIsInstance(r['new'], bool)
566        self.assertIsInstance(r['tags'], list)
567        self.assertIsInstance(r['stock'], dict)
568
569    def testEncodeJson(self):
570        f = self.db.encode_json
571        self.assertEqual(f(None), 'null')
572        data = {
573          "id": 1, "name": "Foo", "price": 1234.5,
574          "new": True, "note": None,
575          "tags": ["Bar", "Eek"],
576          "stock": {"warehouse": 300, "retail": 20}}
577        text = json.dumps(data)
578        r = f(data)
579        self.assertIsInstance(r, str)
580        self.assertEqual(r, text)
581
582    def testGetParameter(self):
583        f = self.db.get_parameter
584        self.assertRaises(TypeError, f)
585        self.assertRaises(TypeError, f, None)
586        self.assertRaises(TypeError, f, 42)
587        self.assertRaises(TypeError, f, '')
588        self.assertRaises(TypeError, f, [])
589        self.assertRaises(TypeError, f, [''])
590        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
591        r = f('standard_conforming_strings')
592        self.assertEqual(r, 'on')
593        r = f('lc_monetary')
594        self.assertEqual(r, 'C')
595        r = f('datestyle')
596        self.assertEqual(r, 'ISO, YMD')
597        r = f('bytea_output')
598        self.assertEqual(r, 'hex')
599        r = f(['bytea_output', 'lc_monetary'])
600        self.assertIsInstance(r, list)
601        self.assertEqual(r, ['hex', 'C'])
602        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
603        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
604        r = f(set(['bytea_output', 'lc_monetary']))
605        self.assertIsInstance(r, dict)
606        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
607        r = f(set(['Bytea_Output', ' LC_Monetary ']))
608        self.assertIsInstance(r, dict)
609        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
610        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
611        r = f(s)
612        self.assertIs(r, s)
613        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
614        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
615        r = f(s)
616        self.assertIs(r, s)
617        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
618
619    def testGetParameterServerVersion(self):
620        r = self.db.get_parameter('server_version_num')
621        self.assertIsInstance(r, str)
622        s = self.db.server_version
623        self.assertIsInstance(s, int)
624        self.assertEqual(r, str(s))
625
626    def testGetParameterAll(self):
627        f = self.db.get_parameter
628        r = f('all')
629        self.assertIsInstance(r, dict)
630        self.assertEqual(r['standard_conforming_strings'], 'on')
631        self.assertEqual(r['lc_monetary'], 'C')
632        self.assertEqual(r['DateStyle'], 'ISO, YMD')
633        self.assertEqual(r['bytea_output'], 'hex')
634
635    def testSetParameter(self):
636        f = self.db.set_parameter
637        g = self.db.get_parameter
638        self.assertRaises(TypeError, f)
639        self.assertRaises(TypeError, f, None)
640        self.assertRaises(TypeError, f, 42)
641        self.assertRaises(TypeError, f, '')
642        self.assertRaises(TypeError, f, [])
643        self.assertRaises(TypeError, f, [''])
644        self.assertRaises(ValueError, f, 'all', 'invalid')
645        self.assertRaises(ValueError, f, {
646            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
647        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
648        f('standard_conforming_strings', 'off')
649        self.assertEqual(g('standard_conforming_strings'), 'off')
650        f('datestyle', 'ISO, DMY')
651        self.assertEqual(g('datestyle'), 'ISO, DMY')
652        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
653        self.assertEqual(g('standard_conforming_strings'), 'on')
654        self.assertEqual(g('datestyle'), 'ISO, DMY')
655        f(['default_with_oids', 'standard_conforming_strings'], 'off')
656        self.assertEqual(g('default_with_oids'), 'off')
657        self.assertEqual(g('standard_conforming_strings'), 'off')
658        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
659        self.assertEqual(g('standard_conforming_strings'), 'on')
660        self.assertEqual(g('datestyle'), 'ISO, YMD')
661        f(('default_with_oids', 'standard_conforming_strings'), 'off')
662        self.assertEqual(g('default_with_oids'), 'off')
663        self.assertEqual(g('standard_conforming_strings'), 'off')
664        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
665        self.assertEqual(g('default_with_oids'), 'on')
666        self.assertEqual(g('standard_conforming_strings'), 'on')
667        self.assertRaises(ValueError, f, set([ 'default_with_oids',
668            'standard_conforming_strings']), ['off', 'on'])
669        f(set(['default_with_oids', 'standard_conforming_strings']),
670            ['off', 'off'])
671        self.assertEqual(g('default_with_oids'), 'off')
672        self.assertEqual(g('standard_conforming_strings'), 'off')
673        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
674        self.assertEqual(g('standard_conforming_strings'), 'on')
675        self.assertEqual(g('datestyle'), 'ISO, YMD')
676
677    def testResetParameter(self):
678        db = DB()
679        f = db.set_parameter
680        g = db.get_parameter
681        r = g('default_with_oids')
682        self.assertIn(r, ('on', 'off'))
683        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
684        r = g('standard_conforming_strings')
685        self.assertIn(r, ('on', 'off'))
686        scs, not_scs = r, 'off' if r == 'on' else 'on'
687        f('default_with_oids', not_dwi)
688        f('standard_conforming_strings', not_scs)
689        self.assertEqual(g('default_with_oids'), not_dwi)
690        self.assertEqual(g('standard_conforming_strings'), not_scs)
691        f('default_with_oids')
692        f('standard_conforming_strings', None)
693        self.assertEqual(g('default_with_oids'), dwi)
694        self.assertEqual(g('standard_conforming_strings'), scs)
695        f('default_with_oids', not_dwi)
696        f('standard_conforming_strings', not_scs)
697        self.assertEqual(g('default_with_oids'), not_dwi)
698        self.assertEqual(g('standard_conforming_strings'), not_scs)
699        f(['default_with_oids', 'standard_conforming_strings'], None)
700        self.assertEqual(g('default_with_oids'), dwi)
701        self.assertEqual(g('standard_conforming_strings'), scs)
702        f('default_with_oids', not_dwi)
703        f('standard_conforming_strings', not_scs)
704        self.assertEqual(g('default_with_oids'), not_dwi)
705        self.assertEqual(g('standard_conforming_strings'), not_scs)
706        f(('default_with_oids', 'standard_conforming_strings'))
707        self.assertEqual(g('default_with_oids'), dwi)
708        self.assertEqual(g('standard_conforming_strings'), scs)
709        f('default_with_oids', not_dwi)
710        f('standard_conforming_strings', not_scs)
711        self.assertEqual(g('default_with_oids'), not_dwi)
712        self.assertEqual(g('standard_conforming_strings'), not_scs)
713        f(set(['default_with_oids', 'standard_conforming_strings']))
714        self.assertEqual(g('default_with_oids'), dwi)
715        self.assertEqual(g('standard_conforming_strings'), scs)
716
717    def testResetParameterAll(self):
718        db = DB()
719        f = db.set_parameter
720        self.assertRaises(ValueError, f, 'all', 0)
721        self.assertRaises(ValueError, f, 'all', 'off')
722        g = db.get_parameter
723        r = g('default_with_oids')
724        self.assertIn(r, ('on', 'off'))
725        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
726        r = g('standard_conforming_strings')
727        self.assertIn(r, ('on', 'off'))
728        scs, not_scs = r, 'off' if r == 'on' else 'on'
729        f('default_with_oids', not_dwi)
730        f('standard_conforming_strings', not_scs)
731        self.assertEqual(g('default_with_oids'), not_dwi)
732        self.assertEqual(g('standard_conforming_strings'), not_scs)
733        f('all')
734        self.assertEqual(g('default_with_oids'), dwi)
735        self.assertEqual(g('standard_conforming_strings'), scs)
736
737    def testSetParameterLocal(self):
738        f = self.db.set_parameter
739        g = self.db.get_parameter
740        self.assertEqual(g('standard_conforming_strings'), 'on')
741        self.db.begin()
742        f('standard_conforming_strings', 'off', local=True)
743        self.assertEqual(g('standard_conforming_strings'), 'off')
744        self.db.end()
745        self.assertEqual(g('standard_conforming_strings'), 'on')
746
747    def testSetParameterSession(self):
748        f = self.db.set_parameter
749        g = self.db.get_parameter
750        self.assertEqual(g('standard_conforming_strings'), 'on')
751        self.db.begin()
752        f('standard_conforming_strings', 'off', local=False)
753        self.assertEqual(g('standard_conforming_strings'), 'off')
754        self.db.end()
755        self.assertEqual(g('standard_conforming_strings'), 'off')
756
757    def testReset(self):
758        db = DB()
759        default_datestyle = db.get_parameter('datestyle')
760        changed_datestyle = 'ISO, DMY'
761        if changed_datestyle == default_datestyle:
762            changed_datestyle == 'ISO, YMD'
763        self.db.set_parameter('datestyle', changed_datestyle)
764        r = self.db.get_parameter('datestyle')
765        self.assertEqual(r, changed_datestyle)
766        con = self.db.db
767        q = con.query("show datestyle")
768        self.db.reset()
769        r = q.getresult()[0][0]
770        self.assertEqual(r, changed_datestyle)
771        q = con.query("show datestyle")
772        r = q.getresult()[0][0]
773        self.assertEqual(r, default_datestyle)
774        r = self.db.get_parameter('datestyle')
775        self.assertEqual(r, default_datestyle)
776
777    def testReopen(self):
778        db = DB()
779        default_datestyle = db.get_parameter('datestyle')
780        changed_datestyle = 'ISO, DMY'
781        if changed_datestyle == default_datestyle:
782            changed_datestyle == 'ISO, YMD'
783        self.db.set_parameter('datestyle', changed_datestyle)
784        r = self.db.get_parameter('datestyle')
785        self.assertEqual(r, changed_datestyle)
786        con = self.db.db
787        q = con.query("show datestyle")
788        self.db.reopen()
789        r = q.getresult()[0][0]
790        self.assertEqual(r, changed_datestyle)
791        self.assertRaises(TypeError, getattr, con, 'query')
792        r = self.db.get_parameter('datestyle')
793        self.assertEqual(r, default_datestyle)
794
795    def testCreateTable(self):
796        table = 'test hello world'
797        values = [(2, "World!"), (1, "Hello")]
798        self.createTable(table, "n smallint, t varchar",
799            temporary=True, oids=True, values=values)
800        r = self.db.query('select t from "%s" order by n' % table).getresult()
801        r = ', '.join(row[0] for row in r)
802        self.assertEqual(r, "Hello, World!")
803        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
804        self.assertIsInstance(r[0][0], int)
805
806    def testQuery(self):
807        query = self.db.query
808        table = 'test_table'
809        self.createTable(table, "n integer", oids=True)
810        q = "insert into test_table values (1)"
811        r = query(q)
812        self.assertIsInstance(r, int)
813        q = "insert into test_table select 2"
814        r = query(q)
815        self.assertIsInstance(r, int)
816        oid = r
817        q = "select oid from test_table where n=2"
818        r = query(q).getresult()
819        self.assertEqual(len(r), 1)
820        r = r[0]
821        self.assertEqual(len(r), 1)
822        r = r[0]
823        self.assertEqual(r, oid)
824        q = "insert into test_table select 3 union select 4 union select 5"
825        r = query(q)
826        self.assertIsInstance(r, str)
827        self.assertEqual(r, '3')
828        q = "update test_table set n=4 where n<5"
829        r = query(q)
830        self.assertIsInstance(r, str)
831        self.assertEqual(r, '4')
832        q = "delete from test_table"
833        r = query(q)
834        self.assertIsInstance(r, str)
835        self.assertEqual(r, '5')
836
837    def testMultipleQueries(self):
838        self.assertEqual(self.db.query(
839            "create temporary table test_multi (n integer);"
840            "insert into test_multi values (4711);"
841            "select n from test_multi").getresult()[0][0], 4711)
842
843    def testQueryWithParams(self):
844        query = self.db.query
845        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
846        q = "insert into test_table values ($1, $2)"
847        r = query(q, (1, 2))
848        self.assertIsInstance(r, int)
849        r = query(q, [3, 4])
850        self.assertIsInstance(r, int)
851        r = query(q, [5, 6])
852        self.assertIsInstance(r, int)
853        q = "select * from test_table order by 1, 2"
854        self.assertEqual(query(q).getresult(),
855            [(1, 2), (3, 4), (5, 6)])
856        q = "select * from test_table where n1=$1 and n2=$2"
857        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
858        q = "update test_table set n2=$2 where n1=$1"
859        r = query(q, 3, 7)
860        self.assertEqual(r, '1')
861        q = "select * from test_table order by 1, 2"
862        self.assertEqual(query(q).getresult(),
863            [(1, 2), (3, 7), (5, 6)])
864        q = "delete from test_table where n2!=$1"
865        r = query(q, 4)
866        self.assertEqual(r, '3')
867
868    def testEmptyQuery(self):
869        self.assertRaises(ValueError, self.db.query, '')
870
871    def testQueryProgrammingError(self):
872        try:
873            self.db.query("select 1/0")
874        except pg.ProgrammingError as error:
875            self.assertEqual(error.sqlstate, '22012')
876
877    def testPkey(self):
878        query = self.db.query
879        pkey = self.db.pkey
880        self.assertRaises(KeyError, pkey, 'test')
881        for t in ('pkeytest', 'primary key test'):
882            self.createTable('%s0' % t, 'a smallint')
883            self.createTable('%s1' % t, 'b smallint primary key')
884            self.createTable('%s2' % t,
885                'c smallint, d smallint primary key')
886            self.createTable('%s3' % t,
887                'e smallint, f smallint, g smallint, h smallint, i smallint,'
888                ' primary key (f, h)')
889            self.createTable('%s4' % t,
890                'e smallint, f smallint, g smallint, h smallint, i smallint,'
891                ' primary key (h, f)')
892            self.createTable('%s5' % t,
893                'more_than_one_letter varchar primary key')
894            self.createTable('%s6' % t,
895                '"with space" date primary key')
896            self.createTable('%s7' % t,
897                'a_very_long_column_name varchar, "with space" date, "42" int,'
898                ' primary key (a_very_long_column_name, "with space", "42")')
899            self.assertRaises(KeyError, pkey, '%s0' % t)
900            self.assertEqual(pkey('%s1' % t), 'b')
901            self.assertEqual(pkey('%s1' % t, True), ('b',))
902            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
903            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
904            self.assertEqual(pkey('%s2' % t), 'd')
905            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
906            r = pkey('%s3' % t)
907            self.assertIsInstance(r, tuple)
908            self.assertEqual(r, ('f', 'h'))
909            r = pkey('%s3' % t, composite=False)
910            self.assertIsInstance(r, tuple)
911            self.assertEqual(r, ('f', 'h'))
912            r = pkey('%s4' % t)
913            self.assertIsInstance(r, tuple)
914            self.assertEqual(r, ('h', 'f'))
915            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
916            self.assertEqual(pkey('%s6' % t), 'with space')
917            r = pkey('%s7' % t)
918            self.assertIsInstance(r, tuple)
919            self.assertEqual(r, (
920                'a_very_long_column_name', 'with space', '42'))
921            # a newly added primary key will be detected
922            query('alter table "%s0" add primary key (a)' % t)
923            self.assertEqual(pkey('%s0' % t), 'a')
924            # a changed primary key will not be detected,
925            # indicating that the internal cache is operating
926            query('alter table "%s1" rename column b to x' % t)
927            self.assertEqual(pkey('%s1' % t), 'b')
928            # we get the changed primary key when the cache is flushed
929            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
930
931    def testGetDatabases(self):
932        databases = self.db.get_databases()
933        self.assertIn('template0', databases)
934        self.assertIn('template1', databases)
935        self.assertNotIn('not existing database', databases)
936        self.assertIn('postgres', databases)
937        self.assertIn(dbname, databases)
938
939    def testGetTables(self):
940        get_tables = self.db.get_tables
941        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
942            'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
943            'averyveryveryveryveryveryveryreallyreallylongtablename',
944            'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
945        for t in tables:
946            self.db.query('drop table if exists "%s" cascade' % t)
947        before_tables = get_tables()
948        self.assertIsInstance(before_tables, list)
949        for t in before_tables:
950            t = t.split('.', 1)
951            self.assertGreaterEqual(len(t), 2)
952            if len(t) > 2:
953                self.assertTrue(t[1].startswith('"'))
954            t = t[0]
955            self.assertNotEqual(t, 'information_schema')
956            self.assertFalse(t.startswith('pg_'))
957        for t in tables:
958            self.createTable(t, 'as select 0', temporary=False)
959        current_tables = get_tables()
960        new_tables = [t for t in current_tables if t not in before_tables]
961        expected_new_tables = ['public.%s' % (
962            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
963        self.assertEqual(new_tables, expected_new_tables)
964        self.doCleanups()
965        after_tables = get_tables()
966        self.assertEqual(after_tables, before_tables)
967
968    def testGetRelations(self):
969        get_relations = self.db.get_relations
970        result = get_relations()
971        self.assertIn('public.test', result)
972        self.assertIn('public.test_view', result)
973        result = get_relations('rv')
974        self.assertIn('public.test', result)
975        self.assertIn('public.test_view', result)
976        result = get_relations('r')
977        self.assertIn('public.test', result)
978        self.assertNotIn('public.test_view', result)
979        result = get_relations('v')
980        self.assertNotIn('public.test', result)
981        self.assertIn('public.test_view', result)
982        result = get_relations('cisSt')
983        self.assertNotIn('public.test', result)
984        self.assertNotIn('public.test_view', result)
985
986    def testGetAttnames(self):
987        get_attnames = self.db.get_attnames
988        self.assertRaises(pg.ProgrammingError,
989            self.db.get_attnames, 'does_not_exist')
990        self.assertRaises(pg.ProgrammingError,
991            self.db.get_attnames, 'has.too.many.dots')
992        r = get_attnames('test')
993        self.assertIsInstance(r, dict)
994        if self.regtypes:
995            self.assertEqual(r, dict(
996                i2='smallint', i4='integer', i8='bigint', d='numeric',
997                f4='real', f8='double precision', m='money',
998                v4='character varying', c4='character', t='text'))
999        else:
1000            self.assertEqual(r, dict(
1001                i2='int', i4='int', i8='int', d='num',
1002                f4='float', f8='float', m='money',
1003                v4='text', c4='text', t='text'))
1004        self.createTable('test_table',
1005            'n int, alpha smallint, beta bool,'
1006            ' gamma char(5), tau text, v varchar(3)')
1007        r = get_attnames('test_table')
1008        self.assertIsInstance(r, dict)
1009        if self.regtypes:
1010            self.assertEqual(r, dict(
1011                n='integer', alpha='smallint', beta='boolean',
1012                gamma='character', tau='text', v='character varying'))
1013        else:
1014            self.assertEqual(r, dict(
1015                n='int', alpha='int', beta='bool',
1016                gamma='text', tau='text', v='text'))
1017
1018    def testGetAttnamesWithQuotes(self):
1019        get_attnames = self.db.get_attnames
1020        table = 'test table for get_attnames()'
1021        self.createTable(table,
1022            '"Prime!" smallint, "much space" integer, "Questions?" text')
1023        r = get_attnames(table)
1024        self.assertIsInstance(r, dict)
1025        if self.regtypes:
1026            self.assertEqual(r, {
1027                'Prime!': 'smallint', 'much space': 'integer',
1028                'Questions?': 'text'})
1029        else:
1030            self.assertEqual(r, {
1031                'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1032        table = 'yet another test table for get_attnames()'
1033        self.createTable(table,
1034            'a smallint, b integer, c bigint,'
1035            ' e numeric, f real, f2 double precision, m money,'
1036            ' x smallint, y smallint, z smallint,'
1037            ' Normal_NaMe smallint, "Special Name" smallint,'
1038            ' t text, u char(2), v varchar(2),'
1039            ' primary key (y, u)', oids=True)
1040        r = get_attnames(table)
1041        self.assertIsInstance(r, dict)
1042        if self.regtypes:
1043            self.assertEqual(r, {
1044                'a': 'smallint', 'b': 'integer', 'c': 'bigint',
1045                'e': 'numeric', 'f': 'real', 'f2': 'double precision',
1046                'm': 'money', 'normal_name': 'smallint',
1047                'Special Name': 'smallint', 'u': 'character',
1048                't': 'text', 'v': 'character varying', 'y': 'smallint',
1049                'x': 'smallint', 'z': 'smallint', 'oid': 'oid'})
1050        else:
1051            self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int',
1052                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1053                'normal_name': 'int', 'Special Name': 'int',
1054                'u': 'text', 't': 'text', 'v': 'text',
1055                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1056
1057    def testGetAttnamesWithRegtypes(self):
1058        get_attnames = self.db.get_attnames
1059        self.createTable('test_table',
1060            ' n int, alpha smallint, beta bool,'
1061            ' gamma char(5), tau text, v varchar(3)')
1062        use_regtypes = self.db.use_regtypes
1063        regtypes = use_regtypes()
1064        self.assertEqual(regtypes, self.regtypes)
1065        use_regtypes(True)
1066        try:
1067            r = get_attnames("test_table")
1068            self.assertIsInstance(r, dict)
1069        finally:
1070            use_regtypes(regtypes)
1071        self.assertEqual(r, dict(
1072            n='integer', alpha='smallint', beta='boolean',
1073            gamma='character', tau='text', v='character varying'))
1074
1075    def testGetAttnamesWithoutRegtypes(self):
1076        get_attnames = self.db.get_attnames
1077        self.createTable('test_table',
1078            ' n int, alpha smallint, beta bool,'
1079            ' gamma char(5), tau text, v varchar(3)')
1080        use_regtypes = self.db.use_regtypes
1081        regtypes = use_regtypes()
1082        self.assertEqual(regtypes, self.regtypes)
1083        use_regtypes(False)
1084        try:
1085            r = get_attnames("test_table")
1086            self.assertIsInstance(r, dict)
1087        finally:
1088            use_regtypes(regtypes)
1089        self.assertEqual(r, dict(
1090            n='int', alpha='int', beta='bool',
1091            gamma='text', tau='text', v='text'))
1092
    def testGetAttnamesIsCached(self):
        """Check that get_attnames() caches its results until flushed.

        Schema changes made after a first call are not visible until
        get_attnames() is called again with flush=True.
        """
        get_attnames = self.db.get_attnames
        int_type = 'integer' if self.regtypes else 'int'
        text_type = 'text'
        query = self.db.query
        self.createTable('test_table', 'col int')
        r = get_attnames("test_table")
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(col=int_type))
        # change the table behind the back of the cache
        query("alter table test_table alter column col type text")
        query("alter table test_table add column col2 int")
        # without flushing, the stale cached result is returned
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=int_type))
        # flushing the cache makes the changes visible
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict(col=text_type, col2=int_type))
        query("alter table test_table drop column col2")
        # the dropped column is still reported from the cache
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=text_type, col2=int_type))
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict(col=text_type))
        query("alter table test_table drop column col")
        r = get_attnames("test_table")
        self.assertEqual(r, dict(col=text_type))
        # a table without any columns yields an empty dict
        r = get_attnames("test_table", flush=True)
        self.assertEqual(r, dict())
1118
1119    def testGetAttnamesIsOrdered(self):
1120        get_attnames = self.db.get_attnames
1121        r = get_attnames('test', flush=True)
1122        self.assertIsInstance(r, OrderedDict)
1123        if self.regtypes:
1124            self.assertEqual(r, OrderedDict([
1125                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1126                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1127                ('m', 'money'), ('v4', 'character varying'),
1128                ('c4', 'character'), ('t', 'text')]))
1129        else:
1130            self.assertEqual(r, OrderedDict([
1131                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1132                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1133                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1134        if OrderedDict is not dict:
1135            r = ' '.join(list(r.keys()))
1136            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1137        table = 'test table for get_attnames'
1138        self.createTable(table,
1139            ' n int, alpha smallint, v varchar(3),'
1140            ' gamma char(5), tau text, beta bool')
1141        r = get_attnames(table)
1142        self.assertIsInstance(r, OrderedDict)
1143        if self.regtypes:
1144            self.assertEqual(r, OrderedDict([
1145                ('n', 'integer'), ('alpha', 'smallint'),
1146                ('v', 'character varying'), ('gamma', 'character'),
1147                ('tau', 'text'), ('beta', 'boolean')]))
1148        else:
1149            self.assertEqual(r, OrderedDict([
1150                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1151                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1152        if OrderedDict is not dict:
1153            r = ' '.join(list(r.keys()))
1154            self.assertEqual(r, 'n alpha v gamma tau beta')
1155        else:
1156            self.skipTest('OrderedDict is not supported')
1157
1158    def testGetAttnamesIsAttrDict(self):
1159        AttrDict = pg.AttrDict
1160        get_attnames = self.db.get_attnames
1161        r = get_attnames('test', flush=True)
1162        self.assertIsInstance(r, AttrDict)
1163        if self.regtypes:
1164            self.assertEqual(r, AttrDict([
1165                ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'),
1166                ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'),
1167                ('m', 'money'), ('v4', 'character varying'),
1168                ('c4', 'character'), ('t', 'text')]))
1169        else:
1170            self.assertEqual(r, AttrDict([
1171                ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1172                ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1173                ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1174        r = ' '.join(list(r.keys()))
1175        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1176        table = 'test table for get_attnames'
1177        self.createTable(table,
1178            ' n int, alpha smallint, v varchar(3),'
1179            ' gamma char(5), tau text, beta bool')
1180        r = get_attnames(table)
1181        self.assertIsInstance(r, AttrDict)
1182        if self.regtypes:
1183            self.assertEqual(r, AttrDict([
1184                ('n', 'integer'), ('alpha', 'smallint'),
1185                ('v', 'character varying'), ('gamma', 'character'),
1186                ('tau', 'text'), ('beta', 'boolean')]))
1187        else:
1188            self.assertEqual(r, AttrDict([
1189                ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1190                ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1191        r = ' '.join(list(r.keys()))
1192        self.assertEqual(r, 'n alpha v gamma tau beta')
1193
1194    def testHasTablePrivilege(self):
1195        can = self.db.has_table_privilege
1196        self.assertEqual(can('test'), True)
1197        self.assertEqual(can('test', 'select'), True)
1198        self.assertEqual(can('test', 'SeLeCt'), True)
1199        self.assertEqual(can('test', 'SELECT'), True)
1200        self.assertEqual(can('test', 'insert'), True)
1201        self.assertEqual(can('test', 'update'), True)
1202        self.assertEqual(can('test', 'delete'), True)
1203        self.assertEqual(can('pg_views', 'select'), True)
1204        self.assertEqual(can('pg_views', 'delete'), False)
1205        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
1206        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1207
    def testGet(self):
        """Check get() on a table without and with a primary key.

        Without a primary key, the key column(s) must be given explicitly;
        once a primary key exists, it is used by default.  When a dict is
        passed as the row, it is filled in and returned as the same object.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # table and row are required arguments
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        self.createTable(table, 'n integer, t text',
            values=enumerate('xyz', start=1))
        # no primary key yet, so the key column must be specified
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # lookups that cannot succeed must raise a DatabaseError
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        # a dict passed as row is filled in and returned as the same object
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        # only the given key column is used; the other value is overwritten
        s.update(t='x')
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        # with a primary key, the key column can be omitted
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # a trailing '*' after the table name is accepted
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # two values do not fit the single-column primary key
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        s.update(n=2)
        # r is the same object as s at this point
        self.assertEqual(get(table, r)['t'], 'y')
        # a row dict without the primary key column raises KeyError
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1262
    def testGetWithOid(self):
        """Check get() on a table with oids, looking up by oid and by keys.

        The oid of a fetched row is stored under the key 'oid(table)'.
        After a primary key has been added, the test expects the primary
        key (or an explicitly given key column) to take precedence over
        an oid contained in the row dict.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
            values=enumerate('xyz', start=1))
        # no primary key yet, so a key column must be specified
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        # looking up by oid requires an oid in the row dict
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        # the oid is returned under a table-qualified key
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        # the oid may be given as value, as 'oid' key or as qualified key
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        # an explicit key column wins over the oid contained in r
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        # r was filled in by the lookup above, so its oid is now row 3's
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        # with a primary key, it is used instead of the oid in r
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # lookups by oid alone still work after adding the primary key
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # when both are present, the primary key takes precedence
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # an explicitly given key column also takes precedence
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1326
1327    def testGetWithCompositeKey(self):
1328        get = self.db.get
1329        query = self.db.query
1330        table = 'get_test_table_1'
1331        self.createTable(table, 'n integer primary key, t text',
1332            values=enumerate('abc', start=1))
1333        self.assertEqual(get(table, 2)['t'], 'b')
1334        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1335        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1336        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1337        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1338        self.assertEqual(get(table, 'b', 't')['n'], 2)
1339        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1340        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1341        table = 'get_test_table_2'
1342        self.createTable(table,
1343            'n integer, m integer, t text, primary key (n, m)',
1344            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1345                for n in range(3) for m in range(2)])
1346        self.assertRaises(KeyError, get, table, 2)
1347        self.assertEqual(get(table, (1, 1))['t'], 'a')
1348        self.assertEqual(get(table, (1, 2))['t'], 'b')
1349        self.assertEqual(get(table, (2, 1))['t'], 'c')
1350        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1351        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1352        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1353        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1354        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1355        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1356        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1357        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1358
1359    def testGetWithQuotedNames(self):
1360        get = self.db.get
1361        query = self.db.query
1362        table = 'test table for get()'
1363        self.createTable(table, '"Prime!" smallint primary key,'
1364            ' "much space" integer, "Questions?" text',
1365            values=[(17, 1001, 'No!')])
1366        r = get(table, 17)
1367        self.assertIsInstance(r, dict)
1368        self.assertEqual(r['Prime!'], 17)
1369        self.assertEqual(r['much space'], 1001)
1370        self.assertEqual(r['Questions?'], 'No!')
1371
1372    def testGetFromView(self):
1373        self.db.query('delete from test where i4=14')
1374        self.db.query('insert into test (i4, v4) values('
1375            "14, 'abc4')")
1376        r = self.db.get('test_view', 14, 'i4')
1377        self.assertIn('v4', r)
1378        self.assertEqual(r['v4'], 'abc4')
1379
1380    def testGetLittleBobbyTables(self):
1381        get = self.db.get
1382        query = self.db.query
1383        self.createTable('test_students',
1384            'firstname varchar primary key, nickname varchar, grade char(2)',
1385            values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1386                    ('Robert', 'Little Bobby Tables', 'D-')])
1387        r = get('test_students', 'Sheldon')
1388        self.assertEqual(r, dict(
1389            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1390        r = get('test_students', 'Robert')
1391        self.assertEqual(r, dict(
1392            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1393        r = get('test_students', "D'Arcy")
1394        self.assertEqual(r, dict(
1395            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1396        try:
1397            get('test_students', "D' Arcy")
1398        except pg.DatabaseError as error:
1399            self.assertEqual(str(error),
1400                'No such record in test_students\nwhere "firstname" = $1\n'
1401                'with $1="D\' Arcy"')
1402        try:
1403            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1404        except pg.DatabaseError as error:
1405            self.assertEqual(str(error),
1406                'No such record in test_students\nwhere "firstname" = $1\n'
1407                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1408        q = "select * from test_students order by 1 limit 4"
1409        r = query(q).getresult()
1410        self.assertEqual(len(r), 3)
1411        self.assertEqual(r[1][2], 'D-')
1412
1413    def testInsert(self):
1414        insert = self.db.insert
1415        query = self.db.query
1416        bool_on = pg.get_bool()
1417        decimal = pg.get_decimal()
1418        table = 'insert_test_table'
1419        self.createTable(table,
1420            'i2 smallint, i4 integer, i8 bigint,'
1421            ' d numeric, f4 real, f8 double precision, m money,'
1422            ' v4 varchar(4), c4 char(4), t text,'
1423            ' b boolean, ts timestamp', oids=True)
1424        oid_table = 'oid(%s)' % table
1425        tests = [dict(i2=None, i4=None, i8=None),
1426            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1427            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1428            dict(i2=42, i4=123456, i8=9876543210),
1429            dict(i2=2 ** 15 - 1,
1430                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1431            dict(d=None), (dict(d=''), dict(d=None)),
1432            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1433            dict(f4=None, f8=None), dict(f4=0, f8=0),
1434            (dict(f4='', f8=''), dict(f4=None, f8=None)),
1435            (dict(d=1234.5, f4=1234.5, f8=1234.5),
1436                  dict(d=Decimal('1234.5'))),
1437            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1438            dict(d=Decimal('123456789.9876543212345678987654321')),
1439            dict(m=None), (dict(m=''), dict(m=None)),
1440            dict(m=Decimal('-1234.56')),
1441            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1442            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1443            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1444            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1445            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1446            (dict(m=123456), dict(m=Decimal('123456'))),
1447            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1448            dict(b=None), (dict(b=''), dict(b=None)),
1449            dict(b='f'), dict(b='t'),
1450            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1451            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1452            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1453            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1454            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1455            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1456            dict(v4=None, c4=None, t=None),
1457            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1458            dict(v4='1234', c4='1234', t='1234' * 10),
1459            dict(v4='abcd', c4='abcd', t='abcdefg'),
1460            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1461            dict(ts=None), (dict(ts=''), dict(ts=None)),
1462            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1463            dict(ts='2012-12-21 00:00:00'),
1464            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1465            dict(ts='2012-12-21 12:21:12'),
1466            dict(ts='2013-01-05 12:13:14'),
1467            dict(ts='current_timestamp')]
1468        for test in tests:
1469            if isinstance(test, dict):
1470                data = test
1471                change = {}
1472            else:
1473                data, change = test
1474            expect = data.copy()
1475            expect.update(change)
1476            if bool_on:
1477                b = expect.get('b')
1478                if b is not None:
1479                    expect['b'] = b == 't'
1480            if decimal is not Decimal:
1481                d = expect.get('d')
1482                if d is not None:
1483                    expect['d'] = decimal(d)
1484                m = expect.get('m')
1485                if m is not None:
1486                    expect['m'] = decimal(m)
1487            self.assertEqual(insert(table, data), data)
1488            self.assertIn(oid_table, data)
1489            oid = data[oid_table]
1490            self.assertIsInstance(oid, int)
1491            data = dict(item for item in data.items()
1492                if item[0] in expect)
1493            ts = expect.get('ts')
1494            if ts == 'current_timestamp':
1495                ts = expect['ts'] = data['ts']
1496                if len(ts) > 19:
1497                    self.assertEqual(ts[19], '.')
1498                    ts = ts[:19]
1499                else:
1500                    self.assertEqual(len(ts), 19)
1501                self.assertTrue(ts[:4].isdigit())
1502                self.assertEqual(ts[4], '-')
1503                self.assertEqual(ts[10], ' ')
1504                self.assertTrue(ts[11:13].isdigit())
1505                self.assertEqual(ts[13], ':')
1506            self.assertEqual(data, expect)
1507            data = query(
1508                'select oid,* from "%s"' % table).dictresult()[0]
1509            self.assertEqual(data['oid'], oid)
1510            data = dict(item for item in data.items()
1511                if item[0] in expect)
1512            self.assertEqual(data, expect)
1513            query('delete from "%s"' % table)
1514
    def testInsertWithOid(self):
        """Test insert() into a table that has oids but no primary key.

        The returned dict carries the oid under the qualified key
        'oid(test_table)'; a plain 'oid' key never appears in the result.
        """
        insert = self.db.insert
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True)
        # m is not a column of the table, so this insert must fail
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1)
        r = insert('test_table', n=1)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        self.assertNotIn('oid', r)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        # a passed-in oid is not reused: the new row gets a different oid
        r = insert('test_table', n=2, oid=oid)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 2)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # None as the row argument, values only given as keywords
        r = insert('test_table', None, n=3)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 3)
        s = r
        # inserting the same dict again adds another row (see the final
        # '3 3 3' check below) and returns the very same dict object
        r = insert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # the table name may also be given with a trailing ' *'
        r = insert('test_table *', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # keyword arguments override the values in the passed dict
        r = insert('test_table', r, n=4)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 4)
        self.assertNotIn('oid', r)
        self.assertIn(qoid, r)
        oid = r[qoid]
        r = insert('test_table', r, n=5, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        # a plain 'oid' key in the dict is removed and not reused either
        r['oid'] = oid = r[qoid]
        r = insert('test_table', r, n=6)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 6)
        self.assertIn(qoid, r)
        self.assertNotEqual(r[qoid], oid)
        self.assertNotIn('oid', r)
        q = 'select n from test_table order by 1 limit 9'
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '1 2 3 3 3 4 5 6')
        # now check the behavior when a unique constraint is violated
        query("truncate test_table")
        query("alter table test_table add unique (n)")
        r = insert('test_table', dict(n=7))
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
        r['n'] = 6
        # the keyword value ends up in the dict even though the insert fails
        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 7)
        r['n'] = 6
        r = insert('test_table', r)
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 6)
        r = ' '.join(str(row[0]) for row in query(q).getresult())
        self.assertEqual(r, '6 7')
1554        self.assertNotEqual(r[qoid], oid)
1555        self.assertNotIn('oid', r)
1556        r['oid'] = oid = r[qoid]
1557        r = insert('test_table', r, n=6)
1558        self.assertIs(r, s)
1559        self.assertEqual(r['n'], 6)
1560        self.assertIn(qoid, r)
1561        self.assertNotEqual(r[qoid], oid)
1562        self.assertNotIn('oid', r)
1563        q = 'select n from test_table order by 1 limit 9'
1564        r = ' '.join(str(row[0]) for row in query(q).getresult())
1565        self.assertEqual(r, '1 2 3 3 3 4 5 6')
1566        query("truncate test_table")
1567        query("alter table test_table add unique (n)")
1568        r = insert('test_table', dict(n=7))
1569        self.assertIsInstance(r, dict)
1570        self.assertEqual(r['n'], 7)
1571        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
1572        r['n'] = 6
1573        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
1574        self.assertIsInstance(r, dict)
1575        self.assertEqual(r['n'], 7)
1576        r['n'] = 6
1577        r = insert('test_table', r)
1578        self.assertIsInstance(r, dict)
1579        self.assertEqual(r['n'], 6)
1580        r = ' '.join(str(row[0]) for row in query(q).getresult())
1581        self.assertEqual(r, '6 7')
1582
1583    def testInsertWithQuotedNames(self):
1584        insert = self.db.insert
1585        query = self.db.query
1586        table = 'test table for insert()'
1587        self.createTable(table, '"Prime!" smallint primary key,'
1588            ' "much space" integer, "Questions?" text')
1589        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1590        r = insert(table, r)
1591        self.assertIsInstance(r, dict)
1592        self.assertEqual(r['Prime!'], 11)
1593        self.assertEqual(r['much space'], 2002)
1594        self.assertEqual(r['Questions?'], 'What?')
1595        r = query('select * from "%s" limit 2' % table).dictresult()
1596        self.assertEqual(len(r), 1)
1597        r = r[0]
1598        self.assertEqual(r['Prime!'], 11)
1599        self.assertEqual(r['much space'], 2002)
1600        self.assertEqual(r['Questions?'], 'What?')
1601
1602    def testInsertIntoView(self):
1603        insert = self.db.insert
1604        query = self.db.query
1605        query("truncate test")
1606        q = 'select * from test_view order by i4 limit 3'
1607        r = query(q).getresult()
1608        self.assertEqual(r, [])
1609        r = dict(i4=1234, v4='abcd')
1610        insert('test', r)
1611        self.assertIsNone(r['i2'])
1612        self.assertEqual(r['i4'], 1234)
1613        self.assertIsNone(r['i8'])
1614        self.assertEqual(r['v4'], 'abcd')
1615        self.assertIsNone(r['c4'])
1616        r = query(q).getresult()
1617        self.assertEqual(r, [(1234, 'abcd')])
1618        r = dict(i4=5678, v4='efgh')
1619        try:
1620            insert('test_view', r)
1621        except pg.ProgrammingError as error:
1622            if self.db.server_version < 90300:
1623                # must setup rules in older PostgreSQL versions
1624                self.skipTest('database cannot insert into view')
1625            self.fail(str(error))
1626        self.assertNotIn('i2', r)
1627        self.assertEqual(r['i4'], 5678)
1628        self.assertNotIn('i8', r)
1629        self.assertEqual(r['v4'], 'efgh')
1630        self.assertNotIn('c4', r)
1631        r = query(q).getresult()
1632        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1633
1634    def testUpdate(self):
1635        update = self.db.update
1636        query = self.db.query
1637        self.assertRaises(pg.ProgrammingError, update,
1638            'test', i2=2, i4=4, i8=8)
1639        table = 'update_test_table'
1640        self.createTable(table, 'n integer, t text', oids=True,
1641            values=enumerate('xyz', start=1))
1642        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1643        r = self.db.get(table, 2, 'n')
1644        r['t'] = 'u'
1645        s = update(table, r)
1646        self.assertEqual(s, r)
1647        q = 'select t from "%s" where n=2' % table
1648        r = query(q).getresult()[0][0]
1649        self.assertEqual(r, 'u')
1650
1651    def testUpdateWithOid(self):
1652        update = self.db.update
1653        get = self.db.get
1654        query = self.db.query
1655        self.createTable('test_table', 'n int', oids=True, values=[1])
1656        s = get('test_table', 1, 'n')
1657        self.assertIsInstance(s, dict)
1658        self.assertEqual(s['n'], 1)
1659        s['n'] = 2
1660        r = update('test_table', s)
1661        self.assertIs(r, s)
1662        self.assertEqual(r['n'], 2)
1663        qoid = 'oid(test_table)'
1664        self.assertIn(qoid, r)
1665        self.assertNotIn('oid', r)
1666        self.assertEqual(sorted(r.keys()), ['n', qoid])
1667        r['n'] = 3
1668        oid = r.pop(qoid)
1669        r = update('test_table', r, oid=oid)
1670        self.assertIs(r, s)
1671        self.assertEqual(r['n'], 3)
1672        r.pop(qoid)
1673        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
1674        s = get('test_table', 3, 'n')
1675        self.assertIsInstance(s, dict)
1676        self.assertEqual(s['n'], 3)
1677        s.pop('n')
1678        r = update('test_table', s)
1679        oid = r.pop(qoid)
1680        self.assertEqual(r, {})
1681        q = "select n from test_table limit 2"
1682        r = query(q).getresult()
1683        self.assertEqual(r, [(3,)])
1684        query("insert into test_table values (1)")
1685        self.assertRaises(pg.ProgrammingError,
1686            update, 'test_table', dict(oid=oid, n=4))
1687        r = update('test_table', dict(n=4), oid=oid)
1688        self.assertEqual(r['n'], 4)
1689        r = update('test_table *', dict(n=5), oid=oid)
1690        self.assertEqual(r['n'], 5)
1691        query("alter table test_table add column m int")
1692        query("alter table test_table add primary key (n)")
1693        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
1694        self.assertEqual('n', self.db.pkey('test_table', flush=True))
1695        s = dict(n=1, m=4)
1696        r = update('test_table', s)
1697        self.assertIs(r, s)
1698        self.assertEqual(r['n'], 1)
1699        self.assertEqual(r['m'], 4)
1700        s = dict(m=7)
1701        r = update('test_table', s, n=5)
1702        self.assertIs(r, s)
1703        self.assertEqual(r['n'], 5)
1704        self.assertEqual(r['m'], 7)
1705        q = "select n, m from test_table order by 1 limit 3"
1706        r = query(q).getresult()
1707        self.assertEqual(r, [(1, 4), (5, 7)])
1708        s = dict(m=9, oid=oid)
1709        self.assertRaises(KeyError, update, 'test_table', s)
1710        r = update('test_table', s, oid=oid)
1711        self.assertIs(r, s)
1712        self.assertEqual(r['n'], 5)
1713        self.assertEqual(r['m'], 9)
1714        s = dict(n=1, m=3, oid=oid)
1715        r = update('test_table', s)
1716        self.assertIs(r, s)
1717        self.assertEqual(r['n'], 1)
1718        self.assertEqual(r['m'], 3)
1719        r = query(q).getresult()
1720        self.assertEqual(r, [(1, 3), (5, 9)])
1721
1722    def testUpdateWithoutOid(self):
1723        update = self.db.update
1724        query = self.db.query
1725        self.assertRaises(pg.ProgrammingError, update,
1726            'test', i2=2, i4=4, i8=8)
1727        table = 'update_test_table'
1728        self.createTable(table, 'n integer primary key, t text', oids=False,
1729            values=enumerate('xyz', start=1))
1730        r = self.db.get(table, 2)
1731        r['t'] = 'u'
1732        s = update(table, r)
1733        self.assertEqual(s, r)
1734        q = 'select t from "%s" where n=2' % table
1735        r = query(q).getresult()[0][0]
1736        self.assertEqual(r, 'u')
1737
1738    def testUpdateWithCompositeKey(self):
1739        update = self.db.update
1740        query = self.db.query
1741        table = 'update_test_table_1'
1742        self.createTable(table, 'n integer primary key, t text',
1743            values=enumerate('abc', start=1))
1744        self.assertRaises(KeyError, update, table, dict(t='b'))
1745        s = dict(n=2, t='d')
1746        r = update(table, s)
1747        self.assertIs(r, s)
1748        self.assertEqual(r['n'], 2)
1749        self.assertEqual(r['t'], 'd')
1750        q = 'select t from "%s" where n=2' % table
1751        r = query(q).getresult()[0][0]
1752        self.assertEqual(r, 'd')
1753        s.update(dict(n=4, t='e'))
1754        r = update(table, s)
1755        self.assertEqual(r['n'], 4)
1756        self.assertEqual(r['t'], 'e')
1757        q = 'select t from "%s" where n=2' % table
1758        r = query(q).getresult()[0][0]
1759        self.assertEqual(r, 'd')
1760        q = 'select t from "%s" where n=4' % table
1761        r = query(q).getresult()
1762        self.assertEqual(len(r), 0)
1763        query('drop table "%s"' % table)
1764        table = 'update_test_table_2'
1765        self.createTable(table,
1766            'n integer, m integer, t text, primary key (n, m)',
1767            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1768                for n in range(3) for m in range(2)])
1769        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1770        self.assertEqual(update(table,
1771            dict(n=2, m=2, t='x'))['t'], 'x')
1772        q = 'select t from "%s" where n=2 order by m' % table
1773        r = [r[0] for r in query(q).getresult()]
1774        self.assertEqual(r, ['c', 'x'])
1775
1776    def testUpdateWithQuotedNames(self):
1777        update = self.db.update
1778        query = self.db.query
1779        table = 'test table for update()'
1780        self.createTable(table, '"Prime!" smallint primary key,'
1781            ' "much space" integer, "Questions?" text',
1782            values=[(13, 3003, 'Why!')])
1783        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1784        r = update(table, r)
1785        self.assertIsInstance(r, dict)
1786        self.assertEqual(r['Prime!'], 13)
1787        self.assertEqual(r['much space'], 7007)
1788        self.assertEqual(r['Questions?'], 'When?')
1789        r = query('select * from "%s" limit 2' % table).dictresult()
1790        self.assertEqual(len(r), 1)
1791        r = r[0]
1792        self.assertEqual(r['Prime!'], 13)
1793        self.assertEqual(r['much space'], 7007)
1794        self.assertEqual(r['Questions?'], 'When?')
1795
1796    def testUpsert(self):
1797        upsert = self.db.upsert
1798        query = self.db.query
1799        self.assertRaises(pg.ProgrammingError, upsert,
1800            'test', i2=2, i4=4, i8=8)
1801        table = 'upsert_test_table'
1802        self.createTable(table, 'n integer primary key, t text', oids=True)
1803        s = dict(n=1, t='x')
1804        try:
1805            r = upsert(table, s)
1806        except pg.ProgrammingError as error:
1807            if self.db.server_version < 90500:
1808                self.skipTest('database does not support upsert')
1809            self.fail(str(error))
1810        self.assertIs(r, s)
1811        self.assertEqual(r['n'], 1)
1812        self.assertEqual(r['t'], 'x')
1813        s.update(n=2, t='y')
1814        r = upsert(table, s, **dict.fromkeys(s))
1815        self.assertIs(r, s)
1816        self.assertEqual(r['n'], 2)
1817        self.assertEqual(r['t'], 'y')
1818        q = 'select n, t from "%s" order by n limit 3' % table
1819        r = query(q).getresult()
1820        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1821        s.update(t='z')
1822        r = upsert(table, s)
1823        self.assertIs(r, s)
1824        self.assertEqual(r['n'], 2)
1825        self.assertEqual(r['t'], 'z')
1826        r = query(q).getresult()
1827        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1828        s.update(t='n')
1829        r = upsert(table, s, t=False)
1830        self.assertIs(r, s)
1831        self.assertEqual(r['n'], 2)
1832        self.assertEqual(r['t'], 'z')
1833        r = query(q).getresult()
1834        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1835        s.update(t='y')
1836        r = upsert(table, s, t=True)
1837        self.assertIs(r, s)
1838        self.assertEqual(r['n'], 2)
1839        self.assertEqual(r['t'], 'y')
1840        r = query(q).getresult()
1841        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1842        s.update(t='n')
1843        r = upsert(table, s, t="included.t || '2'")
1844        self.assertIs(r, s)
1845        self.assertEqual(r['n'], 2)
1846        self.assertEqual(r['t'], 'y2')
1847        r = query(q).getresult()
1848        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1849        s.update(t='y')
1850        r = upsert(table, s, t="excluded.t || '3'")
1851        self.assertIs(r, s)
1852        self.assertEqual(r['n'], 2)
1853        self.assertEqual(r['t'], 'y3')
1854        r = query(q).getresult()
1855        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1856        s.update(n=1, t='2')
1857        r = upsert(table, s, t="included.t || excluded.t")
1858        self.assertIs(r, s)
1859        self.assertEqual(r['n'], 1)
1860        self.assertEqual(r['t'], 'x2')
1861        r = query(q).getresult()
1862        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1863        # not existing columns and oid parameter should be ignored
1864        s = dict(m=3, u='z')
1865        r = upsert(table, s, oid='invalid')
1866        self.assertIs(r, s)
1867
1868    def testUpsertWithOid(self):
1869        upsert = self.db.upsert
1870        get = self.db.get
1871        query = self.db.query
1872        self.createTable('test_table', 'n int', oids=True, values=[1])
1873        self.assertRaises(pg.ProgrammingError,
1874            upsert, 'test_table', dict(n=2))
1875        r = get('test_table', 1, 'n')
1876        self.assertIsInstance(r, dict)
1877        self.assertEqual(r['n'], 1)
1878        qoid = 'oid(test_table)'
1879        self.assertIn(qoid, r)
1880        self.assertNotIn('oid', r)
1881        oid = r[qoid]
1882        self.assertRaises(pg.ProgrammingError,
1883            upsert, 'test_table', dict(n=2, oid=oid))
1884        query("alter table test_table add column m int")
1885        query("alter table test_table add primary key (n)")
1886        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
1887        self.assertEqual('n', self.db.pkey('test_table', flush=True))
1888        s = dict(n=2)
1889        try:
1890            r = upsert('test_table', s)
1891        except pg.ProgrammingError as error:
1892            if self.db.server_version < 90500:
1893                self.skipTest('database does not support upsert')
1894            self.fail(str(error))
1895        self.assertIs(r, s)
1896        self.assertEqual(r['n'], 2)
1897        self.assertIsNone(r['m'])
1898        q = query("select n, m from test_table order by n limit 3")
1899        self.assertEqual(q.getresult(), [(1, None), (2, None)])
1900        r['oid'] = oid
1901        r = upsert('test_table', r)
1902        self.assertIs(r, s)
1903        self.assertEqual(r['n'], 2)
1904        self.assertIsNone(r['m'])
1905        self.assertIn(qoid, r)
1906        self.assertNotIn('oid', r)
1907        self.assertNotEqual(r[qoid], oid)
1908        r['m'] = 7
1909        r = upsert('test_table', r)
1910        self.assertIs(r, s)
1911        self.assertEqual(r['n'], 2)
1912        self.assertEqual(r['m'], 7)
1913        r.update(n=1, m=3)
1914        r = upsert('test_table', r)
1915        self.assertIs(r, s)
1916        self.assertEqual(r['n'], 1)
1917        self.assertEqual(r['m'], 3)
1918        q = query("select n, m from test_table order by n limit 3")
1919        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
1920        r = upsert('test_table', r, oid='invalid')
1921        self.assertIs(r, s)
1922        self.assertEqual(r['n'], 1)
1923        self.assertEqual(r['m'], 3)
1924        r['m'] = 5
1925        r = upsert('test_table', r, m=False)
1926        self.assertIs(r, s)
1927        self.assertEqual(r['n'], 1)
1928        self.assertEqual(r['m'], 3)
1929        r['m'] = 5
1930        r = upsert('test_table', r, m=True)
1931        self.assertIs(r, s)
1932        self.assertEqual(r['n'], 1)
1933        self.assertEqual(r['m'], 5)
1934        r.update(n=2, m=1)
1935        r = upsert('test_table', r, m='included.m')
1936        self.assertIs(r, s)
1937        self.assertEqual(r['n'], 2)
1938        self.assertEqual(r['m'], 7)
1939        r['m'] = 9
1940        r = upsert('test_table', r, m='excluded.m')
1941        self.assertIs(r, s)
1942        self.assertEqual(r['n'], 2)
1943        self.assertEqual(r['m'], 9)
1944        r['m'] = 8
1945        r = upsert('test_table *', r, m='included.m + 1')
1946        self.assertIs(r, s)
1947        self.assertEqual(r['n'], 2)
1948        self.assertEqual(r['m'], 10)
1949        q = query("select n, m from test_table order by n limit 3")
1950        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1951
1952    def testUpsertWithCompositeKey(self):
1953        upsert = self.db.upsert
1954        query = self.db.query
1955        table = 'upsert_test_table_2'
1956        self.createTable(table,
1957            'n integer, m integer, t text, primary key (n, m)')
1958        s = dict(n=1, m=2, t='x')
1959        try:
1960            r = upsert(table, s)
1961        except pg.ProgrammingError as error:
1962            if self.db.server_version < 90500:
1963                self.skipTest('database does not support upsert')
1964            self.fail(str(error))
1965        self.assertIs(r, s)
1966        self.assertEqual(r['n'], 1)
1967        self.assertEqual(r['m'], 2)
1968        self.assertEqual(r['t'], 'x')
1969        s.update(m=3, t='y')
1970        r = upsert(table, s, **dict.fromkeys(s))
1971        self.assertIs(r, s)
1972        self.assertEqual(r['n'], 1)
1973        self.assertEqual(r['m'], 3)
1974        self.assertEqual(r['t'], 'y')
1975        q = 'select n, m, t from "%s" order by n, m limit 3' % table
1976        r = query(q).getresult()
1977        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
1978        s.update(t='z')
1979        r = upsert(table, s)
1980        self.assertIs(r, s)
1981        self.assertEqual(r['n'], 1)
1982        self.assertEqual(r['m'], 3)
1983        self.assertEqual(r['t'], 'z')
1984        r = query(q).getresult()
1985        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1986        s.update(t='n')
1987        r = upsert(table, s, t=False)
1988        self.assertIs(r, s)
1989        self.assertEqual(r['n'], 1)
1990        self.assertEqual(r['m'], 3)
1991        self.assertEqual(r['t'], 'z')
1992        r = query(q).getresult()
1993        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1994        s.update(t='n')
1995        r = upsert(table, s, t=True)
1996        self.assertIs(r, s)
1997        self.assertEqual(r['n'], 1)
1998        self.assertEqual(r['m'], 3)
1999        self.assertEqual(r['t'], 'n')
2000        r = query(q).getresult()
2001        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
2002        s.update(n=2, t='y')
2003        r = upsert(table, s, t="'z'")
2004        self.assertIs(r, s)
2005        self.assertEqual(r['n'], 2)
2006        self.assertEqual(r['m'], 3)
2007        self.assertEqual(r['t'], 'y')
2008        r = query(q).getresult()
2009        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
2010        s.update(n=1, t='m')
2011        r = upsert(table, s, t='included.t || excluded.t')
2012        self.assertIs(r, s)
2013        self.assertEqual(r['n'], 1)
2014        self.assertEqual(r['m'], 3)
2015        self.assertEqual(r['t'], 'nm')
2016        r = query(q).getresult()
2017        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
2018
2019    def testUpsertWithQuotedNames(self):
2020        upsert = self.db.upsert
2021        query = self.db.query
2022        table = 'test table for upsert()'
2023        self.createTable(table, '"Prime!" smallint primary key,'
2024            ' "much space" integer, "Questions?" text')
2025        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
2026        try:
2027            r = upsert(table, s)
2028        except pg.ProgrammingError as error:
2029            if self.db.server_version < 90500:
2030                self.skipTest('database does not support upsert')
2031            self.fail(str(error))
2032        self.assertIs(r, s)
2033        self.assertEqual(r['Prime!'], 31)
2034        self.assertEqual(r['much space'], 9009)
2035        self.assertEqual(r['Questions?'], 'Yes.')
2036        q = 'select * from "%s" limit 2' % table
2037        r = query(q).getresult()
2038        self.assertEqual(r, [(31, 9009, 'Yes.')])
2039        s.update({'Questions?': 'No.'})
2040        r = upsert(table, s)
2041        self.assertIs(r, s)
2042        self.assertEqual(r['Prime!'], 31)
2043        self.assertEqual(r['much space'], 9009)
2044        self.assertEqual(r['Questions?'], 'No.')
2045        r = query(q).getresult()
2046        self.assertEqual(r, [(31, 9009, 'No.')])
2047
2048    def testClear(self):
2049        clear = self.db.clear
2050        f = False if pg.get_bool() else 'f'
2051        r = clear('test')
2052        result = dict(
2053            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
2054        self.assertEqual(r, result)
2055        table = 'clear_test_table'
2056        self.createTable(table,
2057            'n integer, f float, b boolean, d date, t text', oids=True)
2058        r = clear(table)
2059        result = dict(n=0, f=0, b=f, d='', t='')
2060        self.assertEqual(r, result)
2061        r['a'] = r['f'] = r['n'] = 1
2062        r['d'] = r['t'] = 'x'
2063        r['b'] = 't'
2064        r['oid'] = long(1)
2065        r = clear(table, r)
2066        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
2067        self.assertEqual(r, result)
2068
2069    def testClearWithQuotedNames(self):
2070        clear = self.db.clear
2071        table = 'test table for clear()'
2072        self.createTable(table, '"Prime!" smallint primary key,'
2073            ' "much space" integer, "Questions?" text')
2074        r = clear(table)
2075        self.assertIsInstance(r, dict)
2076        self.assertEqual(r['Prime!'], 0)
2077        self.assertEqual(r['much space'], 0)
2078        self.assertEqual(r['Questions?'], '')
2079
2080    def testDelete(self):
2081        delete = self.db.delete
2082        query = self.db.query
2083        self.assertRaises(pg.ProgrammingError, delete,
2084            'test', dict(i2=2, i4=4, i8=8))
2085        table = 'delete_test_table'
2086        self.createTable(table, 'n integer, t text', oids=True,
2087            values=enumerate('xyz', start=1))
2088        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
2089        r = self.db.get(table, 1, 'n')
2090        s = delete(table, r)
2091        self.assertEqual(s, 1)
2092        r = self.db.get(table, 3, 'n')
2093        s = delete(table, r)
2094        self.assertEqual(s, 1)
2095        s = delete(table, r)
2096        self.assertEqual(s, 0)
2097        r = query('select * from "%s"' % table).dictresult()
2098        self.assertEqual(len(r), 1)
2099        r = r[0]
2100        result = {'n': 2, 't': 'y'}
2101        self.assertEqual(r, result)
2102        r = self.db.get(table, 2, 'n')
2103        s = delete(table, r)
2104        self.assertEqual(s, 1)
2105        s = delete(table, r)
2106        self.assertEqual(s, 0)
2107        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
2108        # not existing columns and oid parameter should be ignored
2109        r.update(m=3, u='z', oid='invalid')
2110        s = delete(table, r)
2111        self.assertEqual(s, 0)
2112
2113    def testDeleteWithOid(self):
2114        delete = self.db.delete
2115        get = self.db.get
2116        query = self.db.query
2117        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
2118        r = dict(n=3)
2119        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
2120        s = get('test_table', 1, 'n')
2121        qoid = 'oid(test_table)'
2122        self.assertIn(qoid, s)
2123        r = delete('test_table', s)
2124        self.assertEqual(r, 1)
2125        r = delete('test_table', s)
2126        self.assertEqual(r, 0)
2127        q = "select min(n),count(n) from test_table"
2128        self.assertEqual(query(q).getresult()[0], (2, 5))
2129        oid = get('test_table', 2, 'n')[qoid]
2130        s = dict(oid=oid, n=2)
2131        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
2132        r = delete('test_table', None, oid=oid)
2133        self.assertEqual(r, 1)
2134        r = delete('test_table', None, oid=oid)
2135        self.assertEqual(r, 0)
2136        self.assertEqual(query(q).getresult()[0], (3, 4))
2137        s = dict(oid=oid, n=2)
2138        oid = get('test_table', 3, 'n')[qoid]
2139        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
2140        r = delete('test_table', s, oid=oid)
2141        self.assertEqual(r, 1)
2142        r = delete('test_table', s, oid=oid)
2143        self.assertEqual(r, 0)
2144        self.assertEqual(query(q).getresult()[0], (4, 3))
2145        s = get('test_table', 4, 'n')
2146        r = delete('test_table *', s)
2147        self.assertEqual(r, 1)
2148        r = delete('test_table *', s)
2149        self.assertEqual(r, 0)
2150        self.assertEqual(query(q).getresult()[0], (5, 2))
2151        oid = get('test_table', 5, 'n')[qoid]
2152        s = {qoid: oid, 'm': 4}
2153        r = delete('test_table', s, m=6)
2154        self.assertEqual(r, 1)
2155        r = delete('test_table *', s)
2156        self.assertEqual(r, 0)
2157        self.assertEqual(query(q).getresult()[0], (6, 1))
2158        query("alter table test_table add column m int")
2159        query("alter table test_table add primary key (n)")
2160        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
2161        self.assertEqual('n', self.db.pkey('test_table', flush=True))
2162        for i in range(5):
2163            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
2164        s = dict(m=2)
2165        self.assertRaises(KeyError, delete, 'test_table', s)
2166        s = dict(m=2, oid=oid)
2167        self.assertRaises(KeyError, delete, 'test_table', s)
2168        r = delete('test_table', dict(m=2), oid=oid)
2169        self.assertEqual(r, 0)
2170        oid = get('test_table', 1, 'n')[qoid]
2171        s = dict(oid=oid)
2172        self.assertRaises(KeyError, delete, 'test_table', s)
2173        r = delete('test_table', s, oid=oid)
2174        self.assertEqual(r, 1)
2175        r = delete('test_table', s, oid=oid)
2176        self.assertEqual(r, 0)
2177        self.assertEqual(query(q).getresult()[0], (2, 5))
2178        s = get('test_table', 2, 'n')
2179        del s['n']
2180        r = delete('test_table', s)
2181        self.assertEqual(r, 1)
2182        r = delete('test_table', s)
2183        self.assertEqual(r, 0)
2184        self.assertEqual(query(q).getresult()[0], (3, 4))
2185        r = delete('test_table', n=3)
2186        self.assertEqual(r, 1)
2187        r = delete('test_table', n=3)
2188        self.assertEqual(r, 0)
2189        self.assertEqual(query(q).getresult()[0], (4, 3))
2190        r = delete('test_table', None, n=4)
2191        self.assertEqual(r, 1)
2192        r = delete('test_table', None, n=4)
2193        self.assertEqual(r, 0)
2194        self.assertEqual(query(q).getresult()[0], (5, 2))
2195        s = dict(n=6)
2196        r = delete('test_table', s, n=5)
2197        self.assertEqual(r, 1)
2198        r = delete('test_table', s, n=5)
2199        self.assertEqual(r, 0)
2200        self.assertEqual(query(q).getresult()[0], (6, 1))
2201
2202    def testDeleteWithCompositeKey(self):
2203        query = self.db.query
2204        table = 'delete_test_table_1'
2205        self.createTable(table, 'n integer primary key, t text',
2206            values=enumerate('abc', start=1))
2207        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2208        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2209        r = query('select t from "%s" where n=2' % table).getresult()
2210        self.assertEqual(r, [])
2211        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2212        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2213        self.assertEqual(r, 'c')
2214        table = 'delete_test_table_2'
2215        self.createTable(table,
2216            'n integer, m integer, t text, primary key (n, m)',
2217            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2218                for n in range(3) for m in range(2)])
2219        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2220        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2221        r = [r[0] for r in query('select t from "%s" where n=2'
2222            ' order by m' % table).getresult()]
2223        self.assertEqual(r, ['c'])
2224        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2225        r = [r[0] for r in query('select t from "%s" where n=3'
2226            ' order by m' % table).getresult()]
2227        self.assertEqual(r, ['e', 'f'])
2228        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2229        r = [r[0] for r in query('select t from "%s" where n=3'
2230            ' order by m' % table).getresult()]
2231        self.assertEqual(r, ['f'])
2232
2233    def testDeleteWithQuotedNames(self):
2234        delete = self.db.delete
2235        query = self.db.query
2236        table = 'test table for delete()'
2237        self.createTable(table, '"Prime!" smallint primary key,'
2238            ' "much space" integer, "Questions?" text',
2239            values=[(19, 5005, 'Yes!')])
2240        r = {'Prime!': 17}
2241        r = delete(table, r)
2242        self.assertEqual(r, 0)
2243        r = query('select count(*) from "%s"' % table).getresult()
2244        self.assertEqual(r[0][0], 1)
2245        r = {'Prime!': 19}
2246        r = delete(table, r)
2247        self.assertEqual(r, 1)
2248        r = query('select count(*) from "%s"' % table).getresult()
2249        self.assertEqual(r[0][0], 0)
2250
    def testDeleteReferenced(self):
        """Test delete() on a table that is referenced by a foreign key.

        Deleting a parent row that is still referenced by a child row
        raises a ProgrammingError; after the child row has been deleted,
        the parent row can be deleted as well.
        """
        delete = self.db.delete
        query = self.db.query
        self.createTable('test_parent',
            'n smallint primary key', values=range(3))
        self.createTable('test_child',
            'n smallint primary key references test_parent', values=range(3))
        # number of rows in the parent and in the child table
        q = ("select (select count(*) from test_parent),"
            " (select count(*) from test_child)")
        self.assertEqual(query(q).getresult()[0], (3, 3))
        # parent row n=2 is still referenced (key passed as keyword argument),
        # also when the table name carries the "*" suffix
        self.assertRaises(pg.ProgrammingError,
            delete, 'test_parent', None, n=2)
        self.assertRaises(pg.ProgrammingError,
            delete, 'test_parent *', None, n=2)
        r = delete('test_child', None, n=2)
        self.assertEqual(r, 1)
        self.assertEqual(query(q).getresult()[0], (3, 2))
        # no longer referenced, so it can be deleted now
        r = delete('test_parent', None, n=2)
        self.assertEqual(r, 1)
        self.assertEqual(query(q).getresult()[0], (2, 2))
        # same again for n=0, this time with the key passed as a row dict
        self.assertRaises(pg.ProgrammingError,
            delete, 'test_parent', dict(n=0))
        self.assertRaises(pg.ProgrammingError,
            delete, 'test_parent *', dict(n=0))
        r = delete('test_child', dict(n=0))
        self.assertEqual(r, 1)
        self.assertEqual(query(q).getresult()[0], (2, 1))
        r = delete('test_child', dict(n=0))
        self.assertEqual(r, 0)
        r = delete('test_parent', dict(n=0))
        self.assertEqual(r, 1)
        self.assertEqual(query(q).getresult()[0], (1, 1))
        r = delete('test_parent', None, n=0)
        self.assertEqual(r, 0)
        # only the row n=1 is left in both tables
        q = "select n from test_parent natural join test_child limit 2"
        self.assertEqual(query(q).getresult(), [(1,)])
2287
2288    def testTruncate(self):
2289        truncate = self.db.truncate
2290        self.assertRaises(TypeError, truncate, None)
2291        self.assertRaises(TypeError, truncate, 42)
2292        self.assertRaises(TypeError, truncate, dict(test_table=None))
2293        query = self.db.query
2294        self.createTable('test_table', 'n smallint',
2295            temporary=False, values=[1] * 3)
2296        q = "select count(*) from test_table"
2297        r = query(q).getresult()[0][0]
2298        self.assertEqual(r, 3)
2299        truncate('test_table')
2300        r = query(q).getresult()[0][0]
2301        self.assertEqual(r, 0)
2302        for i in range(3):
2303            query("insert into test_table values (1)")
2304        r = query(q).getresult()[0][0]
2305        self.assertEqual(r, 3)
2306        truncate('public.test_table')
2307        r = query(q).getresult()[0][0]
2308        self.assertEqual(r, 0)
2309        self.createTable('test_table_2', 'n smallint', temporary=True)
2310        for t in (list, tuple, set):
2311            for i in range(3):
2312                query("insert into test_table values (1)")
2313                query("insert into test_table_2 values (2)")
2314            q = ("select (select count(*) from test_table),"
2315                " (select count(*) from test_table_2)")
2316            r = query(q).getresult()[0]
2317            self.assertEqual(r, (3, 3))
2318            truncate(t(['test_table', 'test_table_2']))
2319            r = query(q).getresult()[0]
2320            self.assertEqual(r, (0, 0))
2321
2322    def testTruncateRestart(self):
2323        truncate = self.db.truncate
2324        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2325        query = self.db.query
2326        self.createTable('test_table', 'n serial, t text')
2327        for n in range(3):
2328            query("insert into test_table (t) values ('test')")
2329        q = "select count(n), min(n), max(n) from test_table"
2330        r = query(q).getresult()[0]
2331        self.assertEqual(r, (3, 1, 3))
2332        truncate('test_table')
2333        r = query(q).getresult()[0]
2334        self.assertEqual(r, (0, None, None))
2335        for n in range(3):
2336            query("insert into test_table (t) values ('test')")
2337        r = query(q).getresult()[0]
2338        self.assertEqual(r, (3, 4, 6))
2339        truncate('test_table', restart=True)
2340        r = query(q).getresult()[0]
2341        self.assertEqual(r, (0, None, None))
2342        for n in range(3):
2343            query("insert into test_table (t) values ('test')")
2344        r = query(q).getresult()[0]
2345        self.assertEqual(r, (3, 1, 3))
2346
2347    def testTruncateCascade(self):
2348        truncate = self.db.truncate
2349        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2350        query = self.db.query
2351        self.createTable('test_parent', 'n smallint primary key',
2352            values=range(3))
2353        self.createTable('test_child',
2354            'n smallint primary key references test_parent (n)',
2355            values=range(3))
2356        q = ("select (select count(*) from test_parent),"
2357            " (select count(*) from test_child)")
2358        r = query(q).getresult()[0]
2359        self.assertEqual(r, (3, 3))
2360        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2361        truncate(['test_parent', 'test_child'])
2362        r = query(q).getresult()[0]
2363        self.assertEqual(r, (0, 0))
2364        for n in range(3):
2365            query("insert into test_parent (n) values (%d)" % n)
2366            query("insert into test_child (n) values (%d)" % n)
2367        r = query(q).getresult()[0]
2368        self.assertEqual(r, (3, 3))
2369        truncate('test_parent', cascade=True)
2370        r = query(q).getresult()[0]
2371        self.assertEqual(r, (0, 0))
2372        for n in range(3):
2373            query("insert into test_parent (n) values (%d)" % n)
2374            query("insert into test_child (n) values (%d)" % n)
2375        r = query(q).getresult()[0]
2376        self.assertEqual(r, (3, 3))
2377        truncate('test_child')
2378        r = query(q).getresult()[0]
2379        self.assertEqual(r, (3, 0))
2380        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2381        truncate('test_parent', cascade=True)
2382        r = query(q).getresult()[0]
2383        self.assertEqual(r, (0, 0))
2384
    def testTruncateOnly(self):
        """Test the only parameter of truncate() with inherited tables.

        By default, truncating a parent table also empties its child
        tables; with only=True the child tables are left untouched.
        The only parameter may also be a list parallel to the table names.
        """
        truncate = self.db.truncate
        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
        query = self.db.query
        self.createTable('test_parent', 'n smallint')
        # NOTE(review): presumably createTable wraps the column definitions
        # in parentheses, so this injects an INHERITS clause - confirm
        self.createTable('test_child', 'm smallint) inherits (test_parent')
        for n in range(3):
            query("insert into test_parent (n) values (1)")
            query("insert into test_child (n, m) values (2, 3)")
        # the parent count includes the rows of the child table
        q = ("select (select count(*) from test_parent),"
            " (select count(*) from test_child)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (6, 3))
        truncate('test_parent')
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        for n in range(3):
            query("insert into test_parent (n) values (1)")
            query("insert into test_child (n, m) values (2, 3)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (6, 3))
        truncate('test_parent*')
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        for n in range(3):
            query("insert into test_parent (n) values (1)")
            query("insert into test_child (n, m) values (2, 3)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (6, 3))
        # only the parent's own rows are removed, the child rows remain
        truncate('test_parent', only=True)
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 3))
        truncate('test_parent', only=False)
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        # a "*" suffix contradicts only=True
        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
        truncate('test_parent*', only=False)
        self.createTable('test_parent_2', 'n smallint')
        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
        for t in '', '_2':
            for n in range(3):
                query("insert into test_parent%s (n) values (1)" % t)
                query("insert into test_child%s (n, m) values (2, 3)" % t)
        q = ("select (select count(*) from test_parent),"
            " (select count(*) from test_child),"
            " (select count(*) from test_parent_2),"
            " (select count(*) from test_child_2)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (6, 3, 6, 3))
        # one only flag per table name
        truncate(['test_parent', 'test_parent_2'], only=[False, True])
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0, 3, 3))
        # a single flag applies to all table names
        truncate(['test_parent', 'test_parent_2'], only=False)
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0, 0, 0))
        self.assertRaises(ValueError, truncate,
            ['test_parent*', 'test_child'], only=[True, False])
        truncate(['test_parent*', 'test_child'], only=[False, True])
2443
2444    def testTruncateQuoted(self):
2445        truncate = self.db.truncate
2446        query = self.db.query
2447        table = "test table for truncate()"
2448        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2449        q = 'select count(*) from "%s"' % table
2450        r = query(q).getresult()[0][0]
2451        self.assertEqual(r, 3)
2452        truncate(table)
2453        r = query(q).getresult()[0][0]
2454        self.assertEqual(r, 0)
2455        for i in range(3):
2456            query('insert into "%s" values (1)' % table)
2457        r = query(q).getresult()[0][0]
2458        self.assertEqual(r, 3)
2459        truncate('public."%s"' % table)
2460        r = query(q).getresult()[0][0]
2461        self.assertEqual(r, 0)
2462
    def testGetAsList(self):
        """Test get_as_list() with the what, where, order, limit, offset
        and scalar parameters, with and without a primary key and with
        an arbitrary from clause.
        """
        get_as_list = self.db.get_as_list
        self.assertRaises(TypeError, get_as_list)
        self.assertRaises(TypeError, get_as_list, None)
        query = self.db.query
        table = 'test_aslist'
        # find out whether the query module returns named tuples
        r = query('select 1 as colname').namedresult()[0]
        self.assertIsInstance(r, tuple)
        named = hasattr(r, 'colname')
        names = [(1, 'Homer'), (2, 'Marge'),
                (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
        self.createTable(table,
            'id smallint primary key, name varchar', values=names)
        # by default, all columns are returned ordered by the primary key
        r = get_as_list(table)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names)
        for t, n in zip(r, names):
            self.assertIsInstance(t, tuple)
            self.assertEqual(t, n)
            if named:
                self.assertEqual(t.id, n[0])
                self.assertEqual(t.name, n[1])
                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
        # without the primary key column, the rows come ordered by
        # the selected columns, as the expected values show
        r = get_as_list(table, what='name')
        self.assertIsInstance(r, list)
        expected = sorted((row[1],) for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(table, what='name, id')
        self.assertIsInstance(r, list)
        expected = sorted(tuple(reversed(row)) for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(table, what=['name', 'id'])
        self.assertIsInstance(r, list)
        self.assertEqual(r, expected)
        # where may be a single condition or a list of conditions
        r = get_as_list(table, where="name like 'Ba%'")
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[2:3])
        r = get_as_list(table, what='name', where="name like 'Ma%'")
        self.assertIsInstance(r, list)
        self.assertEqual(r, [('Maggie',), ('Marge',)])
        r = get_as_list(table, what='name',
                        where=["name like 'Ma%'", "name like '%r%'"])
        self.assertIsInstance(r, list)
        self.assertEqual(r, [('Marge',)])
        # explicit order as string or list
        r = get_as_list(table, what='name', order='id')
        self.assertIsInstance(r, list)
        expected = [(row[1],) for row in names]
        self.assertEqual(r, expected)
        r = get_as_list(table, what=['name'], order=['id'])
        self.assertIsInstance(r, list)
        self.assertEqual(r, expected)
        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
        self.assertIsInstance(r, list)
        self.assertEqual(r, names)
        r = get_as_list(table, what='id * 2 as num', order='id desc')
        self.assertIsInstance(r, list)
        expected = [(n,) for n in range(10, 0, -2)]
        self.assertEqual(r, expected)
        # limit and offset
        r = get_as_list(table, limit=2)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[:2])
        r = get_as_list(table, offset=3)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[3:])
        r = get_as_list(table, limit=1, offset=2)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names[2:3])
        # scalar=True returns only the first column of each row
        r = get_as_list(table, scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, list(range(1, 6)))
        r = get_as_list(table, what='name', scalar=True)
        self.assertIsInstance(r, list)
        expected = sorted(row[1] for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(table, what='name', limit=1, scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, expected[:1])
        # test without a primary key, allowing duplicate ids
        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
        names.insert(1, (1, 'Snowball'))
        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
        r = get_as_list(table)
        self.assertIsInstance(r, list)
        self.assertEqual(r, names)
        r = get_as_list(table, what='name', where='id=1', scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, ['Homer', 'Snowball'])
        # test with unordered query
        r = get_as_list(table, order=False)
        self.assertIsInstance(r, list)
        self.assertEqual(set(r), set(names))
        # test with arbitrary from clause
        from_table = '(select lower(name) as n2 from "%s") as t2' % table
        r = get_as_list(from_table)
        self.assertIsInstance(r, list)
        r = set(row[0] for row in r)
        expected = set(row[1].lower() for row in names)
        self.assertEqual(r, expected)
        r = get_as_list(from_table, order='n2', scalar=True)
        self.assertIsInstance(r, list)
        self.assertEqual(r, sorted(expected))
        r = get_as_list(from_table, order='n2', limit=1)
        self.assertIsInstance(r, list)
        self.assertEqual(len(r), 1)
        t = r[0]
        self.assertIsInstance(t, tuple)
        if named:
            self.assertEqual(t.n2, 'bart')
            self.assertEqual(t._asdict(), dict(n2='bart'))
        else:
            self.assertEqual(t, ('bart',))
2574
2575    def testGetAsDict(self):
2576        get_as_dict = self.db.get_as_dict
2577        self.assertRaises(TypeError, get_as_dict)
2578        self.assertRaises(TypeError, get_as_dict, None)
2579        # the test table has no primary key
2580        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2581        query = self.db.query
2582        table = 'test_asdict'
2583        r = query('select 1 as colname').namedresult()[0]
2584        self.assertIsInstance(r, tuple)
2585        named = hasattr(r, 'colname')
2586        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2587                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2588        self.createTable(table,
2589            'id smallint primary key, rgb char(7), name varchar',
2590            values=colors)
2591        # keyname must be string, list or tuple
2592        self.assertRaises(KeyError, get_as_dict, table, 3)
2593        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2594        # missing keyname in row
2595        self.assertRaises(KeyError, get_as_dict, table,
2596                          keyname='rgb', what='name')
2597        r = get_as_dict(table)
2598        self.assertIsInstance(r, OrderedDict)
2599        expected = OrderedDict((row[0], row[1:]) for row in colors)
2600        self.assertEqual(r, expected)
2601        for key in r:
2602            self.assertIsInstance(key, int)
2603            self.assertIn(key, expected)
2604            row = r[key]
2605            self.assertIsInstance(row, tuple)
2606            t = expected[key]
2607            self.assertEqual(row, t)
2608            if named:
2609                self.assertEqual(row.rgb, t[0])
2610                self.assertEqual(row.name, t[1])
2611                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2612        if OrderedDict is not dict:  # Python > 2.6
2613            self.assertEqual(r.keys(), expected.keys())
2614        r = get_as_dict(table, keyname='rgb')
2615        self.assertIsInstance(r, OrderedDict)
2616        expected = OrderedDict((row[1], (row[0], row[2]))
2617            for row in sorted(colors, key=itemgetter(1)))
2618        self.assertEqual(r, expected)
2619        for key in r:
2620            self.assertIsInstance(key, str)
2621            self.assertIn(key, expected)
2622            row = r[key]
2623            self.assertIsInstance(row, tuple)
2624            t = expected[key]
2625            self.assertEqual(row, t)
2626            if named:
2627                self.assertEqual(row.id, t[0])
2628                self.assertEqual(row.name, t[1])
2629                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2630        if OrderedDict is not dict:  # Python > 2.6
2631            self.assertEqual(r.keys(), expected.keys())
2632        r = get_as_dict(table, keyname=['id', 'rgb'])
2633        self.assertIsInstance(r, OrderedDict)
2634        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2635        self.assertEqual(r, expected)
2636        for key in r:
2637            self.assertIsInstance(key, tuple)
2638            self.assertIsInstance(key[0], int)
2639            self.assertIsInstance(key[1], str)
2640            if named:
2641                self.assertEqual(key, (key.id, key.rgb))
2642                self.assertEqual(key._fields, ('id', 'rgb'))
2643            row = r[key]
2644            self.assertIsInstance(row, tuple)
2645            self.assertIsInstance(row[0], str)
2646            t = expected[key]
2647            self.assertEqual(row, t)
2648            if named:
2649                self.assertEqual(row.name, t[0])
2650                self.assertEqual(row._asdict(), dict(name=t[0]))
2651        if OrderedDict is not dict:  # Python > 2.6
2652            self.assertEqual(r.keys(), expected.keys())
2653        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2654        self.assertIsInstance(r, OrderedDict)
2655        expected = OrderedDict((row[:2], row[2]) for row in colors)
2656        self.assertEqual(r, expected)
2657        for key in r:
2658            self.assertIsInstance(key, tuple)
2659            row = r[key]
2660            self.assertIsInstance(row, str)
2661            t = expected[key]
2662            self.assertEqual(row, t)
2663        if OrderedDict is not dict:  # Python > 2.6
2664            self.assertEqual(r.keys(), expected.keys())
2665        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2666        self.assertIsInstance(r, OrderedDict)
2667        expected = OrderedDict((row[1], row[2])
2668            for row in sorted(colors, key=itemgetter(1)))
2669        self.assertEqual(r, expected)
2670        for key in r:
2671            self.assertIsInstance(key, str)
2672            row = r[key]
2673            self.assertIsInstance(row, str)
2674            t = expected[key]
2675            self.assertEqual(row, t)
2676        if OrderedDict is not dict:  # Python > 2.6
2677            self.assertEqual(r.keys(), expected.keys())
2678        r = get_as_dict(table, what='id, name',
2679                        where="rgb like '#b%'", scalar=True)
2680        self.assertIsInstance(r, OrderedDict)
2681        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2682        self.assertEqual(r, expected)
2683        for key in r:
2684            self.assertIsInstance(key, int)
2685            row = r[key]
2686            self.assertIsInstance(row, str)
2687            t = expected[key]
2688            self.assertEqual(row, t)
2689        if OrderedDict is not dict:  # Python > 2.6
2690            self.assertEqual(r.keys(), expected.keys())
2691        expected = r
2692        r = get_as_dict(table, what=['name', 'id'],
2693                        where=['id > 1', 'id < 4', "rgb like '#b%'",
2694                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2695        self.assertEqual(r, expected)
2696        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2697        self.assertEqual(r, expected)
2698        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2699                        where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2700        self.assertEqual(r, expected)
2701        r = get_as_dict(table, limit=1)
2702        self.assertEqual(len(r), 1)
2703        self.assertEqual(r[1][1], 'Aero')
2704        r = get_as_dict(table, offset=3)
2705        self.assertEqual(len(r), 1)
2706        self.assertEqual(r[4][1], 'Desert')
2707        r = get_as_dict(table, order='id desc')
2708        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2709        self.assertEqual(r, expected)
2710        r = get_as_dict(table, where='id > 5')
2711        self.assertIsInstance(r, OrderedDict)
2712        self.assertEqual(len(r), 0)
2713        # test with unordered query
2714        expected = dict((row[0], row[1:]) for row in colors)
2715        r = get_as_dict(table, order=False)
2716        self.assertIsInstance(r, dict)
2717        self.assertEqual(r, expected)
2718        if dict is not OrderedDict:  # Python > 2.6
2719            self.assertNotIsInstance(self, OrderedDict)
2720        # test with arbitrary from clause
2721        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2722        # primary key must be passed explicitly in this case
2723        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2724        r = get_as_dict(from_table, 'id')
2725        self.assertIsInstance(r, OrderedDict)
2726        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2727        self.assertEqual(r, expected)
2728        # test without a primary key
2729        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2730        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2731        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2732        r = get_as_dict(table, keyname='id')
2733        expected = OrderedDict((row[0], row[1:]) for row in colors)
2734        self.assertIsInstance(r, dict)
2735        self.assertEqual(r, expected)
2736        r = (1, '#007fff', 'Azure')
2737        query('insert into "%s" values ($1, $2, $3)' % table, r)
2738        # the last entry will win
2739        expected[1] = r[1:]
2740        r = get_as_dict(table, keyname='id')
2741        self.assertEqual(r, expected)
2742
    def testTransaction(self):
        """Test begin/commit/rollback, savepoints and transaction modes."""
        query = self.db.query
        self.createTable('test_table', 'n integer', temporary=False)
        # committed: 1 and 2 are kept
        self.db.begin()
        query("insert into test_table values (1)")
        query("insert into test_table values (2)")
        self.db.commit()
        # rolled back: 3 and 4 are discarded
        self.db.begin()
        query("insert into test_table values (3)")
        query("insert into test_table values (4)")
        self.db.rollback()
        # rolling back to a savepoint discards only the insert of 6
        self.db.begin()
        query("insert into test_table values (5)")
        self.db.savepoint('before6')
        query("insert into test_table values (6)")
        self.db.rollback('before6')
        query("insert into test_table values (7)")
        self.db.commit()
        # a released savepoint cannot be rolled back to anymore
        self.db.begin()
        self.db.savepoint('before8')
        query("insert into test_table values (8)")
        self.db.release('before8')
        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
        self.db.commit()
        # start() and end() work like begin() and commit()
        self.db.start()
        query("insert into test_table values (9)")
        self.db.end()
        r = [r[0] for r in query(
            "select * from test_table order by 1").getresult()]
        self.assertEqual(r, [1, 2, 5, 7, 9])
        # inserts fail in a read only transaction (mode in any letter case)
        self.db.begin(mode='read only')
        self.assertRaises(pg.ProgrammingError,
            query, "insert into test_table values (0)")
        self.db.rollback()
        self.db.start(mode='Read Only')
        self.assertRaises(pg.ProgrammingError,
            query, "insert into test_table values (0)")
        self.db.abort()
2781
2782    def testTransactionAliases(self):
2783        self.assertEqual(self.db.begin, self.db.start)
2784        self.assertEqual(self.db.commit, self.db.end)
2785        self.assertEqual(self.db.rollback, self.db.abort)
2786
2787    def testContextManager(self):
2788        query = self.db.query
2789        self.createTable('test_table', 'n integer check(n>0)')
2790        with self.db:
2791            query("insert into test_table values (1)")
2792            query("insert into test_table values (2)")
2793        try:
2794            with self.db:
2795                query("insert into test_table values (3)")
2796                query("insert into test_table values (4)")
2797                raise ValueError('test transaction should rollback')
2798        except ValueError as error:
2799            self.assertEqual(str(error), 'test transaction should rollback')
2800        with self.db:
2801            query("insert into test_table values (5)")
2802        try:
2803            with self.db:
2804                query("insert into test_table values (6)")
2805                query("insert into test_table values (-1)")
2806        except pg.ProgrammingError as error:
2807            self.assertTrue('check' in str(error))
2808        with self.db:
2809            query("insert into test_table values (7)")
2810        r = [r[0] for r in query(
2811            "select * from test_table order by 1").getresult()]
2812        self.assertEqual(r, [1, 2, 5, 7])
2813
2814    def testBytea(self):
2815        query = self.db.query
2816        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2817        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2818        r = self.db.escape_bytea(s)
2819        query('insert into bytea_test values(3, $1)', (r,))
2820        r = query('select * from bytea_test where n=3').getresult()
2821        self.assertEqual(len(r), 1)
2822        r = r[0]
2823        self.assertEqual(len(r), 2)
2824        self.assertEqual(r[0], 3)
2825        r = r[1]
2826        self.assertIsInstance(r, bytes)
2827        self.assertEqual(r, s)
2828
2829    def testInsertUpdateGetBytea(self):
2830        query = self.db.query
2831        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2832        # insert null value
2833        r = self.db.insert('bytea_test', n=0, data=None)
2834        self.assertIsInstance(r, dict)
2835        self.assertIn('n', r)
2836        self.assertEqual(r['n'], 0)
2837        self.assertIn('data', r)
2838        self.assertIsNone(r['data'])
2839        s = b'None'
2840        r = self.db.update('bytea_test', n=0, data=s)
2841        self.assertIsInstance(r, dict)
2842        self.assertIn('n', r)
2843        self.assertEqual(r['n'], 0)
2844        self.assertIn('data', r)
2845        r = r['data']
2846        self.assertIsInstance(r, bytes)
2847        self.assertEqual(r, s)
2848        r = self.db.update('bytea_test', n=0, data=None)
2849        self.assertIsNone(r['data'])
2850        # insert as bytes
2851        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2852        r = self.db.insert('bytea_test', n=5, data=s)
2853        self.assertIsInstance(r, dict)
2854        self.assertIn('n', r)
2855        self.assertEqual(r['n'], 5)
2856        self.assertIn('data', r)
2857        r = r['data']
2858        self.assertIsInstance(r, bytes)
2859        self.assertEqual(r, s)
2860        # update as bytes
2861        s += b"and now even more \x00 nasty \t stuff!\f"
2862        r = self.db.update('bytea_test', n=5, data=s)
2863        self.assertIsInstance(r, dict)
2864        self.assertIn('n', r)
2865        self.assertEqual(r['n'], 5)
2866        self.assertIn('data', r)
2867        r = r['data']
2868        self.assertIsInstance(r, bytes)
2869        self.assertEqual(r, s)
2870        r = query('select * from bytea_test where n=5').getresult()
2871        self.assertEqual(len(r), 1)
2872        r = r[0]
2873        self.assertEqual(len(r), 2)
2874        self.assertEqual(r[0], 5)
2875        r = r[1]
2876        self.assertIsInstance(r, bytes)
2877        self.assertEqual(r, s)
2878        r = self.db.get('bytea_test', dict(n=5))
2879        self.assertIsInstance(r, dict)
2880        self.assertIn('n', r)
2881        self.assertEqual(r['n'], 5)
2882        self.assertIn('data', r)
2883        r = r['data']
2884        self.assertIsInstance(r, bytes)
2885        self.assertEqual(r, s)
2886
2887    def testUpsertBytea(self):
2888        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2889        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2890        r = dict(n=7, data=s)
2891        try:
2892            r = self.db.upsert('bytea_test', r)
2893        except pg.ProgrammingError as error:
2894            if self.db.server_version < 90500:
2895                self.skipTest('database does not support upsert')
2896            self.fail(str(error))
2897        self.assertIsInstance(r, dict)
2898        self.assertIn('n', r)
2899        self.assertEqual(r['n'], 7)
2900        self.assertIn('data', r)
2901        self.assertIsInstance(r['data'], bytes)
2902        self.assertEqual(r['data'], s)
2903        r['data'] = None
2904        r = self.db.upsert('bytea_test', r)
2905        self.assertIsInstance(r, dict)
2906        self.assertIn('n', r)
2907        self.assertEqual(r['n'], 7)
2908        self.assertIn('data', r)
2909        self.assertIsNone(r['data'], bytes)
2910
2911    def testInsertGetJson(self):
2912        try:
2913            self.createTable('json_test', 'n smallint primary key, data json')
2914        except pg.ProgrammingError as error:
2915            if self.db.server_version < 90200:
2916                self.skipTest('database does not support json')
2917            self.fail(str(error))
2918        jsondecode = pg.get_jsondecode()
2919        # insert null value
2920        r = self.db.insert('json_test', n=0, data=None)
2921        self.assertIsInstance(r, dict)
2922        self.assertIn('n', r)
2923        self.assertEqual(r['n'], 0)
2924        self.assertIn('data', r)
2925        self.assertIsNone(r['data'])
2926        r = self.db.get('json_test', 0)
2927        self.assertIsInstance(r, dict)
2928        self.assertIn('n', r)
2929        self.assertEqual(r['n'], 0)
2930        self.assertIn('data', r)
2931        self.assertIsNone(r['data'])
2932        # insert JSON object
2933        data = {
2934          "id": 1, "name": "Foo", "price": 1234.5,
2935          "new": True, "note": None,
2936          "tags": ["Bar", "Eek"],
2937          "stock": {"warehouse": 300, "retail": 20}}
2938        r = self.db.insert('json_test', n=1, data=data)
2939        self.assertIsInstance(r, dict)
2940        self.assertIn('n', r)
2941        self.assertEqual(r['n'], 1)
2942        self.assertIn('data', r)
2943        r = r['data']
2944        if jsondecode is None:
2945            self.assertIsInstance(r, str)
2946            r = json.loads(r)
2947        self.assertIsInstance(r, dict)
2948        self.assertEqual(r, data)
2949        self.assertIsInstance(r['id'], int)
2950        self.assertIsInstance(r['name'], unicode)
2951        self.assertIsInstance(r['price'], float)
2952        self.assertIsInstance(r['new'], bool)
2953        self.assertIsInstance(r['tags'], list)
2954        self.assertIsInstance(r['stock'], dict)
2955        r = self.db.get('json_test', 1)
2956        self.assertIsInstance(r, dict)
2957        self.assertIn('n', r)
2958        self.assertEqual(r['n'], 1)
2959        self.assertIn('data', r)
2960        r = r['data']
2961        if jsondecode is None:
2962            self.assertIsInstance(r, str)
2963            r = json.loads(r)
2964        self.assertIsInstance(r, dict)
2965        self.assertEqual(r, data)
2966        self.assertIsInstance(r['id'], int)
2967        self.assertIsInstance(r['name'], unicode)
2968        self.assertIsInstance(r['price'], float)
2969        self.assertIsInstance(r['new'], bool)
2970        self.assertIsInstance(r['tags'], list)
2971        self.assertIsInstance(r['stock'], dict)
2972        # insert JSON object as text
2973        self.db.insert('json_test', n=2, data=json.dumps(data))
2974        q = "select data from json_test where n in (1, 2) order by n"
2975        r = self.db.query(q).getresult()
2976        self.assertEqual(len(r), 2)
2977        self.assertIsInstance(r[0][0], str if jsondecode is None else dict)
2978        self.assertEqual(r[0][0], r[1][0])
2979
2980    def testInsertGetJsonb(self):
2981        try:
2982            self.createTable('jsonb_test',
2983                'n smallint primary key, data jsonb')
2984        except pg.ProgrammingError as error:
2985            if self.db.server_version < 90400:
2986                self.skipTest('database does not support jsonb')
2987            self.fail(str(error))
2988        jsondecode = pg.get_jsondecode()
2989        # insert null value
2990        r = self.db.insert('jsonb_test', n=0, data=None)
2991        self.assertIsInstance(r, dict)
2992        self.assertIn('n', r)
2993        self.assertEqual(r['n'], 0)
2994        self.assertIn('data', r)
2995        self.assertIsNone(r['data'])
2996        r = self.db.get('jsonb_test', 0)
2997        self.assertIsInstance(r, dict)
2998        self.assertIn('n', r)
2999        self.assertEqual(r['n'], 0)
3000        self.assertIn('data', r)
3001        self.assertIsNone(r['data'])
3002        # insert JSON object
3003        data = {
3004          "id": 1, "name": "Foo", "price": 1234.5,
3005          "new": True, "note": None,
3006          "tags": ["Bar", "Eek"],
3007          "stock": {"warehouse": 300, "retail": 20}}
3008        r = self.db.insert('jsonb_test', n=1, data=data)
3009        self.assertIsInstance(r, dict)
3010        self.assertIn('n', r)
3011        self.assertEqual(r['n'], 1)
3012        self.assertIn('data', r)
3013        r = r['data']
3014        if jsondecode is None:
3015            self.assertIsInstance(r, str)
3016            r = json.loads(r)
3017        self.assertIsInstance(r, dict)
3018        self.assertEqual(r, data)
3019        self.assertIsInstance(r['id'], int)
3020        self.assertIsInstance(r['name'], unicode)
3021        self.assertIsInstance(r['price'], float)
3022        self.assertIsInstance(r['new'], bool)
3023        self.assertIsInstance(r['tags'], list)
3024        self.assertIsInstance(r['stock'], dict)
3025        r = self.db.get('jsonb_test', 1)
3026        self.assertIsInstance(r, dict)
3027        self.assertIn('n', r)
3028        self.assertEqual(r['n'], 1)
3029        self.assertIn('data', r)
3030        r = r['data']
3031        if jsondecode is None:
3032            self.assertIsInstance(r, str)
3033            r = json.loads(r)
3034        self.assertIsInstance(r, dict)
3035        self.assertEqual(r, data)
3036        self.assertIsInstance(r['id'], int)
3037        self.assertIsInstance(r['name'], unicode)
3038        self.assertIsInstance(r['price'], float)
3039        self.assertIsInstance(r['new'], bool)
3040        self.assertIsInstance(r['tags'], list)
3041        self.assertIsInstance(r['stock'], dict)
3042
3043    def testArray(self):
3044        self.createTable('arraytest',
3045            'id smallint, i2 smallint[], i4 integer[], i8 bigint[],'
3046            ' d numeric[], f4 real[], f8 double precision[], m money[],'
3047            ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]')
3048        r = self.db.get_attnames('arraytest')
3049        if self.regtypes:
3050            self.assertEqual(r, dict(
3051                id='smallint', i2='smallint[]', i4='integer[]', i8='bigint[]',
3052                d='numeric[]', f4='real[]', f8='double precision[]',
3053                m='money[]', b='boolean[]',
3054                v4='character varying[]', c4='character[]', t='text[]'))
3055        else:
3056            self.assertEqual(r, dict(
3057                id='int', i2='int[]', i4='int[]', i8='int[]',
3058                d='num[]', f4='float[]', f8='float[]', m='money[]',
3059                b='bool[]', v4='text[]', c4='text[]', t='text[]'))
3060        decimal = pg.get_decimal()
3061        if decimal is Decimal:
3062            long_decimal = decimal('123456789.123456789')
3063            odd_money = decimal('1234567891234567.89')
3064        else:
3065            long_decimal = decimal('12345671234.5')
3066            odd_money = decimal('1234567123.25')
3067        t, f = (True, False) if pg.get_bool() else ('t', 'f')
3068        data = dict(id=42, i2=[42, 1234, None, 0, -1],
3069            i4=[42, 123456789, None, 0, 1, -1],
3070            i8=[long(42), long(123456789123456789), None,
3071                long(0), long(1), long(-1)],
3072            d=[decimal(42), long_decimal, None,
3073               decimal(0), decimal(1), decimal(-1), -long_decimal],
3074            f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0,
3075                float('inf'), float('-inf')],
3076            f8=[42.0, 12345671234.5, None, 0.0, 1.0, -1.0,
3077                float('inf'), float('-inf')],
3078            m=[decimal('42.00'), odd_money, None,
3079               decimal('0.00'), decimal('1.00'), decimal('-1.00'), -odd_money],
3080            b=[t, f, t, None, f, t, None, None, t],
3081            v4=['abc', '"Hi"', '', None], c4=['abc ', '"Hi"', '    ', None],
3082            t=['abc', 'Hello, World!', '"Hello, World!"', '', None])
3083        r = data.copy()
3084        self.db.insert('arraytest', r)
3085        self.assertEqual(r, data)
3086        self.db.insert('arraytest', r)
3087        r = self.db.get('arraytest', 42, 'id')
3088        self.assertEqual(r, data)
3089        r = self.db.query('select * from arraytest limit 1').dictresult()[0]
3090        self.assertEqual(r, data)
3091
3092    def testArrayLiteral(self):
3093        insert = self.db.insert
3094        self.createTable('arraytest', 'i int[], t text[]', oids=True)
3095        r = dict(i=[1, 2, 3], t=['a', 'b', 'c'])
3096        insert('arraytest', r)
3097        self.assertEqual(r['i'], [1, 2, 3])
3098        self.assertEqual(r['t'], ['a', 'b', 'c'])
3099        r = dict(i='{1,2,3}', t='{a,b,c}')
3100        self.db.insert('arraytest', r)
3101        self.assertEqual(r['i'], [1, 2, 3])
3102        self.assertEqual(r['t'], ['a', 'b', 'c'])
3103        L = pg._Literal
3104        r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']"))
3105        self.db.insert('arraytest', r)
3106        self.assertEqual(r['i'], [1, 2, 3])
3107        self.assertEqual(r['t'], ['a', 'b', 'c'])
3108        r = dict(i="1, 2, 3", t="'a', 'b', 'c'")
3109        self.assertRaises(pg.ProgrammingError, self.db.insert, 'arraytest', r)
3110
3111    def testArrayOfIds(self):
3112        self.createTable('arraytest', 'c cid[], o oid[], x xid[]', oids=True)
3113        r = self.db.get_attnames('arraytest')
3114        if self.regtypes:
3115            self.assertEqual(r, dict(
3116                oid='oid', c='cid[]', o='oid[]', x='xid[]'))
3117        else:
3118            self.assertEqual(r, dict(
3119                oid='int', c='int[]', o='int[]', x='int[]'))
3120        data = dict(c=[11, 12, 13], o=[21, 22, 23], x=[31, 32, 33])
3121        r = data.copy()
3122        self.db.insert('arraytest', r)
3123        qoid = 'oid(arraytest)'
3124        oid = r.pop(qoid)
3125        self.assertEqual(r, data)
3126        r = {qoid: oid}
3127        self.db.get('arraytest', r)
3128        self.assertEqual(oid, r.pop(qoid))
3129        self.assertEqual(r, data)
3130
3131    def testArrayOfText(self):
3132        self.createTable('arraytest', 'data text[]', oids=True)
3133        r = self.db.get_attnames('arraytest')
3134        self.assertEqual(r['data'], 'text[]')
3135        data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"',
3136                'null', 'NULL', 'Null', 'nulL',
3137                "It's all \\ kinds of\r nasty stuff!\n"]
3138        r = dict(data=data)
3139        self.db.insert('arraytest', r)
3140        self.assertEqual(r['data'], data)
3141        self.assertIsInstance(r['data'][1], str)
3142        self.assertIsNone(r['data'][2])
3143        r['data'] = None
3144        self.db.get('arraytest', r)
3145        self.assertEqual(r['data'], data)
3146        self.assertIsInstance(r['data'][1], str)
3147        self.assertIsNone(r['data'][2])
3148
3149    def testArrayOfBytea(self):
3150        self.createTable('arraytest', 'data bytea[]', oids=True)
3151        r = self.db.get_attnames('arraytest')
3152        self.assertEqual(r['data'], 'bytea[]')
3153        data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"',
3154                b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"]
3155        r = dict(data=data)
3156        self.db.insert('arraytest', r)
3157        self.assertEqual(r['data'], data)
3158        self.assertIsInstance(r['data'][1], bytes)
3159        self.assertIsNone(r['data'][2])
3160        r['data'] = None
3161        self.db.get('arraytest', r)
3162        self.assertEqual(r['data'], data)
3163        self.assertIsInstance(r['data'][1], bytes)
3164        self.assertIsNone(r['data'][2])
3165
3166    def testArrayOfJson(self):
3167        try:
3168            self.createTable('arraytest', 'data json[]', oids=True)
3169        except pg.ProgrammingError as error:
3170            if self.db.server_version < 90200:
3171                self.skipTest('database does not support json')
3172            self.fail(str(error))
3173        r = self.db.get_attnames('arraytest')
3174        self.assertEqual(r['data'], 'json[]')
3175        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3176        jsondecode = pg.get_jsondecode()
3177        r = dict(data=data)
3178        self.db.insert('arraytest', r)
3179        if jsondecode is None:
3180            r['data'] = [json.loads(d) for d in r['data']]
3181        self.assertEqual(r['data'], data)
3182        r['data'] = None
3183        self.db.get('arraytest', r)
3184        if jsondecode is None:
3185            r['data'] = [json.loads(d) for d in r['data']]
3186        self.assertEqual(r['data'], data)
3187        r = dict(data=[json.dumps(d) for d in data])
3188        self.db.insert('arraytest', r)
3189        if jsondecode is None:
3190            r['data'] = [json.loads(d) for d in r['data']]
3191        self.assertEqual(r['data'], data)
3192        r['data'] = None
3193        self.db.get('arraytest', r)
3194        # insert empty json values
3195        r = dict(data=['', None])
3196        self.db.insert('arraytest', r)
3197        r = r['data']
3198        self.assertIsInstance(r, list)
3199        self.assertEqual(len(r), 2)
3200        self.assertIsNone(r[0])
3201        self.assertIsNone(r[1])
3202
3203    def testArrayOfJsonb(self):
3204        try:
3205            self.createTable('arraytest', 'data jsonb[]', oids=True)
3206        except pg.ProgrammingError as error:
3207            if self.db.server_version < 90400:
3208                self.skipTest('database does not support jsonb')
3209            self.fail(str(error))
3210        r = self.db.get_attnames('arraytest')
3211        self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]')
3212        data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')]
3213        jsondecode = pg.get_jsondecode()
3214        r = dict(data=data)
3215        self.db.insert('arraytest', r)
3216        if jsondecode is None:
3217            r['data'] = [json.loads(d) for d in r['data']]
3218        self.assertEqual(r['data'], data)
3219        r['data'] = None
3220        self.db.get('arraytest', r)
3221        if jsondecode is None:
3222            r['data'] = [json.loads(d) for d in r['data']]
3223        self.assertEqual(r['data'], data)
3224        r = dict(data=[json.dumps(d) for d in data])
3225        self.db.insert('arraytest', r)
3226        if jsondecode is None:
3227            r['data'] = [json.loads(d) for d in r['data']]
3228        self.assertEqual(r['data'], data)
3229        r['data'] = None
3230        self.db.get('arraytest', r)
3231        # insert empty json values
3232        r = dict(data=['', None])
3233        self.db.insert('arraytest', r)
3234        r = r['data']
3235        self.assertIsInstance(r, list)
3236        self.assertEqual(len(r), 2)
3237        self.assertIsNone(r[0])
3238        self.assertIsNone(r[1])
3239
3240    def testDeepArray(self):
3241        self.createTable('arraytest', 'data text[][][]', oids=True)
3242        r = self.db.get_attnames('arraytest')
3243        self.assertEqual(r['data'], 'text[]')
3244        data = [[['Hello, World!', '{a,b,c}', 'back\\slash']]]
3245        r = dict(data=data)
3246        self.db.insert('arraytest', r)
3247        self.assertEqual(r['data'], data)
3248        r['data'] = None
3249        self.db.get('arraytest', r)
3250        self.assertEqual(r['data'], data)
3251
3252    def testInsertUpdateGetRecord(self):
3253        query = self.db.query
3254        query('create type test_person_type as'
3255            ' (name varchar, age smallint, married bool,'
3256              ' weight real, salary money)')
3257        self.addCleanup(query, 'drop type test_person_type')
3258        self.createTable('test_person', 'person test_person_type',
3259            temporary=False, oids=True)
3260        attnames = self.db.get_attnames('test_person')
3261        self.assertEqual(len(attnames), 2)
3262        self.assertIn('oid', attnames)
3263        self.assertIn('person', attnames)
3264        person_typ = attnames['person']
3265        if self.regtypes:
3266            self.assertEqual(person_typ, 'test_person_type')
3267        else:
3268            self.assertEqual(person_typ, 'record')
3269        if self.regtypes:
3270            self.assertEqual(person_typ.attnames,
3271                dict(name='character varying', age='smallint',
3272                    married='boolean', weight='real', salary='money'))
3273        else:
3274            self.assertEqual(person_typ.attnames,
3275                dict(name='text', age='int', married='bool',
3276                    weight='float', salary='money'))
3277        decimal = pg.get_decimal()
3278        if pg.get_bool():
3279            bool_class = bool
3280            t, f = True, False
3281        else:
3282            bool_class = str
3283            t, f = 't', 'f'
3284        person = ('John Doe', 61, t, 99.5, decimal('93456.75'))
3285        r = self.db.insert('test_person', None, person=person)
3286        p = r['person']
3287        self.assertIsInstance(p, tuple)
3288        self.assertEqual(p, person)
3289        self.assertEqual(p.name, 'John Doe')
3290        self.assertIsInstance(p.name, str)
3291        self.assertIsInstance(p.age, int)
3292        self.assertIsInstance(p.married, bool_class)
3293        self.assertIsInstance(p.weight, float)
3294        self.assertIsInstance(p.salary, decimal)
3295        person = ('Jane Roe', 59, f, 64.5, decimal('96543.25'))
3296        r['person'] = person
3297        self.db.update('test_person', r)
3298        p = r['person']
3299        self.assertIsInstance(p, tuple)
3300        self.assertEqual(p, person)
3301        self.assertEqual(p.name, 'Jane Roe')
3302        self.assertIsInstance(p.name, str)
3303        self.assertIsInstance(p.age, int)
3304        self.assertIsInstance(p.married, bool_class)
3305        self.assertIsInstance(p.weight, float)
3306        self.assertIsInstance(p.salary, decimal)
3307        r['person'] = None
3308        self.db.get('test_person', r)
3309        p = r['person']
3310        self.assertIsInstance(p, tuple)
3311        self.assertEqual(p, person)
3312        self.assertEqual(p.name, 'Jane Roe')
3313        self.assertIsInstance(p.name, str)
3314        self.assertIsInstance(p.age, int)
3315        self.assertIsInstance(p.married, bool_class)
3316        self.assertIsInstance(p.weight, float)
3317        self.assertIsInstance(p.salary, decimal)
3318        person = (None,) * 5
3319        r = self.db.insert('test_person', None, person=person)
3320        p = r['person']
3321        self.assertIsInstance(p, tuple)
3322        self.assertIsNone(p.name)
3323        self.assertIsNone(p.age)
3324        self.assertIsNone(p.married)
3325        self.assertIsNone(p.weight)
3326        self.assertIsNone(p.salary)
3327        r['person'] = None
3328        self.db.get('test_person', r)
3329        p = r['person']
3330        self.assertIsInstance(p, tuple)
3331        self.assertIsNone(p.name)
3332        self.assertIsNone(p.age)
3333        self.assertIsNone(p.married)
3334        self.assertIsNone(p.weight)
3335        self.assertIsNone(p.salary)
3336        r = self.db.insert('test_person', None, person=None)
3337        self.assertIsNone(r['person'])
3338        r['person'] = None
3339        self.db.get('test_person', r)
3340        self.assertIsNone(r['person'])
3341
3342    def testRecordInsertBytea(self):
3343        query = self.db.query
3344        query('create type test_person_type as'
3345            ' (name text, picture bytea)')
3346        self.addCleanup(query, 'drop type test_person_type')
3347        self.createTable('test_person', 'person test_person_type',
3348            temporary=False, oids=True)
3349        person_typ = self.db.get_attnames('test_person')['person']
3350        self.assertEqual(person_typ.attnames,
3351            dict(name='text', picture='bytea'))
3352        person = ('John Doe', b'O\x00ps\xff!')
3353        r = self.db.insert('test_person', None, person=person)
3354        p = r['person']
3355        self.assertIsInstance(p, tuple)
3356        self.assertEqual(p, person)
3357        self.assertEqual(p.name, 'John Doe')
3358        self.assertIsInstance(p.name, str)
3359        self.assertEqual(p.picture, person[1])
3360        self.assertIsInstance(p.picture, bytes)
3361
3362    def testRecordInsertJson(self):
3363        query = self.db.query
3364        try:
3365            query('create type test_person_type as'
3366                ' (name text, data json)')
3367        except pg.ProgrammingError as error:
3368            if self.db.server_version < 90200:
3369                self.skipTest('database does not support json')
3370            self.fail(str(error))
3371        self.addCleanup(query, 'drop type test_person_type')
3372        self.createTable('test_person', 'person test_person_type',
3373            temporary=False, oids=True)
3374        person_typ = self.db.get_attnames('test_person')['person']
3375        self.assertEqual(person_typ.attnames,
3376            dict(name='text', data='json'))
3377        person = ('John Doe', dict(age=61, married=True, weight=99.5))
3378        r = self.db.insert('test_person', None, person=person)
3379        p = r['person']
3380        self.assertIsInstance(p, tuple)
3381        if pg.get_jsondecode() is None:
3382            p = p._replace(data=json.loads(p.data))
3383        self.assertEqual(p, person)
3384        self.assertEqual(p.name, 'John Doe')
3385        self.assertIsInstance(p.name, str)
3386        self.assertEqual(p.data, person[1])
3387        self.assertIsInstance(p.data, dict)
3388
3389    def testRecordLiteral(self):
3390        query = self.db.query
3391        query('create type test_person_type as'
3392            ' (name varchar, age smallint)')
3393        self.addCleanup(query, 'drop type test_person_type')
3394        self.createTable('test_person', 'person test_person_type',
3395            temporary=False, oids=True)
3396        person_typ = self.db.get_attnames('test_person')['person']
3397        if self.regtypes:
3398            self.assertEqual(person_typ, 'test_person_type')
3399        else:
3400            self.assertEqual(person_typ, 'record')
3401        if self.regtypes:
3402            self.assertEqual(person_typ.attnames,
3403                dict(name='character varying', age='smallint'))
3404        else:
3405            self.assertEqual(person_typ.attnames,
3406                dict(name='text', age='int'))
3407        person = pg._Literal("('John Doe', 61)")
3408        r = self.db.insert('test_person', None, person=person)
3409        p = r['person']
3410        self.assertIsInstance(p, tuple)
3411        self.assertEqual(p.name, 'John Doe')
3412        self.assertIsInstance(p.name, str)
3413        self.assertEqual(p.age, 61)
3414        self.assertIsInstance(p.age, int)
3415
3416    def testNotificationHandler(self):
3417        # the notification handler itself is tested separately
3418        f = self.db.notification_handler
3419        callback = lambda arg_dict: None
3420        handler = f('test', callback)
3421        self.assertIsInstance(handler, pg.NotificationHandler)
3422        self.assertIs(handler.db, self.db)
3423        self.assertEqual(handler.event, 'test')
3424        self.assertEqual(handler.stop_event, 'stop_test')
3425        self.assertIs(handler.callback, callback)
3426        self.assertIsInstance(handler.arg_dict, dict)
3427        self.assertEqual(handler.arg_dict, {})
3428        self.assertIsNone(handler.timeout)
3429        self.assertFalse(handler.listening)
3430        handler.close()
3431        self.assertIsNone(handler.db)
3432        self.db.reopen()
3433        self.assertIsNone(handler.db)
3434        handler = f('test2', callback, timeout=2)
3435        self.assertIsInstance(handler, pg.NotificationHandler)
3436        self.assertIs(handler.db, self.db)
3437        self.assertEqual(handler.event, 'test2')
3438        self.assertEqual(handler.stop_event, 'stop_test2')
3439        self.assertIs(handler.callback, callback)
3440        self.assertIsInstance(handler.arg_dict, dict)
3441        self.assertEqual(handler.arg_dict, {})
3442        self.assertEqual(handler.timeout, 2)
3443        self.assertFalse(handler.listening)
3444        handler.close()
3445        self.assertIsNone(handler.db)
3446        self.db.reopen()
3447        self.assertIsNone(handler.db)
3448        arg_dict = {'testing': 3}
3449        handler = f('test3', callback, arg_dict=arg_dict)
3450        self.assertIsInstance(handler, pg.NotificationHandler)
3451        self.assertIs(handler.db, self.db)
3452        self.assertEqual(handler.event, 'test3')
3453        self.assertEqual(handler.stop_event, 'stop_test3')
3454        self.assertIs(handler.callback, callback)
3455        self.assertIs(handler.arg_dict, arg_dict)
3456        self.assertEqual(arg_dict['testing'], 3)
3457        self.assertIsNone(handler.timeout)
3458        self.assertFalse(handler.listening)
3459        handler.close()
3460        self.assertIsNone(handler.db)
3461        self.db.reopen()
3462        self.assertIsNone(handler.db)
3463        handler = f('test4', callback, stop_event='stop4')
3464        self.assertIsInstance(handler, pg.NotificationHandler)
3465        self.assertIs(handler.db, self.db)
3466        self.assertEqual(handler.event, 'test4')
3467        self.assertEqual(handler.stop_event, 'stop4')
3468        self.assertIs(handler.callback, callback)
3469        self.assertIsInstance(handler.arg_dict, dict)
3470        self.assertEqual(handler.arg_dict, {})
3471        self.assertIsNone(handler.timeout)
3472        self.assertFalse(handler.listening)
3473        handler.close()
3474        self.assertIsNone(handler.db)
3475        self.db.reopen()
3476        self.assertIsNone(handler.db)
3477        arg_dict = {'testing': 5}
3478        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
3479        self.assertIsInstance(handler, pg.NotificationHandler)
3480        self.assertIs(handler.db, self.db)
3481        self.assertEqual(handler.event, 'test5')
3482        self.assertEqual(handler.stop_event, 'stop5')
3483        self.assertIs(handler.callback, callback)
3484        self.assertIs(handler.arg_dict, arg_dict)
3485        self.assertEqual(arg_dict['testing'], 5)
3486        self.assertEqual(handler.timeout, 1.5)
3487        self.assertFalse(handler.listening)
3488        handler.close()
3489        self.assertIsNone(handler.db)
3490        self.db.reopen()
3491        self.assertIsNone(handler.db)
3492
3493
3494class TestDBClassNonStdOpts(TestDBClass):
3495    """Test the methods of the DB class with non-standard global options."""
3496
3497    @classmethod
3498    def setUpClass(cls):
3499        cls.saved_options = {}
3500        cls.set_option('decimal', float)
3501        not_bool = not pg.get_bool()
3502        cls.set_option('bool', not_bool)
3503        cls.set_option('namedresult', None)
3504        cls.set_option('jsondecode', None)
3505        cls.regtypes = not DB().use_regtypes()
3506        super(TestDBClassNonStdOpts, cls).setUpClass()
3507
3508    @classmethod
3509    def tearDownClass(cls):
3510        super(TestDBClassNonStdOpts, cls).tearDownClass()
3511        cls.reset_option('jsondecode')
3512        cls.reset_option('namedresult')
3513        cls.reset_option('bool')
3514        cls.reset_option('decimal')
3515
3516    @classmethod
3517    def set_option(cls, option, value):
3518        cls.saved_options[option] = getattr(pg, 'get_' + option)()
3519        return getattr(pg, 'set_' + option)(value)
3520
3521    @classmethod
3522    def reset_option(cls, option):
3523        return getattr(pg, 'set_' + option)(cls.saved_options[option])
3524
3525
3526class TestSchemas(unittest.TestCase):
3527    """Test correct handling of schemas (namespaces)."""
3528
3529    cls_set_up = False
3530
3531    @classmethod
3532    def setUpClass(cls):
3533        db = DB()
3534        query = db.query
3535        for num_schema in range(5):
3536            if num_schema:
3537                schema = "s%d" % num_schema
3538                query("drop schema if exists %s cascade" % (schema,))
3539                try:
3540                    query("create schema %s" % (schema,))
3541                except pg.ProgrammingError:
3542                    raise RuntimeError("The test user cannot create schemas.\n"
3543                        "Grant create on database %s to the user"
3544                        " for running these tests." % dbname)
3545            else:
3546                schema = "public"
3547                query("drop table if exists %s.t" % (schema,))
3548                query("drop table if exists %s.t%d" % (schema, num_schema))
3549            query("create table %s.t with oids as select 1 as n, %d as d"
3550                  % (schema, num_schema))
3551            query("create table %s.t%d with oids as select 1 as n, %d as d"
3552                  % (schema, num_schema, num_schema))
3553        db.close()
3554        cls.cls_set_up = True
3555
3556    @classmethod
3557    def tearDownClass(cls):
3558        db = DB()
3559        query = db.query
3560        for num_schema in range(5):
3561            if num_schema:
3562                schema = "s%d" % num_schema
3563                query("drop schema %s cascade" % (schema,))
3564            else:
3565                schema = "public"
3566                query("drop table %s.t" % (schema,))
3567                query("drop table %s.t%d" % (schema, num_schema))
3568        db.close()
3569
3570    def setUp(self):
3571        self.assertTrue(self.cls_set_up)
3572        self.db = DB()
3573
3574    def tearDown(self):
3575        self.doCleanups()
3576        self.db.close()
3577
3578    def testGetTables(self):
3579        tables = self.db.get_tables()
3580        for num_schema in range(5):
3581            if num_schema:
3582                schema = "s" + str(num_schema)
3583            else:
3584                schema = "public"
3585            for t in (schema + ".t",
3586                    schema + ".t" + str(num_schema)):
3587                self.assertIn(t, tables)
3588
3589    def testGetAttnames(self):
3590        get_attnames = self.db.get_attnames
3591        query = self.db.query
3592        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
3593        r = get_attnames("t")
3594        self.assertEqual(r, result)
3595        r = get_attnames("s4.t4")
3596        self.assertEqual(r, result)
3597        query("drop table if exists s3.t3m")
3598        self.addCleanup(query, "drop table s3.t3m")
3599        query("create table s3.t3m with oids as select 1 as m")
3600        result_m = {'oid': 'int', 'm': 'int'}
3601        r = get_attnames("s3.t3m")
3602        self.assertEqual(r, result_m)
3603        query("set search_path to s1,s3")
3604        r = get_attnames("t3")
3605        self.assertEqual(r, result)
3606        r = get_attnames("t3m")
3607        self.assertEqual(r, result_m)
3608
3609    def testGet(self):
3610        get = self.db.get
3611        query = self.db.query
3612        PrgError = pg.ProgrammingError
3613        self.assertEqual(get("t", 1, 'n')['d'], 0)
3614        self.assertEqual(get("t0", 1, 'n')['d'], 0)
3615        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
3616        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
3617        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
3618        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
3619        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
3620        query("set search_path to s2,s4")
3621        self.assertRaises(PrgError, get, "t1", 1, 'n')
3622        self.assertEqual(get("t4", 1, 'n')['d'], 4)
3623        self.assertRaises(PrgError, get, "t3", 1, 'n')
3624        self.assertEqual(get("t", 1, 'n')['d'], 2)
3625        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
3626        query("set search_path to s1,s3")
3627        self.assertRaises(PrgError, get, "t2", 1, 'n')
3628        self.assertEqual(get("t3", 1, 'n')['d'], 3)
3629        self.assertRaises(PrgError, get, "t4", 1, 'n')
3630        self.assertEqual(get("t", 1, 'n')['d'], 1)
3631        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
3632
3633    def testMunging(self):
3634        get = self.db.get
3635        query = self.db.query
3636        r = get("t", 1, 'n')
3637        self.assertIn('oid(t)', r)
3638        query("set search_path to s2")
3639        r = get("t2", 1, 'n')
3640        self.assertIn('oid(t2)', r)
3641        query("set search_path to s3")
3642        r = get("t", 1, 'n')
3643        self.assertIn('oid(t)', r)
3644
3645
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class."""

    def setUp(self):
        self.db = DB()
        self.query = self.db.query
        # remember the initial debug setting so tearDown can restore it
        self.debug = self.db.debug
        self.output = StringIO()
        # capture everything printed to stdout while the test runs
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        """Restore stdout and the debug setting, then close the connection."""
        try:
            sys.stdout = self.stdout
            self.output.close()
            # restore the value saved in setUp (not the module-level global)
            self.db.debug = self.debug
        finally:
            self.db.close()

    def get_output(self):
        """Return everything that has been written to stdout so far."""
        return self.output.getvalue()

    def send_queries(self):
        """Send the two queries whose debug output the tests check."""
        self.query("select 1")
        self.query("select 2")

    def testDebugDefault(self):
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")

    def testDebugMultipleArgs(self):
        output = []
        self.db.debug = output.append
        args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]]
        self.db._do_debug(*args)
        # multiple arguments are passed on as one newline-joined string
        self.assertEqual(output, ['\n'.join(str(arg) for arg in args)])
        self.assertEqual(self.get_output(), "")
3716
if __name__ == '__main__':
    # run all test cases of this module when it is executed as a script
    unittest.main(module='__main__')
# Note: See TracBrowser for help on using the repository browser.