source: trunk/tests/test_classic_dbwrapper.py @ 770

Last change on this file since 770 was 770, checked in by cito, 4 years ago

Add methods for getting a table as a list or dict

Also added documentation and 100% test coverage.

The get_attnames() method now always returns a read-only ordered dictionary,
even under Python 2.6 or 3.0. So you can be sure the columns will be returned
in the right order if you iterate over it, and that you don't accidentally
modify the dictionary (since it is cached).

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 121.2 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11"""
12
13try:
14    import unittest2 as unittest  # for Python < 2.7
15except ImportError:
16    import unittest
17
18import os
19import sys
20import tempfile
21
22import pg  # the module under test
23
24from decimal import Decimal
25from operator import itemgetter
26
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None  # no explicit host; testAttributeHost then expects 'localhost'
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Settings defined in LOCAL_PyGreSQL override the defaults above.  The
# relative import is tried first; if that fails (ImportError or ValueError,
# depending on the Python version and on whether the tests are run as a
# package), the module is looked up as a top-level module.  If it cannot be
# found at all, the defaults are kept.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Python 2/3 compatibility: make the names "long" and "unicode" available
# on all supported Python versions.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# Ordered dictionary used as reference type for AttrDict.  On Python 2.6
# and 3.0 the fallback is a plain dict, which does not preserve order.
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

# StringIO for the native str type (bytes on Python 2, text on Python 3).
if str is bytes:
    from StringIO import StringIO
else:
    from io import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(), so tests that ask
# for the host are skipped there:
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
70
71
def DB():
    """Return a new pg.DB wrapper connected to the configured test database.

    The connection parameters come from the module-level settings
    (dbname, dbhost, dbport).  Debug output of the wrapper is enabled when
    the module-level ``debug`` flag is set, and the server is asked to send
    only messages of level warning or above.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
79
80
class TestAttrDict(unittest.TestCase):
    """Test the simple ordered dictionary for attribute names."""

    cls = pg.AttrDict  # the class under test
    base = OrderedDict  # expected base class (plain dict on Python 2.6/3.0)

    def testInit(self):
        """AttrDict can be created empty, from a list or from an iterator."""
        a = self.cls()
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base())
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))
        iteritems = iter(items)
        a = self.cls(iteritems)
        self.assertIsInstance(a, self.base)
        self.assertEqual(a, self.base(items))

    def testIter(self):
        """Iteration yields the keys in insertion order."""
        a = self.cls()
        self.assertEqual(list(a), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a), keys)

    def testKeys(self):
        """keys() returns the keys in insertion order."""
        a = self.cls()
        self.assertEqual(list(a.keys()), [])
        keys = ['id', 'name', 'age']
        items = [(key, None) for key in keys]
        a = self.cls(items)
        self.assertEqual(list(a.keys()), keys)

    def testValues(self):
        """values() returns the values in insertion order."""
        a = self.cls()
        self.assertEqual(list(a.values()), [])
        items = [('id', 'int'), ('name', 'text')]
        values = [item[1] for item in items]
        a = self.cls(items)
        self.assertEqual(list(a.values()), values)

    def testItems(self):
        """items() returns the pairs in insertion order."""
        a = self.cls()
        self.assertEqual(list(a.items()), [])
        items = [('id', 'int'), ('name', 'text')]
        a = self.cls(items)
        self.assertEqual(list(a.items()), items)

    def testGet(self):
        """Items can be read by key."""
        a = self.cls([('id', 1)])
        try:
            self.assertEqual(a['id'], 1)
        except KeyError:
            self.fail('AttrDict should be readable')

    def testSet(self):
        """Setting an item raises a TypeError."""
        a = self.cls()
        try:
            a['id'] = 1
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testDel(self):
        """Deleting an item raises a TypeError."""
        a = self.cls([('id', 1)])
        try:
            del a['id']
        except TypeError:
            pass
        else:
            self.fail('AttrDict should be read-only')

    def testWriteMethods(self):
        """All mutating dict methods raise a TypeError and change nothing."""
        a = self.cls([('id', 1)])
        self.assertEqual(a['id'], 1)
        # Call each method with arguments that would be valid for an
        # ordinary mutable dictionary.  Otherwise a TypeError could stem from
        # the call itself (e.g. clear(a) has the wrong arity and pop(a) uses
        # an unhashable key) and the read-only check would prove nothing.
        calls = [
            ('clear', ()),
            ('update', ({'name': 'text'},)),
            ('pop', ('id',)),
            ('setdefault', ('name', 'text')),
            ('popitem', ()),
        ]
        for name, args in calls:
            method = getattr(a, name)
            self.assertRaises(TypeError, method, *args)
        # none of the attempts may have modified the dictionary
        self.assertEqual(list(a.items()), [('id', 1)])
145            self.fail('AttrDict should be read-only')
146
147    def testDel(self):
148        a = self.cls([('id', 1)])
149        try:
150            del a['id']
151        except TypeError:
152            pass
153        else:
154            self.fail('AttrDict should be read-only')
155
156    def testWriteMethods(self):
157        a = self.cls([('id', 1)])
158        self.assertEqual(a['id'], 1)
159        for method in 'clear', 'update', 'pop', 'setdefault', 'popitem':
160            method = getattr(a, method)
161            self.assertRaises(TypeError, method, a)
162
163
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # some tests close the connection themselves, and closing an
        # already closed connection raises an InternalError
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        # dir() returns names in alphabetical order, so keep this list sorted
        attributes = [
            'abort',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'delete',
            'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_as_dict', 'get_as_list',
            'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        # the error message must be empty, apart from a tolerated krb5_ message
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # accept both hex and escape output format for the empty string
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        # parameters may be passed separately, as a tuple or as a list
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            # 22012 is the SQLSTATE for division_by_zero
            self.assertEqual(error.sqlstate, '22012')
        else:
            self.fail('Division by zero should raise a ProgrammingError')

    def testMethodEndcopy(self):
        # calling endcopy outside of a copy operation may raise an IOError;
        # this only checks that nothing else goes wrong
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        # reset keeps the same underlying connection object
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        # reopen replaces the underlying connection, even after close
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        # a DB wrapper can be created from an existing connection; closing
        # such a wrapper leaves the borrowed connection open
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        # it can also be created from another DB wrapper ...
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        # ... or from a connection passed as keyword argument ...
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        # ... or from any object holding a connection in _cnx
        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
366
367
368class TestDBClass(unittest.TestCase):
369    """Test the methods of the DB class wrapped pg connection."""
370
371    cls_set_up = False
372
373    @classmethod
374    def setUpClass(cls):
375        db = DB()
376        db.query("drop table if exists test cascade")
377        db.query("create table test ("
378            "i2 smallint, i4 integer, i8 bigint,"
379            " d numeric, f4 real, f8 double precision, m money,"
380            " v4 varchar(4), c4 char(4), t text)")
381        db.query("create or replace view test_view as"
382            " select i4, v4 from test")
383        db.close()
384        cls.cls_set_up = True
385
386    @classmethod
387    def tearDownClass(cls):
388        db = DB()
389        db.query("drop table test cascade")
390        db.close()
391
392    def setUp(self):
393        self.assertTrue(self.cls_set_up)
394        self.db = DB()
395        query = self.db.query
396        query('set client_encoding=utf8')
397        query('set standard_conforming_strings=on')
398        query("set lc_monetary='C'")
399        query("set datestyle='ISO,YMD'")
400        query('set bytea_output=hex')
401
402    def tearDown(self):
403        self.doCleanups()
404        self.db.close()
405
406    def createTable(self, table, definition,
407            temporary=True, oids=None, values=None):
408        query = self.db.query
409        if not '"' in table or '.' in table:
410            table = '"%s"' % table
411        if not temporary:
412            q = 'drop table if exists %s cascade' % table
413            query(q)
414            self.addCleanup(query, q)
415        temporary = 'temporary table' if temporary else 'table'
416        as_query = definition.startswith(('as ', 'AS '))
417        if not as_query and not definition.startswith('('):
418            definition = '(%s)' % definition
419        with_oids = 'with oids' if oids else 'without oids'
420        q = ['create', temporary, table]
421        if as_query:
422            q.extend([with_oids, definition])
423        else:
424            q.extend([definition, with_oids])
425        q = ' '.join(q)
426        query(q)
427        if values:
428            for params in values:
429                if not isinstance(params, (list, tuple)):
430                    params = [params]
431                values = ', '.join('$%d' % (n + 1) for n in range(len(params)))
432                q = "insert into %s values (%s)" % (table, values)
433                query(q, params)
434
435    def testClassName(self):
436        self.assertEqual(self.db.__class__.__name__, 'DB')
437
438    def testModuleName(self):
439        self.assertEqual(self.db.__module__, 'pg')
440        self.assertEqual(self.db.__class__.__module__, 'pg')
441
442    def testEscapeLiteral(self):
443        f = self.db.escape_literal
444        r = f(b"plain")
445        self.assertIsInstance(r, bytes)
446        self.assertEqual(r, b"'plain'")
447        r = f(u"plain")
448        self.assertIsInstance(r, unicode)
449        self.assertEqual(r, u"'plain'")
450        r = f(u"that's kÀse".encode('utf-8'))
451        self.assertIsInstance(r, bytes)
452        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
453        r = f(u"that's kÀse")
454        self.assertIsInstance(r, unicode)
455        self.assertEqual(r, u"'that''s kÀse'")
456        self.assertEqual(f(r"It's fine to have a \ inside."),
457            r" E'It''s fine to have a \\ inside.'")
458        self.assertEqual(f('No "quotes" must be escaped.'),
459            "'No \"quotes\" must be escaped.'")
460
461    def testEscapeIdentifier(self):
462        f = self.db.escape_identifier
463        r = f(b"plain")
464        self.assertIsInstance(r, bytes)
465        self.assertEqual(r, b'"plain"')
466        r = f(u"plain")
467        self.assertIsInstance(r, unicode)
468        self.assertEqual(r, u'"plain"')
469        r = f(u"that's kÀse".encode('utf-8'))
470        self.assertIsInstance(r, bytes)
471        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
472        r = f(u"that's kÀse")
473        self.assertIsInstance(r, unicode)
474        self.assertEqual(r, u'"that\'s kÀse"')
475        self.assertEqual(f(r"It's fine to have a \ inside."),
476            '"It\'s fine to have a \\ inside."')
477        self.assertEqual(f('All "quotes" must be escaped.'),
478            '"All ""quotes"" must be escaped."')
479
480    def testEscapeString(self):
481        f = self.db.escape_string
482        r = f(b"plain")
483        self.assertIsInstance(r, bytes)
484        self.assertEqual(r, b"plain")
485        r = f(u"plain")
486        self.assertIsInstance(r, unicode)
487        self.assertEqual(r, u"plain")
488        r = f(u"that's kÀse".encode('utf-8'))
489        self.assertIsInstance(r, bytes)
490        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
491        r = f(u"that's kÀse")
492        self.assertIsInstance(r, unicode)
493        self.assertEqual(r, u"that''s kÀse")
494        self.assertEqual(f(r"It's fine to have a \ inside."),
495            r"It''s fine to have a \ inside.")
496
497    def testEscapeBytea(self):
498        f = self.db.escape_bytea
499        # note that escape_byte always returns hex output since Pg 9.0,
500        # regardless of the bytea_output setting
501        r = f(b'plain')
502        self.assertIsInstance(r, bytes)
503        self.assertEqual(r, b'\\x706c61696e')
504        r = f(u'plain')
505        self.assertIsInstance(r, unicode)
506        self.assertEqual(r, u'\\x706c61696e')
507        r = f(u"das is' kÀse".encode('utf-8'))
508        self.assertIsInstance(r, bytes)
509        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
510        r = f(u"das is' kÀse")
511        self.assertIsInstance(r, unicode)
512        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
513        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
514
515    def testUnescapeBytea(self):
516        f = self.db.unescape_bytea
517        r = f(b'plain')
518        self.assertIsInstance(r, bytes)
519        self.assertEqual(r, b'plain')
520        r = f(u'plain')
521        self.assertIsInstance(r, bytes)
522        self.assertEqual(r, b'plain')
523        r = f(b"das is' k\\303\\244se")
524        self.assertIsInstance(r, bytes)
525        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
526        r = f(u"das is' k\\303\\244se")
527        self.assertIsInstance(r, bytes)
528        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
529        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
530        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
531        self.assertEqual(f(r'\\x746861742773206be47365'),
532            b'\\x746861742773206be47365')
533        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
534
535    def testGetParameter(self):
536        f = self.db.get_parameter
537        self.assertRaises(TypeError, f)
538        self.assertRaises(TypeError, f, None)
539        self.assertRaises(TypeError, f, 42)
540        self.assertRaises(TypeError, f, '')
541        self.assertRaises(TypeError, f, [])
542        self.assertRaises(TypeError, f, [''])
543        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
544        r = f('standard_conforming_strings')
545        self.assertEqual(r, 'on')
546        r = f('lc_monetary')
547        self.assertEqual(r, 'C')
548        r = f('datestyle')
549        self.assertEqual(r, 'ISO, YMD')
550        r = f('bytea_output')
551        self.assertEqual(r, 'hex')
552        r = f(['bytea_output', 'lc_monetary'])
553        self.assertIsInstance(r, list)
554        self.assertEqual(r, ['hex', 'C'])
555        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
556        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
557        r = f(set(['bytea_output', 'lc_monetary']))
558        self.assertIsInstance(r, dict)
559        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
560        r = f(set(['Bytea_Output', ' LC_Monetary ']))
561        self.assertIsInstance(r, dict)
562        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
563        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
564        r = f(s)
565        self.assertIs(r, s)
566        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
567        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
568        r = f(s)
569        self.assertIs(r, s)
570        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
571
572    def testGetParameterServerVersion(self):
573        r = self.db.get_parameter('server_version_num')
574        self.assertIsInstance(r, str)
575        s = self.db.server_version
576        self.assertIsInstance(s, int)
577        self.assertEqual(r, str(s))
578
579    def testGetParameterAll(self):
580        f = self.db.get_parameter
581        r = f('all')
582        self.assertIsInstance(r, dict)
583        self.assertEqual(r['standard_conforming_strings'], 'on')
584        self.assertEqual(r['lc_monetary'], 'C')
585        self.assertEqual(r['DateStyle'], 'ISO, YMD')
586        self.assertEqual(r['bytea_output'], 'hex')
587
588    def testSetParameter(self):
589        f = self.db.set_parameter
590        g = self.db.get_parameter
591        self.assertRaises(TypeError, f)
592        self.assertRaises(TypeError, f, None)
593        self.assertRaises(TypeError, f, 42)
594        self.assertRaises(TypeError, f, '')
595        self.assertRaises(TypeError, f, [])
596        self.assertRaises(TypeError, f, [''])
597        self.assertRaises(ValueError, f, 'all', 'invalid')
598        self.assertRaises(ValueError, f, {
599            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
600        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
601        f('standard_conforming_strings', 'off')
602        self.assertEqual(g('standard_conforming_strings'), 'off')
603        f('datestyle', 'ISO, DMY')
604        self.assertEqual(g('datestyle'), 'ISO, DMY')
605        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
606        self.assertEqual(g('standard_conforming_strings'), 'on')
607        self.assertEqual(g('datestyle'), 'ISO, DMY')
608        f(['default_with_oids', 'standard_conforming_strings'], 'off')
609        self.assertEqual(g('default_with_oids'), 'off')
610        self.assertEqual(g('standard_conforming_strings'), 'off')
611        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
612        self.assertEqual(g('standard_conforming_strings'), 'on')
613        self.assertEqual(g('datestyle'), 'ISO, YMD')
614        f(('default_with_oids', 'standard_conforming_strings'), 'off')
615        self.assertEqual(g('default_with_oids'), 'off')
616        self.assertEqual(g('standard_conforming_strings'), 'off')
617        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
618        self.assertEqual(g('default_with_oids'), 'on')
619        self.assertEqual(g('standard_conforming_strings'), 'on')
620        self.assertRaises(ValueError, f, set([ 'default_with_oids',
621            'standard_conforming_strings']), ['off', 'on'])
622        f(set(['default_with_oids', 'standard_conforming_strings']),
623            ['off', 'off'])
624        self.assertEqual(g('default_with_oids'), 'off')
625        self.assertEqual(g('standard_conforming_strings'), 'off')
626        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
627        self.assertEqual(g('standard_conforming_strings'), 'on')
628        self.assertEqual(g('datestyle'), 'ISO, YMD')
629
630    def testResetParameter(self):
631        db = DB()
632        f = db.set_parameter
633        g = db.get_parameter
634        r = g('default_with_oids')
635        self.assertIn(r, ('on', 'off'))
636        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
637        r = g('standard_conforming_strings')
638        self.assertIn(r, ('on', 'off'))
639        scs, not_scs = r, 'off' if r == 'on' else 'on'
640        f('default_with_oids', not_dwi)
641        f('standard_conforming_strings', not_scs)
642        self.assertEqual(g('default_with_oids'), not_dwi)
643        self.assertEqual(g('standard_conforming_strings'), not_scs)
644        f('default_with_oids')
645        f('standard_conforming_strings', None)
646        self.assertEqual(g('default_with_oids'), dwi)
647        self.assertEqual(g('standard_conforming_strings'), scs)
648        f('default_with_oids', not_dwi)
649        f('standard_conforming_strings', not_scs)
650        self.assertEqual(g('default_with_oids'), not_dwi)
651        self.assertEqual(g('standard_conforming_strings'), not_scs)
652        f(['default_with_oids', 'standard_conforming_strings'], None)
653        self.assertEqual(g('default_with_oids'), dwi)
654        self.assertEqual(g('standard_conforming_strings'), scs)
655        f('default_with_oids', not_dwi)
656        f('standard_conforming_strings', not_scs)
657        self.assertEqual(g('default_with_oids'), not_dwi)
658        self.assertEqual(g('standard_conforming_strings'), not_scs)
659        f(('default_with_oids', 'standard_conforming_strings'))
660        self.assertEqual(g('default_with_oids'), dwi)
661        self.assertEqual(g('standard_conforming_strings'), scs)
662        f('default_with_oids', not_dwi)
663        f('standard_conforming_strings', not_scs)
664        self.assertEqual(g('default_with_oids'), not_dwi)
665        self.assertEqual(g('standard_conforming_strings'), not_scs)
666        f(set(['default_with_oids', 'standard_conforming_strings']))
667        self.assertEqual(g('default_with_oids'), dwi)
668        self.assertEqual(g('standard_conforming_strings'), scs)
669
670    def testResetParameterAll(self):
671        db = DB()
672        f = db.set_parameter
673        self.assertRaises(ValueError, f, 'all', 0)
674        self.assertRaises(ValueError, f, 'all', 'off')
675        g = db.get_parameter
676        r = g('default_with_oids')
677        self.assertIn(r, ('on', 'off'))
678        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
679        r = g('standard_conforming_strings')
680        self.assertIn(r, ('on', 'off'))
681        scs, not_scs = r, 'off' if r == 'on' else 'on'
682        f('default_with_oids', not_dwi)
683        f('standard_conforming_strings', not_scs)
684        self.assertEqual(g('default_with_oids'), not_dwi)
685        self.assertEqual(g('standard_conforming_strings'), not_scs)
686        f('all')
687        self.assertEqual(g('default_with_oids'), dwi)
688        self.assertEqual(g('standard_conforming_strings'), scs)
689
690    def testSetParameterLocal(self):
691        f = self.db.set_parameter
692        g = self.db.get_parameter
693        self.assertEqual(g('standard_conforming_strings'), 'on')
694        self.db.begin()
695        f('standard_conforming_strings', 'off', local=True)
696        self.assertEqual(g('standard_conforming_strings'), 'off')
697        self.db.end()
698        self.assertEqual(g('standard_conforming_strings'), 'on')
699
700    def testSetParameterSession(self):
701        f = self.db.set_parameter
702        g = self.db.get_parameter
703        self.assertEqual(g('standard_conforming_strings'), 'on')
704        self.db.begin()
705        f('standard_conforming_strings', 'off', local=False)
706        self.assertEqual(g('standard_conforming_strings'), 'off')
707        self.db.end()
708        self.assertEqual(g('standard_conforming_strings'), 'off')
709
710    def testReset(self):
711        db = DB()
712        default_datestyle = db.get_parameter('datestyle')
713        changed_datestyle = 'ISO, DMY'
714        if changed_datestyle == default_datestyle:
715            changed_datestyle == 'ISO, YMD'
716        self.db.set_parameter('datestyle', changed_datestyle)
717        r = self.db.get_parameter('datestyle')
718        self.assertEqual(r, changed_datestyle)
719        con = self.db.db
720        q = con.query("show datestyle")
721        self.db.reset()
722        r = q.getresult()[0][0]
723        self.assertEqual(r, changed_datestyle)
724        q = con.query("show datestyle")
725        r = q.getresult()[0][0]
726        self.assertEqual(r, default_datestyle)
727        r = self.db.get_parameter('datestyle')
728        self.assertEqual(r, default_datestyle)
729
730    def testReopen(self):
731        db = DB()
732        default_datestyle = db.get_parameter('datestyle')
733        changed_datestyle = 'ISO, DMY'
734        if changed_datestyle == default_datestyle:
735            changed_datestyle == 'ISO, YMD'
736        self.db.set_parameter('datestyle', changed_datestyle)
737        r = self.db.get_parameter('datestyle')
738        self.assertEqual(r, changed_datestyle)
739        con = self.db.db
740        q = con.query("show datestyle")
741        self.db.reopen()
742        r = q.getresult()[0][0]
743        self.assertEqual(r, changed_datestyle)
744        self.assertRaises(TypeError, getattr, con, 'query')
745        r = self.db.get_parameter('datestyle')
746        self.assertEqual(r, default_datestyle)
747
748    def testCreateTable(self):
749        table = 'test hello world'
750        values = [(2, "World!"), (1, "Hello")]
751        self.createTable(table, "n smallint, t varchar",
752            temporary=True, oids=True, values=values)
753        r = self.db.query('select t from "%s" order by n' % table).getresult()
754        r = ', '.join(row[0] for row in r)
755        self.assertEqual(r, "Hello, World!")
756        r = self.db.query('select oid from "%s" limit 1' % table).getresult()
757        self.assertIsInstance(r[0][0], int)
758
759    def testQuery(self):
760        query = self.db.query
761        table = 'test_table'
762        self.createTable(table, "n integer", oids=True)
763        q = "insert into test_table values (1)"
764        r = query(q)
765        self.assertIsInstance(r, int)
766        q = "insert into test_table select 2"
767        r = query(q)
768        self.assertIsInstance(r, int)
769        oid = r
770        q = "select oid from test_table where n=2"
771        r = query(q).getresult()
772        self.assertEqual(len(r), 1)
773        r = r[0]
774        self.assertEqual(len(r), 1)
775        r = r[0]
776        self.assertEqual(r, oid)
777        q = "insert into test_table select 3 union select 4 union select 5"
778        r = query(q)
779        self.assertIsInstance(r, str)
780        self.assertEqual(r, '3')
781        q = "update test_table set n=4 where n<5"
782        r = query(q)
783        self.assertIsInstance(r, str)
784        self.assertEqual(r, '4')
785        q = "delete from test_table"
786        r = query(q)
787        self.assertIsInstance(r, str)
788        self.assertEqual(r, '5')
789
790    def testMultipleQueries(self):
791        self.assertEqual(self.db.query(
792            "create temporary table test_multi (n integer);"
793            "insert into test_multi values (4711);"
794            "select n from test_multi").getresult()[0][0], 4711)
795
796    def testQueryWithParams(self):
797        query = self.db.query
798        self.createTable('test_table', 'n1 integer, n2 integer', oids=True)
799        q = "insert into test_table values ($1, $2)"
800        r = query(q, (1, 2))
801        self.assertIsInstance(r, int)
802        r = query(q, [3, 4])
803        self.assertIsInstance(r, int)
804        r = query(q, [5, 6])
805        self.assertIsInstance(r, int)
806        q = "select * from test_table order by 1, 2"
807        self.assertEqual(query(q).getresult(),
808            [(1, 2), (3, 4), (5, 6)])
809        q = "select * from test_table where n1=$1 and n2=$2"
810        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
811        q = "update test_table set n2=$2 where n1=$1"
812        r = query(q, 3, 7)
813        self.assertEqual(r, '1')
814        q = "select * from test_table order by 1, 2"
815        self.assertEqual(query(q).getresult(),
816            [(1, 2), (3, 7), (5, 6)])
817        q = "delete from test_table where n2!=$1"
818        r = query(q, 4)
819        self.assertEqual(r, '3')
820
821    def testEmptyQuery(self):
822        self.assertRaises(ValueError, self.db.query, '')
823
824    def testQueryProgrammingError(self):
825        try:
826            self.db.query("select 1/0")
827        except pg.ProgrammingError as error:
828            self.assertEqual(error.sqlstate, '22012')
829
    def testPkey(self):
        """Test pkey() for simple, composite and quoted primary keys.

        Every check runs twice: once with plain table names and once with
        table names that contain spaces and therefore need quoting. The
        end of the loop also checks that results are cached and that
        flush=True refreshes the cached entry.
        """
        query = self.db.query
        pkey = self.db.pkey
        # the fixture table 'test' presumably has no primary key
        self.assertRaises(KeyError, pkey, 'test')
        for t in ('pkeytest', 'primary key test'):
            # table 0 has no primary key, tables 1 and 2 have a single-column
            # key, tables 3, 4 and 7 have composite keys, tables 5-7 have
            # long or quoted column names
            self.createTable('%s0' % t, 'a smallint')
            self.createTable('%s1' % t, 'b smallint primary key')
            self.createTable('%s2' % t,
                'c smallint, d smallint primary key')
            self.createTable('%s3' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (f, h)')
            self.createTable('%s4' % t,
                'e smallint, f smallint, g smallint, h smallint, i smallint,'
                ' primary key (h, f)')
            self.createTable('%s5' % t,
                'more_than_one_letter varchar primary key')
            self.createTable('%s6' % t,
                '"with space" date primary key')
            self.createTable('%s7' % t,
                'a_very_long_column_name varchar, "with space" date, "42" int,'
                ' primary key (a_very_long_column_name, "with space", "42")')
            # a table without a primary key raises KeyError
            self.assertRaises(KeyError, pkey, '%s0' % t)
            # a single-column key is returned as a string, or as a 1-tuple
            # when composite is requested (positionally or by keyword)
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s1' % t, True), ('b',))
            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
            self.assertEqual(pkey('%s2' % t), 'd')
            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
            # a composite key is always returned as a tuple, even with
            # composite=False
            r = pkey('%s3' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            r = pkey('%s3' % t, composite=False)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('f', 'h'))
            # the column order of the key definition is preserved
            r = pkey('%s4' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, ('h', 'f'))
            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
            # quoted column names are returned without quotes
            self.assertEqual(pkey('%s6' % t), 'with space')
            r = pkey('%s7' % t)
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (
                'a_very_long_column_name', 'with space', '42'))
            # a newly added primary key will be detected
            # (so the earlier KeyError for this table was apparently not
            # cached)
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
883
884    def testGetDatabases(self):
885        databases = self.db.get_databases()
886        self.assertIn('template0', databases)
887        self.assertIn('template1', databases)
888        self.assertNotIn('not existing database', databases)
889        self.assertIn('postgres', databases)
890        self.assertIn(dbname, databases)
891
892    def testGetTables(self):
893        get_tables = self.db.get_tables
894        tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe',
895            'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321',
896            'averyveryveryveryveryveryveryreallyreallylongtablename',
897            'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z')
898        for t in tables:
899            self.db.query('drop table if exists "%s" cascade' % t)
900        before_tables = get_tables()
901        self.assertIsInstance(before_tables, list)
902        for t in before_tables:
903            t = t.split('.', 1)
904            self.assertGreaterEqual(len(t), 2)
905            if len(t) > 2:
906                self.assertTrue(t[1].startswith('"'))
907            t = t[0]
908            self.assertNotEqual(t, 'information_schema')
909            self.assertFalse(t.startswith('pg_'))
910        for t in tables:
911            self.createTable(t, 'as select 0', temporary=False)
912        current_tables = get_tables()
913        new_tables = [t for t in current_tables if t not in before_tables]
914        expected_new_tables = ['public.%s' % (
915            '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables]
916        self.assertEqual(new_tables, expected_new_tables)
917        self.doCleanups()
918        after_tables = get_tables()
919        self.assertEqual(after_tables, before_tables)
920
921    def testGetRelations(self):
922        get_relations = self.db.get_relations
923        result = get_relations()
924        self.assertIn('public.test', result)
925        self.assertIn('public.test_view', result)
926        result = get_relations('rv')
927        self.assertIn('public.test', result)
928        self.assertIn('public.test_view', result)
929        result = get_relations('r')
930        self.assertIn('public.test', result)
931        self.assertNotIn('public.test_view', result)
932        result = get_relations('v')
933        self.assertNotIn('public.test', result)
934        self.assertIn('public.test_view', result)
935        result = get_relations('cisSt')
936        self.assertNotIn('public.test', result)
937        self.assertNotIn('public.test_view', result)
938
939    def testGetAttnames(self):
940        get_attnames = self.db.get_attnames
941        self.assertRaises(pg.ProgrammingError,
942            self.db.get_attnames, 'does_not_exist')
943        self.assertRaises(pg.ProgrammingError,
944            self.db.get_attnames, 'has.too.many.dots')
945        r = get_attnames('test')
946        self.assertIsInstance(r, dict)
947        self.assertEqual(r, dict(
948            i2='int', i4='int', i8='int', d='num',
949            f4='float', f8='float', m='money',
950            v4='text', c4='text', t='text'))
951        self.createTable('test_table',
952            'n int, alpha smallint, beta bool,'
953            ' gamma char(5), tau text, v varchar(3)')
954        r = get_attnames('test_table')
955        self.assertIsInstance(r, dict)
956        self.assertEqual(r, dict(
957            n='int', alpha='int', beta='bool',
958            gamma='text', tau='text', v='text'))
959
960    def testGetAttnamesWithQuotes(self):
961        get_attnames = self.db.get_attnames
962        table = 'test table for get_attnames()'
963        self.createTable(table,
964            '"Prime!" smallint, "much space" integer, "Questions?" text')
965        r = get_attnames(table)
966        self.assertIsInstance(r, dict)
967        self.assertEqual(r, {
968            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
969        table = 'yet another test table for get_attnames()'
970        self.createTable(table,
971            'a smallint, b integer, c bigint,'
972            ' e numeric, f float, f2 double precision, m money,'
973            ' x smallint, y smallint, z smallint,'
974            ' Normal_NaMe smallint, "Special Name" smallint,'
975            ' t text, u char(2), v varchar(2),'
976            ' primary key (y, u)', oids=True)
977        r = get_attnames(table)
978        self.assertIsInstance(r, dict)
979        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
980            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
981            'normal_name': 'int', 'Special Name': 'int',
982            'u': 'text', 't': 'text', 'v': 'text',
983            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
984
985    def testGetAttnamesWithRegtypes(self):
986        get_attnames = self.db.get_attnames
987        self.createTable('test_table',
988            ' n int, alpha smallint, beta bool,'
989            ' gamma char(5), tau text, v varchar(3)')
990        use_regtypes = self.db.use_regtypes
991        regtypes = use_regtypes()
992        self.assertFalse(regtypes)
993        use_regtypes(True)
994        try:
995            r = get_attnames("test_table")
996            self.assertIsInstance(r, dict)
997        finally:
998            use_regtypes(regtypes)
999        self.assertEqual(r, dict(
1000            n='integer', alpha='smallint', beta='boolean',
1001            gamma='character', tau='text', v='character varying'))
1002
1003    def testGetAttnamesIsCached(self):
1004        get_attnames = self.db.get_attnames
1005        query = self.db.query
1006        self.createTable('test_table', 'col int')
1007        r = get_attnames("test_table")
1008        self.assertIsInstance(r, dict)
1009        self.assertEqual(r, dict(col='int'))
1010        query("alter table test_table alter column col type text")
1011        query("alter table test_table add column col2 int")
1012        r = get_attnames("test_table")
1013        self.assertEqual(r, dict(col='int'))
1014        r = get_attnames("test_table", flush=True)
1015        self.assertEqual(r, dict(col='text', col2='int'))
1016        query("alter table test_table drop column col2")
1017        r = get_attnames("test_table")
1018        self.assertEqual(r, dict(col='text', col2='int'))
1019        r = get_attnames("test_table", flush=True)
1020        self.assertEqual(r, dict(col='text'))
1021        query("alter table test_table drop column col")
1022        r = get_attnames("test_table")
1023        self.assertEqual(r, dict(col='text'))
1024        r = get_attnames("test_table", flush=True)
1025        self.assertEqual(r, dict())
1026
1027    def testGetAttnamesIsOrdered(self):
1028        get_attnames = self.db.get_attnames
1029        r = get_attnames('test', flush=True)
1030        self.assertIsInstance(r, OrderedDict)
1031        self.assertEqual(r, OrderedDict([
1032            ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1033            ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1034            ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1035        if OrderedDict is not dict:
1036            r = ' '.join(list(r.keys()))
1037            self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1038        table = 'test table for get_attnames'
1039        self.createTable(table,
1040            ' n int, alpha smallint, v varchar(3),'
1041            ' gamma char(5), tau text, beta bool')
1042        r = get_attnames(table)
1043        self.assertIsInstance(r, OrderedDict)
1044        self.assertEqual(r, OrderedDict([
1045            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1046            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1047        if OrderedDict is not dict:
1048            r = ' '.join(list(r.keys()))
1049            self.assertEqual(r, 'n alpha v gamma tau beta')
1050        else:
1051            self.skipTest('OrderedDict is not supported')
1052
1053    def testGetAttnamesIsAttrDict(self):
1054        AttrDict = pg.AttrDict
1055        get_attnames = self.db.get_attnames
1056        r = get_attnames('test', flush=True)
1057        self.assertIsInstance(r, AttrDict)
1058        self.assertEqual(r, AttrDict([
1059            ('i2', 'int'), ('i4', 'int'), ('i8', 'int'),
1060            ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'),
1061            ('v4', 'text'), ('c4', 'text'), ('t', 'text')]))
1062        r = ' '.join(list(r.keys()))
1063        self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t')
1064        table = 'test table for get_attnames'
1065        self.createTable(table,
1066            ' n int, alpha smallint, v varchar(3),'
1067            ' gamma char(5), tau text, beta bool')
1068        r = get_attnames(table)
1069        self.assertIsInstance(r, AttrDict)
1070        self.assertEqual(r, AttrDict([
1071            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
1072            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
1073        r = ' '.join(list(r.keys()))
1074        self.assertEqual(r, 'n alpha v gamma tau beta')
1075
1076    def testHasTablePrivilege(self):
1077        can = self.db.has_table_privilege
1078        self.assertEqual(can('test'), True)
1079        self.assertEqual(can('test', 'select'), True)
1080        self.assertEqual(can('test', 'SeLeCt'), True)
1081        self.assertEqual(can('test', 'SELECT'), True)
1082        self.assertEqual(can('test', 'insert'), True)
1083        self.assertEqual(can('test', 'update'), True)
1084        self.assertEqual(can('test', 'delete'), True)
1085        self.assertEqual(can('pg_views', 'select'), True)
1086        self.assertEqual(can('pg_views', 'delete'), False)
1087        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
1088        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
1089
    def testGet(self):
        """Test get() with explicit key names and with a primary key.

        The first half runs on a table without a primary key, so the key
        column must be passed explicitly. The second half adds a primary
        key and relies on it being detected automatically.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        # the table name and the row (key) argument are required
        self.assertRaises(TypeError, get)
        self.assertRaises(TypeError, get, table)
        self.createTable(table, 'n integer, t text',
            values=enumerate('xyz', start=1))
        # without a primary key the key column must be given
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        r = get(table, 1, 'n')
        self.assertEqual(r, dict(n=1, t='x'))
        r = get(table, (3,), ('n',))
        self.assertEqual(r, dict(n=3, t='z'))
        r = get(table, 'y', 't')
        self.assertEqual(r, dict(n=2, t='y'))
        # lookups that cannot succeed raise DatabaseError (or a subclass)
        self.assertRaises(pg.DatabaseError, get, table, 4)
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        self.assertRaises(pg.DatabaseError, get, table, 'y')
        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
        # a dict passed as row is filled in place and returned
        s = dict(n=3)
        self.assertRaises(pg.ProgrammingError, get, table, s)
        r = get(table, s, 'n')
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(t='x')
        r = get(table, s, 't')
        self.assertIs(r, s)
        self.assertEqual(s, dict(n=1, t='x'))
        r = get(table, s, ('n', 't'))
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=1, t='x'))
        # from here on the primary key is used when no key name is given
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        r = get(table, 2)
        self.assertIsInstance(r, dict)
        self.assertEqual(r, dict(n=2, t='y'))
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 3)['t'], 'z')
        # a trailing '*' (with or without a space) is accepted
        self.assertEqual(get(table + '*', 2)['t'], 'y')
        self.assertEqual(get(table + ' *', 2)['t'], 'y')
        # more key values than key columns raise KeyError
        self.assertRaises(KeyError, get, table, (2, 2))
        s = dict(n=3)
        r = get(table, s)
        self.assertIs(r, s)
        self.assertEqual(r, dict(n=3, t='z'))
        s.update(n=1)
        self.assertEqual(get(table, s)['t'], 'x')
        s.update(n=2)
        # r is the same object as s here
        self.assertEqual(get(table, r)['t'], 'y')
        # a dict missing the primary key column raises KeyError
        s.pop('n')
        self.assertRaises(KeyError, get, table, s)
1144
    def testGetWithOid(self):
        """Test get() on a table with oids, before and after adding a pkey.

        The oid of a fetched row is stored under the key 'oid(table)'. Rows
        can be looked up by oid, given either as a value with key name
        'oid' or inside a dict under 'oid' or 'oid(table)'.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_with_oid_test_table'
        self.createTable(table, 'n integer, t text', oids=True,
            values=enumerate('xyz', start=1))
        # no primary key yet, and an empty dict has no oid to look up
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        self.assertRaises(KeyError, get, table, {}, 'oid')
        r = get(table, 2, 'n')
        qoid = 'oid(%s)' % table
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, qoid: oid}
        self.assertEqual(r, result)
        # all the different ways of passing the oid find the same row
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        self.assertEqual(get(table + '*', 2, 'n'), r)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        # r is presumably updated in place, including its oid entry, so
        # the following lookup by oid finds the row with n=3
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, r, 'oid')['t'], 'z')
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        # with a primary key, plain values and dicts are looked up by it
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        # lookups by oid still work after the primary key was added
        r = get(table, oid, 'oid')
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid))
        self.assertEqual(r, result)
        r = get(table, dict(oid=oid), 'oid')
        self.assertEqual(r, result)
        r = get(table, {qoid: oid})
        self.assertEqual(r, result)
        r = get(table, {qoid: oid}, 'oid')
        self.assertEqual(r, result)
        # if the dict contains the primary key, it wins over the oid
        r = get(table, dict(oid=oid, n=1))
        self.assertEqual(r['n'], 1)
        self.assertNotEqual(r[qoid], oid)
        # an explicitly given key name wins over the oid as well
        r = get(table, dict(oid=oid, t='z'), 't')
        self.assertEqual(r['n'], 3)
        self.assertNotEqual(r[qoid], oid)
1208
1209    def testGetWithCompositeKey(self):
1210        get = self.db.get
1211        query = self.db.query
1212        table = 'get_test_table_1'
1213        self.createTable(table, 'n integer primary key, t text',
1214            values=enumerate('abc', start=1))
1215        self.assertEqual(get(table, 2)['t'], 'b')
1216        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1217        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1218        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1219        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1220        self.assertEqual(get(table, 'b', 't')['n'], 2)
1221        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1222        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1223        table = 'get_test_table_2'
1224        self.createTable(table,
1225            'n integer, m integer, t text, primary key (n, m)',
1226            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1227                for n in range(3) for m in range(2)])
1228        self.assertRaises(KeyError, get, table, 2)
1229        self.assertEqual(get(table, (1, 1))['t'], 'a')
1230        self.assertEqual(get(table, (1, 2))['t'], 'b')
1231        self.assertEqual(get(table, (2, 1))['t'], 'c')
1232        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1233        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1234        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1235        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1236        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1237        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1238        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1239        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1240
1241    def testGetWithQuotedNames(self):
1242        get = self.db.get
1243        query = self.db.query
1244        table = 'test table for get()'
1245        self.createTable(table, '"Prime!" smallint primary key,'
1246            ' "much space" integer, "Questions?" text',
1247            values=[(17, 1001, 'No!')])
1248        r = get(table, 17)
1249        self.assertIsInstance(r, dict)
1250        self.assertEqual(r['Prime!'], 17)
1251        self.assertEqual(r['much space'], 1001)
1252        self.assertEqual(r['Questions?'], 'No!')
1253
1254    def testGetFromView(self):
1255        self.db.query('delete from test where i4=14')
1256        self.db.query('insert into test (i4, v4) values('
1257            "14, 'abc4')")
1258        r = self.db.get('test_view', 14, 'i4')
1259        self.assertIn('v4', r)
1260        self.assertEqual(r['v4'], 'abc4')
1261
1262    def testGetLittleBobbyTables(self):
1263        get = self.db.get
1264        query = self.db.query
1265        self.createTable('test_students',
1266            'firstname varchar primary key, nickname varchar, grade char(2)',
1267            values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'),
1268                    ('Robert', 'Little Bobby Tables', 'D-')])
1269        r = get('test_students', 'Sheldon')
1270        self.assertEqual(r, dict(
1271            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1272        r = get('test_students', 'Robert')
1273        self.assertEqual(r, dict(
1274            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1275        r = get('test_students', "D'Arcy")
1276        self.assertEqual(r, dict(
1277            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1278        try:
1279            get('test_students', "D' Arcy")
1280        except pg.DatabaseError as error:
1281            self.assertEqual(str(error),
1282                'No such record in test_students\nwhere "firstname" = $1\n'
1283                'with $1="D\' Arcy"')
1284        try:
1285            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1286        except pg.DatabaseError as error:
1287            self.assertEqual(str(error),
1288                'No such record in test_students\nwhere "firstname" = $1\n'
1289                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1290        q = "select * from test_students order by 1 limit 4"
1291        r = query(q).getresult()
1292        self.assertEqual(len(r), 3)
1293        self.assertEqual(r[1][2], 'D-')
1294
    def testInsert(self):
        """Test insert() round trips for many column types and values.

        Each entry in tests is either a dict that must be stored and read
        back unchanged, or a pair (data, change) where change holds the
        values that are expected to differ after the round trip. The
        values are checked both in the dict returned by insert() and in a
        fresh select from the table.
        """
        insert = self.db.insert
        query = self.db.query
        # current module settings that affect how values come back
        bool_on = pg.get_bool()
        decimal = pg.get_decimal()
        table = 'insert_test_table'
        self.createTable(table,
            'i2 smallint, i4 integer, i8 bigint,'
            ' d numeric, f4 real, f8 double precision, m money,'
            ' v4 varchar(4), c4 char(4), t text,'
            ' b boolean, ts timestamp', oids=True)
        # key under which insert() stores the oid of the new row
        oid_table = 'oid(%s)' % table
        tests = [dict(i2=None, i4=None, i8=None),
            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
            dict(i2=42, i4=123456, i8=9876543210),
            dict(i2=2 ** 15 - 1,
                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
            dict(d=None), (dict(d=''), dict(d=None)),
            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
            dict(f4=None, f8=None), dict(f4=0, f8=0),
            (dict(f4='', f8=''), dict(f4=None, f8=None)),
            (dict(d=1234.5, f4=1234.5, f8=1234.5),
                  dict(d=Decimal('1234.5'))),
            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
            dict(d=Decimal('123456789.9876543212345678987654321')),
            dict(m=None), (dict(m=''), dict(m=None)),
            dict(m=Decimal('-1234.56')),
            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
            (dict(m=123456), dict(m=Decimal('123456'))),
            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
            dict(b=None), (dict(b=''), dict(b=None)),
            dict(b='f'), dict(b='t'),
            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
            dict(v4=None, c4=None, t=None),
            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
            dict(v4='1234', c4='1234', t='1234' * 10),
            dict(v4='abcd', c4='abcd', t='abcdefg'),
            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
            dict(ts=None), (dict(ts=''), dict(ts=None)),
            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
            dict(ts='2012-12-21 00:00:00'),
            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
            dict(ts='2012-12-21 12:21:12'),
            dict(ts='2013-01-05 12:13:14'),
            dict(ts='current_timestamp')]
        for test in tests:
            if isinstance(test, dict):
                data = test
                change = {}
            else:
                data, change = test
            expect = data.copy()
            expect.update(change)
            # adapt the expected values to the current module settings:
            # booleans may come back as Python bools, and numeric/money
            # values as the configured decimal type
            if bool_on:
                b = expect.get('b')
                if b is not None:
                    expect['b'] = b == 't'
            if decimal is not Decimal:
                d = expect.get('d')
                if d is not None:
                    expect['d'] = decimal(d)
                m = expect.get('m')
                if m is not None:
                    expect['m'] = decimal(m)
            # insert() returns the passed dict, updated with the stored
            # values and the oid of the new row
            self.assertEqual(insert(table, data), data)
            self.assertIn(oid_table, data)
            oid = data[oid_table]
            self.assertIsInstance(oid, int)
            # only compare the columns that are part of this test
            data = dict(item for item in data.items()
                if item[0] in expect)
            ts = expect.get('ts')
            if ts == 'current_timestamp':
                # the server evaluates this, so only check that the result
                # looks like 'YYYY-MM-DD HH:MM:SS' with optional fraction
                ts = expect['ts'] = data['ts']
                if len(ts) > 19:
                    self.assertEqual(ts[19], '.')
                    ts = ts[:19]
                else:
                    self.assertEqual(len(ts), 19)
                self.assertTrue(ts[:4].isdigit())
                self.assertEqual(ts[4], '-')
                self.assertEqual(ts[10], ' ')
                self.assertTrue(ts[11:13].isdigit())
                self.assertEqual(ts[13], ':')
            self.assertEqual(data, expect)
            # the row read back from the table must match as well
            data = query(
                'select oid,* from "%s"' % table).dictresult()[0]
            self.assertEqual(data['oid'], oid)
            data = dict(item for item in data.items()
                if item[0] in expect)
            self.assertEqual(data, expect)
            query('delete from "%s"' % table)
1396
1397    def testInsertWithOid(self):
1398        insert = self.db.insert
1399        query = self.db.query
1400        self.createTable('test_table', 'n int', oids=True)
1401        r = insert('test_table', n=1)
1402        self.assertIsInstance(r, dict)
1403        self.assertEqual(r['n'], 1)
1404        self.assertNotIn('oid', r)
1405        qoid = 'oid(test_table)'
1406        self.assertIn(qoid, r)
1407        oid = r[qoid]
1408        self.assertEqual(sorted(r.keys()), ['n', qoid])
1409        r = insert('test_table', n=2, oid=oid)
1410        self.assertIsInstance(r, dict)
1411        self.assertEqual(r['n'], 2)
1412        self.assertIn(qoid, r)
1413        self.assertNotEqual(r[qoid], oid)
1414        self.assertNotIn('oid', r)
1415        r = insert('test_table', None, n=3)
1416        self.assertIsInstance(r, dict)
1417        self.assertEqual(r['n'], 3)
1418        s = r
1419        r = insert('test_table', r)
1420        self.assertIs(r, s)
1421        self.assertEqual(r['n'], 3)
1422        r = insert('test_table *', r)
1423        self.assertIs(r, s)
1424        self.assertEqual(r['n'], 3)
1425        r = insert('test_table', r, n=4)
1426        self.assertIs(r, s)
1427        self.assertEqual(r['n'], 4)
1428        self.assertNotIn('oid', r)
1429        self.assertIn(qoid, r)
1430        oid = r[qoid]
1431        r = insert('test_table', r, n=5, oid=oid)
1432        self.assertIs(r, s)
1433        self.assertEqual(r['n'], 5)
1434        self.assertIn(qoid, r)
1435        self.assertNotEqual(r[qoid], oid)
1436        self.assertNotIn('oid', r)
1437        r['oid'] = oid = r[qoid]
1438        r = insert('test_table', r, n=6)
1439        self.assertIs(r, s)
1440        self.assertEqual(r['n'], 6)
1441        self.assertIn(qoid, r)
1442        self.assertNotEqual(r[qoid], oid)
1443        self.assertNotIn('oid', r)
1444        q = 'select n from test_table order by 1 limit 9'
1445        r = ' '.join(str(row[0]) for row in query(q).getresult())
1446        self.assertEqual(r, '1 2 3 3 3 4 5 6')
1447        query("truncate test_table")
1448        query("alter table test_table add unique (n)")
1449        r = insert('test_table', dict(n=7))
1450        self.assertIsInstance(r, dict)
1451        self.assertEqual(r['n'], 7)
1452        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
1453        r['n'] = 6
1454        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
1455        self.assertIsInstance(r, dict)
1456        self.assertEqual(r['n'], 7)
1457        r['n'] = 6
1458        r = insert('test_table', r)
1459        self.assertIsInstance(r, dict)
1460        self.assertEqual(r['n'], 6)
1461        r = query(q).getresult()
1462        r = ' '.join(str(row[0]) for row in query(q).getresult())
1463        self.assertEqual(r, '6 7')
1464
1465    def testInsertWithQuotedNames(self):
1466        insert = self.db.insert
1467        query = self.db.query
1468        table = 'test table for insert()'
1469        self.createTable(table, '"Prime!" smallint primary key,'
1470            ' "much space" integer, "Questions?" text')
1471        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1472        r = insert(table, r)
1473        self.assertIsInstance(r, dict)
1474        self.assertEqual(r['Prime!'], 11)
1475        self.assertEqual(r['much space'], 2002)
1476        self.assertEqual(r['Questions?'], 'What?')
1477        r = query('select * from "%s" limit 2' % table).dictresult()
1478        self.assertEqual(len(r), 1)
1479        r = r[0]
1480        self.assertEqual(r['Prime!'], 11)
1481        self.assertEqual(r['much space'], 2002)
1482        self.assertEqual(r['Questions?'], 'What?')
1483
1484    def testInsertIntoView(self):
1485        insert = self.db.insert
1486        query = self.db.query
1487        query("truncate test")
1488        q = 'select * from test_view order by i4 limit 3'
1489        r = query(q).getresult()
1490        self.assertEqual(r, [])
1491        r = dict(i4=1234, v4='abcd')
1492        insert('test', r)
1493        self.assertIsNone(r['i2'])
1494        self.assertEqual(r['i4'], 1234)
1495        self.assertIsNone(r['i8'])
1496        self.assertEqual(r['v4'], 'abcd')
1497        self.assertIsNone(r['c4'])
1498        r = query(q).getresult()
1499        self.assertEqual(r, [(1234, 'abcd')])
1500        r = dict(i4=5678, v4='efgh')
1501        try:
1502            insert('test_view', r)
1503        except pg.ProgrammingError as error:
1504            if self.db.server_version < 90300:
1505                # must setup rules in older PostgreSQL versions
1506                self.skipTest('database cannot insert into view')
1507            self.fail(str(error))
1508        self.assertNotIn('i2', r)
1509        self.assertEqual(r['i4'], 5678)
1510        self.assertNotIn('i8', r)
1511        self.assertEqual(r['v4'], 'efgh')
1512        self.assertNotIn('c4', r)
1513        r = query(q).getresult()
1514        self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')])
1515
1516    def testUpdate(self):
1517        update = self.db.update
1518        query = self.db.query
1519        self.assertRaises(pg.ProgrammingError, update,
1520            'test', i2=2, i4=4, i8=8)
1521        table = 'update_test_table'
1522        self.createTable(table, 'n integer, t text', oids=True,
1523            values=enumerate('xyz', start=1))
1524        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1525        r = self.db.get(table, 2, 'n')
1526        r['t'] = 'u'
1527        s = update(table, r)
1528        self.assertEqual(s, r)
1529        q = 'select t from "%s" where n=2' % table
1530        r = query(q).getresult()[0][0]
1531        self.assertEqual(r, 'u')
1532
    def testUpdateWithOid(self):
        """Update rows of a table with oids, before and after adding a pkey.

        Row dicts returned by get() carry the oid under the qualified key
        'oid(test_table)'. Without a primary key, update() needs that oid
        (or the oid keyword argument) to locate the row.
        """
        update = self.db.update
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        s = get('test_table', 1, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 1)
        s['n'] = 2
        r = update('test_table', s)
        # the passed dict is updated in place and returned
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        qoid = 'oid(test_table)'
        # the oid is stored under the qualified key, not under 'oid'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        r['n'] = 3
        # the oid can also be passed as a keyword argument
        oid = r.pop(qoid)
        r = update('test_table', r, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # without oid and without primary key the row cannot be located
        r.pop(qoid)
        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
        s = get('test_table', 3, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 3)
        # a dict holding nothing but the qualified oid comes back unchanged
        s.pop('n')
        r = update('test_table', s)
        oid = r.pop(qoid)
        self.assertEqual(r, {})
        q = "select n from test_table limit 2"
        r = query(q).getresult()
        self.assertEqual(r, [(3,)])
        query("insert into test_table values (1)")
        # a plain 'oid' entry in the dict is not enough to locate the row,
        # but the oid keyword argument is
        self.assertRaises(pg.ProgrammingError,
            update, 'test_table', dict(oid=oid, n=4))
        r = update('test_table', dict(n=4), oid=oid)
        self.assertEqual(r['n'], 4)
        r = update('test_table *', dict(n=5), oid=oid)
        self.assertEqual(r['n'], 5)
        # add a primary key and flush the cached table metadata
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        # now rows are located by their primary key
        s = dict(n=1, m=4)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 4)
        # the key value may also be passed as a keyword argument
        s = dict(m=7)
        r = update('test_table', s, n=5)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 7)
        q = "select n, m from test_table order by 1 limit 3"
        r = query(q).getresult()
        self.assertEqual(r, [(1, 4), (5, 7)])
        # a plain 'oid' entry does not replace a missing key (KeyError),
        # but the oid keyword argument still locates the row
        s = dict(m=9, oid=oid)
        self.assertRaises(KeyError, update, 'test_table', s)
        r = update('test_table', s, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 9)
        # the oid belongs to row n=5, but the primary key n=1 decides
        s = dict(n=1, m=3, oid=oid)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        r = query(q).getresult()
        self.assertEqual(r, [(1, 3), (5, 9)])
1603
1604    def testUpdateWithoutOid(self):
1605        update = self.db.update
1606        query = self.db.query
1607        self.assertRaises(pg.ProgrammingError, update,
1608            'test', i2=2, i4=4, i8=8)
1609        table = 'update_test_table'
1610        self.createTable(table, 'n integer primary key, t text', oids=False,
1611            values=enumerate('xyz', start=1))
1612        r = self.db.get(table, 2)
1613        r['t'] = 'u'
1614        s = update(table, r)
1615        self.assertEqual(s, r)
1616        q = 'select t from "%s" where n=2' % table
1617        r = query(q).getresult()[0][0]
1618        self.assertEqual(r, 'u')
1619
1620    def testUpdateWithCompositeKey(self):
1621        update = self.db.update
1622        query = self.db.query
1623        table = 'update_test_table_1'
1624        self.createTable(table, 'n integer primary key, t text',
1625            values=enumerate('abc', start=1))
1626        self.assertRaises(KeyError, update, table, dict(t='b'))
1627        s = dict(n=2, t='d')
1628        r = update(table, s)
1629        self.assertIs(r, s)
1630        self.assertEqual(r['n'], 2)
1631        self.assertEqual(r['t'], 'd')
1632        q = 'select t from "%s" where n=2' % table
1633        r = query(q).getresult()[0][0]
1634        self.assertEqual(r, 'd')
1635        s.update(dict(n=4, t='e'))
1636        r = update(table, s)
1637        self.assertEqual(r['n'], 4)
1638        self.assertEqual(r['t'], 'e')
1639        q = 'select t from "%s" where n=2' % table
1640        r = query(q).getresult()[0][0]
1641        self.assertEqual(r, 'd')
1642        q = 'select t from "%s" where n=4' % table
1643        r = query(q).getresult()
1644        self.assertEqual(len(r), 0)
1645        query('drop table "%s"' % table)
1646        table = 'update_test_table_2'
1647        self.createTable(table,
1648            'n integer, m integer, t text, primary key (n, m)',
1649            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
1650                for n in range(3) for m in range(2)])
1651        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1652        self.assertEqual(update(table,
1653            dict(n=2, m=2, t='x'))['t'], 'x')
1654        q = 'select t from "%s" where n=2 order by m' % table
1655        r = [r[0] for r in query(q).getresult()]
1656        self.assertEqual(r, ['c', 'x'])
1657
1658    def testUpdateWithQuotedNames(self):
1659        update = self.db.update
1660        query = self.db.query
1661        table = 'test table for update()'
1662        self.createTable(table, '"Prime!" smallint primary key,'
1663            ' "much space" integer, "Questions?" text',
1664            values=[(13, 3003, 'Why!')])
1665        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1666        r = update(table, r)
1667        self.assertIsInstance(r, dict)
1668        self.assertEqual(r['Prime!'], 13)
1669        self.assertEqual(r['much space'], 7007)
1670        self.assertEqual(r['Questions?'], 'When?')
1671        r = query('select * from "%s" limit 2' % table).dictresult()
1672        self.assertEqual(len(r), 1)
1673        r = r[0]
1674        self.assertEqual(r['Prime!'], 13)
1675        self.assertEqual(r['much space'], 7007)
1676        self.assertEqual(r['Questions?'], 'When?')
1677
    def testUpsert(self):
        """Insert or update rows of a table with a primary key via upsert().

        Keyword arguments named after columns control how an existing row is
        updated: False keeps the stored value, True takes the new value, and
        a string acts as an SQL expression in which 'included' refers to the
        stored row and 'excluded' to the proposed new row.
        """
        upsert = self.db.upsert
        query = self.db.query
        # upserting into "test" this way is expected to fail
        self.assertRaises(pg.ProgrammingError, upsert,
            'test', i2=2, i4=4, i8=8)
        table = 'upsert_test_table'
        self.createTable(table, 'n integer primary key, t text', oids=True)
        s = dict(n=1, t='x')
        # the first upsert inserts the row; skip if the server lacks support
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        # a new key is inserted, even with None passed for every column
        s.update(n=2, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # an existing key is updated with the new values by default
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=False keeps the stored value; the dict is refreshed with it
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=True takes the new value
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # 'included' refers to the stored value ('y')
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # 'excluded' refers to the proposed new value ('y')
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        # both can be combined: stored 'x' followed by new '2'
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        # not existing columns and oid parameter should be ignored
        s = dict(m=3, u='z')
        r = upsert(table, s, oid='invalid')
        self.assertIs(r, s)
1749
    def testUpsertWithOid(self):
        """Upsert rows of a table with oids, before and after adding a pkey.

        Without a primary key, upsert() fails even when an oid is given.
        Once a primary key exists, the key decides which row is affected,
        and oid entries in the dict or as keyword argument are not used.
        """
        upsert = self.db.upsert
        get = self.db.get
        query = self.db.query
        self.createTable('test_table', 'n int', oids=True, values=[1])
        # no primary key yet: upsert is rejected, with or without an oid
        self.assertRaises(pg.ProgrammingError,
            upsert, 'test_table', dict(n=2))
        r = get('test_table', 1, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        oid = r[qoid]
        self.assertRaises(pg.ProgrammingError,
            upsert, 'test_table', dict(n=2, oid=oid))
        # add a primary key and flush the cached table metadata
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        s = dict(n=2)
        try:
            r = upsert('test_table', s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        # the new row is inserted and missing columns come back as None
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, None), (2, None)])
        # the plain 'oid' entry (of row n=1) is dropped; the dict then
        # carries the qualified oid of the upserted row n=2 instead
        r['oid'] = oid
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertIsNone(r['m'])
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertNotEqual(r[qoid], oid)
        r['m'] = 7
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # the dict still holds the qualified oid of row n=2,
        # but the primary key n=1 decides which row is updated
        r.update(n=1, m=3)
        r = upsert('test_table', r)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
        # an oid keyword argument is ignored as well
        r = upsert('test_table', r, oid='invalid')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=False keeps the stored value
        r['m'] = 5
        r = upsert('test_table', r, m=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        # m=True takes the new value
        r['m'] = 5
        r = upsert('test_table', r, m=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 5)
        # 'included.m' keeps the stored value 7 of row n=2
        r.update(n=2, m=1)
        r = upsert('test_table', r, m='included.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 7)
        # 'excluded.m' takes the proposed value 9
        r['m'] = 9
        r = upsert('test_table', r, m='excluded.m')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 9)
        # expressions are possible too (stored 9 + 1); the ' *' suffix
        # on the table name is also accepted
        r['m'] = 8
        r = upsert('test_table *', r, m='included.m + 1')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 10)
        q = query("select n, m from test_table order by n limit 3")
        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1833
    def testUpsertWithCompositeKey(self):
        """Insert or update rows identified by a composite key (n, m)."""
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        self.createTable(table,
            'n integer, m integer, t text, primary key (n, m)')
        s = dict(n=1, m=2, t='x')
        # the first upsert inserts the row; skip if the server lacks support
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # (1, 3) is a new key, so a second row is inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # an existing key is updated with the new values by default
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False keeps the stored value
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True takes the new value
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # (2, 3) is new, so the row is inserted as given ('y')
        # and the update expression for t is not applied
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # stored value 'n' concatenated with proposed value 'm'
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1900
1901    def testUpsertWithQuotedNames(self):
1902        upsert = self.db.upsert
1903        query = self.db.query
1904        table = 'test table for upsert()'
1905        self.createTable(table, '"Prime!" smallint primary key,'
1906            ' "much space" integer, "Questions?" text')
1907        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1908        try:
1909            r = upsert(table, s)
1910        except pg.ProgrammingError as error:
1911            if self.db.server_version < 90500:
1912                self.skipTest('database does not support upsert')
1913            self.fail(str(error))
1914        self.assertIs(r, s)
1915        self.assertEqual(r['Prime!'], 31)
1916        self.assertEqual(r['much space'], 9009)
1917        self.assertEqual(r['Questions?'], 'Yes.')
1918        q = 'select * from "%s" limit 2' % table
1919        r = query(q).getresult()
1920        self.assertEqual(r, [(31, 9009, 'Yes.')])
1921        s.update({'Questions?': 'No.'})
1922        r = upsert(table, s)
1923        self.assertIs(r, s)
1924        self.assertEqual(r['Prime!'], 31)
1925        self.assertEqual(r['much space'], 9009)
1926        self.assertEqual(r['Questions?'], 'No.')
1927        r = query(q).getresult()
1928        self.assertEqual(r, [(31, 9009, 'No.')])
1929
1930    def testClear(self):
1931        clear = self.db.clear
1932        f = False if pg.get_bool() else 'f'
1933        r = clear('test')
1934        result = dict(
1935            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
1936        self.assertEqual(r, result)
1937        table = 'clear_test_table'
1938        self.createTable(table,
1939            'n integer, f float, b boolean, d date, t text', oids=True)
1940        r = clear(table)
1941        result = dict(n=0, f=0, b=f, d='', t='')
1942        self.assertEqual(r, result)
1943        r['a'] = r['f'] = r['n'] = 1
1944        r['d'] = r['t'] = 'x'
1945        r['b'] = 't'
1946        r['oid'] = long(1)
1947        r = clear(table, r)
1948        result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1))
1949        self.assertEqual(r, result)
1950
1951    def testClearWithQuotedNames(self):
1952        clear = self.db.clear
1953        table = 'test table for clear()'
1954        self.createTable(table, '"Prime!" smallint primary key,'
1955            ' "much space" integer, "Questions?" text')
1956        r = clear(table)
1957        self.assertIsInstance(r, dict)
1958        self.assertEqual(r['Prime!'], 0)
1959        self.assertEqual(r['much space'], 0)
1960        self.assertEqual(r['Questions?'], '')
1961
1962    def testDelete(self):
1963        delete = self.db.delete
1964        query = self.db.query
1965        self.assertRaises(pg.ProgrammingError, delete,
1966            'test', dict(i2=2, i4=4, i8=8))
1967        table = 'delete_test_table'
1968        self.createTable(table, 'n integer, t text', oids=True,
1969            values=enumerate('xyz', start=1))
1970        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1971        r = self.db.get(table, 1, 'n')
1972        s = delete(table, r)
1973        self.assertEqual(s, 1)
1974        r = self.db.get(table, 3, 'n')
1975        s = delete(table, r)
1976        self.assertEqual(s, 1)
1977        s = delete(table, r)
1978        self.assertEqual(s, 0)
1979        r = query('select * from "%s"' % table).dictresult()
1980        self.assertEqual(len(r), 1)
1981        r = r[0]
1982        result = {'n': 2, 't': 'y'}
1983        self.assertEqual(r, result)
1984        r = self.db.get(table, 2, 'n')
1985        s = delete(table, r)
1986        self.assertEqual(s, 1)
1987        s = delete(table, r)
1988        self.assertEqual(s, 0)
1989        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1990        # not existing columns and oid parameter should be ignored
1991        r.update(m=3, u='z', oid='invalid')
1992        s = delete(table, r)
1993        self.assertEqual(s, 0)
1994
    def testDeleteWithOid(self):
        """Delete rows of a table with oids, before and after adding a pkey.

        Without a primary key, rows are located by the qualified oid entry
        'oid(test_table)' or by the oid keyword argument; a plain 'oid' key
        in the row dict is not used. Query q reports (min(n), count(n)).
        """
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        # the table starts with rows n = 1 .. 6
        self.createTable('test_table', 'n int', oids=True, values=range(1, 7))
        # no primary key and no oid: the row cannot be located
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        # a row dict from get() carries the qualified oid and can be deleted
        s = get('test_table', 1, 'n')
        qoid = 'oid(test_table)'
        self.assertIn(qoid, s)
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        q = "select min(n),count(n) from test_table"
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # a plain 'oid' key is not enough, but the oid keyword argument is
        oid = get('test_table', 2, 'n')[qoid]
        s = dict(oid=oid, n=2)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # s holds the oid of the deleted row 2; the keyword oid of row 3
        # decides which row is deleted
        s = dict(oid=oid, n=2)
        oid = get('test_table', 3, 'n')[qoid]
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        # the ' *' suffix on the table name is accepted as well
        s = get('test_table', 4, 'n')
        r = delete('test_table *', s)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # column m does not exist yet, so m in the dict and as keyword
        # argument is ignored and only the qualified oid is used
        oid = get('test_table', 5, 'n')[qoid]
        s = {qoid: oid, 'm': 4}
        r = delete('test_table', s, m=6)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
        # add a primary key and flush the cached table metadata
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        # add rows (n, m) = (1, 2) .. (5, 6) next to the remaining row n=6
        for i in range(5):
            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
        # now the primary key is required in the dict, even with an oid
        s = dict(m=2)
        self.assertRaises(KeyError, delete, 'test_table', s)
        s = dict(m=2, oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # oid still refers to the old row 5, which was deleted above,
        # so this oid matches no row
        r = delete('test_table', dict(m=2), oid=oid)
        self.assertEqual(r, 0)
        # with a valid oid keyword argument the row can still be deleted
        oid = get('test_table', 1, 'n')[qoid]
        s = dict(oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # a dict from get() with only the qualified oid left works as well
        s = get('test_table', 2, 'n')
        del s['n']
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the key can be passed as keyword argument, with or without a row
        r = delete('test_table', n=3)
        self.assertEqual(r, 1)
        r = delete('test_table', n=3)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 1)
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # the keyword key value n=5 is used instead of the dict's n=6
        s = dict(n=6)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 1)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
2083
2084    def testDeleteWithCompositeKey(self):
2085        query = self.db.query
2086        table = 'delete_test_table_1'
2087        self.createTable(table, 'n integer primary key, t text',
2088            values=enumerate('abc', start=1))
2089        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2090        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2091        r = query('select t from "%s" where n=2' % table).getresult()
2092        self.assertEqual(r, [])
2093        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2094        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2095        self.assertEqual(r, 'c')
2096        table = 'delete_test_table_2'
2097        self.createTable(table,
2098            'n integer, m integer, t text, primary key (n, m)',
2099            values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m))
2100                for n in range(3) for m in range(2)])
2101        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2102        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2103        r = [r[0] for r in query('select t from "%s" where n=2'
2104            ' order by m' % table).getresult()]
2105        self.assertEqual(r, ['c'])
2106        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2107        r = [r[0] for r in query('select t from "%s" where n=3'
2108            ' order by m' % table).getresult()]
2109        self.assertEqual(r, ['e', 'f'])
2110        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2111        r = [r[0] for r in query('select t from "%s" where n=3'
2112            ' order by m' % table).getresult()]
2113        self.assertEqual(r, ['f'])
2114
2115    def testDeleteWithQuotedNames(self):
2116        delete = self.db.delete
2117        query = self.db.query
2118        table = 'test table for delete()'
2119        self.createTable(table, '"Prime!" smallint primary key,'
2120            ' "much space" integer, "Questions?" text',
2121            values=[(19, 5005, 'Yes!')])
2122        r = {'Prime!': 17}
2123        r = delete(table, r)
2124        self.assertEqual(r, 0)
2125        r = query('select count(*) from "%s"' % table).getresult()
2126        self.assertEqual(r[0][0], 1)
2127        r = {'Prime!': 19}
2128        r = delete(table, r)
2129        self.assertEqual(r, 1)
2130        r = query('select count(*) from "%s"' % table).getresult()
2131        self.assertEqual(r[0][0], 0)
2132
2133    def testDeleteReferenced(self):
2134        delete = self.db.delete
2135        query = self.db.query
2136        self.createTable('test_parent',
2137            'n smallint primary key', values=range(3))
2138        self.createTable('test_child',
2139            'n smallint primary key references test_parent', values=range(3))
2140        q = ("select (select count(*) from test_parent),"
2141            " (select count(*) from test_child)")
2142        self.assertEqual(query(q).getresult()[0], (3, 3))
2143        self.assertRaises(pg.ProgrammingError,
2144            delete, 'test_parent', None, n=2)
2145        self.assertRaises(pg.ProgrammingError,
2146            delete, 'test_parent *', None, n=2)
2147        r = delete('test_child', None, n=2)
2148        self.assertEqual(r, 1)
2149        self.assertEqual(query(q).getresult()[0], (3, 2))
2150        r = delete('test_parent', None, n=2)
2151        self.assertEqual(r, 1)
2152        self.assertEqual(query(q).getresult()[0], (2, 2))
2153        self.assertRaises(pg.ProgrammingError,
2154            delete, 'test_parent', dict(n=0))
2155        self.assertRaises(pg.ProgrammingError,
2156            delete, 'test_parent *', dict(n=0))
2157        r = delete('test_child', dict(n=0))
2158        self.assertEqual(r, 1)
2159        self.assertEqual(query(q).getresult()[0], (2, 1))
2160        r = delete('test_child', dict(n=0))
2161        self.assertEqual(r, 0)
2162        r = delete('test_parent', dict(n=0))
2163        self.assertEqual(r, 1)
2164        self.assertEqual(query(q).getresult()[0], (1, 1))
2165        r = delete('test_parent', None, n=0)
2166        self.assertEqual(r, 0)
2167        q = "select n from test_parent natural join test_child limit 2"
2168        self.assertEqual(query(q).getresult(), [(1,)])
2169
2170    def testTruncate(self):
2171        truncate = self.db.truncate
2172        self.assertRaises(TypeError, truncate, None)
2173        self.assertRaises(TypeError, truncate, 42)
2174        self.assertRaises(TypeError, truncate, dict(test_table=None))
2175        query = self.db.query
2176        self.createTable('test_table', 'n smallint',
2177            temporary=False, values=[1] * 3)
2178        q = "select count(*) from test_table"
2179        r = query(q).getresult()[0][0]
2180        self.assertEqual(r, 3)
2181        truncate('test_table')
2182        r = query(q).getresult()[0][0]
2183        self.assertEqual(r, 0)
2184        for i in range(3):
2185            query("insert into test_table values (1)")
2186        r = query(q).getresult()[0][0]
2187        self.assertEqual(r, 3)
2188        truncate('public.test_table')
2189        r = query(q).getresult()[0][0]
2190        self.assertEqual(r, 0)
2191        self.createTable('test_table_2', 'n smallint', temporary=True)
2192        for t in (list, tuple, set):
2193            for i in range(3):
2194                query("insert into test_table values (1)")
2195                query("insert into test_table_2 values (2)")
2196            q = ("select (select count(*) from test_table),"
2197                " (select count(*) from test_table_2)")
2198            r = query(q).getresult()[0]
2199            self.assertEqual(r, (3, 3))
2200            truncate(t(['test_table', 'test_table_2']))
2201            r = query(q).getresult()[0]
2202            self.assertEqual(r, (0, 0))
2203
2204    def testTruncateRestart(self):
2205        truncate = self.db.truncate
2206        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2207        query = self.db.query
2208        self.createTable('test_table', 'n serial, t text')
2209        for n in range(3):
2210            query("insert into test_table (t) values ('test')")
2211        q = "select count(n), min(n), max(n) from test_table"
2212        r = query(q).getresult()[0]
2213        self.assertEqual(r, (3, 1, 3))
2214        truncate('test_table')
2215        r = query(q).getresult()[0]
2216        self.assertEqual(r, (0, None, None))
2217        for n in range(3):
2218            query("insert into test_table (t) values ('test')")
2219        r = query(q).getresult()[0]
2220        self.assertEqual(r, (3, 4, 6))
2221        truncate('test_table', restart=True)
2222        r = query(q).getresult()[0]
2223        self.assertEqual(r, (0, None, None))
2224        for n in range(3):
2225            query("insert into test_table (t) values ('test')")
2226        r = query(q).getresult()[0]
2227        self.assertEqual(r, (3, 1, 3))
2228
2229    def testTruncateCascade(self):
2230        truncate = self.db.truncate
2231        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2232        query = self.db.query
2233        self.createTable('test_parent', 'n smallint primary key',
2234            values=range(3))
2235        self.createTable('test_child',
2236            'n smallint primary key references test_parent (n)',
2237            values=range(3))
2238        q = ("select (select count(*) from test_parent),"
2239            " (select count(*) from test_child)")
2240        r = query(q).getresult()[0]
2241        self.assertEqual(r, (3, 3))
2242        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2243        truncate(['test_parent', 'test_child'])
2244        r = query(q).getresult()[0]
2245        self.assertEqual(r, (0, 0))
2246        for n in range(3):
2247            query("insert into test_parent (n) values (%d)" % n)
2248            query("insert into test_child (n) values (%d)" % n)
2249        r = query(q).getresult()[0]
2250        self.assertEqual(r, (3, 3))
2251        truncate('test_parent', cascade=True)
2252        r = query(q).getresult()[0]
2253        self.assertEqual(r, (0, 0))
2254        for n in range(3):
2255            query("insert into test_parent (n) values (%d)" % n)
2256            query("insert into test_child (n) values (%d)" % n)
2257        r = query(q).getresult()[0]
2258        self.assertEqual(r, (3, 3))
2259        truncate('test_child')
2260        r = query(q).getresult()[0]
2261        self.assertEqual(r, (3, 0))
2262        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2263        truncate('test_parent', cascade=True)
2264        r = query(q).getresult()[0]
2265        self.assertEqual(r, (0, 0))
2266
2267    def testTruncateOnly(self):
2268        truncate = self.db.truncate
2269        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2270        query = self.db.query
2271        self.createTable('test_parent', 'n smallint')
2272        self.createTable('test_child', 'm smallint) inherits (test_parent')
2273        for n in range(3):
2274            query("insert into test_parent (n) values (1)")
2275            query("insert into test_child (n, m) values (2, 3)")
2276        q = ("select (select count(*) from test_parent),"
2277            " (select count(*) from test_child)")
2278        r = query(q).getresult()[0]
2279        self.assertEqual(r, (6, 3))
2280        truncate('test_parent')
2281        r = query(q).getresult()[0]
2282        self.assertEqual(r, (0, 0))
2283        for n in range(3):
2284            query("insert into test_parent (n) values (1)")
2285            query("insert into test_child (n, m) values (2, 3)")
2286        r = query(q).getresult()[0]
2287        self.assertEqual(r, (6, 3))
2288        truncate('test_parent*')
2289        r = query(q).getresult()[0]
2290        self.assertEqual(r, (0, 0))
2291        for n in range(3):
2292            query("insert into test_parent (n) values (1)")
2293            query("insert into test_child (n, m) values (2, 3)")
2294        r = query(q).getresult()[0]
2295        self.assertEqual(r, (6, 3))
2296        truncate('test_parent', only=True)
2297        r = query(q).getresult()[0]
2298        self.assertEqual(r, (3, 3))
2299        truncate('test_parent', only=False)
2300        r = query(q).getresult()[0]
2301        self.assertEqual(r, (0, 0))
2302        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2303        truncate('test_parent*', only=False)
2304        self.createTable('test_parent_2', 'n smallint')
2305        self.createTable('test_child_2', 'm smallint) inherits (test_parent_2')
2306        for t in '', '_2':
2307            for n in range(3):
2308                query("insert into test_parent%s (n) values (1)" % t)
2309                query("insert into test_child%s (n, m) values (2, 3)" % t)
2310        q = ("select (select count(*) from test_parent),"
2311            " (select count(*) from test_child),"
2312            " (select count(*) from test_parent_2),"
2313            " (select count(*) from test_child_2)")
2314        r = query(q).getresult()[0]
2315        self.assertEqual(r, (6, 3, 6, 3))
2316        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2317        r = query(q).getresult()[0]
2318        self.assertEqual(r, (0, 0, 3, 3))
2319        truncate(['test_parent', 'test_parent_2'], only=False)
2320        r = query(q).getresult()[0]
2321        self.assertEqual(r, (0, 0, 0, 0))
2322        self.assertRaises(ValueError, truncate,
2323            ['test_parent*', 'test_child'], only=[True, False])
2324        truncate(['test_parent*', 'test_child'], only=[False, True])
2325
2326    def testTruncateQuoted(self):
2327        truncate = self.db.truncate
2328        query = self.db.query
2329        table = "test table for truncate()"
2330        self.createTable(table, 'n smallint', temporary=False, values=[1] * 3)
2331        q = 'select count(*) from "%s"' % table
2332        r = query(q).getresult()[0][0]
2333        self.assertEqual(r, 3)
2334        truncate(table)
2335        r = query(q).getresult()[0][0]
2336        self.assertEqual(r, 0)
2337        for i in range(3):
2338            query('insert into "%s" values (1)' % table)
2339        r = query(q).getresult()[0][0]
2340        self.assertEqual(r, 3)
2341        truncate('public."%s"' % table)
2342        r = query(q).getresult()[0][0]
2343        self.assertEqual(r, 0)
2344
2345    def testGetAsList(self):
2346        get_as_list = self.db.get_as_list
2347        self.assertRaises(TypeError, get_as_list)
2348        self.assertRaises(TypeError, get_as_list, None)
2349        query = self.db.query
2350        table = 'test_aslist'
2351        r = query('select 1 as colname').namedresult()[0]
2352        self.assertIsInstance(r, tuple)
2353        named = hasattr(r, 'colname')
2354        names = [(1, 'Homer'), (2, 'Marge'),
2355                (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')]
2356        self.createTable(table,
2357            'id smallint primary key, name varchar', values=names)
2358        r = get_as_list(table)
2359        self.assertIsInstance(r, list)
2360        self.assertEqual(r, names)
2361        for t, n in zip(r, names):
2362            self.assertIsInstance(t, tuple)
2363            self.assertEqual(t, n)
2364            if named:
2365                self.assertEqual(t.id, n[0])
2366                self.assertEqual(t.name, n[1])
2367                self.assertEqual(t._asdict(), dict(id=n[0], name=n[1]))
2368        r = get_as_list(table, what='name')
2369        self.assertIsInstance(r, list)
2370        expected = sorted((row[1],) for row in names)
2371        self.assertEqual(r, expected)
2372        r = get_as_list(table, what='name, id')
2373        self.assertIsInstance(r, list)
2374        expected = sorted(tuple(reversed(row)) for row in names)
2375        self.assertEqual(r, expected)
2376        r = get_as_list(table, what=['name', 'id'])
2377        self.assertIsInstance(r, list)
2378        self.assertEqual(r, expected)
2379        r = get_as_list(table, where="name like 'Ba%'")
2380        self.assertIsInstance(r, list)
2381        self.assertEqual(r, names[2:3])
2382        r = get_as_list(table, what='name', where="name like 'Ma%'")
2383        self.assertIsInstance(r, list)
2384        self.assertEqual(r, [('Maggie',), ('Marge',)])
2385        r = get_as_list(table, what='name',
2386                        where=["name like 'Ma%'", "name like '%r%'"])
2387        self.assertIsInstance(r, list)
2388        self.assertEqual(r, [('Marge',)])
2389        r = get_as_list(table, what='name', order='id')
2390        self.assertIsInstance(r, list)
2391        expected = [(row[1],) for row in names]
2392        self.assertEqual(r, expected)
2393        r = get_as_list(table, what=['name'], order=['id'])
2394        self.assertIsInstance(r, list)
2395        self.assertEqual(r, expected)
2396        r = get_as_list(table, what=['id', 'name'], order=['id', 'name'])
2397        self.assertIsInstance(r, list)
2398        self.assertEqual(r, names)
2399        r = get_as_list(table, what='id * 2 as num', order='id desc')
2400        self.assertIsInstance(r, list)
2401        expected = [(n,) for n in range(10, 0, -2)]
2402        self.assertEqual(r, expected)
2403        r = get_as_list(table, limit=2)
2404        self.assertIsInstance(r, list)
2405        self.assertEqual(r, names[:2])
2406        r = get_as_list(table, offset=3)
2407        self.assertIsInstance(r, list)
2408        self.assertEqual(r, names[3:])
2409        r = get_as_list(table, limit=1, offset=2)
2410        self.assertIsInstance(r, list)
2411        self.assertEqual(r, names[2:3])
2412        r = get_as_list(table, scalar=True)
2413        self.assertIsInstance(r, list)
2414        self.assertEqual(r, list(range(1, 6)))
2415        r = get_as_list(table, what='name', scalar=True)
2416        self.assertIsInstance(r, list)
2417        expected = sorted(row[1] for row in names)
2418        self.assertEqual(r, expected)
2419        r = get_as_list(table, what='name', limit=1, scalar=True)
2420        self.assertIsInstance(r, list)
2421        self.assertEqual(r, expected[:1])
2422        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2423        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2424        names.insert(1, (1, 'Snowball'))
2425        query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball'))
2426        r = get_as_list(table)
2427        self.assertIsInstance(r, list)
2428        self.assertEqual(r, names)
2429        r = get_as_list(table, what='name', where='id=1', scalar=True)
2430        self.assertIsInstance(r, list)
2431        self.assertEqual(r, ['Homer', 'Snowball'])
2432        # test with unordered query
2433        r = get_as_list(table, order=False)
2434        self.assertIsInstance(r, list)
2435        self.assertEqual(set(r), set(names))
2436        # test with arbitrary from clause
2437        from_table = '(select lower(name) as n2 from "%s") as t2' % table
2438        r = get_as_list(from_table)
2439        self.assertIsInstance(r, list)
2440        r = set(row[0] for row in r)
2441        expected = set(row[1].lower() for row in names)
2442        self.assertEqual(r, expected)
2443        r = get_as_list(from_table, order='n2', scalar=True)
2444        self.assertIsInstance(r, list)
2445        self.assertEqual(r, sorted(expected))
2446        r = get_as_list(from_table, order='n2', limit=1)
2447        self.assertIsInstance(r, list)
2448        self.assertEqual(len(r), 1)
2449        t = r[0]
2450        self.assertIsInstance(t, tuple)
2451        if named:
2452            self.assertEqual(t.n2, 'bart')
2453            self.assertEqual(t._asdict(), dict(n2='bart'))
2454        else:
2455            self.assertEqual(t, ('bart',))
2456
2457    def testGetAsDict(self):
2458        get_as_dict = self.db.get_as_dict
2459        self.assertRaises(TypeError, get_as_dict)
2460        self.assertRaises(TypeError, get_as_dict, None)
2461        # the test table has no primary key
2462        self.assertRaises(pg.ProgrammingError, get_as_dict, 'test')
2463        query = self.db.query
2464        table = 'test_asdict'
2465        r = query('select 1 as colname').namedresult()[0]
2466        self.assertIsInstance(r, tuple)
2467        named = hasattr(r, 'colname')
2468        colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'),
2469                  (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')]
2470        self.createTable(table,
2471            'id smallint primary key, rgb char(7), name varchar',
2472            values=colors)
2473        # keyname must be string, list or tuple
2474        self.assertRaises(KeyError, get_as_dict, table, 3)
2475        self.assertRaises(KeyError, get_as_dict, table, dict(id=None))
2476        # missing keyname in row
2477        self.assertRaises(KeyError, get_as_dict, table,
2478                          keyname='rgb', what='name')
2479        r = get_as_dict(table)
2480        self.assertIsInstance(r, OrderedDict)
2481        expected = OrderedDict((row[0], row[1:]) for row in colors)
2482        self.assertEqual(r, expected)
2483        for key in r:
2484            self.assertIsInstance(key, int)
2485            self.assertIn(key, expected)
2486            row = r[key]
2487            self.assertIsInstance(row, tuple)
2488            t = expected[key]
2489            self.assertEqual(row, t)
2490            if named:
2491                self.assertEqual(row.rgb, t[0])
2492                self.assertEqual(row.name, t[1])
2493                self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1]))
2494        if OrderedDict is not dict:  # Python > 2.6
2495            self.assertEqual(r.keys(), expected.keys())
2496        r = get_as_dict(table, keyname='rgb')
2497        self.assertIsInstance(r, OrderedDict)
2498        expected = OrderedDict((row[1], (row[0], row[2]))
2499            for row in sorted(colors, key=itemgetter(1)))
2500        self.assertEqual(r, expected)
2501        for key in r:
2502            self.assertIsInstance(key, str)
2503            self.assertIn(key, expected)
2504            row = r[key]
2505            self.assertIsInstance(row, tuple)
2506            t = expected[key]
2507            self.assertEqual(row, t)
2508            if named:
2509                self.assertEqual(row.id, t[0])
2510                self.assertEqual(row.name, t[1])
2511                self.assertEqual(row._asdict(), dict(id=t[0], name=t[1]))
2512        if OrderedDict is not dict:  # Python > 2.6
2513            self.assertEqual(r.keys(), expected.keys())
2514        r = get_as_dict(table, keyname=['id', 'rgb'])
2515        self.assertIsInstance(r, OrderedDict)
2516        expected = OrderedDict((row[:2], row[2:]) for row in colors)
2517        self.assertEqual(r, expected)
2518        for key in r:
2519            self.assertIsInstance(key, tuple)
2520            self.assertIsInstance(key[0], int)
2521            self.assertIsInstance(key[1], str)
2522            if named:
2523                self.assertEqual(key, (key.id, key.rgb))
2524                self.assertEqual(key._fields, ('id', 'rgb'))
2525            row = r[key]
2526            self.assertIsInstance(row, tuple)
2527            self.assertIsInstance(row[0], str)
2528            t = expected[key]
2529            self.assertEqual(row, t)
2530            if named:
2531                self.assertEqual(row.name, t[0])
2532                self.assertEqual(row._asdict(), dict(name=t[0]))
2533        if OrderedDict is not dict:  # Python > 2.6
2534            self.assertEqual(r.keys(), expected.keys())
2535        r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True)
2536        self.assertIsInstance(r, OrderedDict)
2537        expected = OrderedDict((row[:2], row[2]) for row in colors)
2538        self.assertEqual(r, expected)
2539        for key in r:
2540            self.assertIsInstance(key, tuple)
2541            row = r[key]
2542            self.assertIsInstance(row, str)
2543            t = expected[key]
2544            self.assertEqual(row, t)
2545        if OrderedDict is not dict:  # Python > 2.6
2546            self.assertEqual(r.keys(), expected.keys())
2547        r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True)
2548        self.assertIsInstance(r, OrderedDict)
2549        expected = OrderedDict((row[1], row[2])
2550            for row in sorted(colors, key=itemgetter(1)))
2551        self.assertEqual(r, expected)
2552        for key in r:
2553            self.assertIsInstance(key, str)
2554            row = r[key]
2555            self.assertIsInstance(row, str)
2556            t = expected[key]
2557            self.assertEqual(row, t)
2558        if OrderedDict is not dict:  # Python > 2.6
2559            self.assertEqual(r.keys(), expected.keys())
2560        r = get_as_dict(table, what='id, name',
2561                        where="rgb like '#b%'", scalar=True)
2562        self.assertIsInstance(r, OrderedDict)
2563        expected = OrderedDict((row[0], row[2]) for row in colors[1:3])
2564        self.assertEqual(r, expected)
2565        for key in r:
2566            self.assertIsInstance(key, int)
2567            row = r[key]
2568            self.assertIsInstance(row, str)
2569            t = expected[key]
2570            self.assertEqual(row, t)
2571        if OrderedDict is not dict:  # Python > 2.6
2572            self.assertEqual(r.keys(), expected.keys())
2573        expected = r
2574        r = get_as_dict(table, what=['name', 'id'],
2575                        where=['id > 1', 'id < 4', "rgb like '#b%'",
2576                   "name not like 'A%'", "name not like '%t'"], scalar=True)
2577        self.assertEqual(r, expected)
2578        r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True)
2579        self.assertEqual(r, expected)
2580        r = get_as_dict(table, keyname=('id',), what=('name', 'id'),
2581                        where=('id > 1', 'id < 4'), order=('id',), scalar=True)
2582        self.assertEqual(r, expected)
2583        r = get_as_dict(table, limit=1)
2584        self.assertEqual(len(r), 1)
2585        self.assertEqual(r[1][1], 'Aero')
2586        r = get_as_dict(table, offset=3)
2587        self.assertEqual(len(r), 1)
2588        self.assertEqual(r[4][1], 'Desert')
2589        r = get_as_dict(table, order='id desc')
2590        expected = OrderedDict((row[0], row[1:]) for row in reversed(colors))
2591        self.assertEqual(r, expected)
2592        r = get_as_dict(table, where='id > 5')
2593        self.assertIsInstance(r, OrderedDict)
2594        self.assertEqual(len(r), 0)
2595        # test with unordered query
2596        expected = dict((row[0], row[1:]) for row in colors)
2597        r = get_as_dict(table, order=False)
2598        self.assertIsInstance(r, dict)
2599        self.assertEqual(r, expected)
2600        if dict is not OrderedDict:  # Python > 2.6
2601            self.assertNotIsInstance(self, OrderedDict)
2602        # test with arbitrary from clause
2603        from_table = '(select id, lower(name) as n2 from "%s") as t2' % table
2604        # primary key must be passed explicitly in this case
2605        self.assertRaises(pg.ProgrammingError, get_as_dict, from_table)
2606        r = get_as_dict(from_table, 'id')
2607        self.assertIsInstance(r, OrderedDict)
2608        expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors)
2609        self.assertEqual(r, expected)
2610        # test without a primary key
2611        query('alter table "%s" drop constraint "%s_pkey"' % (table, table))
2612        self.assertRaises(KeyError, self.db.pkey, table, flush=True)
2613        self.assertRaises(pg.ProgrammingError, get_as_dict, table)
2614        r = get_as_dict(table, keyname='id')
2615        expected = OrderedDict((row[0], row[1:]) for row in colors)
2616        self.assertIsInstance(r, dict)
2617        self.assertEqual(r, expected)
2618        r = (1, '#007fff', 'Azure')
2619        query('insert into "%s" values ($1, $2, $3)' % table, r)
2620        # the last entry will win
2621        expected[1] = r[1:]
2622        r = get_as_dict(table, keyname='id')
2623        self.assertEqual(r, expected)
2624
2625    def testTransaction(self):
2626        query = self.db.query
2627        self.createTable('test_table', 'n integer', temporary=False)
2628        self.db.begin()
2629        query("insert into test_table values (1)")
2630        query("insert into test_table values (2)")
2631        self.db.commit()
2632        self.db.begin()
2633        query("insert into test_table values (3)")
2634        query("insert into test_table values (4)")
2635        self.db.rollback()
2636        self.db.begin()
2637        query("insert into test_table values (5)")
2638        self.db.savepoint('before6')
2639        query("insert into test_table values (6)")
2640        self.db.rollback('before6')
2641        query("insert into test_table values (7)")
2642        self.db.commit()
2643        self.db.begin()
2644        self.db.savepoint('before8')
2645        query("insert into test_table values (8)")
2646        self.db.release('before8')
2647        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
2648        self.db.commit()
2649        self.db.start()
2650        query("insert into test_table values (9)")
2651        self.db.end()
2652        r = [r[0] for r in query(
2653            "select * from test_table order by 1").getresult()]
2654        self.assertEqual(r, [1, 2, 5, 7, 9])
2655        self.db.begin(mode='read only')
2656        self.assertRaises(pg.ProgrammingError,
2657            query, "insert into test_table values (0)")
2658        self.db.rollback()
2659        self.db.start(mode='Read Only')
2660        self.assertRaises(pg.ProgrammingError,
2661            query, "insert into test_table values (0)")
2662        self.db.abort()
2663
2664    def testTransactionAliases(self):
2665        self.assertEqual(self.db.begin, self.db.start)
2666        self.assertEqual(self.db.commit, self.db.end)
2667        self.assertEqual(self.db.rollback, self.db.abort)
2668
2669    def testContextManager(self):
2670        query = self.db.query
2671        self.createTable('test_table', 'n integer check(n>0)')
2672        with self.db:
2673            query("insert into test_table values (1)")
2674            query("insert into test_table values (2)")
2675        try:
2676            with self.db:
2677                query("insert into test_table values (3)")
2678                query("insert into test_table values (4)")
2679                raise ValueError('test transaction should rollback')
2680        except ValueError as error:
2681            self.assertEqual(str(error), 'test transaction should rollback')
2682        with self.db:
2683            query("insert into test_table values (5)")
2684        try:
2685            with self.db:
2686                query("insert into test_table values (6)")
2687                query("insert into test_table values (-1)")
2688        except pg.ProgrammingError as error:
2689            self.assertTrue('check' in str(error))
2690        with self.db:
2691            query("insert into test_table values (7)")
2692        r = [r[0] for r in query(
2693            "select * from test_table order by 1").getresult()]
2694        self.assertEqual(r, [1, 2, 5, 7])
2695
2696    def testBytea(self):
2697        query = self.db.query
2698        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2699        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2700        r = self.db.escape_bytea(s)
2701        query('insert into bytea_test values(3,$1)', (r,))
2702        r = query('select * from bytea_test where n=3').getresult()
2703        self.assertEqual(len(r), 1)
2704        r = r[0]
2705        self.assertEqual(len(r), 2)
2706        self.assertEqual(r[0], 3)
2707        r = r[1]
2708        self.assertIsInstance(r, str)
2709        r = self.db.unescape_bytea(r)
2710        self.assertIsInstance(r, bytes)
2711        self.assertEqual(r, s)
2712
2713    def testInsertUpdateGetBytea(self):
2714        query = self.db.query
2715        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2716        # insert null value
2717        r = self.db.insert('bytea_test', n=0, data=None)
2718        self.assertIsInstance(r, dict)
2719        self.assertIn('n', r)
2720        self.assertEqual(r['n'], 0)
2721        self.assertIn('data', r)
2722        self.assertIsNone(r['data'])
2723        s = b'None'
2724        r = self.db.update('bytea_test', n=0, data=s)
2725        self.assertIsInstance(r, dict)
2726        self.assertIn('n', r)
2727        self.assertEqual(r['n'], 0)
2728        self.assertIn('data', r)
2729        r = r['data']
2730        self.assertIsInstance(r, bytes)
2731        self.assertEqual(r, s)
2732        r = self.db.update('bytea_test', n=0, data=None)
2733        self.assertIsNone(r['data'])
2734        # insert as bytes
2735        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2736        r = self.db.insert('bytea_test', n=5, data=s)
2737        self.assertIsInstance(r, dict)
2738        self.assertIn('n', r)
2739        self.assertEqual(r['n'], 5)
2740        self.assertIn('data', r)
2741        r = r['data']
2742        self.assertIsInstance(r, bytes)
2743        self.assertEqual(r, s)
2744        # update as bytes
2745        s += b"and now even more \x00 nasty \t stuff!\f"
2746        r = self.db.update('bytea_test', n=5, data=s)
2747        self.assertIsInstance(r, dict)
2748        self.assertIn('n', r)
2749        self.assertEqual(r['n'], 5)
2750        self.assertIn('data', r)
2751        r = r['data']
2752        self.assertIsInstance(r, bytes)
2753        self.assertEqual(r, s)
2754        r = query('select * from bytea_test where n=5').getresult()
2755        self.assertEqual(len(r), 1)
2756        r = r[0]
2757        self.assertEqual(len(r), 2)
2758        self.assertEqual(r[0], 5)
2759        r = r[1]
2760        self.assertIsInstance(r, str)
2761        r = self.db.unescape_bytea(r)
2762        self.assertIsInstance(r, bytes)
2763        self.assertEqual(r, s)
2764        r = self.db.get('bytea_test', dict(n=5))
2765        self.assertIsInstance(r, dict)
2766        self.assertIn('n', r)
2767        self.assertEqual(r['n'], 5)
2768        self.assertIn('data', r)
2769        r = r['data']
2770        self.assertIsInstance(r, bytes)
2771        self.assertEqual(r, s)
2772
2773    def testUpsertBytea(self):
2774        query = self.db.query
2775        self.createTable('bytea_test', 'n smallint primary key, data bytea')
2776        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2777        r = dict(n=7, data=s)
2778        try:
2779            r = self.db.upsert('bytea_test', r)
2780        except pg.ProgrammingError as error:
2781            if self.db.server_version < 90500:
2782                self.skipTest('database does not support upsert')
2783            self.fail(str(error))
2784        self.assertIsInstance(r, dict)
2785        self.assertIn('n', r)
2786        self.assertEqual(r['n'], 7)
2787        self.assertIn('data', r)
2788        self.assertIsInstance(r['data'], bytes)
2789        self.assertEqual(r['data'], s)
2790        r['data'] = None
2791        r = self.db.upsert('bytea_test', r)
2792        self.assertIsInstance(r, dict)
2793        self.assertIn('n', r)
2794        self.assertEqual(r['n'], 7)
2795        self.assertIn('data', r)
2796        self.assertIsNone(r['data'], bytes)
2797
2798    def testNotificationHandler(self):
2799        # the notification handler itself is tested separately
2800        f = self.db.notification_handler
2801        callback = lambda arg_dict: None
2802        handler = f('test', callback)
2803        self.assertIsInstance(handler, pg.NotificationHandler)
2804        self.assertIs(handler.db, self.db)
2805        self.assertEqual(handler.event, 'test')
2806        self.assertEqual(handler.stop_event, 'stop_test')
2807        self.assertIs(handler.callback, callback)
2808        self.assertIsInstance(handler.arg_dict, dict)
2809        self.assertEqual(handler.arg_dict, {})
2810        self.assertIsNone(handler.timeout)
2811        self.assertFalse(handler.listening)
2812        handler.close()
2813        self.assertIsNone(handler.db)
2814        self.db.reopen()
2815        self.assertIsNone(handler.db)
2816        handler = f('test2', callback, timeout=2)
2817        self.assertIsInstance(handler, pg.NotificationHandler)
2818        self.assertIs(handler.db, self.db)
2819        self.assertEqual(handler.event, 'test2')
2820        self.assertEqual(handler.stop_event, 'stop_test2')
2821        self.assertIs(handler.callback, callback)
2822        self.assertIsInstance(handler.arg_dict, dict)
2823        self.assertEqual(handler.arg_dict, {})
2824        self.assertEqual(handler.timeout, 2)
2825        self.assertFalse(handler.listening)
2826        handler.close()
2827        self.assertIsNone(handler.db)
2828        self.db.reopen()
2829        self.assertIsNone(handler.db)
2830        arg_dict = {'testing': 3}
2831        handler = f('test3', callback, arg_dict=arg_dict)
2832        self.assertIsInstance(handler, pg.NotificationHandler)
2833        self.assertIs(handler.db, self.db)
2834        self.assertEqual(handler.event, 'test3')
2835        self.assertEqual(handler.stop_event, 'stop_test3')
2836        self.assertIs(handler.callback, callback)
2837        self.assertIs(handler.arg_dict, arg_dict)
2838        self.assertEqual(arg_dict['testing'], 3)
2839        self.assertIsNone(handler.timeout)
2840        self.assertFalse(handler.listening)
2841        handler.close()
2842        self.assertIsNone(handler.db)
2843        self.db.reopen()
2844        self.assertIsNone(handler.db)
2845        handler = f('test4', callback, stop_event='stop4')
2846        self.assertIsInstance(handler, pg.NotificationHandler)
2847        self.assertIs(handler.db, self.db)
2848        self.assertEqual(handler.event, 'test4')
2849        self.assertEqual(handler.stop_event, 'stop4')
2850        self.assertIs(handler.callback, callback)
2851        self.assertIsInstance(handler.arg_dict, dict)
2852        self.assertEqual(handler.arg_dict, {})
2853        self.assertIsNone(handler.timeout)
2854        self.assertFalse(handler.listening)
2855        handler.close()
2856        self.assertIsNone(handler.db)
2857        self.db.reopen()
2858        self.assertIsNone(handler.db)
2859        arg_dict = {'testing': 5}
2860        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
2861        self.assertIsInstance(handler, pg.NotificationHandler)
2862        self.assertIs(handler.db, self.db)
2863        self.assertEqual(handler.event, 'test5')
2864        self.assertEqual(handler.stop_event, 'stop5')
2865        self.assertIs(handler.callback, callback)
2866        self.assertIs(handler.arg_dict, arg_dict)
2867        self.assertEqual(arg_dict['testing'], 5)
2868        self.assertEqual(handler.timeout, 1.5)
2869        self.assertFalse(handler.listening)
2870        handler.close()
2871        self.assertIsNone(handler.db)
2872        self.db.reopen()
2873        self.assertIsNone(handler.db)
2874
2875
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        """Change the global pg options, then run the regular class setup.

        If anything here fails, the options are restored at once. unittest
        does not call tearDownClass() after a failed setUpClass(), so the
        changed options would otherwise leak into all following tests.
        """
        cls.saved_options = OrderedDict()
        try:
            cls.set_option('decimal', float)
            # invert whatever the current bool option is
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            # return plain tuples instead of named tuples
            unnamed_result = lambda q: q.getresult()
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            cls.reset_options()
            raise

    @classmethod
    def tearDownClass(cls):
        """Run the regular class teardown and always restore the options."""
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls.reset_options()

    @classmethod
    def set_option(cls, option, value):
        """Remember the current value of a global option, then change it."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore the remembered value of a global option."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_options(cls):
        """Restore all remembered options, the last changed one first.

        With the plain dict fallback for OrderedDict (Python 2.6 / 3.0)
        the restore order is arbitrary.
        """
        for option in reversed(list(cls.saved_options)):
            cls.reset_option(option)
2904
2905
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces).

    The fixture consists of the public schema (number 0) and the schemas
    s1 to s4.  Schema number k contains the tables "t" and "t<k>", both
    created with oids from "select 1 as n, k as d", so the value of the
    column d reveals in which schema a table name has been resolved.
    """

    # set to True by setUpClass once the fixture exists; checked in setUp
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    # drop leftovers (e.g. from an aborted earlier run)
                    # so that creating the schema below can succeed
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # always release the connection, also when creating the
            # fixture failed (tearDownClass is not called in that case)
            db.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            # release the connection even if dropping the fixture failed
            db.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.db = DB()

    def tearDown(self):
        # run the cleanups registered by the tests while the connection
        # they use is still open
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        """Test that get_tables() lists the tables of all schemas."""
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        """Test get_attnames() with qualified and search_path names."""
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        """Test that get() resolves table names as PostgreSQL does.

        Qualified names are looked up in the given schema, unqualified
        names along the current search_path; a table that cannot be found
        this way raises a ProgrammingError.
        """
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        """Test that get() returns the oid under the key 'oid(<table>)'."""
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
3024
3025
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class.

    Output printed to stdout during a test is captured in a StringIO
    buffer, so the tests can check what the debug setting produced.
    """

    def setUp(self):
        self.db = DB()
        # remember the initial debug setting so tearDown can restore it
        self.debug = self.db.debug
        self.output = StringIO()
        # redirect stdout into the capture buffer
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        sys.stdout = self.stdout
        self.output.close()
        # restore the setting saved in setUp before closing the connection;
        # a debug target installed by a test (e.g. a temporary file that
        # has already been closed) should not outlive that test
        self.db.debug = self.debug
        self.db.close()

    def get_output(self):
        """Return everything printed to stdout since setUp."""
        return self.output.getvalue()

    def send_queries(self):
        """Run the two queries whose debug output the tests check."""
        self.db.query("select 1")
        self.db.query("select 2")

    def testDebugDefault(self):
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            # nothing must have been printed to stdout in this case
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")
3087
3088
if __name__ == '__main__':
    # Run all test cases defined in this module when executed as a script.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.