source: trunk/tests/test_classic_dbwrapper.py @ 774

Last change on this file since 774 was 774, checked in by cito, 4 years ago

Add support for JSON and JSONB to pg and pgdb

This adds all necessary functions to make PyGreSQL automatically
convert between JSON columns and Python objects representing them.

The documentation has also been updated, see there for the details.

Also, tuples automatically bind to ROW expressions in pgdb now.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 127.1 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import tempfile
21import json
22
23import pg  # the module under test
24
25from decimal import Decimal
26from operator import itemgetter
27
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Optional local overrides: the star-imports below come after the defaults
# so that any name defined in LOCAL_PyGreSQL replaces the value set above.
# The relative import is tried first, then the plain top-level import.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Compatibility aliases so the tests can use the Python 2 names everywhere.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# Fall back to a plain dict where collections.OrderedDict is not available.
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

# "str is bytes" only holds where str is the byte string type (Python 2).
if str is bytes:
    from StringIO import StringIO
else:
    from io import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
71
72
def DB():
    """Return a pg.DB wrapper connected to the configured test database.

    The connection uses the module level dbname, dbhost and dbport
    settings.  Debug output is switched on when the module level debug
    flag is set, and server messages below warning level are suppressed.
    """
    wrapper = pg.DB(dbname, dbhost, dbport)
    if debug:
        wrapper.debug = debug
    wrapper.query("set client_min_messages=warning")
    return wrapper
80
81
class TestAttrDict(unittest.TestCase):
    """Check the read-only ordered dictionary used for attribute names."""

    cls = pg.AttrDict
    base = OrderedDict

    def testInit(self):
        """An AttrDict can be created empty, from a list or an iterator."""
        empty = self.cls()
        self.assertIsInstance(empty, self.base)
        self.assertEqual(empty, self.base())
        items = [('id', 'int'), ('name', 'text')]
        for source in (items, iter(items)):
            attrs = self.cls(source)
            self.assertIsInstance(attrs, self.base)
            self.assertEqual(attrs, self.base(items))

    def testIter(self):
        """Iterating yields the keys in insertion order."""
        self.assertEqual(list(self.cls()), [])
        keys = ['id', 'name', 'age']
        attrs = self.cls([(key, None) for key in keys])
        self.assertEqual(list(attrs), keys)

    def testKeys(self):
        """keys() yields the keys in insertion order."""
        self.assertEqual(list(self.cls().keys()), [])
        keys = ['id', 'name', 'age']
        attrs = self.cls([(key, None) for key in keys])
        self.assertEqual(list(attrs.keys()), keys)

    def testValues(self):
        """values() yields the values in insertion order."""
        self.assertEqual(list(self.cls().values()), [])
        items = [('id', 'int'), ('name', 'text')]
        expected = [value for _key, value in items]
        attrs = self.cls(items)
        self.assertEqual(list(attrs.values()), expected)

    def testItems(self):
        """items() yields the key/value pairs in insertion order."""
        self.assertEqual(list(self.cls().items()), [])
        items = [('id', 'int'), ('name', 'text')]
        attrs = self.cls(items)
        self.assertEqual(list(attrs.items()), items)

    def testGet(self):
        """Reading an existing key must work."""
        attrs = self.cls([('id', 1)])
        try:
            value = attrs['id']
        except KeyError:
            self.fail('AttrDict should be readable')
        else:
            self.assertEqual(value, 1)

    def testSet(self):
        """Item assignment must be rejected with a TypeError."""
        attrs = self.cls()
        try:
            attrs['id'] = 1
        except TypeError:
            return
        self.fail('AttrDict should be read-only')

    def testDel(self):
        """Item deletion must be rejected with a TypeError."""
        attrs = self.cls([('id', 1)])
        try:
            del attrs['id']
        except TypeError:
            return
        self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        """All mutating dict methods must be rejected with a TypeError."""
        attrs = self.cls([('id', 1)])
        self.assertEqual(attrs['id'], 1)
        names = ('clear', 'update', 'pop', 'setdefault', 'popitem')
        for name in names:
            self.assertRaises(TypeError, getattr(attrs, name), attrs)
163
164
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        """Open a fresh connection for every test."""
        self.db = DB()

    def tearDown(self):
        """Close the connection unless the test has already closed it."""
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        """The public attributes must match this sorted list exactly."""
        attributes = [
            'abort',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'decode_json', 'delete',
            'encode_json', 'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        # an empty error or one that mentions 'krb5_' is accepted
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        # NOTE(review): PostgreSQL 10 and later report version numbers
        # of 100000 and above, so this upper bound rejects such servers.
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # strip a hex prefix or escape backslashes before comparing
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodDecodeJson(self):
        self.assertEqual(self.db.decode_json('{}'), {})

    def testMethodEncodeJson(self):
        self.assertEqual(self.db.encode_json({}), '{}')

    def testMethodQuery(self):
        """Parameters may be passed separately, as a tuple or as a list."""
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        """A division by zero must raise an error with sqlstate 22012."""
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            self.assertEqual(error.sqlstate, '22012')
        else:
            # without this branch the test would pass silently
            # if the query did not raise an error at all
            self.fail('Division by zero should raise a ProgrammingError')

    def testMethodEndcopy(self):
        # an IOError is tolerated here on purpose
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        """A closed wrapper must refuse all further operations."""
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        """reset() must keep the same underlying connection object."""
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        """reopen() must create a new connection, even after close()."""
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        """A DB wrapper can be created around an existing connection."""
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        # the wrapped connection stays set after close() and reopen()
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        # an arbitrary object carrying the connection as "_cnx" works, too
        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
373
374
375class TestDBClass(unittest.TestCase):
376    """Test the methods of the DB class wrapped pg connection."""
377
378    cls_set_up = False
379
380    @classmethod
381    def setUpClass(cls):
382        db = DB()
383        db.query("drop table if exists test cascade")
384        db.query("create table test ("
385            "i2 smallint, i4 integer, i8 bigint,"
386            " d numeric, f4 real, f8 double precision, m money,"
387            " v4 varchar(4), c4 char(4), t text)")
388        db.query("create or replace view test_view as"
389            " select i4, v4 from test")
390        db.close()
391        cls.cls_set_up = True
392
393    @classmethod
394    def tearDownClass(cls):
395        db = DB()
396        db.query("drop table test cascade")
397        db.close()
398
399    def setUp(self):
400        self.assertTrue(self.cls_set_up)
401        self.db = DB()
402        query = self.db.query
403        query('set client_encoding=utf8')
404        query('set standard_conforming_strings=on')
405        query("set lc_monetary='C'")
406        query("set datestyle='ISO,YMD'")
407        query('set bytea_output=hex')
408
409    def tearDown(self):
410        self.doCleanups()
411        self.db.close()
412
413    def createTable(self, table, definition,
414            temporary=True, oids=None, values=None):
415        query = self.db.query
416        if not '"' in table or '.' in table:
417            table = '"%s"' % table
418        if not temporary:
419            q = 'drop table if exists %s cascade' % table
420            query(q)
421            self.addCleanup(query, q)
422        temporary = 'temporary table' if temporary else 'table'
423        as_query = definition.startswith(('as ', 'AS '))
424        if not as_query and not definition.startswith('('):
425            definition = '(%s)' % definition
426        with_oids = 'with oids' if oids else 'without oids'
427        q = ['create', temporary, table]
428        if as_query:
429            q.extend([with_oids, definition])
430        else:
431            q.extend([definition, with_oids])
432        q = ' '.join(q)
433        query(q)
434        if values:
435            for params in values:
436                if not isinstance(params, (list, tuple)):
437                    params = [params]
438                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
439                q = "insert into %s values (%s)" % (table, values)
440                query(q, params)
441
442    def testClassName(self):
443        self.assertEqual(self.db.__class__.__name__, 'DB')
444
445    def testModuleName(self):
446        self.assertEqual(self.db.__module__, 'pg')
447        self.assertEqual(self.db.__class__.__module__, 'pg')
448
449    def testEscapeLiteral(self):
450        f = self.db.escape_literal
451        r = f(b"plain")
452        self.assertIsInstance(r, bytes)
453        self.assertEqual(r, b"'plain'")
454        r = f(u"plain")
455        self.assertIsInstance(r, unicode)
456        self.assertEqual(r, u"'plain'")
457        r = f(u"that's kÀse".encode('utf-8'))
458        self.assertIsInstance(r, bytes)
459        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
460        r = f(u"that's kÀse")
461        self.assertIsInstance(r, unicode)
462        self.assertEqual(r, u"'that''s kÀse'")
463        self.assertEqual(f(r"It's fine to have a \ inside."),
464            r" E'It''s fine to have a \\ inside.'")
465        self.assertEqual(f('No "quotes" must be escaped.'),
466            "'No \"quotes\" must be escaped.'")
467
468    def testEscapeIdentifier(self):
469        f = self.db.escape_identifier
470        r = f(b"plain")
471        self.assertIsInstance(r, bytes)
472        self.assertEqual(r, b'"plain"')
473        r = f(u"plain")
474        self.assertIsInstance(r, unicode)
475        self.assertEqual(r, u'"plain"')
476        r = f(u"that's kÀse".encode('utf-8'))
477        self.assertIsInstance(r, bytes)
478        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
479        r = f(u"that's kÀse")
480        self.assertIsInstance(r, unicode)
481        self.assertEqual(r, u'"that\'s kÀse"')
482        self.assertEqual(f(r"It's fine to have a \ inside."),
483            '"It\'s fine to have a \\ inside."')
484        self.assertEqual(f('All "quotes" must be escaped.'),
485            '"All ""quotes"" must be escaped."')
486
487    def testEscapeString(self):
488        f = self.db.escape_string
489        r = f(b"plain")
490        self.assertIsInstance(r, bytes)
491        self.assertEqual(r, b"plain")
492        r = f(u"plain")
493        self.assertIsInstance(r, unicode)
494        self.assertEqual(r, u"plain")
495        r = f(u"that's kÀse".encode('utf-8'))
496        self.assertIsInstance(r, bytes)
497        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
498        r = f(u"that's kÀse")
499        self.assertIsInstance(r, unicode)
500        self.assertEqual(r, u"that''s kÀse")
501        self.assertEqual(f(r"It's fine to have a \ inside."),
502            r"It''s fine to have a \ inside.")
503
504    def testEscapeBytea(self):
505        f = self.db.escape_bytea
506        # note that escape_byte always returns hex output since Pg 9.0,
507        # regardless of the bytea_output setting
508        r = f(b'plain')
509        self.assertIsInstance(r, bytes)
510        self.assertEqual(r, b'\\x706c61696e')
511        r = f(u'plain')
512        self.assertIsInstance(r, unicode)
513        self.assertEqual(r, u'\\x706c61696e')
514        r = f(u"das is' kÀse".encode('utf-8'))
515        self.assertIsInstance(r, bytes)
516        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
517        r = f(u"das is' kÀse")
518        self.assertIsInstance(r, unicode)
519        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
520        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
521
522    def testUnescapeBytea(self):
523        f = self.db.unescape_bytea
524        r = f(b'plain')
525        self.assertIsInstance(r, bytes)
526        self.assertEqual(r, b'plain')
527        r = f(u'plain')
528        self.assertIsInstance(r, bytes)
529        self.assertEqual(r, b'plain')
530        r = f(b"das is' k\\303\\244se")
531        self.assertIsInstance(r, bytes)
532        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
533        r = f(u"das is' k\\303\\244se")
534        self.assertIsInstance(r, bytes)
535        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
536        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
537        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
538        self.assertEqual(f(r'\\x746861742773206be47365'),
539            b'\\x746861742773206be47365')
540        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
541
542    def testDecodeJson(self):
543        f = self.db.decode_json
544        self.assertIsNone(f('null'))
545        data = {
546          "id": 1, "name": "Foo", "price": 1234.5,
547          "new": True, "note": None,
548          "tags": ["Bar", "Eek"],
549          "stock": {"warehouse": 300, "retail": 20}}
550        text = json.dumps(data)
551        r = f(text)
552        self.assertIsInstance(r, dict)
553        self.assertEqual(r, data)
554        self.assertIsInstance(r['id'], int)
555        self.assertIsInstance(r['name'], unicode)
556        self.assertIsInstance(r['price'], float)
557        self.assertIsInstance(r['new'], bool)
558        self.assertIsInstance(r['tags'], list)
559        self.assertIsInstance(r['stock'], dict)
560
561    def testEncodeJson(self):
562        f = self.db.encode_json
563        self.assertEqual(f(None), 'null')
564        data = {
565          "id": 1, "name": "Foo", "price": 1234.5,
566          "new": True, "note": None,
567          "tags": ["Bar", "Eek"],
568          "stock": {"warehouse": 300, "retail": 20}}
569        text = json.dumps(data)
570        r = f(data)
571        self.assertIsInstance(r, str)
572        self.assertEqual(r, text)
573
574    def testGetParameter(self):
575        f = self.db.get_parameter
576        self.assertRaises(TypeError, f)
577        self.assertRaises(TypeError, f, None)
578        self.assertRaises(TypeError, f, 42)
579        self.assertRaises(TypeError, f, '')
580        self.assertRaises(TypeError, f, [])
581        self.assertRaises(TypeError, f, [''])
582        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
583        r = f('standard_conforming_strings')
584        self.assertEqual(r, 'on')
585        r = f('lc_monetary')
586        self.assertEqual(r, 'C')
587        r = f('datestyle')
588        self.assertEqual(r, 'ISO, YMD')
589        r = f('bytea_output')
590        self.assertEqual(r, 'hex')
591        r = f(['bytea_output', 'lc_monetary'])
592        self.assertIsInstance(r, list)
593        self.assertEqual(r, ['hex', 'C'])
594        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
595        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
596        r = f(set(['bytea_output', 'lc_monetary']))
597        self.assertIsInstance(r, dict)
598        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
599        r = f(set(['Bytea_Output', ' LC_Monetary ']))
600        self.assertIsInstance(r, dict)
601        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
602        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
603        r = f(s)
604        self.assertIs(r, s)
605        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
606        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
607        r = f(s)
608        self.assertIs(r, s)
609        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
610
611    def testGetParameterServerVersion(self):
612        r = self.db.get_parameter('server_version_num')
613        self.assertIsInstance(r, str)
614        s = self.db.server_version
615        self.assertIsInstance(s, int)
616        self.assertEqual(r, str(s))
617
618    def testGetParameterAll(self):
619        f = self.db.get_parameter
620        r = f('all')
621        self.assertIsInstance(r, dict)
622        self.assertEqual(r['standard_conforming_strings'], 'on')
623        self.assertEqual(r['lc_monetary'], 'C')
624        self.assertEqual(r['DateStyle'], 'ISO, YMD')
625        self.assertEqual(r['bytea_output'], 'hex')
626
627    def testSetParameter(self):
628        f = self.db.set_parameter
629        g = self.db.get_parameter
630        self.assertRaises(TypeError, f)
631        self.assertRaises(TypeError, f, None)
632        self.assertRaises(TypeError, f, 42)
633        self.assertRaises(TypeError, f, '')
634        self.assertRaises(TypeError, f, [])
635        self.assertRaises(TypeError, f, [''])
636        self.assertRaises(ValueError, f, 'all', 'invalid')
637        self.assertRaises(ValueError, f, {
638            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
639        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
640        f('standard_conforming_strings', 'off')
641        self.assertEqual(g('standard_conforming_strings'), 'off')
642        f('datestyle', 'ISO, DMY')
643        self.assertEqual(g('datestyle'), 'ISO, DMY')
644        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
645        self.assertEqual(g('standard_conforming_strings'), 'on')
646        self.assertEqual(g('datestyle'), 'ISO, DMY')
647        f(['default_with_oids', 'standard_conforming_strings'], 'off')
648        self.assertEqual(g('default_with_oids'), 'off')
649        self.assertEqual(g('standard_conforming_strings'), 'off')
650        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
651        self.assertEqual(g('standard_conforming_strings'), 'on')
652        self.assertEqual(g('datestyle'), 'ISO, YMD')
653        f(('default_with_oids', 'standard_conforming_strings'), 'off')
654        self.assertEqual(g('default_with_oids'), 'off')
655        self.assertEqual(g('standard_conforming_strings'), 'off')
656        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
657        self.assertEqual(g('default_with_oids'), 'on')
658        self.assertEqual(g('standard_conforming_strings'), 'on')
659        self.assertRaises(ValueError, f, set([ 'default_with_oids',
660            'standard_conforming_strings']), ['off', 'on'])
661        f(set(['default_with_oids', 'standard_conforming_strings']),
662            ['off', 'off'])
663        self.assertEqual(g('default_with_oids'), 'off')
664        self.assertEqual(g('standard_conforming_strings'), 'off')
665        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
666        self.assertEqual(g('standard_conforming_strings'), 'on')
667        self.assertEqual(g('datestyle'), 'ISO, YMD')
668
669    def testResetParameter(self):
670        db = DB()
671        f = db.set_parameter
672        g = db.get_parameter
673        r = g('default_with_oids')
674        self.assertIn(r, ('on', 'off'))
675        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
676        r = g('standard_conforming_strings')
677        self.assertIn(r, ('on', 'off'))
678        scs, not_scs = r, 'off' if r == 'on' else 'on'
679        f('default_with_oids', not_dwi)
680        f('standard_conforming_strings', not_scs)
681        self.assertEqual(g('default_with_oids'), not_dwi)
682        self.assertEqual(g('standard_conforming_strings'), not_scs)
683        f('default_with_oids')
684        f('standard_conforming_strings', None)
685        self.assertEqual(g('default_with_oids'), dwi)
686        self.assertEqual(g('standard_conforming_strings'), scs)
687        f('default_with_oids', not_dwi)
688        f('standard_conforming_strings', not_scs)
689        self.assertEqual(g('default_with_oids'), not_dwi)
690        self.assertEqual(g('standard_conforming_strings'), not_scs)
691        f(['default_with_oids', 'standard_conforming_strings'], None)
692        self.assertEqual(g('default_with_oids'), dwi)
693        self.assertEqual(g('standard_conforming_strings'), scs)
694        f('default_with_oids', not_dwi)
695        f('standard_conforming_strings', not_scs)
696        self.assertEqual(g('default_with_oids'), not_dwi)
697        self.assertEqual(g('standard_conforming_strings'), not_scs)
698        f(('default_with_oids', 'standard_conforming_strings'))
699        self.assertEqual(g('default_with_oids'), dwi)
700        self.assertEqual(g('standard_conforming_strings'), scs)
701        f('default_with_oids', not_dwi)
702        f('standard_conforming_strings', not_scs)
703        self.assertEqual(g('default_with_oids'), not_dwi)
704        self.assertEqual(g('standard_conforming_strings'), not_scs)
705        f(set(['default_with_oids', 'standard_conforming_strings']))
706        self.assertEqual(g('default_with_oids'), dwi)
707        self.assertEqual(g('standard_conforming_strings'), scs)
708
709    def testResetParameterAll(self):
710        db = DB()
711        f = db.set_parameter
712        self.assertRaises(ValueError, f, 'all', 0)
713        self.assertRaises(ValueError, f, 'all', 'off')
714        g = db.get_parameter
715        r = g('default_with_oids')
716        self.assertIn(r, ('on', 'off'))
717        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
718        r = g('standard_conforming_strings')
719        self.assertIn(r, ('on', 'off'))
720        scs, not_scs = r, 'off' if r == 'on' else 'on'
721        f('default_with_oids', not_dwi)
722        f('standard_conforming_strings', not_scs)
723        self.assertEqual(g('default_with_oids'), not_dwi)
724        self.assertEqual(g('standard_conforming_strings'), not_scs)
725        f('all')
726        self.assertEqual(g('default_with_oids'), dwi)
727        self.assertEqual(g('standard_conforming_strings'), scs)
728
729    def testSetParameterLocal(self):
730        f = self.db.set_parameter
731        g = self.db.get_parameter
732        self.assertEqual(g('standard_conforming_strings'), 'on')
733        self.db.begin()
734        f('standard_conforming_strings', 'off', local=True)
735        self.assertEqual(g('standard_conforming_strings'), 'off')
736        self.db.end()
737        self.assertEqual(g('standard_conforming_strings'), 'on')
738
739    def testSetParameterSession(self):
740        f = self.db.set_parameter
741        g = self.db.get_parameter
742        self.assertEqual(g('standard_conforming_strings'), 'on')
743        self.db.begin()
744        f('standard_conforming_strings', 'off', local=False)
745        self.assertEqual(g('standard_conforming_strings'), 'off')
746        self.db.end()
747        self.assertEqual(g('standard_conforming_strings'), 'off')
748
749    def testReset(self):
750        db = DB()
751        default_datestyle = db.get_parameter('datestyle')
752        changed_datestyle = 'ISO, DMY'
753        if changed_datestyle == default_datestyle:
754            changed_datestyle == 'ISO, YMD'
755        self.db.set_parameter('datestyle', changed_datestyle)
756        r = self.db.get_parameter('datestyle')
757        self.assertEqual(r, changed_datestyle)
758        con = self.db.db
759        q = con.query("show datestyle")
760        self.db.reset()
761        r = q.getresult()[0][0]
762        self.assertEqual(r, changed_datestyle)
763        q = con.query("show datestyle")
764        r = q.getresult()[0][0]
765        self.assertEqual(r, default_datestyle)
766        r = self.db.get_parameter('datestyle')
767        self.assertEqual(r, default_datestyle)
768
769    def testReopen(self):
770        db = DB()
771        default_datestyle = db.get_parameter('datestyle')
772        changed_datestyle = 'ISO, DMY'
773        if changed_datestyle == default_datestyle:
774            changed_datestyle == 'ISO, YMD'
775        self.db.set_parameter('datestyle', changed_datestyle)
776        r = self.db.get_parameter('datestyle')
777        self.assertEqual(r, changed_datestyle)
778        con = self.db.db
779        q = con.query("show datestyle")
780        self.db.reopen()
781        r = q.getresult()[0][0]
782        self.assertEqual(r, changed_datestyle)
783        self.assertRaises(TypeError, getattr, con, 'query')
784        r = self.db.get_parameter('datestyle')
785        self.assertEqual(r, default_datestyle)
786
787    def testCreateTable(self):
788        table = 'test hello world'
789        values = [(2, "World!"), (1, "Hello")]
790        self.createTable(table, "n smallint, t varchar",
791            temporary=True, oids=True, values=values)
792        r = self.db.query('select t from "%s" order by n' % table).getresult()
793        r = ', '.join(row[0] for row in r)
794        self.assertEqual(r, "Hello, World!")
795        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
796        self.assertIsInstance(r[0][0], int)
797
798    def testQuery(self):
799        query = self.db.query
800        table = 'test_table'
801        self.createTable(table, "n integer", oids=True)
802        q = "insert into test_table values (1)"
803        r = query(q)
804        self.assertIsInstance(r, int)
805        q = "insert into test_table select 2"
806        r = query(q)
807        self.assertIsInstance(r, int)
808        oid = r
809        q = "select oid from test_table where n=2"
810        r = query(q).getresult()
811        self.assertEqual(len(r), 1)
812        r = r[0]
813        self.assertEqual(len(r), 1)
814        r = r[0]
815        self.assertEqual(r, oid)
816        q = "insert into test_table select 3 union select 4 union select 5"
817        r = query(q)
818        self.assertIsInstance(r, str)
819        self.assertEqual(r, '3')
820        q = "update test_table set n=4 where n<5"
821        r = query(q)
822        self.assertIsInstance(r, str)
823        self.assertEqual(r, '4')
824        q = "delete from test_table"
825        r = query(q)
826        self.assertIsInstance(r, str)
827        self.assertEqual(r, '5')
828
829    def testMultipleQueries(self):
830        self.assertEqual(self.db.query(
831            "create temporary table test_multi (n integer);"
832            "insert into test_multi values (4711);"
833            "select n from test_multi").getresult()[0][0], 4711)
834
835    def testQueryWithParams(self):
836        query = self.db.query
837        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
838        q = "insert into test_table values ($1, $2)"
839        r = query(q, (1, 2))
840        self.assertIsInstance(r, int)
841        r = query(q, [3, 4])
842        self.assertIsInstance(r, int)
843        r = query(q, [5, 6])
844        self.assertIsInstance(r, int)
845        q = "select * from test_table order by 1, 2"
846        self.assertEqual(query(q).getresult(),
847            [(1, 2), (3, 4), (5, 6)])
848        q = "select * from test_table where n1=$1 and n2=$2"
849        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
850        q = "update test_table set n2=$2 where n1=$1"
851        r = query(q, 3, 7)
852        self.assertEqual(r, '1')
853        q = "select * from test_table order by 1, 2"
854        self.assertEqual(query(q).getresult(),
855            [(1, 2), (3, 7), (5, 6)])
856        q = "delete from test_table where n2!=$1"
857        r = query(q, 4)
858        self.assertEqual(r, '3')
859
860    def testEmptyQuery(self):
861        self.assertRaises(ValueError, self.db.query, '')
862
863    def testQueryProgrammingError(self):
864        try:
865            self.db.query("select 1/0")
866        except pg.ProgrammingError as error:
867            self.assertEqual(error.sqlstate, '22012')
868
869    def testPkey(self):
870        query = self.db.query
871        pkey = self.db.pkey
872        self.assertRaises(KeyError, pkey, 'test')
873        for t in ('pkeytest', 'primary key test'):
874            self.createTable('%s0' % t, 'a smallint')
875            self.createTable('%s1' % t, 'b smallint primary key')
876            self.createTable('%s2' % t,
877                'c smallint, d smallint primary key')
878            self.createTable('%s3' % t,
879                'e smallint, f smallint, g smallint, h smallint, i smallint,'
880                ' primary key (f, h)')
881            self.createTable('%s4' % t,
882                'e smallint, f smallint, g smallint, h smallint, i smallint,'
883                ' primary key (h, f)')
884            self.createTable('%s5' % t,
885                'more_than_one_letter varchar primary key')
886            self.createTable('%s6' % t,
887                '"with space" date primary key')
888            self.createTable('%s7' % t,
889                'a_very_long_column_name varchar, "with space" date, "42" int,'
890                ' primary key (a_very_long_column_name, "with space", "42")')
891            self.assertRaises(KeyError, pkey, '%s0' % t)
892            self.assertEqual(pkey('%s1' % t), 'b')
893            self.assertEqual(pkey('%s1' % t, True), ('b',))
894            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
895            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
896            self.assertEqual(pkey('%s2' % t), 'd')
897            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
898            r = pkey('%s3' % t)
899            self.assertIsInstance(r, tuple)
900            self.assertEqual(r, ('f', 'h'))
901            r = pkey('%s3' % t, composite=False)
902            self.assertIsInstance(r, tuple)
903            self.assertEqual(r, ('f', 'h'))
904            r = pkey('%s4' % t)
905            self.assertIsInstance(r, tuple)
906            self.assertEqual(r, ('h', 'f'))
907            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
908            self.assertEqual(pkey('%s6' % t), 'with space')
909            r = pkey('%s7' % t)
910            self.assertIsInstance(r, tuple)
911            self.assertEqual(r, (
912                'a_very_long_column_name', 'with space', '42'))
913            # a newly added primary key will be detected
914            query('alter table "%s0" add primary key (a)' % t)
915            self.assertEqual(pkey('%s0' % t), 'a')
916            # a changed primary key will not be detected,
917            # indicating that the internal cache is operating
918            query('alter table "%s1" rename column b to x' % t)
919            self.assertEqual(pkey('%s1' % t), 'b')
920            # we get the changed primary key when the cache is flushed
921            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
922
923    def testGetDatabases(self):
924        databases = self.db.get_databases()
925        self.assertIn('template0', databases)
926        self.assertIn('template1', databases)
927        self.assertNotIn('not existing database', databases)
928        self.assertIn('postgres', databases)
929        self.assertIn(dbname, databases)
930
931    def testGetTables(self):
932        get_tables = self.db.get_tables
933        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
934            'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
935            'averyveryveryveryveryveryveryreallyreallylongtablename',
936            'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
937        for t in tables:
938            self.db.query('drop table if exists "%s" cascade' % t)
939        before_tables = get_tables()
940        self.assertIsInstance(before_tables, list)
941        for t in before_tables:
942            t = t.split('.', 1)
943            self.assertGreaterEqual(len(t), 2)
944            if len(t) > 2:
945                self.assertTrue(t[1].startswith('"'))
946            t = t[0]
947            self.assertNotEqual(t, 'information_schema')
948            self.assertFalse(t.startswith('pg_'))
949        for t in tables:
950            self.createTable(t, 'as select 0', temporary=False)
951        current_tables = get_tables()
952        new_tables = [t for t in current_tables if t not in before_tables]
953        expected_new_tables = ['public.%s' % (
954            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
955        self.assertEqual(new_tables, expected_new_tables)
956        self.doCleanups()
957        after_tables = get_tables()
958        self.assertEqual(after_tables, before_tables)
959
960    def testGetRelations(self):
961        get_relations = self.db.get_relations
962        result = get_relations()
963        self.assertIn('public.test', result)
964        self.assertIn('public.test_view', result)
965        result = get_relations('rv')
966        self.assertIn('public.test', result)
967        self.assertIn('public.test_view', result)
968        result = get_relations('r')
969        self.assertIn('public.test', result)
970        self.assertNotIn('public.test_view', result)
971        result = get_relations('v')
972        self.assertNotIn('public.test', result)
973        self.assertIn('public.test_view', result)
974        result = get_relations('cisSt')
975        self.assertNotIn('public.test', result)
976        self.assertNotIn('public.test_view', result)
977
978    def testGetAttnames(self):
979        get_attnames = self.db.get_attnames
980        self.assertRaises(pg.ProgrammingError,
981            self.db.get_attnames, 'does_not_exist')
982        self.assertRaises(pg.ProgrammingError,
983            self.db.get_attnames, 'has.too.many.dots')
984        r = get_attnames('test')
985        self.assertIsInstance(r, dict)
986        self.assertEqual(r, dict(
987            i2='int', i4='int', i8='int', d='num',
988            f4='float', f8='float', m='money',
989            v4='text', c4='text', t='text'))
990        self.createTable('test_table',
991            'n int, alpha smallint, beta bool,'
992            ' gamma char(5), tau text, v varchar(3)')
993        r = get_attnames('test_table')
994        self.assertIsInstance(r, dict)
995        self.assertEqual(r, dict(
996            n='int', alpha='int', beta='bool',
997            gamma='text', tau='text', v='text'))
998
999    def testGetAttnamesWithQuotes(self):
1000        get_attnames = self.db.get_attnames
1001        table = 'test table for get_attnames()'
1002        self.createTable(table,
1003            '"Prime!" smallint, "much space" integer, "Questions?" text')
1004        r = get_attnames(table)
1005        self.assertIsInstance(r, dict)
1006        self.assertEqual(r, {
1007            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
1008        table = 'yet another test table for get_attnames()'
1009        self.createTable(table,
1010            'a smallint, b integer, c bigint,'
1011            ' e numeric, f float, f2 double precision, m money,'
1012            ' x smallint, y smallint, z smallint,'
1013            ' Normal_NaMe smallint, "Special Name" smallint,'
1014            ' t text, u char(2), v varchar(2),'
1015            ' primary key (y, u)', oids=True)
1016        r = get_attnames(table)
1017        self.assertIsInstance(r, dict)
1018        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
1019            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
1020            'normal_name': 'int', 'Special Name': 'int',
1021            'u': 'text', 't': 'text', 'v': 'text',
1022            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
1023
1024    def testGetAttnamesWithRegtypes(self):
1025        get_attnames = self.db.get_attnames
1026        self.createTable('test_table',
1027            ' n int, alpha smallint, beta bool,'
1028            ' gamma char(5), tau text, v varchar(3)')
1029        use_regtypes = self.db.use_regtypes
1030        regtypes = use_regtypes()
1031        self.assertFalse(regtypes)
1032        use_regtypes(True)
1033        try:
1034            r = get_attnames("test_table")
1035            self.assertIsInstance(r, dict)
1036        finally:
1037            use_regtypes(regtypes)
1038        self.assertEqual(r, dict(
1039            n='integer', alpha='smallint', beta='boolean',
1040            gamma='character', tau='text', v='character varying'))
1041
1042    def testGetAttnamesIsCached(self):
1043        get_attnames = self.db.get_attnames
1044        query = self.db.query
1045        self.createTable('test_table', 'col int')
1046        r = get_attnames("test_table")
1047        self.assertIsInstance(r, dict)
1048        self.assertEqual(r, dict(col='int'))
1049        query("alter table test_table alter column col type text")
1050        query("alter table test_table add column col2 int")
1051        r = get_attnames("test_table")
1052        self.assertEqual(r, dict(col='int'))
1053        r = get_attnames("test_table", flush=True)
1054        self.assertEqual(r, dict(col='text', col2='int'))
1055        query("alter table test_table drop column col2")
1056        r = get_attnames("test_table")
1057        self.assertEqual(r, dict(col='text', col2='int'))
1058        r = get_attnames("test_table", flush=True)
1059        self.assertEqual(r, dict(col='text'))
1060        query("alter table test_table drop column col")
1061        r = get_attnames("test_table")
1062        self.assertEqual(r, dict(col='text'))
1063        r = get_attnames("test_table", flush=True)
1064        self.assertEqual(r, dict())
1065
1066    def testGetAttnamesIsOrdered(self):
1067        get_attnames = self.db.get_attnames
1068        r = get_attnames('test', flush=True)
1069        self.assertIsInstance(r, OrderedDict)
1070        self.assertEqual(r, OrderedDict([
1071            ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1072            ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1073            ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1074        if OrderedDict is not dict:
1075            r = ' '.join(list(r.keys()))
1076            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1077        table = 'test table for get_attnames'
1078        self.createTable(table,
1079            ' n int, alpha smallint, v varchar(3),'
1080            ' gamma char(5), tau text, beta bool')
1081        r = get_attnames(table)
1082        self.assertIsInstance(r, OrderedDict)
1083        self.assertEqual(r, OrderedDict([
1084            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1085            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1086        if OrderedDict is not dict:
1087            r = ' '.join(list(r.keys()))
1088            self.assertEqual(r, 'n alpha v gamma tau beta')
1089        else:
1090            self.skipTest('OrderedDict is not supported')
1091
1092    def testGetAttnamesIsAttrDict(self):
1093        AttrDict = pg.AttrDict
1094        get_attnames = self.db.get_attnames
1095        r = get_attnames('test', flush=True)
1096        self.assertIsInstance(r, AttrDict)
1097        self.assertEqual(r, AttrDict([
1098            ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1099            ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1100            ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1101        r = ' '.join(list(r.keys()))
1102        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1103        table = 'test table for get_attnames'
1104        self.createTable(table,
1105            ' n int, alpha smallint, v varchar(3),'
1106            ' gamma char(5), tau text, beta bool')
1107        r = get_attnames(table)
1108        self.assertIsInstance(r, AttrDict)
1109        self.assertEqual(r, AttrDict([
1110            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1111            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1112        r = ' '.join(list(r.keys()))
1113        self.assertEqual(r, 'n alpha v gamma tau beta')
1114
1115    def testHasTablePrivilege(self):
1116        can = self.db.has_table_privilege
1117        self.assertEqual(can('test'), True)
1118        self.assertEqual(can('test', 'select'), True)
1119        self.assertEqual(can('test', 'SeLeCt'), True)
1120        self.assertEqual(can('test', 'SELECT'), True)
1121        self.assertEqual(can('test', 'insert'), True)
1122        self.assertEqual(can('test', 'update'), True)
1123        self.assertEqual(can('test', 'delete'), True)
1124        self.assertEqual(can('pg_views', 'select'), True)
1125        self.assertEqual(can('pg_views', 'delete'), False)
1126        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
1127        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1128
    def testGet(self):
        # Exercise DB.get() on a table without OIDs: first with explicit
        # key names while the table has no primary key, then relying on
        # a primary key that is added later on.
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # both the table and the row argument are required
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        # the rows are (1, 'x'), (2, 'y') and (3, 'z')
        self.createTable(table, 'n integer, t text',
            values=enumerate('xyz', start=1))
        # no key name is given and the table has no primary key yet
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        # key value and key name may also be passed as tuples
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        # any column can serve as the key
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # none of the following lookups can deliver a row
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        # the row can also be passed as a dict holding the key value
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        r = get(table, s, 'n')
        # the passed dict itself is filled in and returned
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        # looking up by t overwrites the old value of n in the dict
        s.update(t='x')
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        # from here on the table has a primary key,
        # so the key name can be left out
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # a trailing asterisk after the table name is accepted
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # two values do not fit the single-column key
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        # r and s are the same dict, so updating s also changes r
        s.update(n=2)
        self.assertEqual(get(table, r)['t'], 'y')
        # a dict lacking the primary key column cannot be looked up
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1183
    def testGetWithOid(self):
        # Exercise DB.get() on a table created with OIDs: rows can be
        # fetched by OID, and fetched rows carry their OID under the
        # qualified key 'oid(<table>)'.
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        # the rows are (1, 'x'), (2, 'y') and (3, 'z')
        self.createTable(table, 'n integer, t text', oids=True,
            values=enumerate('xyz', start=1))
        # no key name is given and the table has no primary key yet
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        # an empty dict does not provide the OID to look for
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        # this is what every lookup of the row with n=2 must deliver
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        # the OID can be passed directly or inside a dict, either under
        # the plain key 'oid' or under the qualified key
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # a trailing asterisk after the table name is accepted
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        # r is reused as input from here on and updated by each get()
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        # from here on the table has a primary key,
        # so the key name can be left out
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # lookups by OID keep working after the primary key was added
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # when the dict also holds the primary key, the row is found via
        # the key: a different row than the one with the passed OID
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # likewise an explicitly named key column wins over the OID
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1247
1248    def testGetWithCompositeKey(self):
1249        get = self.db.get
1250        query = self.db.query
1251        table = 'get_test_table_1'
1252        self.createTable(table, 'n integer primary key, t text',
1253            values=enumerate('abc', start=1))
1254        self.assertEqual(get(table, 2)['t'], 'b')
1255        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1256        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1257        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1258        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1259        self.assertEqual(get(table, 'b', 't')['n'], 2)
1260        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1261        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1262        table = 'get_test_table_2'
1263        self.createTable(table,
1264            'n integer, m integer, t text, primary key (n, m)',
1265            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1266                for n in range(3) for m in range(2)])
1267        self.assertRaises(KeyError, get, table, 2)
1268        self.assertEqual(get(table, (1, 1))['t'], 'a')
1269        self.assertEqual(get(table, (1, 2))['t'], 'b')
1270        self.assertEqual(get(table, (2, 1))['t'], 'c')
1271        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1272        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1273        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1274        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1275        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1276        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1277        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1278        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1279
1280    def testGetWithQuotedNames(self):
1281        get = self.db.get
1282        query = self.db.query
1283        table = 'test table for get()'
1284        self.createTable(table, '"Prime!" smallint primary key,'
1285            ' "much space" integer, "Questions?" text',
1286            values=[(17, 1001, 'No!')])
1287        r = get(table, 17)
1288        self.assertIsInstance(r, dict)
1289        self.assertEqual(r['Prime!'], 17)
1290        self.assertEqual(r['much space'], 1001)
1291        self.assertEqual(r['Questions?'], 'No!')
1292
1293    def testGetFromView(self):
1294        self.db.query('delete from test where i4=14')
1295        self.db.query('insert into test (i4, v4) values('
1296            "14, 'abc4')")
1297        r = self.db.get('test_view', 14, 'i4')
1298        self.assertIn('v4', r)
1299        self.assertEqual(r['v4'], 'abc4')
1300
1301    def testGetLittleBobbyTables(self):
1302        get = self.db.get
1303        query = self.db.query
1304        self.createTable('test_students',
1305            'firstname varchar primary key, nickname varchar, grade char(2)',
1306            values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1307                    ('Robert', 'Little Bobby Tables', 'D-')])
1308        r = get('test_students', 'Sheldon')
1309        self.assertEqual(r, dict(
1310            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1311        r = get('test_students', 'Robert')
1312        self.assertEqual(r, dict(
1313            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1314        r = get('test_students', "D'Arcy")
1315        self.assertEqual(r, dict(
1316            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1317        try:
1318            get('test_students', "D' Arcy")
1319        except pg.DatabaseError as error:
1320            self.assertEqual(str(error),
1321                'No such record in test_students\nwhere "firstname" = $1\n'
1322                'with $1="D\' Arcy"')
1323        try:
1324            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1325        except pg.DatabaseError as error:
1326            self.assertEqual(str(error),
1327                'No such record in test_students\nwhere "firstname" = $1\n'
1328                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1329        q = "select * from test_students order by 1 limit 4"
1330        r = query(q).getresult()
1331        self.assertEqual(len(r), 3)
1332        self.assertEqual(r[1][2], 'D-')
1333
    def testInsert(self):
        # Insert rows with many different kinds of values through
        # DB.insert() and check both the returned dict and the row that
        # actually ended up in the table.
        insert = self.db.insert
        query = self.db.query
        # the expectations below depend on these module-level settings
        bool_on = pg.get_bool()
        decimal = pg.get_decimal()
        table = 'insert_test_table'
        self.createTable(table,
            'i2 smallint, i4 integer, i8 bigint,'
            ' d numeric, f4 real, f8 double precision, m money,'
            ' v4 varchar(4), c4 char(4), t text,'
            ' b boolean, ts timestamp', oids=True)
        # key under which insert() stores the OID of the new row
        oid_table = 'oid(%s)' % table
        # Each test is either a dict that is expected to come back
        # unchanged, or a tuple (data, change) where change lists the
        # values that are expected to differ from the inserted data.
        tests = [dict(i2=None, i4=None, i8=None),
            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
            dict(i2=42, i4=123456, i8=9876543210),
            dict(i2=2 ** 15 - 1,
                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
            dict(d=None), (dict(d=''), dict(d=None)),
            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
            dict(f4=None, f8=None), dict(f4=0, f8=0),
            (dict(f4='', f8=''), dict(f4=None, f8=None)),
            (dict(d=1234.5, f4=1234.5, f8=1234.5),
                  dict(d=Decimal('1234.5'))),
            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
            dict(d=Decimal('123456789.9876543212345678987654321')),
            dict(m=None), (dict(m=''), dict(m=None)),
            dict(m=Decimal('-1234.56')),
            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
            (dict(m=123456), dict(m=Decimal('123456'))),
            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
            dict(b=None), (dict(b=''), dict(b=None)),
            dict(b='f'), dict(b='t'),
            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
            dict(v4=None, c4=None, t=None),
            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
            dict(v4='1234', c4='1234', t='1234' * 10),
            dict(v4='abcd', c4='abcd', t='abcdefg'),
            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
            dict(ts=None), (dict(ts=''), dict(ts=None)),
            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
            dict(ts='2012-12-21 00:00:00'),
            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
            dict(ts='2012-12-21 12:21:12'),
            dict(ts='2013-01-05 12:13:14'),
            dict(ts='current_timestamp')]
        for test in tests:
            if isinstance(test, dict):
                data = test
                change = {}
            else:
                data, change = test
            # the expected values are the inserted ones plus the changes
            expect = data.copy()
            expect.update(change)
            if bool_on:
                # booleans are expected as Python bools instead of 't'/'f'
                b = expect.get('b')
                if b is not None:
                    expect['b'] = b == 't'
            if decimal is not Decimal:
                # numeric and money values are expected in the type
                # that has been configured as decimal type
                d = expect.get('d')
                if d is not None:
                    expect['d'] = decimal(d)
                m = expect.get('m')
                if m is not None:
                    expect['m'] = decimal(m)
            # insert() returns a dict that equals the passed one,
            # which now also contains the OID of the new row
            self.assertEqual(insert(table, data), data)
            self.assertIn(oid_table, data)
            oid = data[oid_table]
            self.assertIsInstance(oid, int)
            # only compare the columns this test is about
            data = dict(item for item in data.items()
                if item[0] in expect)
            ts = expect.get('ts')
            if ts == 'current_timestamp':
                # The actual value is not known in advance, so take it
                # from the result and only check its format, namely
                # 'YYYY-MM-DD hh:mm:ss', optionally followed by a
                # fractional part.
                ts = expect['ts'] = data['ts']
                if len(ts) > 19:
                    self.assertEqual(ts[19], '.')
                    ts = ts[:19]
                else:
                    self.assertEqual(len(ts), 19)
                self.assertTrue(ts[:4].isdigit())
                self.assertEqual(ts[4], '-')
                self.assertEqual(ts[10], ' ')
                self.assertTrue(ts[11:13].isdigit())
                self.assertEqual(ts[13], ':')
            self.assertEqual(data, expect)
            # now check what has actually been stored in the table
            data = query(
                'select oid,* from "%s"' % table).dictresult()[0]
            self.assertEqual(data['oid'], oid)
            data = dict(item for item in data.items()
                if item[0] in expect)
            self.assertEqual(data, expect)
            # empty the table again for the next test
            query('delete from "%s"' % table)
1435
1436    def testInsertWithOid(self):
1437        insert = self.db.insert
1438        query = self.db.query
1439        self.createTable('test_table', 'n int', oids=True)
1440        r = insert('test_table', n=1)
1441        self.assertIsInstance(r, dict)
1442        self.assertEqual(r['n'], 1)
1443        self.assertNotIn('oid', r)
1444        qoid = 'oid(test_table)'
1445        self.assertIn(qoid, r)
1446        oid = r[qoid]
1447        self.assertEqual(sorted(r.keys()), ['n', qoid])
1448        r = insert('test_table', n=2, oid=oid)
1449        self.assertIsInstance(r, dict)
1450        self.assertEqual(r['n'], 2)
1451        self.assertIn(qoid, r)
1452        self.assertNotEqual(r[qoid], oid)
1453        self.assertNotIn('oid', r)
1454        r = insert('test_table', None, n=3)
1455        self.assertIsInstance(r, dict)
1456        self.assertEqual(r['n'], 3)
1457        s = r
1458        r = insert('test_table', r)
1459        self.assertIs(r, s)
1460        self.assertEqual(r['n'], 3)
1461        r = insert('test_table *', r)
1462        self.assertIs(r, s)
1463        self.assertEqual(r['n'], 3)
1464        r = insert('test_table', r, n=4)
1465        self.assertIs(r, s)
1466        self.assertEqual(r['n'], 4)
1467        self.assertNotIn('oid', r)
1468        self.assertIn(qoid, r)
1469        oid = r[qoid]
1470        r = insert('test_table', r, n=5, oid=oid)
1471        self.assertIs(r, s)
1472        self.assertEqual(r['n'], 5)
1473        self.assertIn(qoid, r)
1474        self.assertNotEqual(r[qoid], oid)
1475        self.assertNotIn('oid', r)
1476        r['oid'] = oid = r[qoid]
1477        r = insert('test_table', r, n=6)
1478        self.assertIs(r, s)
1479        self.assertEqual(r['n'], 6)
1480        self.assertIn(qoid, r)
1481        self.assertNotEqual(r[qoid], oid)
1482        self.assertNotIn('oid', r)
1483        q = 'select n from test_table order by 1 limit 9'
1484        r = ' '.join(str(row[0]) for row in query(q).getresult())
1485        self.assertEqual(r, '1 2 3 3 3 4 5 6')
1486        query("truncate test_table")
1487        query("alter table test_table add unique (n)")
1488        r = insert('test_table', dict(n=7))
1489        self.assertIsInstance(r, dict)
1490        self.assertEqual(r['n'], 7)
1491        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
1492        r['n'] = 6
1493        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
1494        self.assertIsInstance(r, dict)
1495        self.assertEqual(r['n'], 7)
1496        r['n'] = 6
1497        r = insert('test_table', r)
1498        self.assertIsInstance(r, dict)
1499        self.assertEqual(r['n'], 6)
1500        r = query(q).getresult()
1501        r = ' '.join(str(row[0]) for row in query(q).getresult())
1502        self.assertEqual(r, '6 7')
1503
1504    def testInsertWithQuotedNames(self):
1505        insert = self.db.insert
1506        query = self.db.query
1507        table = 'test table for insert()'
1508        self.createTable(table, '"Prime!" smallint primary key,'
1509            ' "much space" integer, "Questions?" text')
1510        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1511        r = insert(table, r)
1512        self.assertIsInstance(r, dict)
1513        self.assertEqual(r['Prime!'], 11)
1514        self.assertEqual(r['much space'], 2002)
1515        self.assertEqual(r['Questions?'], 'What?')
1516        r = query('select * from "%s" limit 2' % table).dictresult()
1517        self.assertEqual(len(r), 1)
1518        r = r[0]
1519        self.assertEqual(r['Prime!'], 11)
1520        self.assertEqual(r['much space'], 2002)
1521        self.assertEqual(r['Questions?'], 'What?')
1522
1523    def testInsertIntoView(self):
1524        insert = self.db.insert
1525        query = self.db.query
1526        query("truncate test")
1527        q = 'select * from test_view order by i4 limit 3'
1528        r = query(q).getresult()
1529        self.assertEqual(r, [])
1530        r = dict(i4=1234, v4='abcd')
1531        insert('test', r)
1532        self.assertIsNone(r['i2'])
1533        self.assertEqual(r['i4'], 1234)
1534        self.assertIsNone(r['i8'])
1535        self.assertEqual(r['v4'], 'abcd')
1536        self.assertIsNone(r['c4'])
1537        r = query(q).getresult()
1538        self.assertEqual(r, [(1234, 'abcd')])
1539        r = dict(i4=5678, v4='efgh')
1540        try:
1541            insert('test_view', r)
1542        except pg.ProgrammingError as error:
1543            if self.db.server_version < 90300:
1544                # must setup rules in older PostgreSQL versions
1545                self.skipTest('database cannot insert into view')
1546            self.fail(str(error))
1547        self.assertNotIn('i2', r)
1548        self.assertEqual(r['i4'], 5678)
1549        self.assertNotIn('i8', r)
1550        self.assertEqual(r['v4'], 'efgh')
1551        self.assertNotIn('c4', r)
1552        r = query(q).getresult()
1553        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1554
1555    def testUpdate(self):
1556        update = self.db.update
1557        query = self.db.query
1558        self.assertRaises(pg.ProgrammingError, update,
1559            'test', i2=2, i4=4, i8=8)
1560        table = 'update_test_table'
1561        self.createTable(table, 'n integer, t text', oids=True,
1562            values=enumerate('xyz', start=1))
1563        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1564        r = self.db.get(table, 2, 'n')
1565        r['t'] = 'u'
1566        s = update(table, r)
1567        self.assertEqual(s, r)
1568        q = 'select t from "%s" where n=2' % table
1569        r = query(q).getresult()[0][0]
1570        self.assertEqual(r, 'u')
1571
1572    def testUpdateWithOid(self):
1573        update = self.db.update
1574        get = self.db.get
1575        query = self.db.query
1576        self.createTable('test_table', 'n int', oids=True, values=[1])
1577        s = get('test_table', 1, 'n')
1578        self.assertIsInstance(s, dict)
1579        self.assertEqual(s['n'], 1)
1580        s['n'] = 2
1581        r = update('test_table', s)
1582        self.assertIs(r, s)
1583        self.assertEqual(r['n'], 2)
1584        qoid = 'oid(test_table)'
1585        self.assertIn(qoid, r)
1586        self.assertNotIn('oid', r)
1587        self.assertEqual(sorted(r.keys()), ['n', qoid])
1588        r['n'] = 3
1589        oid = r.pop(qoid)
1590        r = update('test_table', r, oid=oid)
1591        self.assertIs(r, s)
1592        self.assertEqual(r['n'], 3)
1593        r.pop(qoid)
1594        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
1595        s = get('test_table', 3, 'n')
1596        self.assertIsInstance(s, dict)
1597        self.assertEqual(s['n'], 3)
1598        s.pop('n')
1599        r = update('test_table', s)
1600        oid = r.pop(qoid)
1601        self.assertEqual(r, {})
1602        q = "select n from test_table limit 2"
1603        r = query(q).getresult()
1604        self.assertEqual(r, [(3,)])
1605        query("insert into test_table values (1)")
1606        self.assertRaises(pg.ProgrammingError,
1607            update, 'test_table', dict(oid=oid, n=4))
1608        r = update('test_table', dict(n=4), oid=oid)
1609        self.assertEqual(r['n'], 4)
1610        r = update('test_table *', dict(n=5), oid=oid)
1611        self.assertEqual(r['n'], 5)
1612        query("alter table test_table add column m int")
1613        query("alter table test_table add primary key (n)")
1614        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
1615        self.assertEqual('n', self.db.pkey('test_table', flush=True))
1616        s = dict(n=1, m=4)
1617        r = update('test_table', s)
1618        self.assertIs(r, s)
1619        self.assertEqual(r['n'], 1)
1620        self.assertEqual(r['m'], 4)
1621        s = dict(m=7)
1622        r = update('test_table', s, n=5)
1623        self.assertIs(r, s)
1624        self.assertEqual(r['n'], 5)
1625        self.assertEqual(r['m'], 7)
1626        q = "select n, m from test_table order by 1 limit 3"
1627        r = query(q).getresult()
1628        self.assertEqual(r, [(1, 4), (5, 7)])
1629        s = dict(m=9, oid=oid)
1630        self.assertRaises(KeyError, update, 'test_table', s)
1631        r = update('test_table', s, oid=oid)
1632        self.assertIs(r, s)
1633        self.assertEqual(r['n'], 5)
1634        self.assertEqual(r['m'], 9)
1635        s = dict(n=1, m=3, oid=oid)
1636        r = update('test_table', s)
1637        self.assertIs(r, s)
1638        self.assertEqual(r['n'], 1)
1639        self.assertEqual(r['m'], 3)
1640        r = query(q).getresult()
1641        self.assertEqual(r, [(1, 3), (5, 9)])
1642
1643    def testUpdateWithoutOid(self):
1644        update = self.db.update
1645        query = self.db.query
1646        self.assertRaises(pg.ProgrammingError, update,
1647            'test', i2=2, i4=4, i8=8)
1648        table = 'update_test_table'
1649        self.createTable(table, 'n integer primary key, t text', oids=False,
1650            values=enumerate('xyz', start=1))
1651        r = self.db.get(table, 2)
1652        r['t'] = 'u'
1653        s = update(table, r)
1654        self.assertEqual(s, r)
1655        q = 'select t from "%s" where n=2' % table
1656        r = query(q).getresult()[0][0]
1657        self.assertEqual(r, 'u')
1658
1659    def testUpdateWithCompositeKey(self):
1660        update = self.db.update
1661        query = self.db.query
1662        table = 'update_test_table_1'
1663        self.createTable(table, 'n integer primary key, t text',
1664            values=enumerate('abc', start=1))
1665        self.assertRaises(KeyError, update, table, dict(t='b'))
1666        s = dict(n=2, t='d')
1667        r = update(table, s)
1668        self.assertIs(r, s)
1669        self.assertEqual(r['n'], 2)
1670        self.assertEqual(r['t'], 'd')
1671        q = 'select t from "%s" where n=2' % table
1672        r = query(q).getresult()[0][0]
1673        self.assertEqual(r, 'd')
1674        s.update(dict(n=4, t='e'))
1675        r = update(table, s)
1676        self.assertEqual(r['n'], 4)
1677        self.assertEqual(r['t'], 'e')
1678        q = 'select t from "%s" where n=2' % table
1679        r = query(q).getresult()[0][0]
1680        self.assertEqual(r, 'd')
1681        q = 'select t from "%s" where n=4' % table
1682        r = query(q).getresult()
1683        self.assertEqual(len(r), 0)
1684        query('drop table "%s"' % table)
1685        table = 'update_test_table_2'
1686        self.createTable(table,
1687            'n integer, m integer, t text, primary key (n, m)',
1688            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1689                for n in range(3) for m in range(2)])
1690        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1691        self.assertEqual(update(table,
1692            dict(n=2, m=2, t='x'))['t'], 'x')
1693        q = 'select t from "%s" where n=2 order by m' % table
1694        r = [r[0] for r in query(q).getresult()]
1695        self.assertEqual(r, ['c', 'x'])
1696
1697    def testUpdateWithQuotedNames(self):
1698        update = self.db.update
1699        query = self.db.query
1700        table = 'test table for update()'
1701        self.createTable(table, '"Prime!" smallint primary key,'
1702            ' "much space" integer, "Questions?" text',
1703            values=[(13, 3003, 'Why!')])
1704        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1705        r = update(table, r)
1706        self.assertIsInstance(r, dict)
1707        self.assertEqual(r['Prime!'], 13)
1708        self.assertEqual(r['much space'], 7007)
1709        self.assertEqual(r['Questions?'], 'When?')
1710        r = query('select * from "%s" limit 2' % table).dictresult()
1711        self.assertEqual(len(r), 1)
1712        r = r[0]
1713        self.assertEqual(r['Prime!'], 13)
1714        self.assertEqual(r['much space'], 7007)
1715        self.assertEqual(r['Questions?'], 'When?')
1716
1717    def testUpsert(self):
1718        upsert = self.db.upsert
1719        query = self.db.query
1720        self.assertRaises(pg.ProgrammingError, upsert,
1721            'test', i2=2, i4=4, i8=8)
1722        table = 'upsert_test_table'
1723        self.createTable(table, 'n integer primary key, t text', oids=True)
1724        s = dict(n=1, t='x')
1725        try:
1726            r = upsert(table, s)
1727        except pg.ProgrammingError as error:
1728            if self.db.server_version < 90500:
1729                self.skipTest('database does not support upsert')
1730            self.fail(str(error))
1731        self.assertIs(r, s)
1732        self.assertEqual(r['n'], 1)
1733        self.assertEqual(r['t'], 'x')
1734        s.update(n=2, t='y')
1735        r = upsert(table, s, **dict.fromkeys(s))
1736        self.assertIs(r, s)
1737        self.assertEqual(r['n'], 2)
1738        self.assertEqual(r['t'], 'y')
1739        q = 'select n, t from "%s" order by n limit 3' % table
1740        r = query(q).getresult()
1741        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1742        s.update(t='z')
1743        r = upsert(table, s)
1744        self.assertIs(r, s)
1745        self.assertEqual(r['n'], 2)
1746        self.assertEqual(r['t'], 'z')
1747        r = query(q).getresult()
1748        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1749        s.update(t='n')
1750        r = upsert(table, s, t=False)
1751        self.assertIs(r, s)
1752        self.assertEqual(r['n'], 2)
1753        self.assertEqual(r['t'], 'z')
1754        r = query(q).getresult()
1755        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1756        s.update(t='y')
1757        r = upsert(table, s, t=True)
1758        self.assertIs(r, s)
1759        self.assertEqual(r['n'], 2)
1760        self.assertEqual(r['t'], 'y')
1761        r = query(q).getresult()
1762        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1763        s.update(t='n')
1764        r = upsert(table, s, t="included.t || '2'")
1765        self.assertIs(r, s)
1766        self.assertEqual(r['n'], 2)
1767        self.assertEqual(r['t'], 'y2')
1768        r = query(q).getresult()
1769        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1770        s.update(t='y')
1771        r = upsert(table, s, t="excluded.t || '3'")
1772        self.assertIs(r, s)
1773        self.assertEqual(r['n'], 2)
1774        self.assertEqual(r['t'], 'y3')
1775        r = query(q).getresult()
1776        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1777        s.update(n=1, t='2')
1778        r = upsert(table, s, t="included.t || excluded.t")
1779        self.assertIs(r, s)
1780        self.assertEqual(r['n'], 1)
1781        self.assertEqual(r['t'], 'x2')
1782        r = query(q).getresult()
1783        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1784        # not existing columns and oid parameter should be ignored
1785        s = dict(m=3, u='z')
1786        r = upsert(table, s, oid='invalid')
1787        self.assertIs(r, s)
1788
1789    def testUpsertWithOid(self):
1790        upsert = self.db.upsert
1791        get = self.db.get
1792        query = self.db.query
1793        self.createTable('test_table', 'n int', oids=True, values=[1])
1794        self.assertRaises(pg.ProgrammingError,
1795            upsert, 'test_table', dict(n=2))
1796        r = get('test_table', 1, 'n')
1797        self.assertIsInstance(r, dict)
1798        self.assertEqual(r['n'], 1)
1799        qoid = 'oid(test_table)'
1800        self.assertIn(qoid, r)
1801        self.assertNotIn('oid', r)
1802        oid = r[qoid]
1803        self.assertRaises(pg.ProgrammingError,
1804            upsert, 'test_table', dict(n=2, oid=oid))
1805        query("alter table test_table add column m int")
1806        query("alter table test_table add primary key (n)")
1807        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
1808        self.assertEqual('n', self.db.pkey('test_table', flush=True))
1809        s = dict(n=2)
1810        try:
1811            r = upsert('test_table', s)
1812        except pg.ProgrammingError as error:
1813            if self.db.server_version < 90500:
1814                self.skipTest('database does not support upsert')
1815            self.fail(str(error))
1816        self.assertIs(r, s)
1817        self.assertEqual(r['n'], 2)
1818        self.assertIsNone(r['m'])
1819        q = query("select n, m from test_table order by n limit 3")
1820        self.assertEqual(q.getresult(), [(1, None), (2, None)])
1821        r['oid'] = oid
1822        r = upsert('test_table', r)
1823        self.assertIs(r, s)
1824        self.assertEqual(r['n'], 2)
1825        self.assertIsNone(r['m'])
1826        self.assertIn(qoid, r)
1827        self.assertNotIn('oid', r)
1828        self.assertNotEqual(r[qoid], oid)
1829        r['m'] = 7
1830        r = upsert('test_table', r)
1831        self.assertIs(r, s)
1832        self.assertEqual(r['n'], 2)
1833        self.assertEqual(r['m'], 7)
1834        r.update(n=1, m=3)
1835        r = upsert('test_table', r)
1836        self.assertIs(r, s)
1837        self.assertEqual(r['n'], 1)
1838        self.assertEqual(r['m'], 3)
1839        q = query("select n, m from test_table order by n limit 3")
1840        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
1841        r = upsert('test_table', r, oid='invalid')
1842        self.assertIs(r, s)
1843        self.assertEqual(r['n'], 1)
1844        self.assertEqual(r['m'], 3)
1845        r['m'] = 5
1846        r = upsert('test_table', r, m=False)
1847        self.assertIs(r, s)
1848        self.assertEqual(r['n'], 1)
1849        self.assertEqual(r['m'], 3)
1850        r['m'] = 5
1851        r = upsert('test_table', r, m=True)
1852        self.assertIs(r, s)
1853        self.assertEqual(r['n'], 1)
1854        self.assertEqual(r['m'], 5)
1855        r.update(n=2, m=1)
1856        r = upsert('test_table', r, m='included.m')
1857        self.assertIs(r, s)
1858        self.assertEqual(r['n'], 2)
1859        self.assertEqual(r['m'], 7)
1860        r['m'] = 9
1861        r = upsert('test_table', r, m='excluded.m')
1862        self.assertIs(r, s)
1863        self.assertEqual(r['n'], 2)
1864        self.assertEqual(r['m'], 9)
1865        r['m'] = 8
1866        r = upsert('test_table *', r, m='included.m + 1')
1867        self.assertIs(r, s)
1868        self.assertEqual(r['n'], 2)
1869        self.assertEqual(r['m'], 10)
1870        q = query("select n, m from test_table order by n limit 3")
1871        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1872
1873    def testUpsertWithCompositeKey(self):
1874        upsert = self.db.upsert
1875        query = self.db.query
1876        table = 'upsert_test_table_2'
1877        self.createTable(table,
1878            'n integer, m integer, t text, primary key (n, m)')
1879        s = dict(n=1, m=2, t='x')
1880        try:
1881            r = upsert(table, s)
1882        except pg.ProgrammingError as error:
1883            if self.db.server_version < 90500:
1884                self.skipTest('database does not support upsert')
1885            self.fail(str(error))
1886        self.assertIs(r, s)
1887        self.assertEqual(r['n'], 1)
1888        self.assertEqual(r['m'], 2)
1889        self.assertEqual(r['t'], 'x')
1890        s.update(m=3, t='y')
1891        r = upsert(table, s, **dict.fromkeys(s))
1892        self.assertIs(r, s)
1893        self.assertEqual(r['n'], 1)
1894        self.assertEqual(r['m'], 3)
1895        self.assertEqual(r['t'], 'y')
1896        q = 'select n, m, t from "%s" order by n, m limit 3' % table
1897        r = query(q).getresult()
1898        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
1899        s.update(t='z')
1900        r = upsert(table, s)
1901        self.assertIs(r, s)
1902        self.assertEqual(r['n'], 1)
1903        self.assertEqual(r['m'], 3)
1904        self.assertEqual(r['t'], 'z')
1905        r = query(q).getresult()
1906        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1907        s.update(t='n')
1908        r = upsert(table, s, t=False)
1909        self.assertIs(r, s)
1910        self.assertEqual(r['n'], 1)
1911        self.assertEqual(r['m'], 3)
1912        self.assertEqual(r['t'], 'z')
1913        r = query(q).getresult()
1914        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1915        s.update(t='n')
1916        r = upsert(table, s, t=True)
1917        self.assertIs(r, s)
1918        self.assertEqual(r['n'], 1)
1919        self.assertEqual(r['m'], 3)
1920        self.assertEqual(r['t'], 'n')
1921        r = query(q).getresult()
1922        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
1923        s.update(n=2, t='y')
1924        r = upsert(table, s, t="'z'")
1925        self.assertIs(r, s)
1926        self.assertEqual(r['n'], 2)
1927        self.assertEqual(r['m'], 3)
1928        self.assertEqual(r['t'], 'y')
1929        r = query(q).getresult()
1930        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
1931        s.update(n=1, t='m')
1932        r = upsert(table, s, t='included.t || excluded.t')
1933        self.assertIs(r, s)
1934        self.assertEqual(r['n'], 1)
1935        self.assertEqual(r['m'], 3)
1936        self.assertEqual(r['t'], 'nm')
1937        r = query(q).getresult()
1938        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1939
1940    def testUpsertWithQuotedNames(self):
1941        upsert = self.db.upsert
1942        query = self.db.query
1943        table = 'test table for upsert()'
1944        self.createTable(table, '"Prime!" smallint primary key,'
1945            ' "much space" integer, "Questions?" text')
1946        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1947        try:
1948            r = upsert(table, s)
1949        except pg.ProgrammingError as error:
1950            if self.db.server_version < 90500:
1951                self.skipTest('database does not support upsert')
1952            self.fail(str(error))
1953        self.assertIs(r, s)
1954        self.assertEqual(r['Prime!'], 31)
1955        self.assertEqual(r['much space'], 9009)
1956        self.assertEqual(r['Questions?'], 'Yes.')
1957        q = 'select * from "%s" limit 2' % table
1958        r = query(q).getresult()
1959        self.assertEqual(r, [(31, 9009, 'Yes.')])
1960        s.update({'Questions?': 'No.'})
1961        r = upsert(table, s)
1962        self.assertIs(r, s)
1963        self.assertEqual(r['Prime!'], 31)
1964        self.assertEqual(r['much space'], 9009)
1965        self.assertEqual(r['Questions?'], 'No.')
1966        r = query(q).getresult()
1967        self.assertEqual(r, [(31, 9009, 'No.')])
1968
1969    def testClear(self):
1970        clear = self.db.clear
1971        f = False if pg.get_bool() else 'f'
1972        r = clear('test')
1973        result = dict(
1974            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
1975        self.assertEqual(r, result)
1976        table = 'clear_test_table'
1977        self.createTable(table,
1978            'n integer, f float, b boolean, d date, t text', oids=True)
1979        r = clear(table)
1980        result = dict(n=0, f=0, b=f, d='', t='')
1981        self.assertEqual(r, result)
1982        r['a'] = r['f'] = r['n'] = 1
1983        r['d'] = r['t'] = 'x'
1984        r['b'] = 't'
1985        r['oid'] = long(1)
1986        r = clear(table, r)
1987        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
1988        self.assertEqual(r, result)
1989
1990    def testClearWithQuotedNames(self):
1991        clear = self.db.clear
1992        table = 'test table for clear()'
1993        self.createTable(table, '"Prime!" smallint primary key,'
1994            ' "much space" integer, "Questions?" text')
1995        r = clear(table)
1996        self.assertIsInstance(r, dict)
1997        self.assertEqual(r['Prime!'], 0)
1998        self.assertEqual(r['much space'], 0)
1999        self.assertEqual(r['Questions?'], '')
2000
2001    def testDelete(self):
2002        delete = self.db.delete
2003        query = self.db.query
2004        self.assertRaises(pg.ProgrammingError, delete,
2005            'test', dict(i2=2, i4=4, i8=8))
2006        table = 'delete_test_table'
2007        self.createTable(table, 'n integer, t text', oids=True,
2008            values=enumerate('xyz', start=1))
2009        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
2010        r = self.db.get(table, 1, 'n')
2011        s = delete(table, r)
2012        self.assertEqual(s, 1)
2013        r = self.db.get(table, 3, 'n')
2014        s = delete(table, r)
2015        self.assertEqual(s, 1)
2016        s = delete(table, r)
2017        self.assertEqual(s, 0)
2018        r = query('select * from "%s"' % table).dictresult()
2019        self.assertEqual(len(r), 1)
2020        r = r[0]
2021        result = {'n': 2, 't': 'y'}
2022        self.assertEqual(r, result)
2023        r = self.db.get(table, 2, 'n')
2024        s = delete(table, r)
2025        self.assertEqual(s, 1)
2026        s = delete(table, r)
2027        self.assertEqual(s, 0)
2028        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
2029        # not existing columns and oid parameter should be ignored
2030        r.update(m=3, u='z', oid='invalid')
2031        s = delete(table, r)
2032        self.assertEqual(s, 0)
2033
2034    def testDeleteWithOid(self):
2035        delete = self.db.delete
2036        get = self.db.get
2037        query = self.db.query
2038        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
2039        r = dict(n=3)
2040        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
2041        s = get('test_table', 1, 'n')
2042        qoid = 'oid(test_table)'
2043        self.assertIn(qoid, s)
2044        r = delete('test_table', s)
2045        self.assertEqual(r, 1)
2046        r = delete('test_table', s)
2047        self.assertEqual(r, 0)
2048        q = "select min(n),count(n) from test_table"
2049        self.assertEqual(query(q).getresult()[0], (2, 5))
2050        oid = get('test_table', 2, 'n')[qoid]
2051        s = dict(oid=oid, n=2)
2052        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
2053        r = delete('test_table', None, oid=oid)
2054        self.assertEqual(r, 1)
2055        r = delete('test_table', None, oid=oid)
2056        self.assertEqual(r, 0)
2057        self.assertEqual(query(q).getresult()[0], (3, 4))
2058        s = dict(oid=oid, n=2)
2059        oid = get('test_table', 3, 'n')[qoid]
2060        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
2061        r = delete('test_table', s, oid=oid)
2062        self.assertEqual(r, 1)
2063        r = delete('test_table', s, oid=oid)
2064        self.assertEqual(r, 0)
2065        self.assertEqual(query(q).getresult()[0], (4, 3))
2066        s = get('test_table', 4, 'n')
2067        r = delete('test_table *', s)
2068        self.assertEqual(r, 1)
2069        r = delete('test_table *', s)
2070        self.assertEqual(r, 0)
2071        self.assertEqual(query(q).getresult()[0], (5, 2))
2072        oid = get('test_table', 5, 'n')[qoid]
2073        s = {qoid: oid, 'm': 4}
2074        r = delete('test_table', s, m=6)
2075        self.assertEqual(r, 1)
2076        r = delete('test_table *', s)
2077        self.assertEqual(r, 0)
2078        self.assertEqual(query(q).getresult()[0], (6, 1))
2079        query("alter table test_table add column m int")
2080        query("alter table test_table add primary key (n)")
2081        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
2082        self.assertEqual('n', self.db.pkey('test_table', flush=True))
2083        for i in range(5):
2084            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
2085        s = dict(m=2)
2086        self.assertRaises(KeyError, delete, 'test_table', s)
2087        s = dict(m=2, oid=oid)
2088        self.assertRaises(KeyError, delete, 'test_table', s)
2089        r = delete('test_table', dict(m=2), oid=oid)
2090        self.assertEqual(r, 0)
2091        oid = get('test_table', 1, 'n')[qoid]
2092        s = dict(oid=oid)
2093        self.assertRaises(KeyError, delete, 'test_table', s)
2094        r = delete('test_table', s, oid=oid)
2095        self.assertEqual(r, 1)
2096        r = delete('test_table', s, oid=oid)
2097        self.assertEqual(r, 0)
2098        self.assertEqual(query(q).getresult()[0], (2, 5))
2099        s = get('test_table', 2, 'n')
2100        del s['n']
2101        r = delete('test_table', s)
2102        self.assertEqual(r, 1)
2103        r = delete('test_table', s)
2104        self.assertEqual(r, 0)
2105        self.assertEqual(query(q).getresult()[0], (3, 4))
2106        r = delete('test_table', n=3)
2107        self.assertEqual(r, 1)
2108        r = delete('test_table', n=3)
2109        self.assertEqual(r, 0)
2110        self.assertEqual(query(q).getresult()[0], (4, 3))
2111        r = delete('test_table', None, n=4)
2112        self.assertEqual(r, 1)
2113        r = delete('test_table', None, n=4)
2114        self.assertEqual(r, 0)
2115        self.assertEqual(query(q).getresult()[0], (5, 2))
2116        s = dict(n=6)
2117        r = delete('test_table', s, n=5)
2118        self.assertEqual(r, 1)
2119        r = delete('test_table', s, n=5)
2120        self.assertEqual(r, 0)
2121        self.assertEqual(query(q).getresult()[0], (6, 1))
2122
2123    def testDeleteWithCompositeKey(self):
2124        query = self.db.query
2125        table = 'delete_test_table_1'
2126        self.createTable(table, 'n integer primary key, t text',
2127            values=enumerate('abc', start=1))
2128        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2129        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2130        r = query('select t from "%s" where n=2' % table).getresult()
2131        self.assertEqual(r, [])
2132        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2133        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2134        self.assertEqual(r, 'c')
2135        table = 'delete_test_table_2'
2136        self.createTable(table,
2137            'n integer, m integer, t text, primary key (n, m)',
2138            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2139                for n in range(3) for m in range(2)])
2140        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2141        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2142        r = [r[0] for r in query('select t from "%s" where n=2'
2143            ' order by m' % table).getresult()]
2144        self.assertEqual(r, ['c'])
2145        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2146        r = [r[0] for r in query('select t from "%s" where n=3'
2147            ' order by m' % table).getresult()]
2148        self.assertEqual(r, ['e', 'f'])
2149        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2150        r = [r[0] for r in query('select t from "%s" where n=3'
2151            ' order by m' % table).getresult()]
2152        self.assertEqual(r, ['f'])
2153
2154    def testDeleteWithQuotedNames(self):
2155        delete = self.db.delete
2156        query = self.db.query
2157        table = 'test table for delete()'
2158        self.createTable(table, '"Prime!" smallint primary key,'
2159            ' "much space" integer, "Questions?" text',
2160            values=[(19, 5005, 'Yes!')])
2161        r = {'Prime!': 17}
2162        r = delete(table, r)
2163        self.assertEqual(r, 0)
2164        r = query('select count(*) from "%s"' % table).getresult()
2165        self.assertEqual(r[0][0], 1)
2166        r = {'Prime!': 19}
2167        r = delete(table, r)
2168        self.assertEqual(r, 1)
2169        r = query('select count(*) from "%s"' % table).getresult()
2170        self.assertEqual(r[0][0], 0)
2171
2172    def testDeleteReferenced(self):
2173        delete = self.db.delete
2174        query = self.db.query
2175        self.createTable('test_parent',
2176            'n smallint primary key', values=range(3))
2177        self.createTable('test_child',
2178            'n smallint primary key references test_parent', values=range(3))
2179        q = ("select (select count(*) from test_parent),"
2180            " (select count(*) from test_child)")
2181        self.assertEqual(query(q).getresult()[0], (3, 3))
2182        self.assertRaises(pg.ProgrammingError,
2183            delete, 'test_parent', None, n=2)
2184        self.assertRaises(pg.ProgrammingError,
2185            delete, 'test_parent *', None, n=2)
2186        r = delete('test_child', None, n=2)
2187        self.assertEqual(r, 1)
2188        self.assertEqual(query(q).getresult()[0], (3, 2))
2189        r = delete('test_parent', None, n=2)
2190        self.assertEqual(r, 1)
2191        self.assertEqual(query(q).getresult()[0], (2, 2))
2192        self.assertRaises(pg.ProgrammingError,
2193            delete, 'test_parent', dict(n=0))
2194        self.assertRaises(pg.ProgrammingError,
2195            delete, 'test_parent *', dict(n=0))
2196        r = delete('test_child', dict(n=0))
2197        self.assertEqual(r, 1)
2198        self.assertEqual(query(q).getresult()[0], (2, 1))
2199        r = delete('test_child', dict(n=0))
2200        self.assertEqual(r, 0)
2201        r = delete('test_parent', dict(n=0))
2202        self.assertEqual(r, 1)
2203        self.assertEqual(query(q).getresult()[0], (1, 1))
2204        r = delete('test_parent', None, n=0)
2205        self.assertEqual(r, 0)
2206        q = "select n from test_parent natural join test_child limit 2"
2207        self.assertEqual(query(q).getresult(), [(1,)])
2208
2209    def testTruncate(self):
2210        truncate = self.db.truncate
2211        self.assertRaises(TypeError, truncate, None)
2212        self.assertRaises(TypeError, truncate, 42)
2213        self.assertRaises(TypeError, truncate, dict(test_table=None))
2214        query = self.db.query
2215        self.createTable('test_table', 'n smallint',
2216            temporary=False, values=[1] * 3)
2217        q = "select count(*) from test_table"
2218        r = query(q).getresult()[0][0]
2219        self.assertEqual(r, 3)
2220        truncate('test_table')
2221        r = query(q).getresult()[0][0]
2222        self.assertEqual(r, 0)
2223        for i in range(3):
2224            query("insert into test_table values (1)")
2225        r = query(q).getresult()[0][0]
2226        self.assertEqual(r, 3)
2227        truncate('public.test_table')
2228        r = query(q).getresult()[0][0]
2229        self.assertEqual(r, 0)
2230        self.createTable('test_table_2', 'n smallint', temporary=True)
2231        for t in (list, tuple, set):
2232            for i in range(3):
2233                query("insert into test_table values (1)")
2234                query("insert into test_table_2 values (2)")
2235            q = ("select (select count(*) from test_table),"
2236                " (select count(*) from test_table_2)")
2237            r = query(q).getresult()[0]
2238            self.assertEqual(r, (3, 3))
2239            truncate(t(['test_table', 'test_table_2']))
2240            r = query(q).getresult()[0]
2241            self.assertEqual(r, (0, 0))
2242
2243    def testTruncateRestart(self):
2244        truncate = self.db.truncate
2245        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2246        query = self.db.query
2247        self.createTable('test_table', 'n serial, t text')
2248        for n in range(3):
2249            query("insert into test_table (t) values ('test')")
2250        q = "select count(n), min(n), max(n) from test_table"
2251        r = query(q).getresult()[0]
2252        self.assertEqual(r, (3, 1, 3))
2253        truncate('test_table')
2254        r = query(q).getresult()[0]
2255        self.assertEqual(r, (0, None, None))
2256        for n in range(3):
2257            query("insert into test_table (t) values ('test')")
2258        r = query(q).getresult()[0]
2259        self.assertEqual(r, (3, 4, 6))
2260        truncate('test_table', restart=True)
2261        r = query(q).getresult()[0]
2262        self.assertEqual(r, (0, None, None))
2263        for n in range(3):
2264            query("insert into test_table (t) values ('test')")
2265        r = query(q).getresult()[0]
2266        self.assertEqual(r, (3, 1, 3))
2267
2268    def testTruncateCascade(self):
2269        truncate = self.db.truncate
2270        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2271        query = self.db.query
2272        self.createTable('test_parent', 'n smallint primary key',
2273            values=range(3))
2274        self.createTable('test_child',
2275            'n smallint primary key references test_parent (n)',
2276            values=range(3))
2277        q = ("select (select count(*) from test_parent),"
2278            " (select count(*) from test_child)")
2279        r = query(q).getresult()[0]
2280        self.assertEqual(r, (3, 3))
2281        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2282        truncate(['test_parent', 'test_child'])
2283        r = query(q).getresult()[0]
2284        self.assertEqual(r, (0, 0))
2285        for n in range(3):
2286            query("insert into test_parent (n) values (%d)" % n)
2287            query("insert into test_child (n) values (%d)" % n)
2288        r = query(q).getresult()[0]
2289        self.assertEqual(r, (3, 3))
2290        truncate('test_parent', cascade=True)
2291        r = query(q).getresult()[0]
2292        self.assertEqual(r, (0, 0))
2293        for n in range(3):
2294            query("insert into test_parent (n) values (%d)" % n)
2295            query("insert into test_child (n) values (%d)" % n)
2296        r = query(q).getresult()[0]
2297        self.assertEqual(r, (3, 3))
2298        truncate('test_child')
2299        r = query(q).getresult()[0]
2300        self.assertEqual(r, (3, 0))
2301        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2302        truncate('test_parent', cascade=True)
2303        r = query(q).getresult()[0]
2304        self.assertEqual(r, (0, 0))
2305
2306    def testTruncateOnly(self):
2307        truncate = self.db.truncate
2308        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2309        query = self.db.query
2310        self.createTable('test_parent', 'n smallint')
2311        self.createTable('test_child', 'm smallint) inherits (test_parent')
2312        for n in range(3):
2313            query("insert into test_parent (n) values (1)")
2314            query("insert into test_child (n, m) values (2, 3)")
2315        q = ("select (select count(*) from test_parent),"
2316            " (select count(*) from test_child)")
2317        r = query(q).getresult()[0]
2318        self.assertEqual(r, (6, 3))
2319        truncate('test_parent')
2320        r = query(q).getresult()[0]
2321        self.assertEqual(r, (0, 0))
2322        for n in range(3):
2323            query("insert into test_parent (n) values (1)")
2324            query("insert into test_child (n, m) values (2, 3)")
2325        r = query(q).getresult()[0]
2326        self.assertEqual(r, (6, 3))
2327        truncate('test_parent*')
2328        r = query(q).getresult()[0]
2329        self.assertEqual(r, (0, 0))
2330        for n in range(3):
2331            query("insert into test_parent (n) values (1)")
2332            query("insert into test_child (n, m) values (2, 3)")
2333        r = query(q).getresult()[0]
2334        self.assertEqual(r, (6, 3))
2335        truncate('test_parent', only=True)
2336        r = query(q).getresult()[0]
2337        self.assertEqual(r, (3, 3))
2338        truncate('test_parent', only=False)
2339        r = query(q).getresult()[0]
2340        self.assertEqual(r, (0, 0))
2341        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2342        truncate('test_parent*', only=False)
2343        self.createTable('test_parent_2', 'n smallint')
2344        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2345        for t in '', '_2':
2346            for n in range(3):
2347                query("insert into test_parent%s (n) values (1)" % t)
2348                query("insert into test_child%s (n, m) values (2, 3)" % t)
2349        q = ("select (select count(*) from test_parent),"
2350            " (select count(*) from test_child),"
2351            " (select count(*) from test_parent_2),"
2352            " (select count(*) from test_child_2)")
2353        r = query(q).getresult()[0]
2354        self.assertEqual(r, (6, 3, 6, 3))
2355        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2356        r = query(q).getresult()[0]
2357        self.assertEqual(r, (0, 0, 3, 3))
2358        truncate(['test_parent', 'test_parent_2'], only=False)
2359        r = query(q).getresult()[0]
2360        self.assertEqual(r, (0, 0, 0, 0))
2361        self.assertRaises(ValueError, truncate,
2362            ['test_parent*', 'test_child'], only=[True, False])
2363        truncate(['test_parent*', 'test_child'], only=[False, True])
2364
2365    def testTruncateQuoted(self):
2366        truncate = self.db.truncate
2367        query = self.db.query
2368        table = "test table for truncate()"
2369        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2370        q = 'select count(*) from "%s"' % table
2371        r = query(q).getresult()[0][0]
2372        self.assertEqual(r, 3)
2373        truncate(table)
2374        r = query(q).getresult()[0][0]
2375        self.assertEqual(r, 0)
2376        for i in range(3):
2377            query('insert into "%s" values (1)' % table)
2378        r = query(q).getresult()[0][0]
2379        self.assertEqual(r, 3)
2380        truncate('public."%s"' % table)
2381        r = query(q).getresult()[0][0]
2382        self.assertEqual(r, 0)
2383
2384    def testGetAsList(self):
2385        get_as_list = self.db.get_as_list
2386        self.assertRaises(TypeError, get_as_list)
2387        self.assertRaises(TypeError, get_as_list, None)
2388        query = self.db.query
2389        table = 'test_aslist'
2390        r = query('select 1 as colname').namedresult()[0]
2391        self.assertIsInstance(r, tuple)
2392        named = hasattr(r, 'colname')
2393        names = [(1, 'Homer'), (2, 'Marge'),
2394                (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
2395        self.createTable(table,
2396            'id smallint primary key, name varchar', values=names)
2397        r = get_as_list(table)
2398        self.assertIsInstance(r, list)
2399        self.assertEqual(r, names)
2400        for t, n in zip(r, names):
2401            self.assertIsInstance(t, tuple)
2402            self.assertEqual(t, n)
2403            if named:
2404                self.assertEqual(t.id, n[0])
2405                self.assertEqual(t.name, n[1])
2406                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
2407        r = get_as_list(table, what='name')
2408        self.assertIsInstance(r, list)
2409        expected = sorted((row[1],) for row in names)
2410        self.assertEqual(r, expected)
2411        r = get_as_list(table, what='name, id')
2412        self.assertIsInstance(r, list)
2413        expected = sorted(tuple(reversed(row)) for row in names)
2414        self.assertEqual(r, expected)
2415        r = get_as_list(table, what=['name', 'id'])
2416        self.assertIsInstance(r, list)
2417        self.assertEqual(r, expected)
2418        r = get_as_list(table, where="name like 'Ba%'")
2419        self.assertIsInstance(r, list)
2420        self.assertEqual(r, names[2:3])
2421        r = get_as_list(table, what='name', where="name like 'Ma%'")
2422        self.assertIsInstance(r, list)
2423        self.assertEqual(r, [('Maggie',), ('Marge',)])
2424        r = get_as_list(table, what='name',
2425                        where=["name like 'Ma%'", "name like '%r%'"])
2426        self.assertIsInstance(r, list)
2427        self.assertEqual(r, [('Marge',)])
2428        r = get_as_list(table, what='name', order='id')
2429        self.assertIsInstance(r, list)
2430        expected = [(row[1],) for row in names]
2431        self.assertEqual(r, expected)
2432        r = get_as_list(table, what=['name'], order=['id'])
2433        self.assertIsInstance(r, list)
2434        self.assertEqual(r, expected)
2435        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
2436        self.assertIsInstance(r, list)
2437        self.assertEqual(r, names)
2438        r = get_as_list(table, what='id * 2 as num', order='id desc')
2439        self.assertIsInstance(r, list)
2440        expected = [(n,) for n in range(10, 0, -2)]
2441        self.assertEqual(r, expected)
2442        r = get_as_list(table, limit=2)
2443        self.assertIsInstance(r, list)
2444        self.assertEqual(r, names[:2])
2445        r = get_as_list(table, offset=3)
2446        self.assertIsInstance(r, list)
2447        self.assertEqual(r, names[3:])
2448        r = get_as_list(table, limit=1, offset=2)
2449        self.assertIsInstance(r, list)
2450        self.assertEqual(r, names[2:3])
2451        r = get_as_list(table, scalar=True)
2452        self.assertIsInstance(r, list)
2453        self.assertEqual(r, list(range(1, 6)))
2454        r = get_as_list(table, what='name', scalar=True)
2455        self.assertIsInstance(r, list)
2456        expected = sorted(row[1] for row in names)
2457        self.assertEqual(r, expected)
2458        r = get_as_list(table, what='name', limit=1, scalar=True)
2459        self.assertIsInstance(r, list)
2460        self.assertEqual(r, expected[:1])
2461        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2462        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2463        names.insert(1, (1, 'Snowball'))
2464        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
2465        r = get_as_list(table)
2466        self.assertIsInstance(r, list)
2467        self.assertEqual(r, names)
2468        r = get_as_list(table, what='name', where='id=1', scalar=True)
2469        self.assertIsInstance(r, list)
2470        self.assertEqual(r, ['Homer', 'Snowball'])
2471        # test with unordered query
2472        r = get_as_list(table, order=False)
2473        self.assertIsInstance(r, list)
2474        self.assertEqual(set(r), set(names))
2475        # test with arbitrary from clause
2476        from_table = '(select lower(name) as n2 from "%s") as t2' % table
2477        r = get_as_list(from_table)
2478        self.assertIsInstance(r, list)
2479        r = set(row[0] for row in r)
2480        expected = set(row[1].lower() for row in names)
2481        self.assertEqual(r, expected)
2482        r = get_as_list(from_table, order='n2', scalar=True)
2483        self.assertIsInstance(r, list)
2484        self.assertEqual(r, sorted(expected))
2485        r = get_as_list(from_table, order='n2', limit=1)
2486        self.assertIsInstance(r, list)
2487        self.assertEqual(len(r), 1)
2488        t = r[0]
2489        self.assertIsInstance(t, tuple)
2490        if named:
2491            self.assertEqual(t.n2, 'bart')
2492            self.assertEqual(t._asdict(), dict(n2='bart'))
2493        else:
2494            self.assertEqual(t, ('bart',))
2495
2496    def testGetAsDict(self):
2497        get_as_dict = self.db.get_as_dict
2498        self.assertRaises(TypeError, get_as_dict)
2499        self.assertRaises(TypeError, get_as_dict, None)
2500        # the test table has no primary key
2501        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2502        query = self.db.query
2503        table = 'test_asdict'
2504        r = query('select 1 as colname').namedresult()[0]
2505        self.assertIsInstance(r, tuple)
2506        named = hasattr(r, 'colname')
2507        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2508                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2509        self.createTable(table,
2510            'id smallint primary key, rgb char(7), name varchar',
2511            values=colors)
2512        # keyname must be string, list or tuple
2513        self.assertRaises(KeyError, get_as_dict, table, 3)
2514        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2515        # missing keyname in row
2516        self.assertRaises(KeyError, get_as_dict, table,
2517                          keyname='rgb', what='name')
2518        r = get_as_dict(table)
2519        self.assertIsInstance(r, OrderedDict)
2520        expected = OrderedDict((row[0], row[1:]) for row in colors)
2521        self.assertEqual(r, expected)
2522        for key in r:
2523            self.assertIsInstance(key, int)
2524            self.assertIn(key, expected)
2525            row = r[key]
2526            self.assertIsInstance(row, tuple)
2527            t = expected[key]
2528            self.assertEqual(row, t)
2529            if named:
2530                self.assertEqual(row.rgb, t[0])
2531                self.assertEqual(row.name, t[1])
2532                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2533        if OrderedDict is not dict:  # Python > 2.6
2534            self.assertEqual(r.keys(), expected.keys())
2535        r = get_as_dict(table, keyname='rgb')
2536        self.assertIsInstance(r, OrderedDict)
2537        expected = OrderedDict((row[1], (row[0], row[2]))
2538            for row in sorted(colors, key=itemgetter(1)))
2539        self.assertEqual(r, expected)
2540        for key in r:
2541            self.assertIsInstance(key, str)
2542            self.assertIn(key, expected)
2543            row = r[key]
2544            self.assertIsInstance(row, tuple)
2545            t = expected[key]
2546            self.assertEqual(row, t)
2547            if named:
2548                self.assertEqual(row.id, t[0])
2549                self.assertEqual(row.name, t[1])
2550                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2551        if OrderedDict is not dict:  # Python > 2.6
2552            self.assertEqual(r.keys(), expected.keys())
2553        r = get_as_dict(table, keyname=['id', 'rgb'])
2554        self.assertIsInstance(r, OrderedDict)
2555        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2556        self.assertEqual(r, expected)
2557        for key in r:
2558            self.assertIsInstance(key, tuple)
2559            self.assertIsInstance(key[0], int)
2560            self.assertIsInstance(key[1], str)
2561            if named:
2562                self.assertEqual(key, (key.id, key.rgb))
2563                self.assertEqual(key._fields, ('id', 'rgb'))
2564            row = r[key]
2565            self.assertIsInstance(row, tuple)
2566            self.assertIsInstance(row[0], str)
2567            t = expected[key]
2568            self.assertEqual(row, t)
2569            if named:
2570                self.assertEqual(row.name, t[0])
2571                self.assertEqual(row._asdict(), dict(name=t[0]))
2572        if OrderedDict is not dict:  # Python > 2.6
2573            self.assertEqual(r.keys(), expected.keys())
2574        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2575        self.assertIsInstance(r, OrderedDict)
2576        expected = OrderedDict((row[:2], row[2]) for row in colors)
2577        self.assertEqual(r, expected)
2578        for key in r:
2579            self.assertIsInstance(key, tuple)
2580            row = r[key]
2581            self.assertIsInstance(row, str)
2582            t = expected[key]
2583            self.assertEqual(row, t)
2584        if OrderedDict is not dict:  # Python > 2.6
2585            self.assertEqual(r.keys(), expected.keys())
2586        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2587        self.assertIsInstance(r, OrderedDict)
2588        expected = OrderedDict((row[1], row[2])
2589            for row in sorted(colors, key=itemgetter(1)))
2590        self.assertEqual(r, expected)
2591        for key in r:
2592            self.assertIsInstance(key, str)
2593            row = r[key]
2594            self.assertIsInstance(row, str)
2595            t = expected[key]
2596            self.assertEqual(row, t)
2597        if OrderedDict is not dict:  # Python > 2.6
2598            self.assertEqual(r.keys(), expected.keys())
2599        r = get_as_dict(table, what='id, name',
2600                        where="rgb like '#b%'", scalar=True)
2601        self.assertIsInstance(r, OrderedDict)
2602        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2603        self.assertEqual(r, expected)
2604        for key in r:
2605            self.assertIsInstance(key, int)
2606            row = r[key]
2607            self.assertIsInstance(row, str)
2608            t = expected[key]
2609            self.assertEqual(row, t)
2610        if OrderedDict is not dict:  # Python > 2.6
2611            self.assertEqual(r.keys(), expected.keys())
2612        expected = r
2613        r = get_as_dict(table, what=['name', 'id'],
2614                        where=['id > 1', 'id < 4', "rgb like '#b%'",
2615                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2616        self.assertEqual(r, expected)
2617        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2618        self.assertEqual(r, expected)
2619        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2620                        where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2621        self.assertEqual(r, expected)
2622        r = get_as_dict(table, limit=1)
2623        self.assertEqual(len(r), 1)
2624        self.assertEqual(r[1][1], 'Aero')
2625        r = get_as_dict(table, offset=3)
2626        self.assertEqual(len(r), 1)
2627        self.assertEqual(r[4][1], 'Desert')
2628        r = get_as_dict(table, order='id desc')
2629        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2630        self.assertEqual(r, expected)
2631        r = get_as_dict(table, where='id > 5')
2632        self.assertIsInstance(r, OrderedDict)
2633        self.assertEqual(len(r), 0)
2634        # test with unordered query
2635        expected = dict((row[0], row[1:]) for row in colors)
2636        r = get_as_dict(table, order=False)
2637        self.assertIsInstance(r, dict)
2638        self.assertEqual(r, expected)
2639        if dict is not OrderedDict:  # Python > 2.6
2640            self.assertNotIsInstance(self, OrderedDict)
2641        # test with arbitrary from clause
2642        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2643        # primary key must be passed explicitly in this case
2644        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2645        r = get_as_dict(from_table, 'id')
2646        self.assertIsInstance(r, OrderedDict)
2647        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2648        self.assertEqual(r, expected)
2649        # test without a primary key
2650        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2651        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2652        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2653        r = get_as_dict(table, keyname='id')
2654        expected = OrderedDict((row[0], row[1:]) for row in colors)
2655        self.assertIsInstance(r, dict)
2656        self.assertEqual(r, expected)
2657        r = (1, '#007fff', 'Azure')
2658        query('insert into "%s" values ($1, $2, $3)' % table, r)
2659        # the last entry will win
2660        expected[1] = r[1:]
2661        r = get_as_dict(table, keyname='id')
2662        self.assertEqual(r, expected)
2663
2664    def testTransaction(self):
2665        query = self.db.query
2666        self.createTable('test_table', 'n integer', temporary=False)
2667        self.db.begin()
2668        query("insert into test_table values (1)")
2669        query("insert into test_table values (2)")
2670        self.db.commit()
2671        self.db.begin()
2672        query("insert into test_table values (3)")
2673        query("insert into test_table values (4)")
2674        self.db.rollback()
2675        self.db.begin()
2676        query("insert into test_table values (5)")
2677        self.db.savepoint('before6')
2678        query("insert into test_table values (6)")
2679        self.db.rollback('before6')
2680        query("insert into test_table values (7)")
2681        self.db.commit()
2682        self.db.begin()
2683        self.db.savepoint('before8')
2684        query("insert into test_table values (8)")
2685        self.db.release('before8')
2686        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
2687        self.db.commit()
2688        self.db.start()
2689        query("insert into test_table values (9)")
2690        self.db.end()
2691        r = [r[0] for r in query(
2692            "select * from test_table order by 1").getresult()]
2693        self.assertEqual(r, [1, 2, 5, 7, 9])
2694        self.db.begin(mode='read only')
2695        self.assertRaises(pg.ProgrammingError,
2696            query, "insert into test_table values (0)")
2697        self.db.rollback()
2698        self.db.start(mode='Read Only')
2699        self.assertRaises(pg.ProgrammingError,
2700            query, "insert into test_table values (0)")
2701        self.db.abort()
2702
2703    def testTransactionAliases(self):
2704        self.assertEqual(self.db.begin, self.db.start)
2705        self.assertEqual(self.db.commit, self.db.end)
2706        self.assertEqual(self.db.rollback, self.db.abort)
2707
2708    def testContextManager(self):
2709        query = self.db.query
2710        self.createTable('test_table', 'n integer check(n>0)')
2711        with self.db:
2712            query("insert into test_table values (1)")
2713            query("insert into test_table values (2)")
2714        try:
2715            with self.db:
2716                query("insert into test_table values (3)")
2717                query("insert into test_table values (4)")
2718                raise ValueError('test transaction should rollback')
2719        except ValueError as error:
2720            self.assertEqual(str(error), 'test transaction should rollback')
2721        with self.db:
2722            query("insert into test_table values (5)")
2723        try:
2724            with self.db:
2725                query("insert into test_table values (6)")
2726                query("insert into test_table values (-1)")
2727        except pg.ProgrammingError as error:
2728            self.assertTrue('check' in str(error))
2729        with self.db:
2730            query("insert into test_table values (7)")
2731        r = [r[0] for r in query(
2732            "select * from test_table order by 1").getresult()]
2733        self.assertEqual(r, [1, 2, 5, 7])
2734
2735    def testBytea(self):
2736        query = self.db.query
2737        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2738        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2739        r = self.db.escape_bytea(s)
2740        query('insert into bytea_test values(3,$1)', (r,))
2741        r = query('select * from bytea_test where n=3').getresult()
2742        self.assertEqual(len(r), 1)
2743        r = r[0]
2744        self.assertEqual(len(r), 2)
2745        self.assertEqual(r[0], 3)
2746        r = r[1]
2747        self.assertIsInstance(r, str)
2748        r = self.db.unescape_bytea(r)
2749        self.assertIsInstance(r, bytes)
2750        self.assertEqual(r, s)
2751
2752    def testInsertUpdateGetBytea(self):
2753        query = self.db.query
2754        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2755        # insert null value
2756        r = self.db.insert('bytea_test', n=0, data=None)
2757        self.assertIsInstance(r, dict)
2758        self.assertIn('n', r)
2759        self.assertEqual(r['n'], 0)
2760        self.assertIn('data', r)
2761        self.assertIsNone(r['data'])
2762        s = b'None'
2763        r = self.db.update('bytea_test', n=0, data=s)
2764        self.assertIsInstance(r, dict)
2765        self.assertIn('n', r)
2766        self.assertEqual(r['n'], 0)
2767        self.assertIn('data', r)
2768        r = r['data']
2769        self.assertIsInstance(r, bytes)
2770        self.assertEqual(r, s)
2771        r = self.db.update('bytea_test', n=0, data=None)
2772        self.assertIsNone(r['data'])
2773        # insert as bytes
2774        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2775        r = self.db.insert('bytea_test', n=5, data=s)
2776        self.assertIsInstance(r, dict)
2777        self.assertIn('n', r)
2778        self.assertEqual(r['n'], 5)
2779        self.assertIn('data', r)
2780        r = r['data']
2781        self.assertIsInstance(r, bytes)
2782        self.assertEqual(r, s)
2783        # update as bytes
2784        s += b"and now even more \x00 nasty \t stuff!\f"
2785        r = self.db.update('bytea_test', n=5, data=s)
2786        self.assertIsInstance(r, dict)
2787        self.assertIn('n', r)
2788        self.assertEqual(r['n'], 5)
2789        self.assertIn('data', r)
2790        r = r['data']
2791        self.assertIsInstance(r, bytes)
2792        self.assertEqual(r, s)
2793        r = query('select * from bytea_test where n=5').getresult()
2794        self.assertEqual(len(r), 1)
2795        r = r[0]
2796        self.assertEqual(len(r), 2)
2797        self.assertEqual(r[0], 5)
2798        r = r[1]
2799        self.assertIsInstance(r, str)
2800        r = self.db.unescape_bytea(r)
2801        self.assertIsInstance(r, bytes)
2802        self.assertEqual(r, s)
2803        r = self.db.get('bytea_test', dict(n=5))
2804        self.assertIsInstance(r, dict)
2805        self.assertIn('n', r)
2806        self.assertEqual(r['n'], 5)
2807        self.assertIn('data', r)
2808        r = r['data']
2809        self.assertIsInstance(r, bytes)
2810        self.assertEqual(r, s)
2811
2812    def testUpsertBytea(self):
2813        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2814        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2815        r = dict(n=7, data=s)
2816        try:
2817            r = self.db.upsert('bytea_test', r)
2818        except pg.ProgrammingError as error:
2819            if self.db.server_version < 90500:
2820                self.skipTest('database does not support upsert')
2821            self.fail(str(error))
2822        self.assertIsInstance(r, dict)
2823        self.assertIn('n', r)
2824        self.assertEqual(r['n'], 7)
2825        self.assertIn('data', r)
2826        self.assertIsInstance(r['data'], bytes)
2827        self.assertEqual(r['data'], s)
2828        r['data'] = None
2829        r = self.db.upsert('bytea_test', r)
2830        self.assertIsInstance(r, dict)
2831        self.assertIn('n', r)
2832        self.assertEqual(r['n'], 7)
2833        self.assertIn('data', r)
2834        self.assertIsNone(r['data'], bytes)
2835
2836    def testInsertGetJson(self):
2837        try:
2838            self.createTable('json_test', 'n smallint primary key, data json')
2839        except pg.ProgrammingError as error:
2840            if self.db.server_version < 90200:
2841                self.skipTest('database does not support json')
2842            self.fail(str(error))
2843        jsondecode = pg.get_jsondecode()
2844        # insert null value
2845        r = self.db.insert('json_test', n=0, data=None)
2846        self.assertIsInstance(r, dict)
2847        self.assertIn('n', r)
2848        self.assertEqual(r['n'], 0)
2849        self.assertIn('data', r)
2850        self.assertIsNone(r['data'])
2851        r = self.db.get('json_test', 0)
2852        self.assertIsInstance(r, dict)
2853        self.assertIn('n', r)
2854        self.assertEqual(r['n'], 0)
2855        self.assertIn('data', r)
2856        self.assertIsNone(r['data'])
2857        # insert JSON object
2858        data = {
2859          "id": 1, "name": "Foo", "price": 1234.5,
2860          "new": True, "note": None,
2861          "tags": ["Bar", "Eek"],
2862          "stock": {"warehouse": 300, "retail": 20}}
2863        r = self.db.insert('json_test', n=1, data=data)
2864        self.assertIsInstance(r, dict)
2865        self.assertIn('n', r)
2866        self.assertEqual(r['n'], 1)
2867        self.assertIn('data', r)
2868        r = r['data']
2869        if jsondecode is None:
2870            self.assertIsInstance(r, str)
2871            r = json.loads(r)
2872        self.assertIsInstance(r, dict)
2873        self.assertEqual(r, data)
2874        self.assertIsInstance(r['id'], int)
2875        self.assertIsInstance(r['name'], unicode)
2876        self.assertIsInstance(r['price'], float)
2877        self.assertIsInstance(r['new'], bool)
2878        self.assertIsInstance(r['tags'], list)
2879        self.assertIsInstance(r['stock'], dict)
2880        r = self.db.get('json_test', 1)
2881        self.assertIsInstance(r, dict)
2882        self.assertIn('n', r)
2883        self.assertEqual(r['n'], 1)
2884        self.assertIn('data', r)
2885        r = r['data']
2886        if jsondecode is None:
2887            self.assertIsInstance(r, str)
2888            r = json.loads(r)
2889        self.assertIsInstance(r, dict)
2890        self.assertEqual(r, data)
2891        self.assertIsInstance(r['id'], int)
2892        self.assertIsInstance(r['name'], unicode)
2893        self.assertIsInstance(r['price'], float)
2894        self.assertIsInstance(r['new'], bool)
2895        self.assertIsInstance(r['tags'], list)
2896        self.assertIsInstance(r['stock'], dict)
2897
def testInsertGetJsonb(self):
    """Round-trip a null and a nested object through a jsonb column.

    Every row is verified twice: once as returned by insert() and once
    as fetched again with get().  If the table cannot be created, the
    test is skipped when server_version is below 90400 and fails with
    the error text otherwise.
    """
    table = 'jsonb_test'
    try:
        self.createTable(table, 'n smallint primary key, data jsonb')
    except pg.ProgrammingError as error:
        if self.db.server_version < 90400:
            self.skipTest('database does not support jsonb')
        self.fail(str(error))
    jsondecode = pg.get_jsondecode()
    data = {
        "id": 1, "name": "Foo", "price": 1234.5,
        "new": True, "note": None,
        "tags": ["Bar", "Eek"],
        "stock": {"warehouse": 300, "retail": 20}}
    # Python type expected for each member once the value is decoded
    member_types = (
        ('id', int), ('name', unicode), ('price', float),
        ('new', bool), ('tags', list), ('stock', dict))

    def check_row(row, n):
        # every returned row is a dict carrying the key and the column
        self.assertIsInstance(row, dict)
        self.assertIn('n', row)
        self.assertEqual(row['n'], n)
        self.assertIn('data', row)

    def check_null_row(row):
        check_row(row, 0)
        self.assertIsNone(row['data'])

    def check_object_row(row):
        check_row(row, 1)
        value = row['data']
        if jsondecode is None:
            # with no decode function set, the value arrives as a
            # string and has to be parsed here before comparing
            self.assertIsInstance(value, str)
            value = json.loads(value)
        self.assertIsInstance(value, dict)
        self.assertEqual(value, data)
        for member, member_type in member_types:
            self.assertIsInstance(value[member], member_type)

    # insert null value
    check_null_row(self.db.insert(table, n=0, data=None))
    check_null_row(self.db.get(table, 0))
    # insert JSON object
    check_object_row(self.db.insert(table, n=1, data=data))
    check_object_row(self.db.get(table, 1))
2960
def testNotificationHandler(self):
    """Check the attributes of handlers made by notification_handler().

    Five handlers are created with different argument combinations
    (defaults, timeout, arg_dict, stop_event, all positional).  For
    each one the stored attributes are verified, then the handler is
    closed and its db attribute must be None, also after the
    connection has been reopened.
    """
    # the notification handler itself is tested separately
    make_handler = self.db.notification_handler

    def callback(arg_dict):
        return None

    def verify(handler, event, stop_event,
               timeout=None, arg_dict=None, testing=None):
        self.assertIsInstance(handler, pg.NotificationHandler)
        self.assertIs(handler.db, self.db)
        self.assertEqual(handler.event, event)
        self.assertEqual(handler.stop_event, stop_event)
        self.assertIs(handler.callback, callback)
        if arg_dict is None:
            # no dict was passed in, so the handler must hold an empty one
            self.assertIsInstance(handler.arg_dict, dict)
            self.assertEqual(handler.arg_dict, {})
        else:
            # the very same dict object must be kept, unmodified
            self.assertIs(handler.arg_dict, arg_dict)
            self.assertEqual(arg_dict['testing'], testing)
        if timeout is None:
            self.assertIsNone(handler.timeout)
        else:
            self.assertEqual(handler.timeout, timeout)
        self.assertFalse(handler.listening)
        handler.close()
        self.assertIsNone(handler.db)
        self.db.reopen()
        self.assertIsNone(handler.db)

    verify(make_handler('test', callback), 'test', 'stop_test')

    verify(make_handler('test2', callback, timeout=2),
           'test2', 'stop_test2', timeout=2)

    arg_dict = {'testing': 3}
    verify(make_handler('test3', callback, arg_dict=arg_dict),
           'test3', 'stop_test3', arg_dict=arg_dict, testing=3)

    verify(make_handler('test4', callback, stop_event='stop4'),
           'test4', 'stop4')

    arg_dict = {'testing': 5}
    verify(make_handler('test5', callback, arg_dict, 1.5, 'stop5'),
           'test5', 'stop5', timeout=1.5, arg_dict=arg_dict, testing=5)
3037
3038
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    For the lifetime of this class the pg module options are changed:
    'decimal' is set to float, 'bool' to the negation of its current
    value, and 'namedresult' and 'jsondecode' to None.  The previous
    values are kept in saved_options and put back in tearDownClass.
    """

    @classmethod
    def setUpClass(cls):
        """Override the global options, then run the inherited set-up."""
        cls.saved_options = {}
        cls.set_option('decimal', float)
        cls.set_option('bool', not pg.get_bool())
        for option in ('namedresult', 'jsondecode'):
            cls.set_option(option, None)
        super(TestDBClassNonStdOpts, cls).setUpClass()

    @classmethod
    def tearDownClass(cls):
        """Run the inherited tear-down, then restore the saved options."""
        super(TestDBClassNonStdOpts, cls).tearDownClass()
        # undo the overrides in the reverse order of setUpClass
        for option in ('jsondecode', 'namedresult', 'bool', 'decimal'):
            cls.reset_option(option)

    @classmethod
    def set_option(cls, option, value):
        """Remember the current value of a pg option and set a new one.

        Returns whatever the pg.set_<option>() function returns.
        """
        getter = getattr(pg, 'get_' + option)
        cls.saved_options[option] = getter()
        setter = getattr(pg, 'set_' + option)
        return setter(value)

    @classmethod
    def reset_option(cls, option):
        """Set a pg option back to the value saved by set_option()."""
        setter = getattr(pg, 'set_' + option)
        return setter(cls.saved_options[option])
3068
3069
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces).

    The class fixture creates the schemas s1 to s4 and, in each of them
    as well as in "public", two tables named t and t<num>, where <num>
    is the schema number (0 for "public").  Each table holds one row
    with n = 1 and d = <num>, so the value of d reveals the schema in
    which a table name has been resolved.
    """

    cls_set_up = False  # becomes True once setUpClass has completed

    @classmethod
    def setUpClass(cls):
        """Create the schemas and tables shared by all tests of the class.

        Raises RuntimeError if creating a schema fails with a
        pg.ProgrammingError.
        """
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    # "public" is not dropped, only its test tables are
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # release the connection even if the fixture could not be built
            db.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        """Drop the schemas and tables created by setUpClass."""
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            # release the connection even if a drop statement failed
            db.close()

    def setUp(self):
        """Open a fresh connection after checking the class fixture."""
        self.assertTrue(self.cls_set_up)
        self.db = DB()

    def tearDown(self):
        """Run pending cleanups, then close the connection."""
        # cleanups registered by the tests issue queries on self.db,
        # so they have to run while the connection is still open
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        """All fixture tables are listed with schema-qualified names."""
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        """Attribute names are found for qualified and unqualified tables."""
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        # register the cleanup only after the table exists, so that a
        # failed create is not followed by a second error from the drop
        self.addCleanup(query, "drop table s3.t3m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        # unqualified names must now be looked up via the search path
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        """get() resolves table names according to the search path.

        The d column of the fetched row identifies the schema the table
        was taken from; tables not reachable through the current search
        path must raise a ProgrammingError.
        """
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        # t exists in both s2 and s4; the first schema in the path wins
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        """The oid key in a fetched row is named after the table as given."""
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
3188
3189
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class.

    Standard output is redirected into a string buffer for each test,
    so that anything printed while queries run can be inspected.
    """

    def setUp(self):
        """Open a connection and start capturing standard output."""
        self.db = DB()
        self.query = self.db.query
        self.debug = self.db.debug
        self.output = StringIO()
        self.stdout = sys.stdout
        sys.stdout = self.output

    def tearDown(self):
        """Restore standard output, reset debug and close the connection."""
        sys.stdout = self.stdout
        self.output.close()
        self.db.debug = debug
        self.db.close()

    def get_output(self):
        """Return everything written to standard output so far."""
        return self.output.getvalue()

    def send_queries(self):
        """Run the two sample queries whose debug output is checked."""
        for sql in ("select 1", "select 2"):
            self.db.query(sql)

    def testDebugDefault(self):
        """A new connection carries the module-level debug setting."""
        if not debug:
            self.assertIsNone(self.db.debug)
        else:
            self.assertEqual(self.db.debug, debug)

    def testDebugIsFalse(self):
        """With debug set to False nothing is printed."""
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        """With debug set to True each query is printed on its own line."""
        self.db.debug = True
        self.send_queries()
        printed = self.get_output()
        self.assertEqual(printed, "select 1\nselect 2\n")

    def testDebugIsString(self):
        """A string is used as a format for printing each query."""
        self.db.debug = "Test with string: %s."
        self.send_queries()
        printed = self.get_output()
        self.assertEqual(
            printed,
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        """A file-like object receives the queries instead of stdout."""
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            # rewind to read back what has been written to the file
            debug_file.seek(0)
            written = debug_file.read()
            self.assertEqual(written, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        """A callable is invoked with each query and nothing is printed."""
        received = []
        self.db.debug = received.append
        self.send_queries()
        self.assertEqual(received, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")
3251
3252
3253if __name__ == '__main__':
3254    unittest.main()
Note: See TracBrowser for help on using the repository browser.