source: trunk/tests/test_classic_dbwrapper.py @ 760

Last change on this file since 760 was 760, checked in by cito, 4 years ago

Test reset() and reopen() methods of DB

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 82.7 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18
19import os
20import sys
21import tempfile
22
23import pg  # the module under test
24
25from decimal import Decimal
26
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Settings from LOCAL_PyGreSQL override the defaults above, so the import
# has to come after them.  The relative import fails when this module is
# not part of a package: Python 2 raises ValueError, Python 3 before 3.6
# raises SystemError, and newer versions raise ImportError.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, SystemError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass
43
# Compatibility aliases so that the tests read the same on Python 2 and 3.
if str is bytes:  # Python 2: long and unicode are still builtins
    from StringIO import StringIO
else:  # Python >= 3.0: map the removed names to their successors
    from io import StringIO
    long = int
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

# There is a known bug in libpq under Windows which can cause the
# interface to crash when calling PQhost(), so tests that ask for the
# host are skipped on that platform.
windows = os.name == 'nt'
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
70
71
def DB():
    """Open a DB wrapper connection to the test database.

    The module-level ``debug`` flag is passed on to the wrapper when set,
    and the session is configured with client_min_messages=warning.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
79
80
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # Some tests close the connection themselves; closing an already
        # closed connection raises InternalError, which is ignored here.
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        # the complete list of public attributes, sorted like dir() output
        attributes = [
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'delete',
            'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        # PostgreSQL 10 and later report major * 10000 + minor
        # (e.g. 100004 for 10.4), so six-digit values must be accepted
        self.assertTrue(70400 <= server_version < 1000000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        # parameters may be passed as separate arguments, tuple or list
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        # assertRaises makes the test fail if no error is raised at all
        with self.assertRaises(pg.ProgrammingError) as context:
            self.db.query("select 1/0")
        self.assertEqual(context.exception.sqlstate, '22012')

    def testMethodEndcopy(self):
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testMethodReset(self):
        # reset() keeps the same connection object
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testMethodReopen(self):
        # reopen() replaces the connection object
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        # a wrapper can be built around the connection of another wrapper;
        # close() and reopen() on it leave its db attribute set
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        # other DB wrappers are accepted, also via the db keyword
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        # so is any object providing the connection as _cnx attribute
        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
271
272
273class TestDBClass(unittest.TestCase):
274    """Test the methods of the DB class wrapped pg connection."""
275
276    @classmethod
277    def setUpClass(cls):
278        db = DB()
279        db.query("drop table if exists test cascade")
280        db.query("create table test ("
281            "i2 smallint, i4 integer, i8 bigint,"
282            " d numeric, f4 real, f8 double precision, m money,"
283            " v4 varchar(4), c4 char(4), t text)")
284        db.query("create or replace view test_view as"
285            " select i4, v4 from test")
286        db.close()
287
288    @classmethod
289    def tearDownClass(cls):
290        db = DB()
291        db.query("drop table test cascade")
292        db.close()
293
294    def setUp(self):
295        self.db = DB()
296        query = self.db.query
297        query('set client_encoding=utf8')
298        query('set standard_conforming_strings=on')
299        query("set lc_monetary='C'")
300        query("set datestyle='ISO,YMD'")
301        query('set bytea_output=hex')
302
303    def tearDown(self):
304        self.doCleanups()
305        self.db.close()
306
307    def testClassName(self):
308        self.assertEqual(self.db.__class__.__name__, 'DB')
309
310    def testModuleName(self):
311        self.assertEqual(self.db.__module__, 'pg')
312        self.assertEqual(self.db.__class__.__module__, 'pg')
313
314    def testEscapeLiteral(self):
315        f = self.db.escape_literal
316        r = f(b"plain")
317        self.assertIsInstance(r, bytes)
318        self.assertEqual(r, b"'plain'")
319        r = f(u"plain")
320        self.assertIsInstance(r, unicode)
321        self.assertEqual(r, u"'plain'")
322        r = f(u"that's kÀse".encode('utf-8'))
323        self.assertIsInstance(r, bytes)
324        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
325        r = f(u"that's kÀse")
326        self.assertIsInstance(r, unicode)
327        self.assertEqual(r, u"'that''s kÀse'")
328        self.assertEqual(f(r"It's fine to have a \ inside."),
329            r" E'It''s fine to have a \\ inside.'")
330        self.assertEqual(f('No "quotes" must be escaped.'),
331            "'No \"quotes\" must be escaped.'")
332
333    def testEscapeIdentifier(self):
334        f = self.db.escape_identifier
335        r = f(b"plain")
336        self.assertIsInstance(r, bytes)
337        self.assertEqual(r, b'"plain"')
338        r = f(u"plain")
339        self.assertIsInstance(r, unicode)
340        self.assertEqual(r, u'"plain"')
341        r = f(u"that's kÀse".encode('utf-8'))
342        self.assertIsInstance(r, bytes)
343        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
344        r = f(u"that's kÀse")
345        self.assertIsInstance(r, unicode)
346        self.assertEqual(r, u'"that\'s kÀse"')
347        self.assertEqual(f(r"It's fine to have a \ inside."),
348            '"It\'s fine to have a \\ inside."')
349        self.assertEqual(f('All "quotes" must be escaped.'),
350            '"All ""quotes"" must be escaped."')
351
352    def testEscapeString(self):
353        f = self.db.escape_string
354        r = f(b"plain")
355        self.assertIsInstance(r, bytes)
356        self.assertEqual(r, b"plain")
357        r = f(u"plain")
358        self.assertIsInstance(r, unicode)
359        self.assertEqual(r, u"plain")
360        r = f(u"that's kÀse".encode('utf-8'))
361        self.assertIsInstance(r, bytes)
362        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
363        r = f(u"that's kÀse")
364        self.assertIsInstance(r, unicode)
365        self.assertEqual(r, u"that''s kÀse")
366        self.assertEqual(f(r"It's fine to have a \ inside."),
367            r"It''s fine to have a \ inside.")
368
369    def testEscapeBytea(self):
370        f = self.db.escape_bytea
371        # note that escape_byte always returns hex output since Pg 9.0,
372        # regardless of the bytea_output setting
373        r = f(b'plain')
374        self.assertIsInstance(r, bytes)
375        self.assertEqual(r, b'\\x706c61696e')
376        r = f(u'plain')
377        self.assertIsInstance(r, unicode)
378        self.assertEqual(r, u'\\x706c61696e')
379        r = f(u"das is' kÀse".encode('utf-8'))
380        self.assertIsInstance(r, bytes)
381        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
382        r = f(u"das is' kÀse")
383        self.assertIsInstance(r, unicode)
384        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
385        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
386
387    def testUnescapeBytea(self):
388        f = self.db.unescape_bytea
389        r = f(b'plain')
390        self.assertIsInstance(r, bytes)
391        self.assertEqual(r, b'plain')
392        r = f(u'plain')
393        self.assertIsInstance(r, bytes)
394        self.assertEqual(r, b'plain')
395        r = f(b"das is' k\\303\\244se")
396        self.assertIsInstance(r, bytes)
397        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
398        r = f(u"das is' k\\303\\244se")
399        self.assertIsInstance(r, bytes)
400        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
401        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
402        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
403        self.assertEqual(f(r'\\x746861742773206be47365'),
404            b'\\x746861742773206be47365')
405        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
406
407    def testGetParameter(self):
408        f = self.db.get_parameter
409        self.assertRaises(TypeError, f)
410        self.assertRaises(TypeError, f, None)
411        self.assertRaises(TypeError, f, 42)
412        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
413        r = f('standard_conforming_strings')
414        self.assertEqual(r, 'on')
415        r = f('lc_monetary')
416        self.assertEqual(r, 'C')
417        r = f('datestyle')
418        self.assertEqual(r, 'ISO, YMD')
419        r = f('bytea_output')
420        self.assertEqual(r, 'hex')
421        r = f(['bytea_output', 'lc_monetary'])
422        self.assertIsInstance(r, list)
423        self.assertEqual(r, ['hex', 'C'])
424        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
425        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
426        r = f(set(['bytea_output', 'lc_monetary']))
427        self.assertIsInstance(r, dict)
428        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
429        r = f(set(['Bytea_Output', ' LC_Monetary ']))
430        self.assertIsInstance(r, dict)
431        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
432        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
433        r = f(s)
434        self.assertIs(r, s)
435        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
436        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
437        r = f(s)
438        self.assertIs(r, s)
439        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
440
441    def testGetParameterServerVersion(self):
442        r = self.db.get_parameter('server_version_num')
443        self.assertIsInstance(r, str)
444        s = self.db.server_version
445        self.assertIsInstance(s, int)
446        self.assertEqual(r, str(s))
447
448    def testGetParameterAll(self):
449        f = self.db.get_parameter
450        r = f('all')
451        self.assertIsInstance(r, dict)
452        self.assertEqual(r['standard_conforming_strings'], 'on')
453        self.assertEqual(r['lc_monetary'], 'C')
454        self.assertEqual(r['DateStyle'], 'ISO, YMD')
455        self.assertEqual(r['bytea_output'], 'hex')
456
457    def testSetParameter(self):
458        f = self.db.set_parameter
459        g = self.db.get_parameter
460        self.assertRaises(TypeError, f)
461        self.assertRaises(TypeError, f, None)
462        self.assertRaises(TypeError, f, 42)
463        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
464        f('standard_conforming_strings', 'off')
465        self.assertEqual(g('standard_conforming_strings'), 'off')
466        f('datestyle', 'ISO, DMY')
467        self.assertEqual(g('datestyle'), 'ISO, DMY')
468        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
469        self.assertEqual(g('standard_conforming_strings'), 'on')
470        self.assertEqual(g('datestyle'), 'ISO, DMY')
471        f(['default_with_oids', 'standard_conforming_strings'], 'off')
472        self.assertEqual(g('default_with_oids'), 'off')
473        self.assertEqual(g('standard_conforming_strings'), 'off')
474        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
475        self.assertEqual(g('standard_conforming_strings'), 'on')
476        self.assertEqual(g('datestyle'), 'ISO, YMD')
477        f(('default_with_oids', 'standard_conforming_strings'), 'off')
478        self.assertEqual(g('default_with_oids'), 'off')
479        self.assertEqual(g('standard_conforming_strings'), 'off')
480        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
481        self.assertEqual(g('default_with_oids'), 'on')
482        self.assertEqual(g('standard_conforming_strings'), 'on')
483        self.assertRaises(ValueError, f, set([ 'default_with_oids',
484            'standard_conforming_strings']), ['off', 'on'])
485        f(set(['default_with_oids', 'standard_conforming_strings']),
486            ['off', 'off'])
487        self.assertEqual(g('default_with_oids'), 'off')
488        self.assertEqual(g('standard_conforming_strings'), 'off')
489        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
490        self.assertEqual(g('standard_conforming_strings'), 'on')
491        self.assertEqual(g('datestyle'), 'ISO, YMD')
492
493    def testResetParameter(self):
494        db = DB()
495        f = db.set_parameter
496        g = db.get_parameter
497        r = g('default_with_oids')
498        self.assertIn(r, ('on', 'off'))
499        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
500        r = g('standard_conforming_strings')
501        self.assertIn(r, ('on', 'off'))
502        scs, not_scs = r, 'off' if r == 'on' else 'on'
503        f('default_with_oids', not_dwi)
504        f('standard_conforming_strings', not_scs)
505        self.assertEqual(g('default_with_oids'), not_dwi)
506        self.assertEqual(g('standard_conforming_strings'), not_scs)
507        f('default_with_oids')
508        f('standard_conforming_strings', None)
509        self.assertEqual(g('default_with_oids'), dwi)
510        self.assertEqual(g('standard_conforming_strings'), scs)
511        f('default_with_oids', not_dwi)
512        f('standard_conforming_strings', not_scs)
513        self.assertEqual(g('default_with_oids'), not_dwi)
514        self.assertEqual(g('standard_conforming_strings'), not_scs)
515        f(['default_with_oids', 'standard_conforming_strings'], None)
516        self.assertEqual(g('default_with_oids'), dwi)
517        self.assertEqual(g('standard_conforming_strings'), scs)
518        f('default_with_oids', not_dwi)
519        f('standard_conforming_strings', not_scs)
520        self.assertEqual(g('default_with_oids'), not_dwi)
521        self.assertEqual(g('standard_conforming_strings'), not_scs)
522        f(('default_with_oids', 'standard_conforming_strings'))
523        self.assertEqual(g('default_with_oids'), dwi)
524        self.assertEqual(g('standard_conforming_strings'), scs)
525        f('default_with_oids', not_dwi)
526        f('standard_conforming_strings', not_scs)
527        self.assertEqual(g('default_with_oids'), not_dwi)
528        self.assertEqual(g('standard_conforming_strings'), not_scs)
529        f(set(['default_with_oids', 'standard_conforming_strings']))
530        self.assertEqual(g('default_with_oids'), dwi)
531        self.assertEqual(g('standard_conforming_strings'), scs)
532
533    def testResetParameterAll(self):
534        db = DB()
535        f = db.set_parameter
536        self.assertRaises(ValueError, f, 'all', 0)
537        self.assertRaises(ValueError, f, 'all', 'off')
538        g = db.get_parameter
539        r = g('default_with_oids')
540        self.assertIn(r, ('on', 'off'))
541        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
542        r = g('standard_conforming_strings')
543        self.assertIn(r, ('on', 'off'))
544        scs, not_scs = r, 'off' if r == 'on' else 'on'
545        f('default_with_oids', not_dwi)
546        f('standard_conforming_strings', not_scs)
547        self.assertEqual(g('default_with_oids'), not_dwi)
548        self.assertEqual(g('standard_conforming_strings'), not_scs)
549        f('all')
550        self.assertEqual(g('default_with_oids'), dwi)
551        self.assertEqual(g('standard_conforming_strings'), scs)
552
553    def testSetParameterLocal(self):
554        f = self.db.set_parameter
555        g = self.db.get_parameter
556        self.assertEqual(g('standard_conforming_strings'), 'on')
557        self.db.begin()
558        f('standard_conforming_strings', 'off', local=True)
559        self.assertEqual(g('standard_conforming_strings'), 'off')
560        self.db.end()
561        self.assertEqual(g('standard_conforming_strings'), 'on')
562
563    def testSetParameterSession(self):
564        f = self.db.set_parameter
565        g = self.db.get_parameter
566        self.assertEqual(g('standard_conforming_strings'), 'on')
567        self.db.begin()
568        f('standard_conforming_strings', 'off', local=False)
569        self.assertEqual(g('standard_conforming_strings'), 'off')
570        self.db.end()
571        self.assertEqual(g('standard_conforming_strings'), 'off')
572
573    def testReset(self):
574        db = DB()
575        default_datestyle = db.get_parameter('datestyle')
576        changed_datestyle = 'ISO, DMY'
577        if changed_datestyle == default_datestyle:
578            changed_datestyle == 'ISO, YMD'
579        self.db.set_parameter('datestyle', changed_datestyle)
580        r = self.db.get_parameter('datestyle')
581        self.assertEqual(r, changed_datestyle)
582        con = self.db.db
583        q = con.query("show datestyle")
584        self.db.reset()
585        r = q.getresult()[0][0]
586        self.assertEqual(r, changed_datestyle)
587        q = con.query("show datestyle")
588        r = q.getresult()[0][0]
589        self.assertEqual(r, default_datestyle)
590        r = self.db.get_parameter('datestyle')
591        self.assertEqual(r, default_datestyle)
592
593    def testReopen(self):
594        db = DB()
595        default_datestyle = db.get_parameter('datestyle')
596        changed_datestyle = 'ISO, DMY'
597        if changed_datestyle == default_datestyle:
598            changed_datestyle == 'ISO, YMD'
599        self.db.set_parameter('datestyle', changed_datestyle)
600        r = self.db.get_parameter('datestyle')
601        self.assertEqual(r, changed_datestyle)
602        con = self.db.db
603        q = con.query("show datestyle")
604        self.db.reopen()
605        r = q.getresult()[0][0]
606        self.assertEqual(r, changed_datestyle)
607        self.assertRaises(TypeError, getattr, con, 'query')
608        r = self.db.get_parameter('datestyle')
609        self.assertEqual(r, default_datestyle)
610
611    def testQuery(self):
612        query = self.db.query
613        query("drop table if exists test_table")
614        self.addCleanup(query, "drop table test_table")
615        q = "create table test_table (n integer) with oids"
616        r = query(q)
617        self.assertIsNone(r)
618        q = "insert into test_table values (1)"
619        r = query(q)
620        self.assertIsInstance(r, int)
621        q = "insert into test_table select 2"
622        r = query(q)
623        self.assertIsInstance(r, int)
624        oid = r
625        q = "select oid from test_table where n=2"
626        r = query(q).getresult()
627        self.assertEqual(len(r), 1)
628        r = r[0]
629        self.assertEqual(len(r), 1)
630        r = r[0]
631        self.assertEqual(r, oid)
632        q = "insert into test_table select 3 union select 4 union select 5"
633        r = query(q)
634        self.assertIsInstance(r, str)
635        self.assertEqual(r, '3')
636        q = "update test_table set n=4 where n<5"
637        r = query(q)
638        self.assertIsInstance(r, str)
639        self.assertEqual(r, '4')
640        q = "delete from test_table"
641        r = query(q)
642        self.assertIsInstance(r, str)
643        self.assertEqual(r, '5')
644
645    def testMultipleQueries(self):
646        self.assertEqual(self.db.query(
647            "create temporary table test_multi (n integer);"
648            "insert into test_multi values (4711);"
649            "select n from test_multi").getresult()[0][0], 4711)
650
651    def testQueryWithParams(self):
652        query = self.db.query
653        query("drop table if exists test_table")
654        self.addCleanup(query, "drop table test_table")
655        q = "create table test_table (n1 integer, n2 integer) with oids"
656        query(q)
657        q = "insert into test_table values ($1, $2)"
658        r = query(q, (1, 2))
659        self.assertIsInstance(r, int)
660        r = query(q, [3, 4])
661        self.assertIsInstance(r, int)
662        r = query(q, [5, 6])
663        self.assertIsInstance(r, int)
664        q = "select * from test_table order by 1, 2"
665        self.assertEqual(query(q).getresult(),
666            [(1, 2), (3, 4), (5, 6)])
667        q = "select * from test_table where n1=$1 and n2=$2"
668        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
669        q = "update test_table set n2=$2 where n1=$1"
670        r = query(q, 3, 7)
671        self.assertEqual(r, '1')
672        q = "select * from test_table order by 1, 2"
673        self.assertEqual(query(q).getresult(),
674            [(1, 2), (3, 7), (5, 6)])
675        q = "delete from test_table where n2!=$1"
676        r = query(q, 4)
677        self.assertEqual(r, '3')
678
679    def testEmptyQuery(self):
680        self.assertRaises(ValueError, self.db.query, '')
681
682    def testQueryProgrammingError(self):
683        try:
684            self.db.query("select 1/0")
685        except pg.ProgrammingError as error:
686            self.assertEqual(error.sqlstate, '22012')
687
688    def testPkey(self):
689        query = self.db.query
690        pkey = self.db.pkey
691        for t in ('pkeytest', 'primary key test'):
692            for n in range(7):
693                query('drop table if exists "%s%d"' % (t, n))
694                self.addCleanup(query, 'drop table "%s%d"' % (t, n))
695            query('create table "%s0" ('
696                "a smallint)" % t)
697            query('create table "%s1" ('
698                "b smallint primary key)" % t)
699            query('create table "%s2" ('
700                "c smallint, d smallint primary key)" % t)
701            query('create table "%s3" ('
702                "e smallint, f smallint, g smallint,"
703                " h smallint, i smallint,"
704                " primary key (f, h))" % t)
705            query('create table "%s4" ('
706                "more_than_one_letter varchar primary key)" % t)
707            query('create table "%s5" ('
708                '"with space" date primary key)' % t)
709            query('create table "%s6" ('
710                'a_very_long_column_name varchar,'
711                ' "with space" date,'
712                ' "42" int,'
713                " primary key (a_very_long_column_name,"
714                ' "with space", "42"))' % t)
715            self.assertRaises(KeyError, pkey, '%s0' % t)
716            self.assertEqual(pkey('%s1' % t), 'b')
717            self.assertEqual(pkey('%s2' % t), 'd')
718            r = pkey('%s3' % t)
719            self.assertIsInstance(r, frozenset)
720            self.assertEqual(r, frozenset('fh'))
721            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
722            self.assertEqual(pkey('%s5' % t), 'with space')
723            r = pkey('%s6' % t)
724            self.assertIsInstance(r, frozenset)
725            self.assertEqual(r, frozenset([
726                'a_very_long_column_name', 'with space', '42']))
727            # a newly added primary key will be detected
728            query('alter table "%s0" add primary key (a)' % t)
729            self.assertEqual(pkey('%s0' % t), 'a')
730            # a changed primary key will not be detected,
731            # indicating that the internal cache is operating
732            query('alter table "%s1" rename column b to x' % t)
733            self.assertEqual(pkey('%s1' % t), 'b')
734            # we get the changed primary key when the cache is flushed
735            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
736
737    def testGetDatabases(self):
738        databases = self.db.get_databases()
739        self.assertIn('template0', databases)
740        self.assertIn('template1', databases)
741        self.assertNotIn('not existing database', databases)
742        self.assertIn('postgres', databases)
743        self.assertIn(dbname, databases)
744
745    def testGetTables(self):
746        get_tables = self.db.get_tables
747        result1 = get_tables()
748        self.assertIsInstance(result1, list)
749        for t in result1:
750            t = t.split('.', 1)
751            self.assertGreaterEqual(len(t), 2)
752            if len(t) > 2:
753                self.assertTrue(t[1].startswith('"'))
754            t = t[0]
755            self.assertNotEqual(t, 'information_schema')
756            self.assertFalse(t.startswith('pg_'))
757        tables = ('"A very Special Name"',
758            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
759            'A_MiXeD_NaMe', '"another special name"',
760            'averyveryveryveryveryveryverylongtablename',
761            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
762        for t in tables:
763            self.db.query('drop table if exists %s' % t)
764            self.db.query("create table %s"
765                " as select 0" % t)
766        result3 = get_tables()
767        result2 = []
768        for t in result3:
769            if t not in result1:
770                result2.append(t)
771        result3 = []
772        for t in tables:
773            if not t.startswith('"'):
774                t = t.lower()
775            result3.append('public.' + t)
776        self.assertEqual(result2, result3)
777        for t in result2:
778            self.db.query('drop table %s' % t)
779        result2 = get_tables()
780        self.assertEqual(result2, result1)
781
782    def testGetRelations(self):
783        get_relations = self.db.get_relations
784        result = get_relations()
785        self.assertIn('public.test', result)
786        self.assertIn('public.test_view', result)
787        result = get_relations('rv')
788        self.assertIn('public.test', result)
789        self.assertIn('public.test_view', result)
790        result = get_relations('r')
791        self.assertIn('public.test', result)
792        self.assertNotIn('public.test_view', result)
793        result = get_relations('v')
794        self.assertNotIn('public.test', result)
795        self.assertIn('public.test_view', result)
796        result = get_relations('cisSt')
797        self.assertNotIn('public.test', result)
798        self.assertNotIn('public.test_view', result)
799
800    def testGetAttnames(self):
801        get_attnames = self.db.get_attnames
802        self.assertRaises(pg.ProgrammingError,
803            self.db.get_attnames, 'does_not_exist')
804        self.assertRaises(pg.ProgrammingError,
805            self.db.get_attnames, 'has.too.many.dots')
806        query = self.db.query
807        query("drop table if exists test_table")
808        self.addCleanup(query, "drop table test_table")
809        query("create table test_table("
810            " n int, alpha smallint, beta bool,"
811            " gamma char(5), tau text, v varchar(3))")
812        r = get_attnames('test_table')
813        self.assertIsInstance(r, dict)
814        self.assertEqual(r, dict(
815            n='int', alpha='int', beta='bool',
816            gamma='text', tau='text', v='text'))
817
818    def testGetAttnamesWithQuotes(self):
819        get_attnames = self.db.get_attnames
820        query = self.db.query
821        table = 'test table for get_attnames()'
822        query('drop table if exists "%s"' % table)
823        self.addCleanup(query, 'drop table "%s"' % table)
824        query('create table "%s"('
825            '"Prime!" smallint,'
826            ' "much space" integer, "Questions?" text)' % table)
827        r = get_attnames(table)
828        self.assertIsInstance(r, dict)
829        self.assertEqual(r, {
830            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
831        table = 'yet another test table for get_attnames()'
832        query('drop table if exists "%s"' % table)
833        self.addCleanup(query, 'drop table "%s"' % table)
834        self.db.query('create table "%s" ('
835            'a smallint, b integer, c bigint,'
836            ' e numeric, f float, f2 double precision, m money,'
837            ' x smallint, y smallint, z smallint,'
838            ' Normal_NaMe smallint, "Special Name" smallint,'
839            ' t text, u char(2), v varchar(2),'
840            ' primary key (y, u)) with oids' % table)
841        r = get_attnames(table)
842        self.assertIsInstance(r, dict)
843        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
844            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
845            'normal_name': 'int', 'Special Name': 'int',
846            'u': 'text', 't': 'text', 'v': 'text',
847            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
848
849    def testGetAttnamesWithRegtypes(self):
850        get_attnames = self.db.get_attnames
851        query = self.db.query
852        query("drop table if exists test_table")
853        self.addCleanup(query, "drop table test_table")
854        query("create table test_table("
855            " n int, alpha smallint, beta bool,"
856            " gamma char(5), tau text, v varchar(3))")
857        self.db.use_regtypes(True)
858        try:
859            r = get_attnames("test_table")
860            self.assertIsInstance(r, dict)
861        finally:
862            self.db.use_regtypes(False)
863        self.assertEqual(r, dict(
864            n='integer', alpha='smallint', beta='boolean',
865            gamma='character', tau='text', v='character varying'))
866
867    def testGetAttnamesIsCached(self):
868        get_attnames = self.db.get_attnames
869        query = self.db.query
870        query("drop table if exists test_table")
871        self.addCleanup(query, "drop table if exists test_table")
872        query("create table test_table(col int)")
873        r = get_attnames("test_table")
874        self.assertIsInstance(r, dict)
875        self.assertEqual(r, dict(col='int'))
876        query("drop table test_table")
877        query("create table test_table(col text)")
878        r = get_attnames("test_table")
879        self.assertEqual(r, dict(col='int'))
880        r = get_attnames("test_table", flush=True)
881        self.assertEqual(r, dict(col='text'))
882        query("drop table test_table")
883        r = get_attnames("test_table")
884        self.assertEqual(r, dict(col='text'))
885        self.assertRaises(pg.ProgrammingError,
886            get_attnames, "test_table", flush=True)
887
888    def testGetAttnamesIsOrdered(self):
889        get_attnames = self.db.get_attnames
890        query = self.db.query
891        query("drop table if exists test_table")
892        self.addCleanup(query, "drop table test_table")
893        query("create table test_table("
894            " n int, alpha smallint, v varchar(3),"
895            " gamma char(5), tau text, beta bool)")
896        r = get_attnames("test_table")
897        self.assertIsInstance(r, OrderedDict)
898        self.assertEqual(r, OrderedDict([
899            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
900            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
901        if OrderedDict is dict:
902            self.skipTest('OrderedDict is not supported')
903        r = ' '.join(list(r.keys()))
904        self.assertEqual(r, 'n alpha v gamma tau beta')
905
906    def testHasTablePrivilege(self):
907        can = self.db.has_table_privilege
908        self.assertEqual(can('test'), True)
909        self.assertEqual(can('test', 'select'), True)
910        self.assertEqual(can('test', 'SeLeCt'), True)
911        self.assertEqual(can('test', 'SELECT'), True)
912        self.assertEqual(can('test', 'insert'), True)
913        self.assertEqual(can('test', 'update'), True)
914        self.assertEqual(can('test', 'delete'), True)
915        self.assertEqual(can('pg_views', 'select'), True)
916        self.assertEqual(can('pg_views', 'delete'), False)
917        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
918        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
919
920    def testGet(self):
921        get = self.db.get
922        query = self.db.query
923        table = 'get_test_table'
924        query('drop table if exists "%s"' % table)
925        self.addCleanup(query, 'drop table "%s"' % table)
926        query('create table "%s" ('
927            "n integer, t text) with oids" % table)
928        for n, t in enumerate('xyz'):
929            query('insert into "%s" values('"%d, '%s')"
930                % (table, n + 1, t))
931        self.assertRaises(pg.ProgrammingError, get, table, 2)
932        r = get(table, 2, 'n')
933        oid_table = 'oid(%s)' % table
934        self.assertIn(oid_table, r)
935        oid = r[oid_table]
936        self.assertIsInstance(oid, int)
937        result = {'t': 'y', 'n': 2, oid_table: oid}
938        self.assertEqual(r, result)
939        self.assertEqual(get(table + ' *', 2, 'n'), r)
940        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
941        self.assertEqual(get(table, 1, 'n')['t'], 'x')
942        self.assertEqual(get(table, 3, 'n')['t'], 'z')
943        self.assertEqual(get(table, 2, 'n')['t'], 'y')
944        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
945        r['n'] = 3
946        self.assertEqual(get(table, r, 'n')['t'], 'z')
947        self.assertEqual(get(table, 1, 'n')['t'], 'x')
948        query('alter table "%s" alter n set not null' % table)
949        query('alter table "%s" add primary key (n)' % table)
950        self.assertEqual(get(table, 3)['t'], 'z')
951        self.assertEqual(get(table, 1)['t'], 'x')
952        self.assertEqual(get(table, 2)['t'], 'y')
953        r['n'] = 1
954        self.assertEqual(get(table, r)['t'], 'x')
955        r['n'] = 3
956        self.assertEqual(get(table, r)['t'], 'z')
957        r['n'] = 2
958        self.assertEqual(get(table, r)['t'], 'y')
959
960    def testGetWithCompositeKey(self):
961        get = self.db.get
962        query = self.db.query
963        table = 'get_test_table_1'
964        query('drop table if exists "%s"' % table)
965        self.addCleanup(query, 'drop table "%s"' % table)
966        query('create table "%s" ('
967            "n integer, t text, primary key (n))" % table)
968        for n, t in enumerate('abc'):
969            query('insert into "%s" values('
970                "%d, '%s')" % (table, n + 1, t))
971        self.assertEqual(get(table, 2)['t'], 'b')
972        table = 'get_test_table_2'
973        query('drop table if exists "%s"' % table)
974        self.addCleanup(query, 'drop table "%s"' % table)
975        query('create table "%s" ('
976            "n integer, m integer, t text, primary key (n, m))" % table)
977        for n in range(3):
978            for m in range(2):
979                t = chr(ord('a') + 2 * n + m)
980                query('insert into "%s" values('
981                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
982        self.assertRaises(pg.ProgrammingError, get, table, 2)
983        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
984        r = get(table, dict(n=1, m=2), ('n', 'm'))
985        self.assertEqual(r['t'], 'b')
986        r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
987        self.assertEqual(r['t'], 'f')
988
989    def testGetWithQuotedNames(self):
990        get = self.db.get
991        query = self.db.query
992        table = 'test table for get()'
993        query('drop table if exists "%s"' % table)
994        self.addCleanup(query, 'drop table "%s"' % table)
995        query('create table "%s" ('
996            '"Prime!" smallint primary key,'
997            ' "much space" integer, "Questions?" text)' % table)
998        query('insert into "%s"'
999              " values(17, 1001, 'No!')" % table)
1000        r = get(table, 17)
1001        self.assertIsInstance(r, dict)
1002        self.assertEqual(r['Prime!'], 17)
1003        self.assertEqual(r['much space'], 1001)
1004        self.assertEqual(r['Questions?'], 'No!')
1005
1006    def testGetFromView(self):
1007        self.db.query('delete from test where i4=14')
1008        self.db.query('insert into test (i4, v4) values('
1009            "14, 'abc4')")
1010        r = self.db.get('test_view', 14, 'i4')
1011        self.assertIn('v4', r)
1012        self.assertEqual(r['v4'], 'abc4')
1013
1014    def testGetLittleBobbyTables(self):
1015        get = self.db.get
1016        query = self.db.query
1017        query("drop table if exists test_students")
1018        self.addCleanup(query, "drop table test_students")
1019        query("create table test_students (firstname varchar primary key,"
1020            " nickname varchar, grade char(2))")
1021        query("insert into test_students values ("
1022              "'D''Arcy', 'Darcey', 'A+')")
1023        query("insert into test_students values ("
1024              "'Sheldon', 'Moonpie', 'A+')")
1025        query("insert into test_students values ("
1026              "'Robert', 'Little Bobby Tables', 'D-')")
1027        r = get('test_students', 'Sheldon')
1028        self.assertEqual(r, dict(
1029            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1030        r = get('test_students', 'Robert')
1031        self.assertEqual(r, dict(
1032            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1033        r = get('test_students', "D'Arcy")
1034        self.assertEqual(r, dict(
1035            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1036        try:
1037            get('test_students', "D' Arcy")
1038        except pg.DatabaseError as error:
1039            self.assertEqual(str(error),
1040                'No such record in test_students\nwhere "firstname" = $1\n'
1041                'with $1="D\' Arcy"')
1042        try:
1043            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1044        except pg.DatabaseError as error:
1045            self.assertEqual(str(error),
1046                'No such record in test_students\nwhere "firstname" = $1\n'
1047                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1048        q = "select * from test_students order by 1 limit 4"
1049        r = query(q).getresult()
1050        self.assertEqual(len(r), 3)
1051        self.assertEqual(r[1][2], 'D-')
1052
1053    def testInsert(self):
1054        insert = self.db.insert
1055        query = self.db.query
1056        bool_on = pg.get_bool()
1057        decimal = pg.get_decimal()
1058        table = 'insert_test_table'
1059        query('drop table if exists "%s"' % table)
1060        self.addCleanup(query, 'drop table "%s"' % table)
1061        query('create table "%s" ('
1062            "i2 smallint, i4 integer, i8 bigint,"
1063            " d numeric, f4 real, f8 double precision, m money,"
1064            " v4 varchar(4), c4 char(4), t text,"
1065            " b boolean, ts timestamp) with oids" % table)
1066        oid_table = 'oid(%s)' % table
1067        tests = [dict(i2=None, i4=None, i8=None),
1068            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1069            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1070            dict(i2=42, i4=123456, i8=9876543210),
1071            dict(i2=2 ** 15 - 1,
1072                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1073            dict(d=None), (dict(d=''), dict(d=None)),
1074            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1075            dict(f4=None, f8=None), dict(f4=0, f8=0),
1076            (dict(f4='', f8=''), dict(f4=None, f8=None)),
1077            (dict(d=1234.5, f4=1234.5, f8=1234.5),
1078                  dict(d=Decimal('1234.5'))),
1079            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1080            dict(d=Decimal('123456789.9876543212345678987654321')),
1081            dict(m=None), (dict(m=''), dict(m=None)),
1082            dict(m=Decimal('-1234.56')),
1083            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1084            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1085            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1086            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1087            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1088            (dict(m=123456), dict(m=Decimal('123456'))),
1089            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1090            dict(b=None), (dict(b=''), dict(b=None)),
1091            dict(b='f'), dict(b='t'),
1092            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1093            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1094            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1095            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1096            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1097            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1098            dict(v4=None, c4=None, t=None),
1099            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1100            dict(v4='1234', c4='1234', t='1234' * 10),
1101            dict(v4='abcd', c4='abcd', t='abcdefg'),
1102            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1103            dict(ts=None), (dict(ts=''), dict(ts=None)),
1104            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1105            dict(ts='2012-12-21 00:00:00'),
1106            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1107            dict(ts='2012-12-21 12:21:12'),
1108            dict(ts='2013-01-05 12:13:14'),
1109            dict(ts='current_timestamp')]
1110        for test in tests:
1111            if isinstance(test, dict):
1112                data = test
1113                change = {}
1114            else:
1115                data, change = test
1116            expect = data.copy()
1117            expect.update(change)
1118            if bool_on:
1119                b = expect.get('b')
1120                if b is not None:
1121                    expect['b'] = b == 't'
1122            if decimal is not Decimal:
1123                d = expect.get('d')
1124                if d is not None:
1125                    expect['d'] = decimal(d)
1126                m = expect.get('m')
1127                if m is not None:
1128                    expect['m'] = decimal(m)
1129            self.assertEqual(insert(table, data), data)
1130            self.assertIn(oid_table, data)
1131            oid = data[oid_table]
1132            self.assertIsInstance(oid, int)
1133            data = dict(item for item in data.items()
1134                if item[0] in expect)
1135            ts = expect.get('ts')
1136            if ts == 'current_timestamp':
1137                ts = expect['ts'] = data['ts']
1138                if len(ts) > 19:
1139                    self.assertEqual(ts[19], '.')
1140                    ts = ts[:19]
1141                else:
1142                    self.assertEqual(len(ts), 19)
1143                self.assertTrue(ts[:4].isdigit())
1144                self.assertEqual(ts[4], '-')
1145                self.assertEqual(ts[10], ' ')
1146                self.assertTrue(ts[11:13].isdigit())
1147                self.assertEqual(ts[13], ':')
1148            self.assertEqual(data, expect)
1149            data = query(
1150                'select oid,* from "%s"' % table).dictresult()[0]
1151            self.assertEqual(data['oid'], oid)
1152            data = dict(item for item in data.items()
1153                if item[0] in expect)
1154            self.assertEqual(data, expect)
1155            query('delete from "%s"' % table)
1156
1157    def testInsertWithQuotedNames(self):
1158        insert = self.db.insert
1159        query = self.db.query
1160        table = 'test table for insert()'
1161        query('drop table if exists "%s"' % table)
1162        self.addCleanup(query, 'drop table "%s"' % table)
1163        query('create table "%s" ('
1164            '"Prime!" smallint primary key,'
1165            ' "much space" integer, "Questions?" text)' % table)
1166        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1167        r = insert(table, r)
1168        self.assertIsInstance(r, dict)
1169        self.assertEqual(r['Prime!'], 11)
1170        self.assertEqual(r['much space'], 2002)
1171        self.assertEqual(r['Questions?'], 'What?')
1172        r = query('select * from "%s" limit 2' % table).dictresult()
1173        self.assertEqual(len(r), 1)
1174        r = r[0]
1175        self.assertEqual(r['Prime!'], 11)
1176        self.assertEqual(r['much space'], 2002)
1177        self.assertEqual(r['Questions?'], 'What?')
1178
1179    def testUpdate(self):
1180        update = self.db.update
1181        query = self.db.query
1182        table = 'update_test_table'
1183        query('drop table if exists "%s"' % table)
1184        self.addCleanup(query, 'drop table "%s"' % table)
1185        query('create table "%s" ('
1186            "n integer, t text) with oids" % table)
1187        for n, t in enumerate('xyz'):
1188            query('insert into "%s" values('
1189                "%d, '%s')" % (table, n + 1, t))
1190        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1191        r = self.db.get(table, 2, 'n')
1192        r['t'] = 'u'
1193        s = update(table, r)
1194        self.assertEqual(s, r)
1195        q = 'select t from "%s" where n=2' % table
1196        r = query(q).getresult()[0][0]
1197        self.assertEqual(r, 'u')
1198
1199    def testUpdateWithCompositeKey(self):
1200        update = self.db.update
1201        query = self.db.query
1202        table = 'update_test_table_1'
1203        query('drop table if exists "%s"' % table)
1204        self.addCleanup(query, 'drop table if exists "%s"' % table)
1205        query('create table "%s" ('
1206            "n integer, t text, primary key (n))" % table)
1207        for n, t in enumerate('abc'):
1208            query('insert into "%s" values('
1209                "%d, '%s')" % (table, n + 1, t))
1210        self.assertRaises(pg.ProgrammingError, update,
1211                          table, dict(t='b'))
1212        s = dict(n=2, t='d')
1213        r = update(table, s)
1214        self.assertIs(r, s)
1215        self.assertEqual(r['n'], 2)
1216        self.assertEqual(r['t'], 'd')
1217        q = 'select t from "%s" where n=2' % table
1218        r = query(q).getresult()[0][0]
1219        self.assertEqual(r, 'd')
1220        s.update(dict(n=4, t='e'))
1221        r = update(table, s)
1222        self.assertEqual(r['n'], 4)
1223        self.assertEqual(r['t'], 'e')
1224        q = 'select t from "%s" where n=2' % table
1225        r = query(q).getresult()[0][0]
1226        self.assertEqual(r, 'd')
1227        q = 'select t from "%s" where n=4' % table
1228        r = query(q).getresult()
1229        self.assertEqual(len(r), 0)
1230        query('drop table "%s"' % table)
1231        table = 'update_test_table_2'
1232        query('drop table if exists "%s"' % table)
1233        query('create table "%s" ('
1234            "n integer, m integer, t text, primary key (n, m))" % table)
1235        for n in range(3):
1236            for m in range(2):
1237                t = chr(ord('a') + 2 * n + m)
1238                query('insert into "%s" values('
1239                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1240        self.assertRaises(pg.ProgrammingError, update,
1241                          table, dict(n=2, t='b'))
1242        self.assertEqual(update(table,
1243                                dict(n=2, m=2, t='x'))['t'], 'x')
1244        q = 'select t from "%s" where n=2 order by m' % table
1245        r = [r[0] for r in query(q).getresult()]
1246        self.assertEqual(r, ['c', 'x'])
1247
1248    def testUpdateWithQuotedNames(self):
1249        update = self.db.update
1250        query = self.db.query
1251        table = 'test table for update()'
1252        query('drop table if exists "%s"' % table)
1253        self.addCleanup(query, 'drop table "%s"' % table)
1254        query('create table "%s" ('
1255            '"Prime!" smallint primary key,'
1256            ' "much space" integer, "Questions?" text)' % table)
1257        query('insert into "%s"'
1258              " values(13, 3003, 'Why!')" % table)
1259        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1260        r = update(table, r)
1261        self.assertIsInstance(r, dict)
1262        self.assertEqual(r['Prime!'], 13)
1263        self.assertEqual(r['much space'], 7007)
1264        self.assertEqual(r['Questions?'], 'When?')
1265        r = query('select * from "%s" limit 2' % table).dictresult()
1266        self.assertEqual(len(r), 1)
1267        r = r[0]
1268        self.assertEqual(r['Prime!'], 13)
1269        self.assertEqual(r['much space'], 7007)
1270        self.assertEqual(r['Questions?'], 'When?')
1271
1272    def testUpsert(self):
1273        upsert = self.db.upsert
1274        query = self.db.query
1275        table = 'upsert_test_table'
1276        query('drop table if exists "%s"' % table)
1277        self.addCleanup(query, 'drop table "%s"' % table)
1278        query('create table "%s" ('
1279            "n integer primary key, t text) with oids" % table)
1280        s = dict(n=1, t='x')
1281        try:
1282            r = upsert(table, s)
1283        except pg.ProgrammingError as error:
1284            if self.db.server_version < 90500:
1285                self.skipTest('database does not support upsert')
1286            self.fail(str(error))
1287        self.assertIs(r, s)
1288        self.assertEqual(r['n'], 1)
1289        self.assertEqual(r['t'], 'x')
1290        s.update(n=2, t='y')
1291        r = upsert(table, s, **dict.fromkeys(s))
1292        self.assertIs(r, s)
1293        self.assertEqual(r['n'], 2)
1294        self.assertEqual(r['t'], 'y')
1295        q = 'select n, t from "%s" order by n limit 3' % table
1296        r = query(q).getresult()
1297        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1298        s.update(t='z')
1299        r = upsert(table, s)
1300        self.assertIs(r, s)
1301        self.assertEqual(r['n'], 2)
1302        self.assertEqual(r['t'], 'z')
1303        r = query(q).getresult()
1304        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1305        s.update(t='n')
1306        r = upsert(table, s, t=False)
1307        self.assertIs(r, s)
1308        self.assertEqual(r['n'], 2)
1309        self.assertEqual(r['t'], 'z')
1310        r = query(q).getresult()
1311        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1312        s.update(t='y')
1313        r = upsert(table, s, t=True)
1314        self.assertIs(r, s)
1315        self.assertEqual(r['n'], 2)
1316        self.assertEqual(r['t'], 'y')
1317        r = query(q).getresult()
1318        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1319        s.update(t='n')
1320        r = upsert(table, s, t="included.t || '2'")
1321        self.assertIs(r, s)
1322        self.assertEqual(r['n'], 2)
1323        self.assertEqual(r['t'], 'y2')
1324        r = query(q).getresult()
1325        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1326        s.update(t='y')
1327        r = upsert(table, s, t="excluded.t || '3'")
1328        self.assertIs(r, s)
1329        self.assertEqual(r['n'], 2)
1330        self.assertEqual(r['t'], 'y3')
1331        r = query(q).getresult()
1332        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1333        s.update(n=1, t='2')
1334        r = upsert(table, s, t="included.t || excluded.t")
1335        self.assertIs(r, s)
1336        self.assertEqual(r['n'], 1)
1337        self.assertEqual(r['t'], 'x2')
1338        r = query(q).getresult()
1339        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1340
1341    def testUpsertWithCompositeKey(self):
1342        upsert = self.db.upsert
1343        query = self.db.query
1344        table = 'upsert_test_table_2'
1345        query('drop table if exists "%s"' % table)
1346        self.addCleanup(query, 'drop table "%s"' % table)
1347        query('create table "%s" ('
1348            "n integer, m integer, t text, primary key (n, m))" % table)
1349        s = dict(n=1, m=2, t='x')
1350        try:
1351            r = upsert(table, s)
1352        except pg.ProgrammingError as error:
1353            if self.db.server_version < 90500:
1354                self.skipTest('database does not support upsert')
1355            self.fail(str(error))
1356        self.assertIs(r, s)
1357        self.assertEqual(r['n'], 1)
1358        self.assertEqual(r['m'], 2)
1359        self.assertEqual(r['t'], 'x')
1360        s.update(m=3, t='y')
1361        r = upsert(table, s, **dict.fromkeys(s))
1362        self.assertIs(r, s)
1363        self.assertEqual(r['n'], 1)
1364        self.assertEqual(r['m'], 3)
1365        self.assertEqual(r['t'], 'y')
1366        q = 'select n, m, t from "%s" order by n, m limit 3' % table
1367        r = query(q).getresult()
1368        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
1369        s.update(t='z')
1370        r = upsert(table, s)
1371        self.assertIs(r, s)
1372        self.assertEqual(r['n'], 1)
1373        self.assertEqual(r['m'], 3)
1374        self.assertEqual(r['t'], 'z')
1375        r = query(q).getresult()
1376        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1377        s.update(t='n')
1378        r = upsert(table, s, t=False)
1379        self.assertIs(r, s)
1380        self.assertEqual(r['n'], 1)
1381        self.assertEqual(r['m'], 3)
1382        self.assertEqual(r['t'], 'z')
1383        r = query(q).getresult()
1384        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1385        s.update(t='n')
1386        r = upsert(table, s, t=True)
1387        self.assertIs(r, s)
1388        self.assertEqual(r['n'], 1)
1389        self.assertEqual(r['m'], 3)
1390        self.assertEqual(r['t'], 'n')
1391        r = query(q).getresult()
1392        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
1393        s.update(n=2, t='y')
1394        r = upsert(table, s, t="'z'")
1395        self.assertIs(r, s)
1396        self.assertEqual(r['n'], 2)
1397        self.assertEqual(r['m'], 3)
1398        self.assertEqual(r['t'], 'y')
1399        r = query(q).getresult()
1400        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
1401        s.update(n=1, t='m')
1402        r = upsert(table, s, t='included.t || excluded.t')
1403        self.assertIs(r, s)
1404        self.assertEqual(r['n'], 1)
1405        self.assertEqual(r['m'], 3)
1406        self.assertEqual(r['t'], 'nm')
1407        r = query(q).getresult()
1408        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1409
1410    def testUpsertWithQuotedNames(self):
1411        upsert = self.db.upsert
1412        query = self.db.query
1413        table = 'test table for upsert()'
1414        query('drop table if exists "%s"' % table)
1415        self.addCleanup(query, 'drop table "%s"' % table)
1416        query('create table "%s" ('
1417            '"Prime!" smallint primary key,'
1418            ' "much space" integer, "Questions?" text)' % table)
1419        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1420        try:
1421            r = upsert(table, s)
1422        except pg.ProgrammingError as error:
1423            if self.db.server_version < 90500:
1424                self.skipTest('database does not support upsert')
1425            self.fail(str(error))
1426        self.assertIs(r, s)
1427        self.assertEqual(r['Prime!'], 31)
1428        self.assertEqual(r['much space'], 9009)
1429        self.assertEqual(r['Questions?'], 'Yes.')
1430        q = 'select * from "%s" limit 2' % table
1431        r = query(q).getresult()
1432        self.assertEqual(r, [(31, 9009, 'Yes.')])
1433        s.update({'Questions?': 'No.'})
1434        r = upsert(table, s)
1435        self.assertIs(r, s)
1436        self.assertEqual(r['Prime!'], 31)
1437        self.assertEqual(r['much space'], 9009)
1438        self.assertEqual(r['Questions?'], 'No.')
1439        r = query(q).getresult()
1440        self.assertEqual(r, [(31, 9009, 'No.')])
1441
1442    def testClear(self):
1443        clear = self.db.clear
1444        query = self.db.query
1445        f = False if pg.get_bool() else 'f'
1446        table = 'clear_test_table'
1447        query('drop table if exists "%s"' % table)
1448        self.addCleanup(query, 'drop table "%s"' % table)
1449        query('create table "%s" ('
1450            "n integer, b boolean, d date, t text)" % table)
1451        r = clear(table)
1452        result = {'n': 0, 'b': f, 'd': '', 't': ''}
1453        self.assertEqual(r, result)
1454        r['a'] = r['n'] = 1
1455        r['d'] = r['t'] = 'x'
1456        r['b'] = 't'
1457        r['oid'] = long(1)
1458        r = clear(table, r)
1459        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
1460            'oid': long(1)}
1461        self.assertEqual(r, result)
1462
1463    def testClearWithQuotedNames(self):
1464        clear = self.db.clear
1465        query = self.db.query
1466        table = 'test table for clear()'
1467        query('drop table if exists "%s"' % table)
1468        self.addCleanup(query, 'drop table "%s"' % table)
1469        query('create table "%s" ('
1470            '"Prime!" smallint primary key,'
1471            ' "much space" integer, "Questions?" text)' % table)
1472        r = clear(table)
1473        self.assertIsInstance(r, dict)
1474        self.assertEqual(r['Prime!'], 0)
1475        self.assertEqual(r['much space'], 0)
1476        self.assertEqual(r['Questions?'], '')
1477
1478    def testDelete(self):
1479        delete = self.db.delete
1480        query = self.db.query
1481        table = 'delete_test_table'
1482        query('drop table if exists "%s"' % table)
1483        self.addCleanup(query, 'drop table "%s"' % table)
1484        query('create table "%s" ('
1485            "n integer, t text) with oids" % table)
1486        for n, t in enumerate('xyz'):
1487            query('insert into "%s" values('
1488                "%d, '%s')" % (table, n + 1, t))
1489        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1490        r = self.db.get(table, 1, 'n')
1491        s = delete(table, r)
1492        self.assertEqual(s, 1)
1493        r = self.db.get(table, 3, 'n')
1494        s = delete(table, r)
1495        self.assertEqual(s, 1)
1496        s = delete(table, r)
1497        self.assertEqual(s, 0)
1498        r = query('select * from "%s"' % table).dictresult()
1499        self.assertEqual(len(r), 1)
1500        r = r[0]
1501        result = {'n': 2, 't': 'y'}
1502        self.assertEqual(r, result)
1503        r = self.db.get(table, 2, 'n')
1504        s = delete(table, r)
1505        self.assertEqual(s, 1)
1506        s = delete(table, r)
1507        self.assertEqual(s, 0)
1508        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1509
1510    def testDeleteWithCompositeKey(self):
1511        query = self.db.query
1512        table = 'delete_test_table_1'
1513        query('drop table if exists "%s"' % table)
1514        self.addCleanup(query, 'drop table "%s"' % table)
1515        query('create table "%s" ('
1516            "n integer, t text, primary key (n))" % table)
1517        for n, t in enumerate('abc'):
1518            query("insert into %s values("
1519                "%d, '%s')" % (table, n + 1, t))
1520        self.assertRaises(pg.ProgrammingError, self.db.delete,
1521            table, dict(t='b'))
1522        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1523        r = query('select t from "%s" where n=2' % table
1524                  ).getresult()
1525        self.assertEqual(r, [])
1526        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1527        r = query('select t from "%s" where n=3' % table
1528                  ).getresult()[0][0]
1529        self.assertEqual(r, 'c')
1530        table = 'delete_test_table_2'
1531        query('drop table if exists "%s"' % table)
1532        self.addCleanup(query, 'drop table "%s"' % table)
1533        query('create table "%s" ('
1534            "n integer, m integer, t text, primary key (n, m))" % table)
1535        for n in range(3):
1536            for m in range(2):
1537                t = chr(ord('a') + 2 * n + m)
1538                query('insert into "%s" values('
1539                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1540        self.assertRaises(pg.ProgrammingError, self.db.delete,
1541            table, dict(n=2, t='b'))
1542        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1543        r = [r[0] for r in query('select t from "%s" where n=2'
1544            ' order by m' % table).getresult()]
1545        self.assertEqual(r, ['c'])
1546        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1547        r = [r[0] for r in query('select t from "%s" where n=3'
1548            ' order by m' % table).getresult()]
1549        self.assertEqual(r, ['e', 'f'])
1550        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1551        r = [r[0] for r in query('select t from "%s" where n=3'
1552            ' order by m' % table).getresult()]
1553        self.assertEqual(r, ['f'])
1554
1555    def testDeleteWithQuotedNames(self):
1556        delete = self.db.delete
1557        query = self.db.query
1558        table = 'test table for delete()'
1559        query('drop table if exists "%s"' % table)
1560        self.addCleanup(query, 'drop table "%s"' % table)
1561        query('create table "%s" ('
1562            '"Prime!" smallint primary key,'
1563            ' "much space" integer, "Questions?" text)' % table)
1564        query('insert into "%s"'
1565              " values(19, 5005, 'Yes!')" % table)
1566        r = {'Prime!': 17}
1567        r = delete(table, r)
1568        self.assertEqual(r, 0)
1569        r = query('select count(*) from "%s"' % table).getresult()
1570        self.assertEqual(r[0][0], 1)
1571        r = {'Prime!': 19}
1572        r = delete(table, r)
1573        self.assertEqual(r, 1)
1574        r = query('select count(*) from "%s"' % table).getresult()
1575        self.assertEqual(r[0][0], 0)
1576
1577    def testTruncate(self):
1578        truncate = self.db.truncate
1579        self.assertRaises(TypeError, truncate, None)
1580        self.assertRaises(TypeError, truncate, 42)
1581        self.assertRaises(TypeError, truncate, dict(test_table=None))
1582        query = self.db.query
1583        query("drop table if exists test_table")
1584        self.addCleanup(query, "drop table test_table")
1585        query("create table test_table (n smallint)")
1586        for i in range(3):
1587            query("insert into test_table values (1)")
1588        q = "select count(*) from test_table"
1589        r = query(q).getresult()[0][0]
1590        self.assertEqual(r, 3)
1591        truncate('test_table')
1592        r = query(q).getresult()[0][0]
1593        self.assertEqual(r, 0)
1594        for i in range(3):
1595            query("insert into test_table values (1)")
1596        r = query(q).getresult()[0][0]
1597        self.assertEqual(r, 3)
1598        truncate('public.test_table')
1599        r = query(q).getresult()[0][0]
1600        self.assertEqual(r, 0)
1601        query("drop table if exists test_table_2")
1602        self.addCleanup(query, "drop table test_table_2")
1603        query('create table test_table_2 (n smallint)')
1604        for t in (list, tuple, set):
1605            for i in range(3):
1606                query("insert into test_table values (1)")
1607                query("insert into test_table_2 values (2)")
1608            q = ("select (select count(*) from test_table),"
1609                " (select count(*) from test_table_2)")
1610            r = query(q).getresult()[0]
1611            self.assertEqual(r, (3, 3))
1612            truncate(t(['test_table', 'test_table_2']))
1613            r = query(q).getresult()[0]
1614            self.assertEqual(r, (0, 0))
1615
1616    def testTruncateRestart(self):
1617        truncate = self.db.truncate
1618        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
1619        query = self.db.query
1620        query("drop table if exists test_table")
1621        self.addCleanup(query, "drop table test_table")
1622        query("create table test_table (n serial, t text)")
1623        for n in range(3):
1624            query("insert into test_table (t) values ('test')")
1625        q = "select count(n), min(n), max(n) from test_table"
1626        r = query(q).getresult()[0]
1627        self.assertEqual(r, (3, 1, 3))
1628        truncate('test_table')
1629        r = query(q).getresult()[0]
1630        self.assertEqual(r, (0, None, None))
1631        for n in range(3):
1632            query("insert into test_table (t) values ('test')")
1633        r = query(q).getresult()[0]
1634        self.assertEqual(r, (3, 4, 6))
1635        truncate('test_table', restart=True)
1636        r = query(q).getresult()[0]
1637        self.assertEqual(r, (0, None, None))
1638        for n in range(3):
1639            query("insert into test_table (t) values ('test')")
1640        r = query(q).getresult()[0]
1641        self.assertEqual(r, (3, 1, 3))
1642
    def testTruncateCascade(self):
        """Test truncate() on tables linked by a foreign key.

        Truncating the referenced parent table alone must fail, unless the
        referencing child table is truncated along with it, either by
        passing both tables explicitly or by passing cascade=True.
        """
        truncate = self.db.truncate
        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
        query = self.db.query
        # drop the referencing child table before the referenced parent
        query("drop table if exists test_child")
        query("drop table if exists test_parent")
        # cleanups run in reverse order, so test_child is dropped first
        self.addCleanup(query, "drop table test_parent")
        query("create table test_parent (n smallint primary key)")
        self.addCleanup(query, "drop table test_child")
        query("create table test_child ("
            " n smallint primary key references test_parent (n))")
        for n in range(3):
            query("insert into test_parent (n) values (%d)" % n)
            query("insert into test_child (n) values (%d)" % n)
        q = ("select (select count(*) from test_parent),"
            " (select count(*) from test_child)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 3))
        # the parent is referenced by the child, so it cannot be truncated alone
        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
        truncate(['test_parent', 'test_child'])
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        for n in range(3):
            query("insert into test_parent (n) values (%d)" % n)
            query("insert into test_child (n) values (%d)" % n)
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 3))
        # with cascade=True the child table is truncated as well
        truncate('test_parent', cascade=True)
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        for n in range(3):
            query("insert into test_parent (n) values (%d)" % n)
            query("insert into test_child (n) values (%d)" % n)
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 3))
        # the referencing child table can be truncated alone
        truncate('test_child')
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 0))
        # the parent still cannot be truncated alone without cascade
        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
        truncate('test_parent', cascade=True)
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
1685
1686    def testTruncateOnly(self):
1687        truncate = self.db.truncate
1688        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
1689        query = self.db.query
1690        query("drop table if exists test_child")
1691        query("drop table if exists test_parent")
1692        self.addCleanup(query, "drop table test_parent")
1693        query("create table test_parent (n smallint)")
1694        self.addCleanup(query, "drop table test_child")
1695        query("create table test_child ("
1696            " m smallint) inherits (test_parent)")
1697        for n in range(3):
1698            query("insert into test_parent (n) values (1)")
1699            query("insert into test_child (n, m) values (2, 3)")
1700        q = ("select (select count(*) from test_parent),"
1701            " (select count(*) from test_child)")
1702        r = query(q).getresult()[0]
1703        self.assertEqual(r, (6, 3))
1704        truncate('test_parent')
1705        r = query(q).getresult()[0]
1706        self.assertEqual(r, (0, 0))
1707        for n in range(3):
1708            query("insert into test_parent (n) values (1)")
1709            query("insert into test_child (n, m) values (2, 3)")
1710        r = query(q).getresult()[0]
1711        self.assertEqual(r, (6, 3))
1712        truncate('test_parent*')
1713        r = query(q).getresult()[0]
1714        self.assertEqual(r, (0, 0))
1715        for n in range(3):
1716            query("insert into test_parent (n) values (1)")
1717            query("insert into test_child (n, m) values (2, 3)")
1718        r = query(q).getresult()[0]
1719        self.assertEqual(r, (6, 3))
1720        truncate('test_parent', only=True)
1721        r = query(q).getresult()[0]
1722        self.assertEqual(r, (3, 3))
1723        truncate('test_parent', only=False)
1724        r = query(q).getresult()[0]
1725        self.assertEqual(r, (0, 0))
1726        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
1727        truncate('test_parent*', only=False)
1728        query("drop table if exists test_parent_2")
1729        self.addCleanup(query, "drop table test_parent_2")
1730        query("create table test_parent_2 (n smallint)")
1731        query("drop table if exists test_child_2")
1732        self.addCleanup(query, "drop table test_child_2")
1733        query("create table test_child_2 ("
1734            " m smallint) inherits (test_parent_2)")
1735        for n in range(3):
1736            query("insert into test_parent (n) values (1)")
1737            query("insert into test_child (n, m) values (2, 3)")
1738            query("insert into test_parent_2 (n) values (1)")
1739            query("insert into test_child_2 (n, m) values (2, 3)")
1740        q = ("select (select count(*) from test_parent),"
1741            " (select count(*) from test_child),"
1742            " (select count(*) from test_parent_2),"
1743            " (select count(*) from test_child_2)")
1744        r = query(q).getresult()[0]
1745        self.assertEqual(r, (6, 3, 6, 3))
1746        truncate(['test_parent', 'test_parent_2'], only=[False, True])
1747        r = query(q).getresult()[0]
1748        self.assertEqual(r, (0, 0, 3, 3))
1749        truncate(['test_parent', 'test_parent_2'], only=False)
1750        r = query(q).getresult()[0]
1751        self.assertEqual(r, (0, 0, 0, 0))
1752        self.assertRaises(ValueError, truncate,
1753            ['test_parent*', 'test_child'], only=[True, False])
1754        truncate(['test_parent*', 'test_child'], only=[False, True])
1755
1756    def testTruncateQuoted(self):
1757        truncate = self.db.truncate
1758        query = self.db.query
1759        table = "test table for truncate()"
1760        query('drop table if exists "%s"' % table)
1761        self.addCleanup(query, 'drop table "%s"' % table)
1762        query('create table "%s" (n smallint)' % table)
1763        for i in range(3):
1764            query('insert into "%s" values (1)' % table)
1765        q = 'select count(*) from "%s"' % table
1766        r = query(q).getresult()[0][0]
1767        self.assertEqual(r, 3)
1768        truncate(table)
1769        r = query(q).getresult()[0][0]
1770        self.assertEqual(r, 0)
1771        for i in range(3):
1772            query('insert into "%s" values (1)' % table)
1773        r = query(q).getresult()[0][0]
1774        self.assertEqual(r, 3)
1775        truncate('public."%s"' % table)
1776        r = query(q).getresult()[0][0]
1777        self.assertEqual(r, 0)
1778
    def testTransaction(self):
        """Test the transaction and savepoint methods of the DB class.

        Exercises begin(), commit(), rollback() with and without a
        savepoint name, savepoint(), release(), start() and end(), and
        checks which of the inserted rows were finally persisted.
        """
        query = self.db.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        query("create table test_table (n integer)")
        # committed: rows 1 and 2 are kept
        self.db.begin()
        query("insert into test_table values (1)")
        query("insert into test_table values (2)")
        self.db.commit()
        # rolled back: rows 3 and 4 are discarded
        self.db.begin()
        query("insert into test_table values (3)")
        query("insert into test_table values (4)")
        self.db.rollback()
        # rollback to a savepoint discards only row 6
        self.db.begin()
        query("insert into test_table values (5)")
        self.db.savepoint('before6')
        query("insert into test_table values (6)")
        self.db.rollback('before6')
        query("insert into test_table values (7)")
        self.db.commit()
        # a released savepoint cannot be rolled back to; the failed
        # rollback aborts the transaction, so row 8 is not persisted
        # by the following commit (see the expected result below)
        self.db.begin()
        self.db.savepoint('before8')
        query("insert into test_table values (8)")
        self.db.release('before8')
        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
        self.db.commit()
        # start() and end() are used here like begin() and commit()
        self.db.start()
        query("insert into test_table values (9)")
        self.db.end()
        r = [r[0] for r in query(
            "select * from test_table order by 1").getresult()]
        self.assertEqual(r, [1, 2, 5, 7, 9])
1811
1812    def testContextManager(self):
1813        query = self.db.query
1814        query("drop table if exists test_table")
1815        self.addCleanup(query, "drop table test_table")
1816        query("create table test_table (n integer check(n>0))")
1817        with self.db:
1818            query("insert into test_table values (1)")
1819            query("insert into test_table values (2)")
1820        try:
1821            with self.db:
1822                query("insert into test_table values (3)")
1823                query("insert into test_table values (4)")
1824                raise ValueError('test transaction should rollback')
1825        except ValueError as error:
1826            self.assertEqual(str(error), 'test transaction should rollback')
1827        with self.db:
1828            query("insert into test_table values (5)")
1829        try:
1830            with self.db:
1831                query("insert into test_table values (6)")
1832                query("insert into test_table values (-1)")
1833        except pg.ProgrammingError as error:
1834            self.assertTrue('check' in str(error))
1835        with self.db:
1836            query("insert into test_table values (7)")
1837        r = [r[0] for r in query(
1838            "select * from test_table order by 1").getresult()]
1839        self.assertEqual(r, [1, 2, 5, 7])
1840
1841    def testBytea(self):
1842        query = self.db.query
1843        query('drop table if exists bytea_test')
1844        self.addCleanup(query, 'drop table bytea_test')
1845        query('create table bytea_test (n smallint primary key, data bytea)')
1846        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1847        r = self.db.escape_bytea(s)
1848        query('insert into bytea_test values(3,$1)', (r,))
1849        r = query('select * from bytea_test where n=3').getresult()
1850        self.assertEqual(len(r), 1)
1851        r = r[0]
1852        self.assertEqual(len(r), 2)
1853        self.assertEqual(r[0], 3)
1854        r = r[1]
1855        self.assertIsInstance(r, str)
1856        r = self.db.unescape_bytea(r)
1857        self.assertIsInstance(r, bytes)
1858        self.assertEqual(r, s)
1859
1860    def testInsertUpdateGetBytea(self):
1861        query = self.db.query
1862        query('drop table if exists bytea_test')
1863        self.addCleanup(query, 'drop table bytea_test')
1864        query('create table bytea_test (n smallint primary key, data bytea)')
1865        # insert null value
1866        r = self.db.insert('bytea_test', n=0, data=None)
1867        self.assertIsInstance(r, dict)
1868        self.assertIn('n', r)
1869        self.assertEqual(r['n'], 0)
1870        self.assertIn('data', r)
1871        self.assertIsNone(r['data'])
1872        s = b'None'
1873        r = self.db.update('bytea_test', n=0, data=s)
1874        self.assertIsInstance(r, dict)
1875        self.assertIn('n', r)
1876        self.assertEqual(r['n'], 0)
1877        self.assertIn('data', r)
1878        r = r['data']
1879        self.assertIsInstance(r, bytes)
1880        self.assertEqual(r, s)
1881        r = self.db.update('bytea_test', n=0, data=None)
1882        self.assertIsNone(r['data'])
1883        # insert as bytes
1884        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1885        r = self.db.insert('bytea_test', n=5, data=s)
1886        self.assertIsInstance(r, dict)
1887        self.assertIn('n', r)
1888        self.assertEqual(r['n'], 5)
1889        self.assertIn('data', r)
1890        r = r['data']
1891        self.assertIsInstance(r, bytes)
1892        self.assertEqual(r, s)
1893        # update as bytes
1894        s += b"and now even more \x00 nasty \t stuff!\f"
1895        r = self.db.update('bytea_test', n=5, data=s)
1896        self.assertIsInstance(r, dict)
1897        self.assertIn('n', r)
1898        self.assertEqual(r['n'], 5)
1899        self.assertIn('data', r)
1900        r = r['data']
1901        self.assertIsInstance(r, bytes)
1902        self.assertEqual(r, s)
1903        r = query('select * from bytea_test where n=5').getresult()
1904        self.assertEqual(len(r), 1)
1905        r = r[0]
1906        self.assertEqual(len(r), 2)
1907        self.assertEqual(r[0], 5)
1908        r = r[1]
1909        self.assertIsInstance(r, str)
1910        r = self.db.unescape_bytea(r)
1911        self.assertIsInstance(r, bytes)
1912        self.assertEqual(r, s)
1913        r = self.db.get('bytea_test', dict(n=5))
1914        self.assertIsInstance(r, dict)
1915        self.assertIn('n', r)
1916        self.assertEqual(r['n'], 5)
1917        self.assertIn('data', r)
1918        r = r['data']
1919        self.assertIsInstance(r, bytes)
1920        self.assertEqual(r, s)
1921
1922
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    The global pg options changed here are restored afterwards, also when
    the set-up or tear-down of the parent class fails.
    """

    # options changed by this class, in the order they are restored
    restore_order = ('namedresult', 'bool', 'decimal')

    @classmethod
    def setUpClass(cls):
        cls.saved_options = {}

        def unnamed_result(q):
            # return plain tuples instead of named tuples
            return q.getresult()

        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # tearDownClass() is not called when setUpClass() fails, so the
            # global options must be restored here to not affect other tests
            cls.reset_options()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls.reset_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global option and set a new one."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore the saved value of a global option."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_options(cls):
        """Restore all global options that have been saved so far."""
        for option in cls.restore_order:
            if option in cls.saved_options:
                cls.reset_option(option)
1951
1952
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    @classmethod
    def setUpClass(cls):
        """Create the schemas s1 to s4 and the test tables.

        The public schema and each schema s<num> get a table t and a table
        t<num>, each with one row n = 1, d = <num> (num is 0 for public).
        """
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d" % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # close the connection even if the set-up fails, since
            # tearDownClass() will not be called in that case
            db.close()

    @classmethod
    def tearDownClass(cls):
        """Drop the schemas s1 to s4 and the test tables in public."""
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.db = DB()
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        # run the cleanups while the connection they use is still open
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        # unqualified names are resolved using the search path
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        # tables outside the search path must be schema-qualified
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        # the row returned by get() has the oid under the key 'oid(<table>)'
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
2070
2071
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class."""

    def setUp(self):
        self.db = DB()
        self.query = self.db.query
        # remember the initial debug setting so that tearDown() restores it
        self.debug = self.db.debug
        # capture everything printed to stdout during the test
        self.output = StringIO()
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        sys.stdout = self.stdout
        self.output.close()
        self.db.debug = self.debug
        self.db.close()

    def get_output(self):
        """Return what has been printed to stdout so far."""
        return self.output.getvalue()

    def send_queries(self):
        """Send two simple queries through the DB wrapper."""
        self.db.query("select 1")
        self.db.query("select 2")

    def testDebugDefault(self):
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        # the string is used as format with the query as argument
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        # the queries are written to the file instead of stdout
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        # the callable gets the queries without trailing newline
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")
2133
2134
if __name__ == '__main__':
    # run all test cases of this module when executed as a script
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.