source: trunk/tests/test_classic_dbwrapper.py @ 759

Last change on this file since 759 was 759, checked in by cito, 4 years ago

Add tests for debug modes of DB

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 80.9 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18
19import os
20import sys
21import tempfile
22
23import pg  # the module under test
24
25from decimal import Decimal
26
# We need a database to test against.  If LOCAL_PyGreSQL.py exists, the
# settings defined there override the defaults below; therefore the defaults
# must be assigned before that module is star-imported.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError, SystemError):
    # A relative import outside of a package raises ValueError on Python 2,
    # SystemError on Python 3.3 to 3.5 and ImportError on newer versions;
    # in all of these cases fall back to an absolute import.
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local configuration, keep the defaults

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

if str is bytes:  # Python < 3.0
    from StringIO import StringIO
else:
    from io import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
70
71
def DB():
    """Return a new DB wrapper connected to the configured test database.

    Debugging output of the wrapper is enabled when the module-level
    ``debug`` flag is set.  The session is configured with
    client_min_messages=warning before the wrapper is handed out.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
79
80
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            pass  # some tests close the connection themselves

    def testAllDBAttributes(self):
        attributes = [
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'delete',
            'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        # dir() returns the names sorted, so the list above must be sorted
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        # NOTE(review): a message mentioning krb5_ is tolerated, presumably
        # a harmless Kerberos notice from libpq -- confirm
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # strip the hex prefix or escape backslashes, whichever is used
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        # the query must actually fail; the original try/except version
        # passed silently when no exception was raised at all
        with self.assertRaises(pg.ProgrammingError) as cm:
            self.db.query("select 1/0")
        self.assertEqual(cm.exception.sqlstate, '22012')

    def testMethodEndcopy(self):
        try:
            self.db.endcopy()
        except IOError:
            pass  # no copy operation is in progress

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        # a DB wrapper can be created from a raw connection ...
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        # ... from another DB wrapper, or via the db keyword argument
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        # ... and from any object with a _cnx attribute
        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
257
258
259class TestDBClass(unittest.TestCase):
260    """Test the methods of the DB class wrapped pg connection."""
261
262    @classmethod
263    def setUpClass(cls):
264        db = DB()
265        db.query("drop table if exists test cascade")
266        db.query("create table test ("
267            "i2 smallint, i4 integer, i8 bigint,"
268            " d numeric, f4 real, f8 double precision, m money,"
269            " v4 varchar(4), c4 char(4), t text)")
270        db.query("create or replace view test_view as"
271            " select i4, v4 from test")
272        db.close()
273
274    @classmethod
275    def tearDownClass(cls):
276        db = DB()
277        db.query("drop table test cascade")
278        db.close()
279
280    def setUp(self):
281        self.db = DB()
282        query = self.db.query
283        query('set client_encoding=utf8')
284        query('set standard_conforming_strings=on')
285        query("set lc_monetary='C'")
286        query("set datestyle='ISO,YMD'")
287        query('set bytea_output=hex')
288
289    def tearDown(self):
290        self.doCleanups()
291        self.db.close()
292
293    def testClassName(self):
294        self.assertEqual(self.db.__class__.__name__, 'DB')
295
296    def testModuleName(self):
297        self.assertEqual(self.db.__module__, 'pg')
298        self.assertEqual(self.db.__class__.__module__, 'pg')
299
300    def testEscapeLiteral(self):
301        f = self.db.escape_literal
302        r = f(b"plain")
303        self.assertIsInstance(r, bytes)
304        self.assertEqual(r, b"'plain'")
305        r = f(u"plain")
306        self.assertIsInstance(r, unicode)
307        self.assertEqual(r, u"'plain'")
308        r = f(u"that's kÀse".encode('utf-8'))
309        self.assertIsInstance(r, bytes)
310        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
311        r = f(u"that's kÀse")
312        self.assertIsInstance(r, unicode)
313        self.assertEqual(r, u"'that''s kÀse'")
314        self.assertEqual(f(r"It's fine to have a \ inside."),
315            r" E'It''s fine to have a \\ inside.'")
316        self.assertEqual(f('No "quotes" must be escaped.'),
317            "'No \"quotes\" must be escaped.'")
318
319    def testEscapeIdentifier(self):
320        f = self.db.escape_identifier
321        r = f(b"plain")
322        self.assertIsInstance(r, bytes)
323        self.assertEqual(r, b'"plain"')
324        r = f(u"plain")
325        self.assertIsInstance(r, unicode)
326        self.assertEqual(r, u'"plain"')
327        r = f(u"that's kÀse".encode('utf-8'))
328        self.assertIsInstance(r, bytes)
329        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
330        r = f(u"that's kÀse")
331        self.assertIsInstance(r, unicode)
332        self.assertEqual(r, u'"that\'s kÀse"')
333        self.assertEqual(f(r"It's fine to have a \ inside."),
334            '"It\'s fine to have a \\ inside."')
335        self.assertEqual(f('All "quotes" must be escaped.'),
336            '"All ""quotes"" must be escaped."')
337
338    def testEscapeString(self):
339        f = self.db.escape_string
340        r = f(b"plain")
341        self.assertIsInstance(r, bytes)
342        self.assertEqual(r, b"plain")
343        r = f(u"plain")
344        self.assertIsInstance(r, unicode)
345        self.assertEqual(r, u"plain")
346        r = f(u"that's kÀse".encode('utf-8'))
347        self.assertIsInstance(r, bytes)
348        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
349        r = f(u"that's kÀse")
350        self.assertIsInstance(r, unicode)
351        self.assertEqual(r, u"that''s kÀse")
352        self.assertEqual(f(r"It's fine to have a \ inside."),
353            r"It''s fine to have a \ inside.")
354
355    def testEscapeBytea(self):
356        f = self.db.escape_bytea
357        # note that escape_byte always returns hex output since Pg 9.0,
358        # regardless of the bytea_output setting
359        r = f(b'plain')
360        self.assertIsInstance(r, bytes)
361        self.assertEqual(r, b'\\x706c61696e')
362        r = f(u'plain')
363        self.assertIsInstance(r, unicode)
364        self.assertEqual(r, u'\\x706c61696e')
365        r = f(u"das is' kÀse".encode('utf-8'))
366        self.assertIsInstance(r, bytes)
367        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
368        r = f(u"das is' kÀse")
369        self.assertIsInstance(r, unicode)
370        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
371        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
372
373    def testUnescapeBytea(self):
374        f = self.db.unescape_bytea
375        r = f(b'plain')
376        self.assertIsInstance(r, bytes)
377        self.assertEqual(r, b'plain')
378        r = f(u'plain')
379        self.assertIsInstance(r, bytes)
380        self.assertEqual(r, b'plain')
381        r = f(b"das is' k\\303\\244se")
382        self.assertIsInstance(r, bytes)
383        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
384        r = f(u"das is' k\\303\\244se")
385        self.assertIsInstance(r, bytes)
386        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
387        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
388        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
389        self.assertEqual(f(r'\\x746861742773206be47365'),
390            b'\\x746861742773206be47365')
391        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
392
393    def testGetParameter(self):
394        f = self.db.get_parameter
395        self.assertRaises(TypeError, f)
396        self.assertRaises(TypeError, f, None)
397        self.assertRaises(TypeError, f, 42)
398        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
399        r = f('standard_conforming_strings')
400        self.assertEqual(r, 'on')
401        r = f('lc_monetary')
402        self.assertEqual(r, 'C')
403        r = f('datestyle')
404        self.assertEqual(r, 'ISO, YMD')
405        r = f('bytea_output')
406        self.assertEqual(r, 'hex')
407        r = f(['bytea_output', 'lc_monetary'])
408        self.assertIsInstance(r, list)
409        self.assertEqual(r, ['hex', 'C'])
410        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
411        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
412        r = f(set(['bytea_output', 'lc_monetary']))
413        self.assertIsInstance(r, dict)
414        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
415        r = f(set(['Bytea_Output', ' LC_Monetary ']))
416        self.assertIsInstance(r, dict)
417        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
418        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
419        r = f(s)
420        self.assertIs(r, s)
421        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
422        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
423        r = f(s)
424        self.assertIs(r, s)
425        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
426
427    def testGetParameterServerVersion(self):
428        r = self.db.get_parameter('server_version_num')
429        self.assertIsInstance(r, str)
430        s = self.db.server_version
431        self.assertIsInstance(s, int)
432        self.assertEqual(r, str(s))
433
434    def testGetParameterAll(self):
435        f = self.db.get_parameter
436        r = f('all')
437        self.assertIsInstance(r, dict)
438        self.assertEqual(r['standard_conforming_strings'], 'on')
439        self.assertEqual(r['lc_monetary'], 'C')
440        self.assertEqual(r['DateStyle'], 'ISO, YMD')
441        self.assertEqual(r['bytea_output'], 'hex')
442
443    def testSetParameter(self):
444        f = self.db.set_parameter
445        g = self.db.get_parameter
446        self.assertRaises(TypeError, f)
447        self.assertRaises(TypeError, f, None)
448        self.assertRaises(TypeError, f, 42)
449        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
450        f('standard_conforming_strings', 'off')
451        self.assertEqual(g('standard_conforming_strings'), 'off')
452        f('datestyle', 'ISO, DMY')
453        self.assertEqual(g('datestyle'), 'ISO, DMY')
454        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
455        self.assertEqual(g('standard_conforming_strings'), 'on')
456        self.assertEqual(g('datestyle'), 'ISO, DMY')
457        f(['default_with_oids', 'standard_conforming_strings'], 'off')
458        self.assertEqual(g('default_with_oids'), 'off')
459        self.assertEqual(g('standard_conforming_strings'), 'off')
460        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
461        self.assertEqual(g('standard_conforming_strings'), 'on')
462        self.assertEqual(g('datestyle'), 'ISO, YMD')
463        f(('default_with_oids', 'standard_conforming_strings'), 'off')
464        self.assertEqual(g('default_with_oids'), 'off')
465        self.assertEqual(g('standard_conforming_strings'), 'off')
466        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
467        self.assertEqual(g('default_with_oids'), 'on')
468        self.assertEqual(g('standard_conforming_strings'), 'on')
469        self.assertRaises(ValueError, f, set([ 'default_with_oids',
470            'standard_conforming_strings']), ['off', 'on'])
471        f(set(['default_with_oids', 'standard_conforming_strings']),
472            ['off', 'off'])
473        self.assertEqual(g('default_with_oids'), 'off')
474        self.assertEqual(g('standard_conforming_strings'), 'off')
475        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
476        self.assertEqual(g('standard_conforming_strings'), 'on')
477        self.assertEqual(g('datestyle'), 'ISO, YMD')
478
479    def testResetParameter(self):
480        db = DB()
481        f = db.set_parameter
482        g = db.get_parameter
483        r = g('default_with_oids')
484        self.assertIn(r, ('on', 'off'))
485        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
486        r = g('standard_conforming_strings')
487        self.assertIn(r, ('on', 'off'))
488        scs, not_scs = r, 'off' if r == 'on' else 'on'
489        f('default_with_oids', not_dwi)
490        f('standard_conforming_strings', not_scs)
491        self.assertEqual(g('default_with_oids'), not_dwi)
492        self.assertEqual(g('standard_conforming_strings'), not_scs)
493        f('default_with_oids')
494        f('standard_conforming_strings', None)
495        self.assertEqual(g('default_with_oids'), dwi)
496        self.assertEqual(g('standard_conforming_strings'), scs)
497        f('default_with_oids', not_dwi)
498        f('standard_conforming_strings', not_scs)
499        self.assertEqual(g('default_with_oids'), not_dwi)
500        self.assertEqual(g('standard_conforming_strings'), not_scs)
501        f(['default_with_oids', 'standard_conforming_strings'], None)
502        self.assertEqual(g('default_with_oids'), dwi)
503        self.assertEqual(g('standard_conforming_strings'), scs)
504        f('default_with_oids', not_dwi)
505        f('standard_conforming_strings', not_scs)
506        self.assertEqual(g('default_with_oids'), not_dwi)
507        self.assertEqual(g('standard_conforming_strings'), not_scs)
508        f(('default_with_oids', 'standard_conforming_strings'))
509        self.assertEqual(g('default_with_oids'), dwi)
510        self.assertEqual(g('standard_conforming_strings'), scs)
511        f('default_with_oids', not_dwi)
512        f('standard_conforming_strings', not_scs)
513        self.assertEqual(g('default_with_oids'), not_dwi)
514        self.assertEqual(g('standard_conforming_strings'), not_scs)
515        f(set(['default_with_oids', 'standard_conforming_strings']))
516        self.assertEqual(g('default_with_oids'), dwi)
517        self.assertEqual(g('standard_conforming_strings'), scs)
518
519    def testResetParameterAll(self):
520        db = DB()
521        f = db.set_parameter
522        self.assertRaises(ValueError, f, 'all', 0)
523        self.assertRaises(ValueError, f, 'all', 'off')
524        g = db.get_parameter
525        r = g('default_with_oids')
526        self.assertIn(r, ('on', 'off'))
527        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
528        r = g('standard_conforming_strings')
529        self.assertIn(r, ('on', 'off'))
530        scs, not_scs = r, 'off' if r == 'on' else 'on'
531        f('default_with_oids', not_dwi)
532        f('standard_conforming_strings', not_scs)
533        self.assertEqual(g('default_with_oids'), not_dwi)
534        self.assertEqual(g('standard_conforming_strings'), not_scs)
535        f('all')
536        self.assertEqual(g('default_with_oids'), dwi)
537        self.assertEqual(g('standard_conforming_strings'), scs)
538
539    def testSetParameterLocal(self):
540        f = self.db.set_parameter
541        g = self.db.get_parameter
542        self.assertEqual(g('standard_conforming_strings'), 'on')
543        self.db.begin()
544        f('standard_conforming_strings', 'off', local=True)
545        self.assertEqual(g('standard_conforming_strings'), 'off')
546        self.db.end()
547        self.assertEqual(g('standard_conforming_strings'), 'on')
548
549    def testSetParameterSession(self):
550        f = self.db.set_parameter
551        g = self.db.get_parameter
552        self.assertEqual(g('standard_conforming_strings'), 'on')
553        self.db.begin()
554        f('standard_conforming_strings', 'off', local=False)
555        self.assertEqual(g('standard_conforming_strings'), 'off')
556        self.db.end()
557        self.assertEqual(g('standard_conforming_strings'), 'off')
558
559    def testQuery(self):
560        query = self.db.query
561        query("drop table if exists test_table")
562        self.addCleanup(query, "drop table test_table")
563        q = "create table test_table (n integer) with oids"
564        r = query(q)
565        self.assertIsNone(r)
566        q = "insert into test_table values (1)"
567        r = query(q)
568        self.assertIsInstance(r, int)
569        q = "insert into test_table select 2"
570        r = query(q)
571        self.assertIsInstance(r, int)
572        oid = r
573        q = "select oid from test_table where n=2"
574        r = query(q).getresult()
575        self.assertEqual(len(r), 1)
576        r = r[0]
577        self.assertEqual(len(r), 1)
578        r = r[0]
579        self.assertEqual(r, oid)
580        q = "insert into test_table select 3 union select 4 union select 5"
581        r = query(q)
582        self.assertIsInstance(r, str)
583        self.assertEqual(r, '3')
584        q = "update test_table set n=4 where n<5"
585        r = query(q)
586        self.assertIsInstance(r, str)
587        self.assertEqual(r, '4')
588        q = "delete from test_table"
589        r = query(q)
590        self.assertIsInstance(r, str)
591        self.assertEqual(r, '5')
592
593    def testMultipleQueries(self):
594        self.assertEqual(self.db.query(
595            "create temporary table test_multi (n integer);"
596            "insert into test_multi values (4711);"
597            "select n from test_multi").getresult()[0][0], 4711)
598
599    def testQueryWithParams(self):
600        query = self.db.query
601        query("drop table if exists test_table")
602        self.addCleanup(query, "drop table test_table")
603        q = "create table test_table (n1 integer, n2 integer) with oids"
604        query(q)
605        q = "insert into test_table values ($1, $2)"
606        r = query(q, (1, 2))
607        self.assertIsInstance(r, int)
608        r = query(q, [3, 4])
609        self.assertIsInstance(r, int)
610        r = query(q, [5, 6])
611        self.assertIsInstance(r, int)
612        q = "select * from test_table order by 1, 2"
613        self.assertEqual(query(q).getresult(),
614            [(1, 2), (3, 4), (5, 6)])
615        q = "select * from test_table where n1=$1 and n2=$2"
616        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
617        q = "update test_table set n2=$2 where n1=$1"
618        r = query(q, 3, 7)
619        self.assertEqual(r, '1')
620        q = "select * from test_table order by 1, 2"
621        self.assertEqual(query(q).getresult(),
622            [(1, 2), (3, 7), (5, 6)])
623        q = "delete from test_table where n2!=$1"
624        r = query(q, 4)
625        self.assertEqual(r, '3')
626
627    def testEmptyQuery(self):
628        self.assertRaises(ValueError, self.db.query, '')
629
630    def testQueryProgrammingError(self):
631        try:
632            self.db.query("select 1/0")
633        except pg.ProgrammingError as error:
634            self.assertEqual(error.sqlstate, '22012')
635
636    def testPkey(self):
637        query = self.db.query
638        pkey = self.db.pkey
639        for t in ('pkeytest', 'primary key test'):
640            for n in range(7):
641                query('drop table if exists "%s%d"' % (t, n))
642                self.addCleanup(query, 'drop table "%s%d"' % (t, n))
643            query('create table "%s0" ('
644                "a smallint)" % t)
645            query('create table "%s1" ('
646                "b smallint primary key)" % t)
647            query('create table "%s2" ('
648                "c smallint, d smallint primary key)" % t)
649            query('create table "%s3" ('
650                "e smallint, f smallint, g smallint,"
651                " h smallint, i smallint,"
652                " primary key (f, h))" % t)
653            query('create table "%s4" ('
654                "more_than_one_letter varchar primary key)" % t)
655            query('create table "%s5" ('
656                '"with space" date primary key)' % t)
657            query('create table "%s6" ('
658                'a_very_long_column_name varchar,'
659                ' "with space" date,'
660                ' "42" int,'
661                " primary key (a_very_long_column_name,"
662                ' "with space", "42"))' % t)
663            self.assertRaises(KeyError, pkey, '%s0' % t)
664            self.assertEqual(pkey('%s1' % t), 'b')
665            self.assertEqual(pkey('%s2' % t), 'd')
666            r = pkey('%s3' % t)
667            self.assertIsInstance(r, frozenset)
668            self.assertEqual(r, frozenset('fh'))
669            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
670            self.assertEqual(pkey('%s5' % t), 'with space')
671            r = pkey('%s6' % t)
672            self.assertIsInstance(r, frozenset)
673            self.assertEqual(r, frozenset([
674                'a_very_long_column_name', 'with space', '42']))
675            # a newly added primary key will be detected
676            query('alter table "%s0" add primary key (a)' % t)
677            self.assertEqual(pkey('%s0' % t), 'a')
678            # a changed primary key will not be detected,
679            # indicating that the internal cache is operating
680            query('alter table "%s1" rename column b to x' % t)
681            self.assertEqual(pkey('%s1' % t), 'b')
682            # we get the changed primary key when the cache is flushed
683            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
684
685    def testGetDatabases(self):
686        databases = self.db.get_databases()
687        self.assertIn('template0', databases)
688        self.assertIn('template1', databases)
689        self.assertNotIn('not existing database', databases)
690        self.assertIn('postgres', databases)
691        self.assertIn(dbname, databases)
692
693    def testGetTables(self):
694        get_tables = self.db.get_tables
695        result1 = get_tables()
696        self.assertIsInstance(result1, list)
697        for t in result1:
698            t = t.split('.', 1)
699            self.assertGreaterEqual(len(t), 2)
700            if len(t) > 2:
701                self.assertTrue(t[1].startswith('"'))
702            t = t[0]
703            self.assertNotEqual(t, 'information_schema')
704            self.assertFalse(t.startswith('pg_'))
705        tables = ('"A very Special Name"',
706            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
707            'A_MiXeD_NaMe', '"another special name"',
708            'averyveryveryveryveryveryverylongtablename',
709            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
710        for t in tables:
711            self.db.query('drop table if exists %s' % t)
712            self.db.query("create table %s"
713                " as select 0" % t)
714        result3 = get_tables()
715        result2 = []
716        for t in result3:
717            if t not in result1:
718                result2.append(t)
719        result3 = []
720        for t in tables:
721            if not t.startswith('"'):
722                t = t.lower()
723            result3.append('public.' + t)
724        self.assertEqual(result2, result3)
725        for t in result2:
726            self.db.query('drop table %s' % t)
727        result2 = get_tables()
728        self.assertEqual(result2, result1)
729
730    def testGetRelations(self):
731        get_relations = self.db.get_relations
732        result = get_relations()
733        self.assertIn('public.test', result)
734        self.assertIn('public.test_view', result)
735        result = get_relations('rv')
736        self.assertIn('public.test', result)
737        self.assertIn('public.test_view', result)
738        result = get_relations('r')
739        self.assertIn('public.test', result)
740        self.assertNotIn('public.test_view', result)
741        result = get_relations('v')
742        self.assertNotIn('public.test', result)
743        self.assertIn('public.test_view', result)
744        result = get_relations('cisSt')
745        self.assertNotIn('public.test', result)
746        self.assertNotIn('public.test_view', result)
747
748    def testGetAttnames(self):
749        get_attnames = self.db.get_attnames
750        self.assertRaises(pg.ProgrammingError,
751            self.db.get_attnames, 'does_not_exist')
752        self.assertRaises(pg.ProgrammingError,
753            self.db.get_attnames, 'has.too.many.dots')
754        query = self.db.query
755        query("drop table if exists test_table")
756        self.addCleanup(query, "drop table test_table")
757        query("create table test_table("
758            " n int, alpha smallint, beta bool,"
759            " gamma char(5), tau text, v varchar(3))")
760        r = get_attnames('test_table')
761        self.assertIsInstance(r, dict)
762        self.assertEqual(r, dict(
763            n='int', alpha='int', beta='bool',
764            gamma='text', tau='text', v='text'))
765
766    def testGetAttnamesWithQuotes(self):
767        get_attnames = self.db.get_attnames
768        query = self.db.query
769        table = 'test table for get_attnames()'
770        query('drop table if exists "%s"' % table)
771        self.addCleanup(query, 'drop table "%s"' % table)
772        query('create table "%s"('
773            '"Prime!" smallint,'
774            ' "much space" integer, "Questions?" text)' % table)
775        r = get_attnames(table)
776        self.assertIsInstance(r, dict)
777        self.assertEqual(r, {
778            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
779        table = 'yet another test table for get_attnames()'
780        query('drop table if exists "%s"' % table)
781        self.addCleanup(query, 'drop table "%s"' % table)
782        self.db.query('create table "%s" ('
783            'a smallint, b integer, c bigint,'
784            ' e numeric, f float, f2 double precision, m money,'
785            ' x smallint, y smallint, z smallint,'
786            ' Normal_NaMe smallint, "Special Name" smallint,'
787            ' t text, u char(2), v varchar(2),'
788            ' primary key (y, u)) with oids' % table)
789        r = get_attnames(table)
790        self.assertIsInstance(r, dict)
791        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
792            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
793            'normal_name': 'int', 'Special Name': 'int',
794            'u': 'text', 't': 'text', 'v': 'text',
795            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
796
797    def testGetAttnamesWithRegtypes(self):
798        get_attnames = self.db.get_attnames
799        query = self.db.query
800        query("drop table if exists test_table")
801        self.addCleanup(query, "drop table test_table")
802        query("create table test_table("
803            " n int, alpha smallint, beta bool,"
804            " gamma char(5), tau text, v varchar(3))")
805        self.db.use_regtypes(True)
806        try:
807            r = get_attnames("test_table")
808            self.assertIsInstance(r, dict)
809        finally:
810            self.db.use_regtypes(False)
811        self.assertEqual(r, dict(
812            n='integer', alpha='smallint', beta='boolean',
813            gamma='character', tau='text', v='character varying'))
814
815    def testGetAttnamesIsCached(self):
816        get_attnames = self.db.get_attnames
817        query = self.db.query
818        query("drop table if exists test_table")
819        self.addCleanup(query, "drop table if exists test_table")
820        query("create table test_table(col int)")
821        r = get_attnames("test_table")
822        self.assertIsInstance(r, dict)
823        self.assertEqual(r, dict(col='int'))
824        query("drop table test_table")
825        query("create table test_table(col text)")
826        r = get_attnames("test_table")
827        self.assertEqual(r, dict(col='int'))
828        r = get_attnames("test_table", flush=True)
829        self.assertEqual(r, dict(col='text'))
830        query("drop table test_table")
831        r = get_attnames("test_table")
832        self.assertEqual(r, dict(col='text'))
833        self.assertRaises(pg.ProgrammingError,
834            get_attnames, "test_table", flush=True)
835
836    def testGetAttnamesIsOrdered(self):
837        get_attnames = self.db.get_attnames
838        query = self.db.query
839        query("drop table if exists test_table")
840        self.addCleanup(query, "drop table test_table")
841        query("create table test_table("
842            " n int, alpha smallint, v varchar(3),"
843            " gamma char(5), tau text, beta bool)")
844        r = get_attnames("test_table")
845        self.assertIsInstance(r, OrderedDict)
846        self.assertEqual(r, OrderedDict([
847            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
848            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
849        if OrderedDict is dict:
850            self.skipTest('OrderedDict is not supported')
851        r = ' '.join(list(r.keys()))
852        self.assertEqual(r, 'n alpha v gamma tau beta')
853
854    def testHasTablePrivilege(self):
855        can = self.db.has_table_privilege
856        self.assertEqual(can('test'), True)
857        self.assertEqual(can('test', 'select'), True)
858        self.assertEqual(can('test', 'SeLeCt'), True)
859        self.assertEqual(can('test', 'SELECT'), True)
860        self.assertEqual(can('test', 'insert'), True)
861        self.assertEqual(can('test', 'update'), True)
862        self.assertEqual(can('test', 'delete'), True)
863        self.assertEqual(can('pg_views', 'select'), True)
864        self.assertEqual(can('pg_views', 'delete'), False)
865        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
866        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
867
868    def testGet(self):
869        get = self.db.get
870        query = self.db.query
871        table = 'get_test_table'
872        query('drop table if exists "%s"' % table)
873        self.addCleanup(query, 'drop table "%s"' % table)
874        query('create table "%s" ('
875            "n integer, t text) with oids" % table)
876        for n, t in enumerate('xyz'):
877            query('insert into "%s" values('"%d, '%s')"
878                % (table, n + 1, t))
879        self.assertRaises(pg.ProgrammingError, get, table, 2)
880        r = get(table, 2, 'n')
881        oid_table = 'oid(%s)' % table
882        self.assertIn(oid_table, r)
883        oid = r[oid_table]
884        self.assertIsInstance(oid, int)
885        result = {'t': 'y', 'n': 2, oid_table: oid}
886        self.assertEqual(r, result)
887        self.assertEqual(get(table + ' *', 2, 'n'), r)
888        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
889        self.assertEqual(get(table, 1, 'n')['t'], 'x')
890        self.assertEqual(get(table, 3, 'n')['t'], 'z')
891        self.assertEqual(get(table, 2, 'n')['t'], 'y')
892        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
893        r['n'] = 3
894        self.assertEqual(get(table, r, 'n')['t'], 'z')
895        self.assertEqual(get(table, 1, 'n')['t'], 'x')
896        query('alter table "%s" alter n set not null' % table)
897        query('alter table "%s" add primary key (n)' % table)
898        self.assertEqual(get(table, 3)['t'], 'z')
899        self.assertEqual(get(table, 1)['t'], 'x')
900        self.assertEqual(get(table, 2)['t'], 'y')
901        r['n'] = 1
902        self.assertEqual(get(table, r)['t'], 'x')
903        r['n'] = 3
904        self.assertEqual(get(table, r)['t'], 'z')
905        r['n'] = 2
906        self.assertEqual(get(table, r)['t'], 'y')
907
908    def testGetWithCompositeKey(self):
909        get = self.db.get
910        query = self.db.query
911        table = 'get_test_table_1'
912        query('drop table if exists "%s"' % table)
913        self.addCleanup(query, 'drop table "%s"' % table)
914        query('create table "%s" ('
915            "n integer, t text, primary key (n))" % table)
916        for n, t in enumerate('abc'):
917            query('insert into "%s" values('
918                "%d, '%s')" % (table, n + 1, t))
919        self.assertEqual(get(table, 2)['t'], 'b')
920        table = 'get_test_table_2'
921        query('drop table if exists "%s"' % table)
922        self.addCleanup(query, 'drop table "%s"' % table)
923        query('create table "%s" ('
924            "n integer, m integer, t text, primary key (n, m))" % table)
925        for n in range(3):
926            for m in range(2):
927                t = chr(ord('a') + 2 * n + m)
928                query('insert into "%s" values('
929                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
930        self.assertRaises(pg.ProgrammingError, get, table, 2)
931        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
932        r = get(table, dict(n=1, m=2), ('n', 'm'))
933        self.assertEqual(r['t'], 'b')
934        r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
935        self.assertEqual(r['t'], 'f')
936
937    def testGetWithQuotedNames(self):
938        get = self.db.get
939        query = self.db.query
940        table = 'test table for get()'
941        query('drop table if exists "%s"' % table)
942        self.addCleanup(query, 'drop table "%s"' % table)
943        query('create table "%s" ('
944            '"Prime!" smallint primary key,'
945            ' "much space" integer, "Questions?" text)' % table)
946        query('insert into "%s"'
947              " values(17, 1001, 'No!')" % table)
948        r = get(table, 17)
949        self.assertIsInstance(r, dict)
950        self.assertEqual(r['Prime!'], 17)
951        self.assertEqual(r['much space'], 1001)
952        self.assertEqual(r['Questions?'], 'No!')
953
954    def testGetFromView(self):
955        self.db.query('delete from test where i4=14')
956        self.db.query('insert into test (i4, v4) values('
957            "14, 'abc4')")
958        r = self.db.get('test_view', 14, 'i4')
959        self.assertIn('v4', r)
960        self.assertEqual(r['v4'], 'abc4')
961
962    def testGetLittleBobbyTables(self):
963        get = self.db.get
964        query = self.db.query
965        query("drop table if exists test_students")
966        self.addCleanup(query, "drop table test_students")
967        query("create table test_students (firstname varchar primary key,"
968            " nickname varchar, grade char(2))")
969        query("insert into test_students values ("
970              "'D''Arcy', 'Darcey', 'A+')")
971        query("insert into test_students values ("
972              "'Sheldon', 'Moonpie', 'A+')")
973        query("insert into test_students values ("
974              "'Robert', 'Little Bobby Tables', 'D-')")
975        r = get('test_students', 'Sheldon')
976        self.assertEqual(r, dict(
977            firstname="Sheldon", nickname='Moonpie', grade='A+'))
978        r = get('test_students', 'Robert')
979        self.assertEqual(r, dict(
980            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
981        r = get('test_students', "D'Arcy")
982        self.assertEqual(r, dict(
983            firstname="D'Arcy", nickname='Darcey', grade='A+'))
984        try:
985            get('test_students', "D' Arcy")
986        except pg.DatabaseError as error:
987            self.assertEqual(str(error),
988                'No such record in test_students\nwhere "firstname" = $1\n'
989                'with $1="D\' Arcy"')
990        try:
991            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
992        except pg.DatabaseError as error:
993            self.assertEqual(str(error),
994                'No such record in test_students\nwhere "firstname" = $1\n'
995                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
996        q = "select * from test_students order by 1 limit 4"
997        r = query(q).getresult()
998        self.assertEqual(len(r), 3)
999        self.assertEqual(r[1][2], 'D-')
1000
    def testInsert(self):
        """Check that insert() stores and returns values of many types.

        Each entry of ``tests`` is either a dict of column values that is
        expected to come back unchanged, or a pair ``(data, change)`` where
        ``change`` holds the values expected instead of the corresponding
        ones in ``data``.  For every entry the row is inserted, then both
        the dict passed to insert() and the row read back with a plain
        query are compared with the expectation, and the table is emptied
        again before the next entry.
        """
        insert = self.db.insert
        query = self.db.query
        # the expected Python values depend on these global type settings
        bool_on = pg.get_bool()
        decimal = pg.get_decimal()
        table = 'insert_test_table'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        query('create table "%s" ('
            "i2 smallint, i4 integer, i8 bigint,"
            " d numeric, f4 real, f8 double precision, m money,"
            " v4 varchar(4), c4 char(4), t text,"
            " b boolean, ts timestamp) with oids" % table)
        # key under which the oid of the inserted row is expected
        oid_table = 'oid(%s)' % table
        tests = [dict(i2=None, i4=None, i8=None),
            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
            dict(i2=42, i4=123456, i8=9876543210),
            dict(i2=2 ** 15 - 1,
                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
            dict(d=None), (dict(d=''), dict(d=None)),
            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
            dict(f4=None, f8=None), dict(f4=0, f8=0),
            (dict(f4='', f8=''), dict(f4=None, f8=None)),
            (dict(d=1234.5, f4=1234.5, f8=1234.5),
                  dict(d=Decimal('1234.5'))),
            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
            dict(d=Decimal('123456789.9876543212345678987654321')),
            dict(m=None), (dict(m=''), dict(m=None)),
            dict(m=Decimal('-1234.56')),
            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
            (dict(m=123456), dict(m=Decimal('123456'))),
            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
            dict(b=None), (dict(b=''), dict(b=None)),
            dict(b='f'), dict(b='t'),
            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
            dict(v4=None, c4=None, t=None),
            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
            dict(v4='1234', c4='1234', t='1234' * 10),
            dict(v4='abcd', c4='abcd', t='abcdefg'),
            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
            dict(ts=None), (dict(ts=''), dict(ts=None)),
            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
            dict(ts='2012-12-21 00:00:00'),
            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
            dict(ts='2012-12-21 12:21:12'),
            dict(ts='2013-01-05 12:13:14'),
            dict(ts='current_timestamp')]
        for test in tests:
            # a bare dict means "expect the input unchanged"
            if isinstance(test, dict):
                data = test
                change = {}
            else:
                data, change = test
            expect = data.copy()
            expect.update(change)
            if bool_on:
                # booleans are then expected as Python bools, not 't'/'f'
                b = expect.get('b')
                if b is not None:
                    expect['b'] = b == 't'
            if decimal is not Decimal:
                # numeric and money values are then expected as the
                # configured decimal type instead of Decimal
                d = expect.get('d')
                if d is not None:
                    expect['d'] = decimal(d)
                m = expect.get('m')
                if m is not None:
                    expect['m'] = decimal(m)
            # insert() must return a dict equal to the passed one, which
            # afterwards also contains the oid of the new row
            self.assertEqual(insert(table, data), data)
            self.assertIn(oid_table, data)
            oid = data[oid_table]
            self.assertIsInstance(oid, int)
            # compare only the columns this entry is about
            data = dict(item for item in data.items()
                if item[0] in expect)
            ts = expect.get('ts')
            if ts == 'current_timestamp':
                # the actual time is unknown: take it from the result and
                # only check that it looks like 'YYYY-MM-DD HH:MM:SS[.f]'
                ts = expect['ts'] = data['ts']
                if len(ts) > 19:
                    self.assertEqual(ts[19], '.')
                    ts = ts[:19]
                else:
                    self.assertEqual(len(ts), 19)
                self.assertTrue(ts[:4].isdigit())
                self.assertEqual(ts[4], '-')
                self.assertEqual(ts[10], ' ')
                self.assertTrue(ts[11:13].isdigit())
                self.assertEqual(ts[13], ':')
            self.assertEqual(data, expect)
            # the stored row must match the expectation as well
            data = query(
                'select oid,* from "%s"' % table).dictresult()[0]
            self.assertEqual(data['oid'], oid)
            data = dict(item for item in data.items()
                if item[0] in expect)
            self.assertEqual(data, expect)
            # start the next entry with an empty table
            query('delete from "%s"' % table)
1104
1105    def testInsertWithQuotedNames(self):
1106        insert = self.db.insert
1107        query = self.db.query
1108        table = 'test table for insert()'
1109        query('drop table if exists "%s"' % table)
1110        self.addCleanup(query, 'drop table "%s"' % table)
1111        query('create table "%s" ('
1112            '"Prime!" smallint primary key,'
1113            ' "much space" integer, "Questions?" text)' % table)
1114        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1115        r = insert(table, r)
1116        self.assertIsInstance(r, dict)
1117        self.assertEqual(r['Prime!'], 11)
1118        self.assertEqual(r['much space'], 2002)
1119        self.assertEqual(r['Questions?'], 'What?')
1120        r = query('select * from "%s" limit 2' % table).dictresult()
1121        self.assertEqual(len(r), 1)
1122        r = r[0]
1123        self.assertEqual(r['Prime!'], 11)
1124        self.assertEqual(r['much space'], 2002)
1125        self.assertEqual(r['Questions?'], 'What?')
1126
1127    def testUpdate(self):
1128        update = self.db.update
1129        query = self.db.query
1130        table = 'update_test_table'
1131        query('drop table if exists "%s"' % table)
1132        self.addCleanup(query, 'drop table "%s"' % table)
1133        query('create table "%s" ('
1134            "n integer, t text) with oids" % table)
1135        for n, t in enumerate('xyz'):
1136            query('insert into "%s" values('
1137                "%d, '%s')" % (table, n + 1, t))
1138        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1139        r = self.db.get(table, 2, 'n')
1140        r['t'] = 'u'
1141        s = update(table, r)
1142        self.assertEqual(s, r)
1143        q = 'select t from "%s" where n=2' % table
1144        r = query(q).getresult()[0][0]
1145        self.assertEqual(r, 'u')
1146
1147    def testUpdateWithCompositeKey(self):
1148        update = self.db.update
1149        query = self.db.query
1150        table = 'update_test_table_1'
1151        query('drop table if exists "%s"' % table)
1152        self.addCleanup(query, 'drop table if exists "%s"' % table)
1153        query('create table "%s" ('
1154            "n integer, t text, primary key (n))" % table)
1155        for n, t in enumerate('abc'):
1156            query('insert into "%s" values('
1157                "%d, '%s')" % (table, n + 1, t))
1158        self.assertRaises(pg.ProgrammingError, update,
1159                          table, dict(t='b'))
1160        s = dict(n=2, t='d')
1161        r = update(table, s)
1162        self.assertIs(r, s)
1163        self.assertEqual(r['n'], 2)
1164        self.assertEqual(r['t'], 'd')
1165        q = 'select t from "%s" where n=2' % table
1166        r = query(q).getresult()[0][0]
1167        self.assertEqual(r, 'd')
1168        s.update(dict(n=4, t='e'))
1169        r = update(table, s)
1170        self.assertEqual(r['n'], 4)
1171        self.assertEqual(r['t'], 'e')
1172        q = 'select t from "%s" where n=2' % table
1173        r = query(q).getresult()[0][0]
1174        self.assertEqual(r, 'd')
1175        q = 'select t from "%s" where n=4' % table
1176        r = query(q).getresult()
1177        self.assertEqual(len(r), 0)
1178        query('drop table "%s"' % table)
1179        table = 'update_test_table_2'
1180        query('drop table if exists "%s"' % table)
1181        query('create table "%s" ('
1182            "n integer, m integer, t text, primary key (n, m))" % table)
1183        for n in range(3):
1184            for m in range(2):
1185                t = chr(ord('a') + 2 * n + m)
1186                query('insert into "%s" values('
1187                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1188        self.assertRaises(pg.ProgrammingError, update,
1189                          table, dict(n=2, t='b'))
1190        self.assertEqual(update(table,
1191                                dict(n=2, m=2, t='x'))['t'], 'x')
1192        q = 'select t from "%s" where n=2 order by m' % table
1193        r = [r[0] for r in query(q).getresult()]
1194        self.assertEqual(r, ['c', 'x'])
1195
1196    def testUpdateWithQuotedNames(self):
1197        update = self.db.update
1198        query = self.db.query
1199        table = 'test table for update()'
1200        query('drop table if exists "%s"' % table)
1201        self.addCleanup(query, 'drop table "%s"' % table)
1202        query('create table "%s" ('
1203            '"Prime!" smallint primary key,'
1204            ' "much space" integer, "Questions?" text)' % table)
1205        query('insert into "%s"'
1206              " values(13, 3003, 'Why!')" % table)
1207        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1208        r = update(table, r)
1209        self.assertIsInstance(r, dict)
1210        self.assertEqual(r['Prime!'], 13)
1211        self.assertEqual(r['much space'], 7007)
1212        self.assertEqual(r['Questions?'], 'When?')
1213        r = query('select * from "%s" limit 2' % table).dictresult()
1214        self.assertEqual(len(r), 1)
1215        r = r[0]
1216        self.assertEqual(r['Prime!'], 13)
1217        self.assertEqual(r['much space'], 7007)
1218        self.assertEqual(r['Questions?'], 'When?')
1219
    def testUpsert(self):
        """Check upsert() on a table with a single-column primary key.

        upsert() must insert new rows, update existing ones and return the
        passed dict, which afterwards reflects the values as stored.  For
        an existing row, a keyword argument per column controls the update:
        False keeps the stored value, True takes the new value, and a
        string is used as an SQL expression in which ``included`` refers to
        the stored row and ``excluded`` to the proposed one.  The test is
        skipped if upsert() fails on a server older than 9.5.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        query('create table "%s" ('
            "n integer primary key, t text) with oids" % table)
        s = dict(n=1, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # older servers do not support upsert: skip instead of failing
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        # new key: the row is inserted although None is passed per column
        s.update(n=2, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # existing key without keywords: the row is updated
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=False: the stored value is kept and copied back into s
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=True: the new value is stored
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # expression based on the stored value ('y' -> 'y2')
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # expression based on the proposed value ('y' -> 'y3')
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        # expression combining stored 'x' and proposed '2'
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1288
    def testUpsertWithCompositeKey(self):
        """Check upsert() on a table with a two-column primary key.

        Same scenario as for a single key: rows are inserted or updated
        depending on whether the (n, m) key exists, and the per-column
        keyword arguments (False, True or an SQL expression) only affect
        rows that already exist.  The test is skipped if upsert() fails on
        a server older than 9.5.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        query('create table "%s" ('
            "n integer, m integer, t text, primary key (n, m))" % table)
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # older servers do not support upsert: skip instead of failing
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # new key (1, 3): the row is inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # existing key without keywords: the row is updated
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False: the stored value is kept and copied back into s
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True: the new value is stored
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # new key (2, 3): inserted as given, the expression is not used
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # expression combining stored 'n' and proposed 'm'
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1357
1358    def testUpsertWithQuotedNames(self):
1359        upsert = self.db.upsert
1360        query = self.db.query
1361        table = 'test table for upsert()'
1362        query('drop table if exists "%s"' % table)
1363        self.addCleanup(query, 'drop table "%s"' % table)
1364        query('create table "%s" ('
1365            '"Prime!" smallint primary key,'
1366            ' "much space" integer, "Questions?" text)' % table)
1367        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1368        try:
1369            r = upsert(table, s)
1370        except pg.ProgrammingError as error:
1371            if self.db.server_version < 90500:
1372                self.skipTest('database does not support upsert')
1373            self.fail(str(error))
1374        self.assertIs(r, s)
1375        self.assertEqual(r['Prime!'], 31)
1376        self.assertEqual(r['much space'], 9009)
1377        self.assertEqual(r['Questions?'], 'Yes.')
1378        q = 'select * from "%s" limit 2' % table
1379        r = query(q).getresult()
1380        self.assertEqual(r, [(31, 9009, 'Yes.')])
1381        s.update({'Questions?': 'No.'})
1382        r = upsert(table, s)
1383        self.assertIs(r, s)
1384        self.assertEqual(r['Prime!'], 31)
1385        self.assertEqual(r['much space'], 9009)
1386        self.assertEqual(r['Questions?'], 'No.')
1387        r = query(q).getresult()
1388        self.assertEqual(r, [(31, 9009, 'No.')])
1389
1390    def testClear(self):
1391        clear = self.db.clear
1392        query = self.db.query
1393        f = False if pg.get_bool() else 'f'
1394        table = 'clear_test_table'
1395        query('drop table if exists "%s"' % table)
1396        self.addCleanup(query, 'drop table "%s"' % table)
1397        query('create table "%s" ('
1398            "n integer, b boolean, d date, t text)" % table)
1399        r = clear(table)
1400        result = {'n': 0, 'b': f, 'd': '', 't': ''}
1401        self.assertEqual(r, result)
1402        r['a'] = r['n'] = 1
1403        r['d'] = r['t'] = 'x'
1404        r['b'] = 't'
1405        r['oid'] = long(1)
1406        r = clear(table, r)
1407        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
1408            'oid': long(1)}
1409        self.assertEqual(r, result)
1410
1411    def testClearWithQuotedNames(self):
1412        clear = self.db.clear
1413        query = self.db.query
1414        table = 'test table for clear()'
1415        query('drop table if exists "%s"' % table)
1416        self.addCleanup(query, 'drop table "%s"' % table)
1417        query('create table "%s" ('
1418            '"Prime!" smallint primary key,'
1419            ' "much space" integer, "Questions?" text)' % table)
1420        r = clear(table)
1421        self.assertIsInstance(r, dict)
1422        self.assertEqual(r['Prime!'], 0)
1423        self.assertEqual(r['much space'], 0)
1424        self.assertEqual(r['Questions?'], '')
1425
1426    def testDelete(self):
1427        delete = self.db.delete
1428        query = self.db.query
1429        table = 'delete_test_table'
1430        query('drop table if exists "%s"' % table)
1431        self.addCleanup(query, 'drop table "%s"' % table)
1432        query('create table "%s" ('
1433            "n integer, t text) with oids" % table)
1434        for n, t in enumerate('xyz'):
1435            query('insert into "%s" values('
1436                "%d, '%s')" % (table, n + 1, t))
1437        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1438        r = self.db.get(table, 1, 'n')
1439        s = delete(table, r)
1440        self.assertEqual(s, 1)
1441        r = self.db.get(table, 3, 'n')
1442        s = delete(table, r)
1443        self.assertEqual(s, 1)
1444        s = delete(table, r)
1445        self.assertEqual(s, 0)
1446        r = query('select * from "%s"' % table).dictresult()
1447        self.assertEqual(len(r), 1)
1448        r = r[0]
1449        result = {'n': 2, 't': 'y'}
1450        self.assertEqual(r, result)
1451        r = self.db.get(table, 2, 'n')
1452        s = delete(table, r)
1453        self.assertEqual(s, 1)
1454        s = delete(table, r)
1455        self.assertEqual(s, 0)
1456        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1457
1458    def testDeleteWithCompositeKey(self):
1459        query = self.db.query
1460        table = 'delete_test_table_1'
1461        query('drop table if exists "%s"' % table)
1462        self.addCleanup(query, 'drop table "%s"' % table)
1463        query('create table "%s" ('
1464            "n integer, t text, primary key (n))" % table)
1465        for n, t in enumerate('abc'):
1466            query("insert into %s values("
1467                "%d, '%s')" % (table, n + 1, t))
1468        self.assertRaises(pg.ProgrammingError, self.db.delete,
1469            table, dict(t='b'))
1470        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1471        r = query('select t from "%s" where n=2' % table
1472                  ).getresult()
1473        self.assertEqual(r, [])
1474        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1475        r = query('select t from "%s" where n=3' % table
1476                  ).getresult()[0][0]
1477        self.assertEqual(r, 'c')
1478        table = 'delete_test_table_2'
1479        query('drop table if exists "%s"' % table)
1480        self.addCleanup(query, 'drop table "%s"' % table)
1481        query('create table "%s" ('
1482            "n integer, m integer, t text, primary key (n, m))" % table)
1483        for n in range(3):
1484            for m in range(2):
1485                t = chr(ord('a') + 2 * n + m)
1486                query('insert into "%s" values('
1487                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1488        self.assertRaises(pg.ProgrammingError, self.db.delete,
1489            table, dict(n=2, t='b'))
1490        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1491        r = [r[0] for r in query('select t from "%s" where n=2'
1492            ' order by m' % table).getresult()]
1493        self.assertEqual(r, ['c'])
1494        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1495        r = [r[0] for r in query('select t from "%s" where n=3'
1496            ' order by m' % table).getresult()]
1497        self.assertEqual(r, ['e', 'f'])
1498        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1499        r = [r[0] for r in query('select t from "%s" where n=3'
1500            ' order by m' % table).getresult()]
1501        self.assertEqual(r, ['f'])
1502
1503    def testDeleteWithQuotedNames(self):
1504        delete = self.db.delete
1505        query = self.db.query
1506        table = 'test table for delete()'
1507        query('drop table if exists "%s"' % table)
1508        self.addCleanup(query, 'drop table "%s"' % table)
1509        query('create table "%s" ('
1510            '"Prime!" smallint primary key,'
1511            ' "much space" integer, "Questions?" text)' % table)
1512        query('insert into "%s"'
1513              " values(19, 5005, 'Yes!')" % table)
1514        r = {'Prime!': 17}
1515        r = delete(table, r)
1516        self.assertEqual(r, 0)
1517        r = query('select count(*) from "%s"' % table).getresult()
1518        self.assertEqual(r[0][0], 1)
1519        r = {'Prime!': 19}
1520        r = delete(table, r)
1521        self.assertEqual(r, 1)
1522        r = query('select count(*) from "%s"' % table).getresult()
1523        self.assertEqual(r[0][0], 0)
1524
1525    def testTruncate(self):
1526        truncate = self.db.truncate
1527        self.assertRaises(TypeError, truncate, None)
1528        self.assertRaises(TypeError, truncate, 42)
1529        self.assertRaises(TypeError, truncate, dict(test_table=None))
1530        query = self.db.query
1531        query("drop table if exists test_table")
1532        self.addCleanup(query, "drop table test_table")
1533        query("create table test_table (n smallint)")
1534        for i in range(3):
1535            query("insert into test_table values (1)")
1536        q = "select count(*) from test_table"
1537        r = query(q).getresult()[0][0]
1538        self.assertEqual(r, 3)
1539        truncate('test_table')
1540        r = query(q).getresult()[0][0]
1541        self.assertEqual(r, 0)
1542        for i in range(3):
1543            query("insert into test_table values (1)")
1544        r = query(q).getresult()[0][0]
1545        self.assertEqual(r, 3)
1546        truncate('public.test_table')
1547        r = query(q).getresult()[0][0]
1548        self.assertEqual(r, 0)
1549        query("drop table if exists test_table_2")
1550        self.addCleanup(query, "drop table test_table_2")
1551        query('create table test_table_2 (n smallint)')
1552        for t in (list, tuple, set):
1553            for i in range(3):
1554                query("insert into test_table values (1)")
1555                query("insert into test_table_2 values (2)")
1556            q = ("select (select count(*) from test_table),"
1557                " (select count(*) from test_table_2)")
1558            r = query(q).getresult()[0]
1559            self.assertEqual(r, (3, 3))
1560            truncate(t(['test_table', 'test_table_2']))
1561            r = query(q).getresult()[0]
1562            self.assertEqual(r, (0, 0))
1563
1564    def testTruncateRestart(self):
1565        truncate = self.db.truncate
1566        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
1567        query = self.db.query
1568        query("drop table if exists test_table")
1569        self.addCleanup(query, "drop table test_table")
1570        query("create table test_table (n serial, t text)")
1571        for n in range(3):
1572            query("insert into test_table (t) values ('test')")
1573        q = "select count(n), min(n), max(n) from test_table"
1574        r = query(q).getresult()[0]
1575        self.assertEqual(r, (3, 1, 3))
1576        truncate('test_table')
1577        r = query(q).getresult()[0]
1578        self.assertEqual(r, (0, None, None))
1579        for n in range(3):
1580            query("insert into test_table (t) values ('test')")
1581        r = query(q).getresult()[0]
1582        self.assertEqual(r, (3, 4, 6))
1583        truncate('test_table', restart=True)
1584        r = query(q).getresult()[0]
1585        self.assertEqual(r, (0, None, None))
1586        for n in range(3):
1587            query("insert into test_table (t) values ('test')")
1588        r = query(q).getresult()[0]
1589        self.assertEqual(r, (3, 1, 3))
1590
1591    def testTruncateCascade(self):
1592        truncate = self.db.truncate
1593        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
1594        query = self.db.query
1595        query("drop table if exists test_child")
1596        query("drop table if exists test_parent")
1597        self.addCleanup(query, "drop table test_parent")
1598        query("create table test_parent (n smallint primary key)")
1599        self.addCleanup(query, "drop table test_child")
1600        query("create table test_child ("
1601            " n smallint primary key references test_parent (n))")
1602        for n in range(3):
1603            query("insert into test_parent (n) values (%d)" % n)
1604            query("insert into test_child (n) values (%d)" % n)
1605        q = ("select (select count(*) from test_parent),"
1606            " (select count(*) from test_child)")
1607        r = query(q).getresult()[0]
1608        self.assertEqual(r, (3, 3))
1609        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
1610        truncate(['test_parent', 'test_child'])
1611        r = query(q).getresult()[0]
1612        self.assertEqual(r, (0, 0))
1613        for n in range(3):
1614            query("insert into test_parent (n) values (%d)" % n)
1615            query("insert into test_child (n) values (%d)" % n)
1616        r = query(q).getresult()[0]
1617        self.assertEqual(r, (3, 3))
1618        truncate('test_parent', cascade=True)
1619        r = query(q).getresult()[0]
1620        self.assertEqual(r, (0, 0))
1621        for n in range(3):
1622            query("insert into test_parent (n) values (%d)" % n)
1623            query("insert into test_child (n) values (%d)" % n)
1624        r = query(q).getresult()[0]
1625        self.assertEqual(r, (3, 3))
1626        truncate('test_child')
1627        r = query(q).getresult()[0]
1628        self.assertEqual(r, (3, 0))
1629        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
1630        truncate('test_parent', cascade=True)
1631        r = query(q).getresult()[0]
1632        self.assertEqual(r, (0, 0))
1633
1634    def testTruncateOnly(self):
1635        truncate = self.db.truncate
1636        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
1637        query = self.db.query
1638        query("drop table if exists test_child")
1639        query("drop table if exists test_parent")
1640        self.addCleanup(query, "drop table test_parent")
1641        query("create table test_parent (n smallint)")
1642        self.addCleanup(query, "drop table test_child")
1643        query("create table test_child ("
1644            " m smallint) inherits (test_parent)")
1645        for n in range(3):
1646            query("insert into test_parent (n) values (1)")
1647            query("insert into test_child (n, m) values (2, 3)")
1648        q = ("select (select count(*) from test_parent),"
1649            " (select count(*) from test_child)")
1650        r = query(q).getresult()[0]
1651        self.assertEqual(r, (6, 3))
1652        truncate('test_parent')
1653        r = query(q).getresult()[0]
1654        self.assertEqual(r, (0, 0))
1655        for n in range(3):
1656            query("insert into test_parent (n) values (1)")
1657            query("insert into test_child (n, m) values (2, 3)")
1658        r = query(q).getresult()[0]
1659        self.assertEqual(r, (6, 3))
1660        truncate('test_parent*')
1661        r = query(q).getresult()[0]
1662        self.assertEqual(r, (0, 0))
1663        for n in range(3):
1664            query("insert into test_parent (n) values (1)")
1665            query("insert into test_child (n, m) values (2, 3)")
1666        r = query(q).getresult()[0]
1667        self.assertEqual(r, (6, 3))
1668        truncate('test_parent', only=True)
1669        r = query(q).getresult()[0]
1670        self.assertEqual(r, (3, 3))
1671        truncate('test_parent', only=False)
1672        r = query(q).getresult()[0]
1673        self.assertEqual(r, (0, 0))
1674        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
1675        truncate('test_parent*', only=False)
1676        query("drop table if exists test_parent_2")
1677        self.addCleanup(query, "drop table test_parent_2")
1678        query("create table test_parent_2 (n smallint)")
1679        query("drop table if exists test_child_2")
1680        self.addCleanup(query, "drop table test_child_2")
1681        query("create table test_child_2 ("
1682            " m smallint) inherits (test_parent_2)")
1683        for n in range(3):
1684            query("insert into test_parent (n) values (1)")
1685            query("insert into test_child (n, m) values (2, 3)")
1686            query("insert into test_parent_2 (n) values (1)")
1687            query("insert into test_child_2 (n, m) values (2, 3)")
1688        q = ("select (select count(*) from test_parent),"
1689            " (select count(*) from test_child),"
1690            " (select count(*) from test_parent_2),"
1691            " (select count(*) from test_child_2)")
1692        r = query(q).getresult()[0]
1693        self.assertEqual(r, (6, 3, 6, 3))
1694        truncate(['test_parent', 'test_parent_2'], only=[False, True])
1695        r = query(q).getresult()[0]
1696        self.assertEqual(r, (0, 0, 3, 3))
1697        truncate(['test_parent', 'test_parent_2'], only=False)
1698        r = query(q).getresult()[0]
1699        self.assertEqual(r, (0, 0, 0, 0))
1700        self.assertRaises(ValueError, truncate,
1701            ['test_parent*', 'test_child'], only=[True, False])
1702        truncate(['test_parent*', 'test_child'], only=[False, True])
1703
1704    def testTruncateQuoted(self):
1705        truncate = self.db.truncate
1706        query = self.db.query
1707        table = "test table for truncate()"
1708        query('drop table if exists "%s"' % table)
1709        self.addCleanup(query, 'drop table "%s"' % table)
1710        query('create table "%s" (n smallint)' % table)
1711        for i in range(3):
1712            query('insert into "%s" values (1)' % table)
1713        q = 'select count(*) from "%s"' % table
1714        r = query(q).getresult()[0][0]
1715        self.assertEqual(r, 3)
1716        truncate(table)
1717        r = query(q).getresult()[0][0]
1718        self.assertEqual(r, 0)
1719        for i in range(3):
1720            query('insert into "%s" values (1)' % table)
1721        r = query(q).getresult()[0][0]
1722        self.assertEqual(r, 3)
1723        truncate('public."%s"' % table)
1724        r = query(q).getresult()[0][0]
1725        self.assertEqual(r, 0)
1726
1727    def testTransaction(self):
1728        query = self.db.query
1729        query("drop table if exists test_table")
1730        self.addCleanup(query, "drop table test_table")
1731        query("create table test_table (n integer)")
1732        self.db.begin()
1733        query("insert into test_table values (1)")
1734        query("insert into test_table values (2)")
1735        self.db.commit()
1736        self.db.begin()
1737        query("insert into test_table values (3)")
1738        query("insert into test_table values (4)")
1739        self.db.rollback()
1740        self.db.begin()
1741        query("insert into test_table values (5)")
1742        self.db.savepoint('before6')
1743        query("insert into test_table values (6)")
1744        self.db.rollback('before6')
1745        query("insert into test_table values (7)")
1746        self.db.commit()
1747        self.db.begin()
1748        self.db.savepoint('before8')
1749        query("insert into test_table values (8)")
1750        self.db.release('before8')
1751        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1752        self.db.commit()
1753        self.db.start()
1754        query("insert into test_table values (9)")
1755        self.db.end()
1756        r = [r[0] for r in query(
1757            "select * from test_table order by 1").getresult()]
1758        self.assertEqual(r, [1, 2, 5, 7, 9])
1759
1760    def testContextManager(self):
1761        query = self.db.query
1762        query("drop table if exists test_table")
1763        self.addCleanup(query, "drop table test_table")
1764        query("create table test_table (n integer check(n>0))")
1765        with self.db:
1766            query("insert into test_table values (1)")
1767            query("insert into test_table values (2)")
1768        try:
1769            with self.db:
1770                query("insert into test_table values (3)")
1771                query("insert into test_table values (4)")
1772                raise ValueError('test transaction should rollback')
1773        except ValueError as error:
1774            self.assertEqual(str(error), 'test transaction should rollback')
1775        with self.db:
1776            query("insert into test_table values (5)")
1777        try:
1778            with self.db:
1779                query("insert into test_table values (6)")
1780                query("insert into test_table values (-1)")
1781        except pg.ProgrammingError as error:
1782            self.assertTrue('check' in str(error))
1783        with self.db:
1784            query("insert into test_table values (7)")
1785        r = [r[0] for r in query(
1786            "select * from test_table order by 1").getresult()]
1787        self.assertEqual(r, [1, 2, 5, 7])
1788
1789    def testBytea(self):
1790        query = self.db.query
1791        query('drop table if exists bytea_test')
1792        self.addCleanup(query, 'drop table bytea_test')
1793        query('create table bytea_test (n smallint primary key, data bytea)')
1794        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1795        r = self.db.escape_bytea(s)
1796        query('insert into bytea_test values(3,$1)', (r,))
1797        r = query('select * from bytea_test where n=3').getresult()
1798        self.assertEqual(len(r), 1)
1799        r = r[0]
1800        self.assertEqual(len(r), 2)
1801        self.assertEqual(r[0], 3)
1802        r = r[1]
1803        self.assertIsInstance(r, str)
1804        r = self.db.unescape_bytea(r)
1805        self.assertIsInstance(r, bytes)
1806        self.assertEqual(r, s)
1807
1808    def testInsertUpdateGetBytea(self):
1809        query = self.db.query
1810        query('drop table if exists bytea_test')
1811        self.addCleanup(query, 'drop table bytea_test')
1812        query('create table bytea_test (n smallint primary key, data bytea)')
1813        # insert null value
1814        r = self.db.insert('bytea_test', n=0, data=None)
1815        self.assertIsInstance(r, dict)
1816        self.assertIn('n', r)
1817        self.assertEqual(r['n'], 0)
1818        self.assertIn('data', r)
1819        self.assertIsNone(r['data'])
1820        s = b'None'
1821        r = self.db.update('bytea_test', n=0, data=s)
1822        self.assertIsInstance(r, dict)
1823        self.assertIn('n', r)
1824        self.assertEqual(r['n'], 0)
1825        self.assertIn('data', r)
1826        r = r['data']
1827        self.assertIsInstance(r, bytes)
1828        self.assertEqual(r, s)
1829        r = self.db.update('bytea_test', n=0, data=None)
1830        self.assertIsNone(r['data'])
1831        # insert as bytes
1832        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1833        r = self.db.insert('bytea_test', n=5, data=s)
1834        self.assertIsInstance(r, dict)
1835        self.assertIn('n', r)
1836        self.assertEqual(r['n'], 5)
1837        self.assertIn('data', r)
1838        r = r['data']
1839        self.assertIsInstance(r, bytes)
1840        self.assertEqual(r, s)
1841        # update as bytes
1842        s += b"and now even more \x00 nasty \t stuff!\f"
1843        r = self.db.update('bytea_test', n=5, data=s)
1844        self.assertIsInstance(r, dict)
1845        self.assertIn('n', r)
1846        self.assertEqual(r['n'], 5)
1847        self.assertIn('data', r)
1848        r = r['data']
1849        self.assertIsInstance(r, bytes)
1850        self.assertEqual(r, s)
1851        r = query('select * from bytea_test where n=5').getresult()
1852        self.assertEqual(len(r), 1)
1853        r = r[0]
1854        self.assertEqual(len(r), 2)
1855        self.assertEqual(r[0], 5)
1856        r = r[1]
1857        self.assertIsInstance(r, str)
1858        r = self.db.unescape_bytea(r)
1859        self.assertIsInstance(r, bytes)
1860        self.assertEqual(r, s)
1861        r = self.db.get('bytea_test', dict(n=5))
1862        self.assertIsInstance(r, dict)
1863        self.assertIn('n', r)
1864        self.assertEqual(r['n'], 5)
1865        self.assertIn('data', r)
1866        r = r['data']
1867        self.assertIsInstance(r, bytes)
1868        self.assertEqual(r, s)
1869
1870
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    The global pg options decimal, bool and namedresult are changed for
    the whole test class and restored afterwards.
    """

    @classmethod
    def setUpClass(cls):
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            unnamed_result = lambda q: q.getresult()
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # unittest does not call tearDownClass() when setUpClass()
            # fails, so restore the global options here to not let the
            # changed options leak into other tests
            cls._reset_all_options()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # restore the global options even if the teardown failed
            cls._reset_all_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option and set a new one.

        Returns the result of the corresponding pg.set_<option> function.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _reset_all_options(cls):
        """Restore all global pg options saved by set_option()."""
        for option in list(cls.saved_options):
            cls.reset_option(option)
1899
1900
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces).

    The class fixture creates the schemas s1 to s4.  The public schema
    and each of these schemas get a table t and a table t<num> (t0 in
    public), each containing one row with n = 1 and d = <num>.
    """

    @classmethod
    def setUpClass(cls):
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # close the connection even if creating the fixture failed
            db.close()

    @classmethod
    def tearDownClass(cls):
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            # close the connection even if removing the fixture failed
            db.close()

    def setUp(self):
        self.db = DB()
        # do not let the server report messages below warning level
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        # cleanups use self.db, so run them before closing the connection
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        """get_tables() lists the tables of all schemas schema-qualified."""
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        """get_attnames() finds tables by qualified name or search path."""
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        """get() accesses the table found by qualified name or search path.

        Tables that cannot be found this way raise a ProgrammingError.
        """
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        """get() returns the oid under a key containing the table name."""
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
2018
2019
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class.

    Everything written to sys.stdout during a test is captured so that
    the debugging output can be checked.
    """

    def setUp(self):
        self.db = DB()
        self.query = self.db.query
        # remember the initial debug setting so tearDown can restore it
        self.debug = self.db.debug
        self.output = StringIO()
        # redirect stdout last, so a failing DB() leaves it untouched
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        sys.stdout = self.stdout
        self.output.close()
        # restore the debug setting that was saved in setUp()
        self.db.debug = self.debug
        self.db.close()

    def get_output(self):
        """Return everything printed to stdout since setUp()."""
        return self.output.getvalue()

    def send_queries(self):
        """Send the two queries expected in the debugging output."""
        self.db.query("select 1")
        self.db.query("select 2")

    def testDebugDefault(self):
        """debug is None unless the module-level debug option is set."""
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        """With debug set to False nothing is printed."""
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        """With debug set to True every query is printed on its own line."""
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        """A debug string is used as a format for printing the queries."""
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        """With a file as debug the queries are written to that file."""
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        """With a callable as debug it is called with every query."""
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")
2081
2082
if __name__ == '__main__':
    # run all test cases of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.