source: trunk/tests/test_classic_dbwrapper.py @ 763

Last change on this file since 763 was 763, checked in by cito, 4 years ago

Achieve 100% test coverage for pg module on the trunk

Note that some lines are only covered in certain Pg or Py versions,
so you need to run tests with different versions to be sure.

Also added another synonym for transaction methods,
you can now pick your favorite for all three of them.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 93.4 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18
19import os
20import sys
21import tempfile
22
23import pg  # the module under test
24
25from decimal import Decimal
26
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None  # no explicit host; testAttributeHost then expects localhost
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Override the defaults above with the settings from LOCAL_PyGreSQL, if it
# exists: first as a sibling module of a tests package (relative import),
# then as a top-level module.  ValueError is caught as well because a
# relative import outside of a package can raise it (on Python 2).
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings, keep the defaults

# Compatibility aliases so that the tests run on Python 2 and Python 3.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict  # fallback without ordering guarantees

# text buffer class: StringIO module on Python 2, io.StringIO on Python 3
if str is bytes:
    from StringIO import StringIO
else:
    from io import StringIO

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
70
71
def DB():
    """Return a pg.DB wrapper connected to the configured test database.

    The wrapper's debug attribute is set when the module-level ``debug``
    flag is true, and the session is told to send only messages of level
    warning and above to the client.
    """
    wrapper = pg.DB(dbname, dbhost, dbport)
    if debug:
        wrapper.debug = debug
    wrapper.query("set client_min_messages=warning")
    return wrapper
79
80
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # Several tests close the connection themselves; closing it again
        # raises pg.InternalError (see testMethodClose), which is ignored.
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        """The public attribute names of a DB instance are as expected."""
        # dir() returns the names sorted, so keep this list alphabetical
        attributes = [
            'abort',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'delete',
            'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        """The underlying connection reports the test database name."""
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        """The wrapper exposes the test database name as dbname."""
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        """The error attribute is delegated to the connection."""
        error = self.db.error
        # an empty message is expected; a Kerberos ('krb5_') message is
        # tolerated as well
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        """The host is a str: the configured host, or localhost."""
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        """No connection options are set for the test connection."""
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        """The port is an int: the configured port, or 5432."""
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        """The protocol version is an int, either 2 or 3."""
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        """The server version is an int from 70400 (7.4) below 100000."""
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        """A freshly opened connection has status 1 (ok)."""
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        """The user attribute is a non-empty str naming a real user."""
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        """An empty literal is escaped as two single quotes."""
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        """An empty identifier is escaped as two double quotes."""
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        """An empty string stays empty when escaped."""
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        """An empty bytea escapes to nothing but backslashes/hex prefix."""
        # the result may be the bare hex prefix or empty, so strip both
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        """An empty string unescapes to empty bytes."""
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        """Parameters can be passed as arguments, a tuple, or a list."""
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        """An empty query string is rejected with ValueError."""
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        """A failing statement raises ProgrammingError with its SQLSTATE."""
        # assertRaises makes the test fail if no exception is raised at
        # all; 22012 is the SQLSTATE for division_by_zero
        with self.assertRaises(pg.ProgrammingError) as context:
            self.db.query("select 1/0")
        self.assertEqual(context.exception.sqlstate, '22012')

    def testMethodEndcopy(self):
        """endcopy outside of a COPY may raise IOError, which is accepted."""
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        """After close, the connection is gone and most access fails."""
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        """reset keeps the same connection object; fails once closed."""
        con = self.db.db
        self.db.reset()
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        """reopen creates a new connection object, even after close."""
        con = self.db.db
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        """A DB can wrap an existing connection or connection holder."""
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        # closing or reopening a wrapper around a passed-in connection
        # leaves that connection attached to the wrapper
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        # another DB instance and the db keyword argument work as well
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        # any object holding the connection in a _cnx attribute is accepted
        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
282
283
284class TestDBClass(unittest.TestCase):
285    """Test the methods of the DB class wrapped pg connection."""
286
287    @classmethod
288    def setUpClass(cls):
289        db = DB()
290        db.query("drop table if exists test cascade")
291        db.query("create table test ("
292            "i2 smallint, i4 integer, i8 bigint,"
293            " d numeric, f4 real, f8 double precision, m money,"
294            " v4 varchar(4), c4 char(4), t text)")
295        db.query("create or replace view test_view as"
296            " select i4, v4 from test")
297        db.close()
298
299    @classmethod
300    def tearDownClass(cls):
301        db = DB()
302        db.query("drop table test cascade")
303        db.close()
304
305    def setUp(self):
306        self.db = DB()
307        query = self.db.query
308        query('set client_encoding=utf8')
309        query('set standard_conforming_strings=on')
310        query("set lc_monetary='C'")
311        query("set datestyle='ISO,YMD'")
312        query('set bytea_output=hex')
313
314    def tearDown(self):
315        self.doCleanups()
316        self.db.close()
317
318    def testClassName(self):
319        self.assertEqual(self.db.__class__.__name__, 'DB')
320
321    def testModuleName(self):
322        self.assertEqual(self.db.__module__, 'pg')
323        self.assertEqual(self.db.__class__.__module__, 'pg')
324
325    def testEscapeLiteral(self):
326        f = self.db.escape_literal
327        r = f(b"plain")
328        self.assertIsInstance(r, bytes)
329        self.assertEqual(r, b"'plain'")
330        r = f(u"plain")
331        self.assertIsInstance(r, unicode)
332        self.assertEqual(r, u"'plain'")
333        r = f(u"that's kÀse".encode('utf-8'))
334        self.assertIsInstance(r, bytes)
335        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
336        r = f(u"that's kÀse")
337        self.assertIsInstance(r, unicode)
338        self.assertEqual(r, u"'that''s kÀse'")
339        self.assertEqual(f(r"It's fine to have a \ inside."),
340            r" E'It''s fine to have a \\ inside.'")
341        self.assertEqual(f('No "quotes" must be escaped.'),
342            "'No \"quotes\" must be escaped.'")
343
344    def testEscapeIdentifier(self):
345        f = self.db.escape_identifier
346        r = f(b"plain")
347        self.assertIsInstance(r, bytes)
348        self.assertEqual(r, b'"plain"')
349        r = f(u"plain")
350        self.assertIsInstance(r, unicode)
351        self.assertEqual(r, u'"plain"')
352        r = f(u"that's kÀse".encode('utf-8'))
353        self.assertIsInstance(r, bytes)
354        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
355        r = f(u"that's kÀse")
356        self.assertIsInstance(r, unicode)
357        self.assertEqual(r, u'"that\'s kÀse"')
358        self.assertEqual(f(r"It's fine to have a \ inside."),
359            '"It\'s fine to have a \\ inside."')
360        self.assertEqual(f('All "quotes" must be escaped.'),
361            '"All ""quotes"" must be escaped."')
362
363    def testEscapeString(self):
364        f = self.db.escape_string
365        r = f(b"plain")
366        self.assertIsInstance(r, bytes)
367        self.assertEqual(r, b"plain")
368        r = f(u"plain")
369        self.assertIsInstance(r, unicode)
370        self.assertEqual(r, u"plain")
371        r = f(u"that's kÀse".encode('utf-8'))
372        self.assertIsInstance(r, bytes)
373        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
374        r = f(u"that's kÀse")
375        self.assertIsInstance(r, unicode)
376        self.assertEqual(r, u"that''s kÀse")
377        self.assertEqual(f(r"It's fine to have a \ inside."),
378            r"It''s fine to have a \ inside.")
379
380    def testEscapeBytea(self):
381        f = self.db.escape_bytea
382        # note that escape_byte always returns hex output since Pg 9.0,
383        # regardless of the bytea_output setting
384        r = f(b'plain')
385        self.assertIsInstance(r, bytes)
386        self.assertEqual(r, b'\\x706c61696e')
387        r = f(u'plain')
388        self.assertIsInstance(r, unicode)
389        self.assertEqual(r, u'\\x706c61696e')
390        r = f(u"das is' kÀse".encode('utf-8'))
391        self.assertIsInstance(r, bytes)
392        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
393        r = f(u"das is' kÀse")
394        self.assertIsInstance(r, unicode)
395        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
396        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
397
398    def testUnescapeBytea(self):
399        f = self.db.unescape_bytea
400        r = f(b'plain')
401        self.assertIsInstance(r, bytes)
402        self.assertEqual(r, b'plain')
403        r = f(u'plain')
404        self.assertIsInstance(r, bytes)
405        self.assertEqual(r, b'plain')
406        r = f(b"das is' k\\303\\244se")
407        self.assertIsInstance(r, bytes)
408        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
409        r = f(u"das is' k\\303\\244se")
410        self.assertIsInstance(r, bytes)
411        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
412        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
413        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
414        self.assertEqual(f(r'\\x746861742773206be47365'),
415            b'\\x746861742773206be47365')
416        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
417
418    def testGetParameter(self):
419        f = self.db.get_parameter
420        self.assertRaises(TypeError, f)
421        self.assertRaises(TypeError, f, None)
422        self.assertRaises(TypeError, f, 42)
423        self.assertRaises(TypeError, f, '')
424        self.assertRaises(TypeError, f, [])
425        self.assertRaises(TypeError, f, [''])
426        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
427        r = f('standard_conforming_strings')
428        self.assertEqual(r, 'on')
429        r = f('lc_monetary')
430        self.assertEqual(r, 'C')
431        r = f('datestyle')
432        self.assertEqual(r, 'ISO, YMD')
433        r = f('bytea_output')
434        self.assertEqual(r, 'hex')
435        r = f(['bytea_output', 'lc_monetary'])
436        self.assertIsInstance(r, list)
437        self.assertEqual(r, ['hex', 'C'])
438        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
439        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
440        r = f(set(['bytea_output', 'lc_monetary']))
441        self.assertIsInstance(r, dict)
442        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
443        r = f(set(['Bytea_Output', ' LC_Monetary ']))
444        self.assertIsInstance(r, dict)
445        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
446        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
447        r = f(s)
448        self.assertIs(r, s)
449        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
450        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
451        r = f(s)
452        self.assertIs(r, s)
453        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
454
455    def testGetParameterServerVersion(self):
456        r = self.db.get_parameter('server_version_num')
457        self.assertIsInstance(r, str)
458        s = self.db.server_version
459        self.assertIsInstance(s, int)
460        self.assertEqual(r, str(s))
461
462    def testGetParameterAll(self):
463        f = self.db.get_parameter
464        r = f('all')
465        self.assertIsInstance(r, dict)
466        self.assertEqual(r['standard_conforming_strings'], 'on')
467        self.assertEqual(r['lc_monetary'], 'C')
468        self.assertEqual(r['DateStyle'], 'ISO, YMD')
469        self.assertEqual(r['bytea_output'], 'hex')
470
471    def testSetParameter(self):
472        f = self.db.set_parameter
473        g = self.db.get_parameter
474        self.assertRaises(TypeError, f)
475        self.assertRaises(TypeError, f, None)
476        self.assertRaises(TypeError, f, 42)
477        self.assertRaises(TypeError, f, '')
478        self.assertRaises(TypeError, f, [])
479        self.assertRaises(TypeError, f, [''])
480        self.assertRaises(ValueError, f, 'all', 'invalid')
481        self.assertRaises(ValueError, f, {
482            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
483        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
484        f('standard_conforming_strings', 'off')
485        self.assertEqual(g('standard_conforming_strings'), 'off')
486        f('datestyle', 'ISO, DMY')
487        self.assertEqual(g('datestyle'), 'ISO, DMY')
488        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
489        self.assertEqual(g('standard_conforming_strings'), 'on')
490        self.assertEqual(g('datestyle'), 'ISO, DMY')
491        f(['default_with_oids', 'standard_conforming_strings'], 'off')
492        self.assertEqual(g('default_with_oids'), 'off')
493        self.assertEqual(g('standard_conforming_strings'), 'off')
494        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
495        self.assertEqual(g('standard_conforming_strings'), 'on')
496        self.assertEqual(g('datestyle'), 'ISO, YMD')
497        f(('default_with_oids', 'standard_conforming_strings'), 'off')
498        self.assertEqual(g('default_with_oids'), 'off')
499        self.assertEqual(g('standard_conforming_strings'), 'off')
500        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
501        self.assertEqual(g('default_with_oids'), 'on')
502        self.assertEqual(g('standard_conforming_strings'), 'on')
503        self.assertRaises(ValueError, f, set([ 'default_with_oids',
504            'standard_conforming_strings']), ['off', 'on'])
505        f(set(['default_with_oids', 'standard_conforming_strings']),
506            ['off', 'off'])
507        self.assertEqual(g('default_with_oids'), 'off')
508        self.assertEqual(g('standard_conforming_strings'), 'off')
509        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
510        self.assertEqual(g('standard_conforming_strings'), 'on')
511        self.assertEqual(g('datestyle'), 'ISO, YMD')
512
513    def testResetParameter(self):
514        db = DB()
515        f = db.set_parameter
516        g = db.get_parameter
517        r = g('default_with_oids')
518        self.assertIn(r, ('on', 'off'))
519        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
520        r = g('standard_conforming_strings')
521        self.assertIn(r, ('on', 'off'))
522        scs, not_scs = r, 'off' if r == 'on' else 'on'
523        f('default_with_oids', not_dwi)
524        f('standard_conforming_strings', not_scs)
525        self.assertEqual(g('default_with_oids'), not_dwi)
526        self.assertEqual(g('standard_conforming_strings'), not_scs)
527        f('default_with_oids')
528        f('standard_conforming_strings', None)
529        self.assertEqual(g('default_with_oids'), dwi)
530        self.assertEqual(g('standard_conforming_strings'), scs)
531        f('default_with_oids', not_dwi)
532        f('standard_conforming_strings', not_scs)
533        self.assertEqual(g('default_with_oids'), not_dwi)
534        self.assertEqual(g('standard_conforming_strings'), not_scs)
535        f(['default_with_oids', 'standard_conforming_strings'], None)
536        self.assertEqual(g('default_with_oids'), dwi)
537        self.assertEqual(g('standard_conforming_strings'), scs)
538        f('default_with_oids', not_dwi)
539        f('standard_conforming_strings', not_scs)
540        self.assertEqual(g('default_with_oids'), not_dwi)
541        self.assertEqual(g('standard_conforming_strings'), not_scs)
542        f(('default_with_oids', 'standard_conforming_strings'))
543        self.assertEqual(g('default_with_oids'), dwi)
544        self.assertEqual(g('standard_conforming_strings'), scs)
545        f('default_with_oids', not_dwi)
546        f('standard_conforming_strings', not_scs)
547        self.assertEqual(g('default_with_oids'), not_dwi)
548        self.assertEqual(g('standard_conforming_strings'), not_scs)
549        f(set(['default_with_oids', 'standard_conforming_strings']))
550        self.assertEqual(g('default_with_oids'), dwi)
551        self.assertEqual(g('standard_conforming_strings'), scs)
552
553    def testResetParameterAll(self):
554        db = DB()
555        f = db.set_parameter
556        self.assertRaises(ValueError, f, 'all', 0)
557        self.assertRaises(ValueError, f, 'all', 'off')
558        g = db.get_parameter
559        r = g('default_with_oids')
560        self.assertIn(r, ('on', 'off'))
561        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
562        r = g('standard_conforming_strings')
563        self.assertIn(r, ('on', 'off'))
564        scs, not_scs = r, 'off' if r == 'on' else 'on'
565        f('default_with_oids', not_dwi)
566        f('standard_conforming_strings', not_scs)
567        self.assertEqual(g('default_with_oids'), not_dwi)
568        self.assertEqual(g('standard_conforming_strings'), not_scs)
569        f('all')
570        self.assertEqual(g('default_with_oids'), dwi)
571        self.assertEqual(g('standard_conforming_strings'), scs)
572
573    def testSetParameterLocal(self):
574        f = self.db.set_parameter
575        g = self.db.get_parameter
576        self.assertEqual(g('standard_conforming_strings'), 'on')
577        self.db.begin()
578        f('standard_conforming_strings', 'off', local=True)
579        self.assertEqual(g('standard_conforming_strings'), 'off')
580        self.db.end()
581        self.assertEqual(g('standard_conforming_strings'), 'on')
582
583    def testSetParameterSession(self):
584        f = self.db.set_parameter
585        g = self.db.get_parameter
586        self.assertEqual(g('standard_conforming_strings'), 'on')
587        self.db.begin()
588        f('standard_conforming_strings', 'off', local=False)
589        self.assertEqual(g('standard_conforming_strings'), 'off')
590        self.db.end()
591        self.assertEqual(g('standard_conforming_strings'), 'off')
592
593    def testReset(self):
594        db = DB()
595        default_datestyle = db.get_parameter('datestyle')
596        changed_datestyle = 'ISO, DMY'
597        if changed_datestyle == default_datestyle:
598            changed_datestyle == 'ISO, YMD'
599        self.db.set_parameter('datestyle', changed_datestyle)
600        r = self.db.get_parameter('datestyle')
601        self.assertEqual(r, changed_datestyle)
602        con = self.db.db
603        q = con.query("show datestyle")
604        self.db.reset()
605        r = q.getresult()[0][0]
606        self.assertEqual(r, changed_datestyle)
607        q = con.query("show datestyle")
608        r = q.getresult()[0][0]
609        self.assertEqual(r, default_datestyle)
610        r = self.db.get_parameter('datestyle')
611        self.assertEqual(r, default_datestyle)
612
613    def testReopen(self):
614        db = DB()
615        default_datestyle = db.get_parameter('datestyle')
616        changed_datestyle = 'ISO, DMY'
617        if changed_datestyle == default_datestyle:
618            changed_datestyle == 'ISO, YMD'
619        self.db.set_parameter('datestyle', changed_datestyle)
620        r = self.db.get_parameter('datestyle')
621        self.assertEqual(r, changed_datestyle)
622        con = self.db.db
623        q = con.query("show datestyle")
624        self.db.reopen()
625        r = q.getresult()[0][0]
626        self.assertEqual(r, changed_datestyle)
627        self.assertRaises(TypeError, getattr, con, 'query')
628        r = self.db.get_parameter('datestyle')
629        self.assertEqual(r, default_datestyle)
630
631    def testQuery(self):
632        query = self.db.query
633        query("drop table if exists test_table")
634        self.addCleanup(query, "drop table test_table")
635        q = "create table test_table (n integer) with oids"
636        r = query(q)
637        self.assertIsNone(r)
638        q = "insert into test_table values (1)"
639        r = query(q)
640        self.assertIsInstance(r, int)
641        q = "insert into test_table select 2"
642        r = query(q)
643        self.assertIsInstance(r, int)
644        oid = r
645        q = "select oid from test_table where n=2"
646        r = query(q).getresult()
647        self.assertEqual(len(r), 1)
648        r = r[0]
649        self.assertEqual(len(r), 1)
650        r = r[0]
651        self.assertEqual(r, oid)
652        q = "insert into test_table select 3 union select 4 union select 5"
653        r = query(q)
654        self.assertIsInstance(r, str)
655        self.assertEqual(r, '3')
656        q = "update test_table set n=4 where n<5"
657        r = query(q)
658        self.assertIsInstance(r, str)
659        self.assertEqual(r, '4')
660        q = "delete from test_table"
661        r = query(q)
662        self.assertIsInstance(r, str)
663        self.assertEqual(r, '5')
664
665    def testMultipleQueries(self):
666        self.assertEqual(self.db.query(
667            "create temporary table test_multi (n integer);"
668            "insert into test_multi values (4711);"
669            "select n from test_multi").getresult()[0][0], 4711)
670
671    def testQueryWithParams(self):
672        query = self.db.query
673        query("drop table if exists test_table")
674        self.addCleanup(query, "drop table test_table")
675        q = "create table test_table (n1 integer, n2 integer) with oids"
676        query(q)
677        q = "insert into test_table values ($1, $2)"
678        r = query(q, (1, 2))
679        self.assertIsInstance(r, int)
680        r = query(q, [3, 4])
681        self.assertIsInstance(r, int)
682        r = query(q, [5, 6])
683        self.assertIsInstance(r, int)
684        q = "select * from test_table order by 1, 2"
685        self.assertEqual(query(q).getresult(),
686            [(1, 2), (3, 4), (5, 6)])
687        q = "select * from test_table where n1=$1 and n2=$2"
688        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
689        q = "update test_table set n2=$2 where n1=$1"
690        r = query(q, 3, 7)
691        self.assertEqual(r, '1')
692        q = "select * from test_table order by 1, 2"
693        self.assertEqual(query(q).getresult(),
694            [(1, 2), (3, 7), (5, 6)])
695        q = "delete from test_table where n2!=$1"
696        r = query(q, 4)
697        self.assertEqual(r, '3')
698
699    def testEmptyQuery(self):
700        self.assertRaises(ValueError, self.db.query, '')
701
702    def testQueryProgrammingError(self):
703        try:
704            self.db.query("select 1/0")
705        except pg.ProgrammingError as error:
706            self.assertEqual(error.sqlstate, '22012')
707
708    def testPkey(self):
709        query = self.db.query
710        pkey = self.db.pkey
711        for t in ('pkeytest', 'primary key test'):
712            for n in range(7):
713                query('drop table if exists "%s%d"' % (t, n))
714                self.addCleanup(query, 'drop table "%s%d"' % (t, n))
715            query('create table "%s0" ('
716                "a smallint)" % t)
717            query('create table "%s1" ('
718                "b smallint primary key)" % t)
719            query('create table "%s2" ('
720                "c smallint, d smallint primary key)" % t)
721            query('create table "%s3" ('
722                "e smallint, f smallint, g smallint,"
723                " h smallint, i smallint,"
724                " primary key (f, h))" % t)
725            query('create table "%s4" ('
726                "more_than_one_letter varchar primary key)" % t)
727            query('create table "%s5" ('
728                '"with space" date primary key)' % t)
729            query('create table "%s6" ('
730                'a_very_long_column_name varchar,'
731                ' "with space" date,'
732                ' "42" int,'
733                " primary key (a_very_long_column_name,"
734                ' "with space", "42"))' % t)
735            self.assertRaises(KeyError, pkey, '%s0' % t)
736            self.assertEqual(pkey('%s1' % t), 'b')
737            self.assertEqual(pkey('%s2' % t), 'd')
738            r = pkey('%s3' % t)
739            self.assertIsInstance(r, frozenset)
740            self.assertEqual(r, frozenset('fh'))
741            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
742            self.assertEqual(pkey('%s5' % t), 'with space')
743            r = pkey('%s6' % t)
744            self.assertIsInstance(r, frozenset)
745            self.assertEqual(r, frozenset([
746                'a_very_long_column_name', 'with space', '42']))
747            # a newly added primary key will be detected
748            query('alter table "%s0" add primary key (a)' % t)
749            self.assertEqual(pkey('%s0' % t), 'a')
750            # a changed primary key will not be detected,
751            # indicating that the internal cache is operating
752            query('alter table "%s1" rename column b to x' % t)
753            self.assertEqual(pkey('%s1' % t), 'b')
754            # we get the changed primary key when the cache is flushed
755            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
756
757    def testGetDatabases(self):
758        databases = self.db.get_databases()
759        self.assertIn('template0', databases)
760        self.assertIn('template1', databases)
761        self.assertNotIn('not existing database', databases)
762        self.assertIn('postgres', databases)
763        self.assertIn(dbname, databases)
764
765    def testGetTables(self):
766        get_tables = self.db.get_tables
767        result1 = get_tables()
768        self.assertIsInstance(result1, list)
769        for t in result1:
770            t = t.split('.', 1)
771            self.assertGreaterEqual(len(t), 2)
772            if len(t) > 2:
773                self.assertTrue(t[1].startswith('"'))
774            t = t[0]
775            self.assertNotEqual(t, 'information_schema')
776            self.assertFalse(t.startswith('pg_'))
777        tables = ('"A very Special Name"',
778            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
779            'A_MiXeD_NaMe', '"another special name"',
780            'averyveryveryveryveryveryverylongtablename',
781            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
782        for t in tables:
783            self.db.query('drop table if exists %s' % t)
784            self.db.query("create table %s"
785                " as select 0" % t)
786        result3 = get_tables()
787        result2 = []
788        for t in result3:
789            if t not in result1:
790                result2.append(t)
791        result3 = []
792        for t in tables:
793            if not t.startswith('"'):
794                t = t.lower()
795            result3.append('public.' + t)
796        self.assertEqual(result2, result3)
797        for t in result2:
798            self.db.query('drop table %s' % t)
799        result2 = get_tables()
800        self.assertEqual(result2, result1)
801
802    def testGetRelations(self):
803        get_relations = self.db.get_relations
804        result = get_relations()
805        self.assertIn('public.test', result)
806        self.assertIn('public.test_view', result)
807        result = get_relations('rv')
808        self.assertIn('public.test', result)
809        self.assertIn('public.test_view', result)
810        result = get_relations('r')
811        self.assertIn('public.test', result)
812        self.assertNotIn('public.test_view', result)
813        result = get_relations('v')
814        self.assertNotIn('public.test', result)
815        self.assertIn('public.test_view', result)
816        result = get_relations('cisSt')
817        self.assertNotIn('public.test', result)
818        self.assertNotIn('public.test_view', result)
819
820    def testGetAttnames(self):
821        get_attnames = self.db.get_attnames
822        self.assertRaises(pg.ProgrammingError,
823            self.db.get_attnames, 'does_not_exist')
824        self.assertRaises(pg.ProgrammingError,
825            self.db.get_attnames, 'has.too.many.dots')
826        r = get_attnames('test')
827        self.assertIsInstance(r, dict)
828        self.assertEqual(r, dict(
829            i2='int', i4='int', i8='int', d='num',
830            f4='float', f8='float', m='money',
831            v4='text', c4='text', t='text'))
832        query = self.db.query
833        query("drop table if exists test_table")
834        self.addCleanup(query, "drop table test_table")
835        query("create table test_table("
836            " n int, alpha smallint, beta bool,"
837            " gamma char(5), tau text, v varchar(3))")
838        r = get_attnames('test_table')
839        self.assertIsInstance(r, dict)
840        self.assertEqual(r, dict(
841            n='int', alpha='int', beta='bool',
842            gamma='text', tau='text', v='text'))
843
844    def testGetAttnamesWithQuotes(self):
845        get_attnames = self.db.get_attnames
846        query = self.db.query
847        table = 'test table for get_attnames()'
848        query('drop table if exists "%s"' % table)
849        self.addCleanup(query, 'drop table "%s"' % table)
850        query('create table "%s"('
851            '"Prime!" smallint,'
852            ' "much space" integer, "Questions?" text)' % table)
853        r = get_attnames(table)
854        self.assertIsInstance(r, dict)
855        self.assertEqual(r, {
856            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
857        table = 'yet another test table for get_attnames()'
858        query('drop table if exists "%s"' % table)
859        self.addCleanup(query, 'drop table "%s"' % table)
860        self.db.query('create table "%s" ('
861            'a smallint, b integer, c bigint,'
862            ' e numeric, f float, f2 double precision, m money,'
863            ' x smallint, y smallint, z smallint,'
864            ' Normal_NaMe smallint, "Special Name" smallint,'
865            ' t text, u char(2), v varchar(2),'
866            ' primary key (y, u)) with oids' % table)
867        r = get_attnames(table)
868        self.assertIsInstance(r, dict)
869        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
870            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
871            'normal_name': 'int', 'Special Name': 'int',
872            'u': 'text', 't': 'text', 'v': 'text',
873            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
874
875    def testGetAttnamesWithRegtypes(self):
876        get_attnames = self.db.get_attnames
877        query = self.db.query
878        query("drop table if exists test_table")
879        self.addCleanup(query, "drop table test_table")
880        query("create table test_table("
881            " n int, alpha smallint, beta bool,"
882            " gamma char(5), tau text, v varchar(3))")
883        use_regtypes = self.db.use_regtypes
884        regtypes = use_regtypes()
885        self.assertFalse(regtypes)
886        use_regtypes(True)
887        try:
888            r = get_attnames("test_table")
889            self.assertIsInstance(r, dict)
890        finally:
891            use_regtypes(regtypes)
892        self.assertEqual(r, dict(
893            n='integer', alpha='smallint', beta='boolean',
894            gamma='character', tau='text', v='character varying'))
895
896    def testGetAttnamesIsCached(self):
897        get_attnames = self.db.get_attnames
898        query = self.db.query
899        query("drop table if exists test_table")
900        self.addCleanup(query, "drop table test_table")
901        query("create table test_table(col int)")
902        r = get_attnames("test_table")
903        self.assertIsInstance(r, dict)
904        self.assertEqual(r, dict(col='int'))
905        query("alter table test_table alter column col type text")
906        query("alter table test_table add column col2 int")
907        r = get_attnames("test_table")
908        self.assertEqual(r, dict(col='int'))
909        r = get_attnames("test_table", flush=True)
910        self.assertEqual(r, dict(col='text', col2='int'))
911        query("alter table test_table drop column col2")
912        r = get_attnames("test_table")
913        self.assertEqual(r, dict(col='text', col2='int'))
914        r = get_attnames("test_table", flush=True)
915        self.assertEqual(r, dict(col='text'))
916        query("alter table test_table drop column col")
917        r = get_attnames("test_table")
918        self.assertEqual(r, dict(col='text'))
919        r = get_attnames("test_table", flush=True)
920        self.assertEqual(r, dict())
921
922    def testGetAttnamesIsOrdered(self):
923        get_attnames = self.db.get_attnames
924        query = self.db.query
925        query("drop table if exists test_table")
926        self.addCleanup(query, "drop table test_table")
927        query("create table test_table("
928            " n int, alpha smallint, v varchar(3),"
929            " gamma char(5), tau text, beta bool)")
930        r = get_attnames("test_table")
931        self.assertIsInstance(r, OrderedDict)
932        self.assertEqual(r, OrderedDict([
933            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
934            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
935        if OrderedDict is dict:
936            self.skipTest('OrderedDict is not supported')
937        r = ' '.join(list(r.keys()))
938        self.assertEqual(r, 'n alpha v gamma tau beta')
939
940    def testHasTablePrivilege(self):
941        can = self.db.has_table_privilege
942        self.assertEqual(can('test'), True)
943        self.assertEqual(can('test', 'select'), True)
944        self.assertEqual(can('test', 'SeLeCt'), True)
945        self.assertEqual(can('test', 'SELECT'), True)
946        self.assertEqual(can('test', 'insert'), True)
947        self.assertEqual(can('test', 'update'), True)
948        self.assertEqual(can('test', 'delete'), True)
949        self.assertEqual(can('pg_views', 'select'), True)
950        self.assertEqual(can('pg_views', 'delete'), False)
951        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
952        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
953
954    def testGet(self):
955        get = self.db.get
956        query = self.db.query
957        table = 'get_test_table'
958        query('drop table if exists "%s"' % table)
959        self.addCleanup(query, 'drop table "%s"' % table)
960        query('create table "%s" ('
961            "n integer, t text) with oids" % table)
962        for n, t in enumerate('xyz'):
963            query('insert into "%s" values('"%d, '%s')"
964                % (table, n + 1, t))
965        self.assertRaises(pg.ProgrammingError, get, table, 2)
966        self.assertRaises(pg.ProgrammingError, get, table, {}, 'oid')
967        r = get(table, 2, 'n')
968        oid_table = 'oid(%s)' % table
969        self.assertIn(oid_table, r)
970        oid = r[oid_table]
971        self.assertIsInstance(oid, int)
972        result = {'t': 'y', 'n': 2, oid_table: oid}
973        self.assertEqual(r, result)
974        self.assertEqual(get(table + ' *', 2, 'n'), r)
975        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
976        self.assertEqual(get(table, 1, 'n')['t'], 'x')
977        self.assertEqual(get(table, 3, 'n')['t'], 'z')
978        self.assertEqual(get(table, 2, 'n')['t'], 'y')
979        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
980        r['n'] = 3
981        self.assertEqual(get(table, r, 'n')['t'], 'z')
982        self.assertEqual(get(table, 1, 'n')['t'], 'x')
983        query('alter table "%s" alter n set not null' % table)
984        query('alter table "%s" add primary key (n)' % table)
985        self.assertEqual(get(table, 3)['t'], 'z')
986        self.assertEqual(get(table, 1)['t'], 'x')
987        self.assertEqual(get(table, 2)['t'], 'y')
988        r['n'] = 1
989        self.assertEqual(get(table, r)['t'], 'x')
990        r['n'] = 3
991        self.assertEqual(get(table, r)['t'], 'z')
992        r['n'] = 2
993        self.assertEqual(get(table, r)['t'], 'y')
994
995    def testGetWithCompositeKey(self):
996        get = self.db.get
997        query = self.db.query
998        table = 'get_test_table_1'
999        query('drop table if exists "%s"' % table)
1000        self.addCleanup(query, 'drop table "%s"' % table)
1001        query('create table "%s" ('
1002            "n integer, t text, primary key (n))" % table)
1003        for n, t in enumerate('abc'):
1004            query('insert into "%s" values('
1005                "%d, '%s')" % (table, n + 1, t))
1006        self.assertEqual(get(table, 2)['t'], 'b')
1007        table = 'get_test_table_2'
1008        query('drop table if exists "%s"' % table)
1009        self.addCleanup(query, 'drop table "%s"' % table)
1010        query('create table "%s" ('
1011            "n integer, m integer, t text, primary key (n, m))" % table)
1012        for n in range(3):
1013            for m in range(2):
1014                t = chr(ord('a') + 2 * n + m)
1015                query('insert into "%s" values('
1016                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1017        self.assertRaises(pg.ProgrammingError, get, table, 2)
1018        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1019        r = get(table, dict(n=1, m=2), ('n', 'm'))
1020        self.assertEqual(r['t'], 'b')
1021        r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
1022        self.assertEqual(r['t'], 'f')
1023
1024    def testGetWithQuotedNames(self):
1025        get = self.db.get
1026        query = self.db.query
1027        table = 'test table for get()'
1028        query('drop table if exists "%s"' % table)
1029        self.addCleanup(query, 'drop table "%s"' % table)
1030        query('create table "%s" ('
1031            '"Prime!" smallint primary key,'
1032            ' "much space" integer, "Questions?" text)' % table)
1033        query('insert into "%s"'
1034              " values(17, 1001, 'No!')" % table)
1035        r = get(table, 17)
1036        self.assertIsInstance(r, dict)
1037        self.assertEqual(r['Prime!'], 17)
1038        self.assertEqual(r['much space'], 1001)
1039        self.assertEqual(r['Questions?'], 'No!')
1040
1041    def testGetFromView(self):
1042        self.db.query('delete from test where i4=14')
1043        self.db.query('insert into test (i4, v4) values('
1044            "14, 'abc4')")
1045        r = self.db.get('test_view', 14, 'i4')
1046        self.assertIn('v4', r)
1047        self.assertEqual(r['v4'], 'abc4')
1048
1049    def testGetLittleBobbyTables(self):
1050        get = self.db.get
1051        query = self.db.query
1052        query("drop table if exists test_students")
1053        self.addCleanup(query, "drop table test_students")
1054        query("create table test_students (firstname varchar primary key,"
1055            " nickname varchar, grade char(2))")
1056        query("insert into test_students values ("
1057              "'D''Arcy', 'Darcey', 'A+')")
1058        query("insert into test_students values ("
1059              "'Sheldon', 'Moonpie', 'A+')")
1060        query("insert into test_students values ("
1061              "'Robert', 'Little Bobby Tables', 'D-')")
1062        r = get('test_students', 'Sheldon')
1063        self.assertEqual(r, dict(
1064            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1065        r = get('test_students', 'Robert')
1066        self.assertEqual(r, dict(
1067            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1068        r = get('test_students', "D'Arcy")
1069        self.assertEqual(r, dict(
1070            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1071        try:
1072            get('test_students', "D' Arcy")
1073        except pg.DatabaseError as error:
1074            self.assertEqual(str(error),
1075                'No such record in test_students\nwhere "firstname" = $1\n'
1076                'with $1="D\' Arcy"')
1077        try:
1078            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1079        except pg.DatabaseError as error:
1080            self.assertEqual(str(error),
1081                'No such record in test_students\nwhere "firstname" = $1\n'
1082                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1083        q = "select * from test_students order by 1 limit 4"
1084        r = query(q).getresult()
1085        self.assertEqual(len(r), 3)
1086        self.assertEqual(r[1][2], 'D-')
1087
1088    def testInsert(self):
1089        insert = self.db.insert
1090        query = self.db.query
1091        bool_on = pg.get_bool()
1092        decimal = pg.get_decimal()
1093        table = 'insert_test_table'
1094        query('drop table if exists "%s"' % table)
1095        self.addCleanup(query, 'drop table "%s"' % table)
1096        query('create table "%s" ('
1097            "i2 smallint, i4 integer, i8 bigint,"
1098            " d numeric, f4 real, f8 double precision, m money,"
1099            " v4 varchar(4), c4 char(4), t text,"
1100            " b boolean, ts timestamp) with oids" % table)
1101        oid_table = 'oid(%s)' % table
1102        tests = [dict(i2=None, i4=None, i8=None),
1103            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1104            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1105            dict(i2=42, i4=123456, i8=9876543210),
1106            dict(i2=2 ** 15 - 1,
1107                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1108            dict(d=None), (dict(d=''), dict(d=None)),
1109            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1110            dict(f4=None, f8=None), dict(f4=0, f8=0),
1111            (dict(f4='', f8=''), dict(f4=None, f8=None)),
1112            (dict(d=1234.5, f4=1234.5, f8=1234.5),
1113                  dict(d=Decimal('1234.5'))),
1114            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1115            dict(d=Decimal('123456789.9876543212345678987654321')),
1116            dict(m=None), (dict(m=''), dict(m=None)),
1117            dict(m=Decimal('-1234.56')),
1118            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1119            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1120            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1121            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1122            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1123            (dict(m=123456), dict(m=Decimal('123456'))),
1124            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1125            dict(b=None), (dict(b=''), dict(b=None)),
1126            dict(b='f'), dict(b='t'),
1127            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1128            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1129            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1130            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1131            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1132            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1133            dict(v4=None, c4=None, t=None),
1134            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1135            dict(v4='1234', c4='1234', t='1234' * 10),
1136            dict(v4='abcd', c4='abcd', t='abcdefg'),
1137            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1138            dict(ts=None), (dict(ts=''), dict(ts=None)),
1139            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1140            dict(ts='2012-12-21 00:00:00'),
1141            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1142            dict(ts='2012-12-21 12:21:12'),
1143            dict(ts='2013-01-05 12:13:14'),
1144            dict(ts='current_timestamp')]
1145        for test in tests:
1146            if isinstance(test, dict):
1147                data = test
1148                change = {}
1149            else:
1150                data, change = test
1151            expect = data.copy()
1152            expect.update(change)
1153            if bool_on:
1154                b = expect.get('b')
1155                if b is not None:
1156                    expect['b'] = b == 't'
1157            if decimal is not Decimal:
1158                d = expect.get('d')
1159                if d is not None:
1160                    expect['d'] = decimal(d)
1161                m = expect.get('m')
1162                if m is not None:
1163                    expect['m'] = decimal(m)
1164            self.assertEqual(insert(table, data), data)
1165            self.assertIn(oid_table, data)
1166            oid = data[oid_table]
1167            self.assertIsInstance(oid, int)
1168            data = dict(item for item in data.items()
1169                if item[0] in expect)
1170            ts = expect.get('ts')
1171            if ts == 'current_timestamp':
1172                ts = expect['ts'] = data['ts']
1173                if len(ts) > 19:
1174                    self.assertEqual(ts[19], '.')
1175                    ts = ts[:19]
1176                else:
1177                    self.assertEqual(len(ts), 19)
1178                self.assertTrue(ts[:4].isdigit())
1179                self.assertEqual(ts[4], '-')
1180                self.assertEqual(ts[10], ' ')
1181                self.assertTrue(ts[11:13].isdigit())
1182                self.assertEqual(ts[13], ':')
1183            self.assertEqual(data, expect)
1184            data = query(
1185                'select oid,* from "%s"' % table).dictresult()[0]
1186            self.assertEqual(data['oid'], oid)
1187            data = dict(item for item in data.items()
1188                if item[0] in expect)
1189            self.assertEqual(data, expect)
1190            query('delete from "%s"' % table)
1191
1192    def testInsertWithOid(self):
1193        insert = self.db.insert
1194        query = self.db.query
1195        query("drop table if exists test_table")
1196        self.addCleanup(query, "drop table test_table")
1197        query("create table test_table (n int) with oids")
1198        r = insert('test_table', n=1)
1199        self.assertIsInstance(r, dict)
1200        self.assertEqual(r['n'], 1)
1201        qoid = 'oid(test_table)'
1202        self.assertIn(qoid, r)
1203        r = insert('test_table', n=2, oid='invalid')
1204        self.assertIsInstance(r, dict)
1205        self.assertEqual(r['n'], 2)
1206        r['n'] = 3
1207        r = insert('test_table', r)
1208        self.assertIsInstance(r, dict)
1209        self.assertEqual(r['n'], 3)
1210        r = insert('test_table', r, n=4)
1211        self.assertIsInstance(r, dict)
1212        self.assertEqual(r['n'], 4)
1213        q = 'select n from test_table order by 1 limit 5'
1214        r = query(q).getresult()
1215        self.assertEqual(r, [(1,), (2,), (3,), (4,)])
1216
1217    def testInsertWithQuotedNames(self):
1218        insert = self.db.insert
1219        query = self.db.query
1220        table = 'test table for insert()'
1221        query('drop table if exists "%s"' % table)
1222        self.addCleanup(query, 'drop table "%s"' % table)
1223        query('create table "%s" ('
1224            '"Prime!" smallint primary key,'
1225            ' "much space" integer, "Questions?" text)' % table)
1226        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1227        r = insert(table, r)
1228        self.assertIsInstance(r, dict)
1229        self.assertEqual(r['Prime!'], 11)
1230        self.assertEqual(r['much space'], 2002)
1231        self.assertEqual(r['Questions?'], 'What?')
1232        r = query('select * from "%s" limit 2' % table).dictresult()
1233        self.assertEqual(len(r), 1)
1234        r = r[0]
1235        self.assertEqual(r['Prime!'], 11)
1236        self.assertEqual(r['much space'], 2002)
1237        self.assertEqual(r['Questions?'], 'What?')
1238
1239    def testUpdate(self):
1240        update = self.db.update
1241        query = self.db.query
1242        self.assertRaises(pg.ProgrammingError, update,
1243            'test', i2=2, i4=4, i8=8)
1244        table = 'update_test_table'
1245        query('drop table if exists "%s"' % table)
1246        self.addCleanup(query, 'drop table "%s"' % table)
1247        query('create table "%s" ('
1248            "n integer, t text) with oids" % table)
1249        for n, t in enumerate('xyz'):
1250            query('insert into "%s" values('
1251                "%d, '%s')" % (table, n + 1, t))
1252        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1253        r = self.db.get(table, 2, 'n')
1254        r['t'] = 'u'
1255        s = update(table, r)
1256        self.assertEqual(s, r)
1257        q = 'select t from "%s" where n=2' % table
1258        r = query(q).getresult()[0][0]
1259        self.assertEqual(r, 'u')
1260
1261    def testUpdateWithOid(self):
1262        update = self.db.update
1263        get = self.db.get
1264        query = self.db.query
1265        query("drop table if exists test_table")
1266        self.addCleanup(query, "drop table test_table")
1267        query("create table test_table (n int) with oids")
1268        query("insert into test_table values (1)")
1269        r = get('test_table', 1, 'n')
1270        self.assertIsInstance(r, dict)
1271        self.assertEqual(r['n'], 1)
1272        r['n'] = 2
1273        r = update('test_table', r)
1274        self.assertIsInstance(r, dict)
1275        self.assertEqual(r['n'], 2)
1276        qoid = 'oid(test_table)'
1277        self.assertIn(qoid, r)
1278        r['n'] = 3
1279        r = update('test_table', r, oid=r.pop(qoid))
1280        self.assertIsInstance(r, dict)
1281        self.assertEqual(r['n'], 3)
1282        r.pop(qoid)
1283        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
1284        r = get('test_table', 3, 'n')
1285        self.assertIsInstance(r, dict)
1286        self.assertEqual(r['n'], 3)
1287        r.pop('n')
1288        r = update('test_table', r)
1289        r.pop(qoid)
1290        self.assertEqual(r, {})
1291        q = 'select n from test_table limit 2'
1292        r = query(q).getresult()
1293        self.assertEqual(r, [(3,)])
1294
1295    def testUpdateWithCompositeKey(self):
1296        update = self.db.update
1297        query = self.db.query
1298        table = 'update_test_table_1'
1299        query('drop table if exists "%s"' % table)
1300        self.addCleanup(query, 'drop table if exists "%s"' % table)
1301        query('create table "%s" ('
1302            "n integer, t text, primary key (n))" % table)
1303        for n, t in enumerate('abc'):
1304            query('insert into "%s" values('
1305                "%d, '%s')" % (table, n + 1, t))
1306        self.assertRaises(pg.ProgrammingError, update,
1307                          table, dict(t='b'))
1308        s = dict(n=2, t='d')
1309        r = update(table, s)
1310        self.assertIs(r, s)
1311        self.assertEqual(r['n'], 2)
1312        self.assertEqual(r['t'], 'd')
1313        q = 'select t from "%s" where n=2' % table
1314        r = query(q).getresult()[0][0]
1315        self.assertEqual(r, 'd')
1316        s.update(dict(n=4, t='e'))
1317        r = update(table, s)
1318        self.assertEqual(r['n'], 4)
1319        self.assertEqual(r['t'], 'e')
1320        q = 'select t from "%s" where n=2' % table
1321        r = query(q).getresult()[0][0]
1322        self.assertEqual(r, 'd')
1323        q = 'select t from "%s" where n=4' % table
1324        r = query(q).getresult()
1325        self.assertEqual(len(r), 0)
1326        query('drop table "%s"' % table)
1327        table = 'update_test_table_2'
1328        query('drop table if exists "%s"' % table)
1329        query('create table "%s" ('
1330            "n integer, m integer, t text, primary key (n, m))" % table)
1331        for n in range(3):
1332            for m in range(2):
1333                t = chr(ord('a') + 2 * n + m)
1334                query('insert into "%s" values('
1335                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1336        self.assertRaises(pg.ProgrammingError, update,
1337                          table, dict(n=2, t='b'))
1338        self.assertEqual(update(table,
1339                                dict(n=2, m=2, t='x'))['t'], 'x')
1340        q = 'select t from "%s" where n=2 order by m' % table
1341        r = [r[0] for r in query(q).getresult()]
1342        self.assertEqual(r, ['c', 'x'])
1343
1344    def testUpdateWithQuotedNames(self):
1345        update = self.db.update
1346        query = self.db.query
1347        table = 'test table for update()'
1348        query('drop table if exists "%s"' % table)
1349        self.addCleanup(query, 'drop table "%s"' % table)
1350        query('create table "%s" ('
1351            '"Prime!" smallint primary key,'
1352            ' "much space" integer, "Questions?" text)' % table)
1353        query('insert into "%s"'
1354              " values(13, 3003, 'Why!')" % table)
1355        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1356        r = update(table, r)
1357        self.assertIsInstance(r, dict)
1358        self.assertEqual(r['Prime!'], 13)
1359        self.assertEqual(r['much space'], 7007)
1360        self.assertEqual(r['Questions?'], 'When?')
1361        r = query('select * from "%s" limit 2' % table).dictresult()
1362        self.assertEqual(len(r), 1)
1363        r = r[0]
1364        self.assertEqual(r['Prime!'], 13)
1365        self.assertEqual(r['much space'], 7007)
1366        self.assertEqual(r['Questions?'], 'When?')
1367
1368    def testUpsert(self):
1369        upsert = self.db.upsert
1370        query = self.db.query
1371        self.assertRaises(pg.ProgrammingError, upsert,
1372            'test', i2=2, i4=4, i8=8)
1373        table = 'upsert_test_table'
1374        query('drop table if exists "%s"' % table)
1375        self.addCleanup(query, 'drop table "%s"' % table)
1376        query('create table "%s" ('
1377            "n integer primary key, t text) with oids" % table)
1378        s = dict(n=1, t='x')
1379        try:
1380            r = upsert(table, s)
1381        except pg.ProgrammingError as error:
1382            if self.db.server_version < 90500:
1383                self.skipTest('database does not support upsert')
1384            self.fail(str(error))
1385        self.assertIs(r, s)
1386        self.assertEqual(r['n'], 1)
1387        self.assertEqual(r['t'], 'x')
1388        s.update(n=2, t='y')
1389        r = upsert(table, s, **dict.fromkeys(s))
1390        self.assertIs(r, s)
1391        self.assertEqual(r['n'], 2)
1392        self.assertEqual(r['t'], 'y')
1393        q = 'select n, t from "%s" order by n limit 3' % table
1394        r = query(q).getresult()
1395        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1396        s.update(t='z')
1397        r = upsert(table, s)
1398        self.assertIs(r, s)
1399        self.assertEqual(r['n'], 2)
1400        self.assertEqual(r['t'], 'z')
1401        r = query(q).getresult()
1402        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1403        s.update(t='n')
1404        r = upsert(table, s, t=False)
1405        self.assertIs(r, s)
1406        self.assertEqual(r['n'], 2)
1407        self.assertEqual(r['t'], 'z')
1408        r = query(q).getresult()
1409        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1410        s.update(t='y')
1411        r = upsert(table, s, t=True)
1412        self.assertIs(r, s)
1413        self.assertEqual(r['n'], 2)
1414        self.assertEqual(r['t'], 'y')
1415        r = query(q).getresult()
1416        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1417        s.update(t='n')
1418        r = upsert(table, s, t="included.t || '2'")
1419        self.assertIs(r, s)
1420        self.assertEqual(r['n'], 2)
1421        self.assertEqual(r['t'], 'y2')
1422        r = query(q).getresult()
1423        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1424        s.update(t='y')
1425        r = upsert(table, s, t="excluded.t || '3'")
1426        self.assertIs(r, s)
1427        self.assertEqual(r['n'], 2)
1428        self.assertEqual(r['t'], 'y3')
1429        r = query(q).getresult()
1430        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1431        s.update(n=1, t='2')
1432        r = upsert(table, s, t="included.t || excluded.t")
1433        self.assertIs(r, s)
1434        self.assertEqual(r['n'], 1)
1435        self.assertEqual(r['t'], 'x2')
1436        r = query(q).getresult()
1437        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1438        # not existing columns and oid parameter should be ignored
1439        s = dict(m=3, u='z')
1440        r = upsert(table, s, oid='invalid')
1441        self.assertIs(r, s)
1442
    def testUpsertWithCompositeKey(self):
        """Test upsert() on a table with a composite primary key (n, m).

        The row dict passed to upsert() is updated in place and returned.
        Keyword arguments control how a column is handled when the key
        already exists: False keeps the old value, True takes the new one,
        and a string is used as an SQL expression for the new value.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        query('create table "%s" ('
            "n integer, m integer, t text, primary key (n, m))" % table)
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # upsert needs PostgreSQL 9.5 (server_version 90500) or newer
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # new key (1, 3) is inserted as a second row;
        # passing None as keyword value for every column is accepted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # existing key (1, 3): by default the row is updated
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False: the existing value of t is kept and copied back into s
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True: t is updated with the new value
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # new key (2, 3): the row is inserted as given,
        # the update expression for t is not applied
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # SQL expression: included.t is the stored value ('n'),
        # excluded.t the proposed value ('m')
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1511
    def testUpsertWithQuotedNames(self):
        """Test upsert() on a table whose table and column names need quoting."""
        upsert = self.db.upsert
        query = self.db.query
        table = 'test table for upsert()'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        # names with spaces and punctuation, primary key on "Prime!"
        query('create table "%s" ('
            '"Prime!" smallint primary key,'
            ' "much space" integer, "Questions?" text)' % table)
        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            # upsert needs PostgreSQL 9.5 (server_version 90500) or newer
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['Prime!'], 31)
        self.assertEqual(r['much space'], 9009)
        self.assertEqual(r['Questions?'], 'Yes.')
        q = 'select * from "%s" limit 2' % table
        r = query(q).getresult()
        self.assertEqual(r, [(31, 9009, 'Yes.')])
        # same key again: the existing row is updated, no new row added
        s.update({'Questions?': 'No.'})
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['Prime!'], 31)
        self.assertEqual(r['much space'], 9009)
        self.assertEqual(r['Questions?'], 'No.')
        r = query(q).getresult()
        self.assertEqual(r, [(31, 9009, 'No.')])
1543
    def testClear(self):
        """Test clear() returning empty values for all columns of a table."""
        clear = self.db.clear
        query = self.db.query
        # the expected cleared value of a boolean column depends on
        # the global bool option of the pg module
        f = False if pg.get_bool() else 'f'
        r = clear('test')
        # numeric columns are cleared to 0, character columns to ''
        result = dict(
            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
        self.assertEqual(r, result)
        table = 'clear_test_table'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        query('create table "%s" ('
            "n integer, b boolean, d date, t text) with oids" % table)
        r = clear(table)
        result = dict(n=0, b=f, d='', t='')
        self.assertEqual(r, result)
        r['a'] = r['n'] = 1
        r['d'] = r['t'] = 'x'
        r['b'] = 't'
        r['oid'] = long(1)
        # clearing a given dict resets its column values, but keeps
        # keys that are not columns ('a') as well as the 'oid' key
        r = clear(table, r)
        result = dict(a=1, n=0, b=f, d='', t='', oid=long(1))
        self.assertEqual(r, result)
1567
1568    def testClearWithQuotedNames(self):
1569        clear = self.db.clear
1570        query = self.db.query
1571        table = 'test table for clear()'
1572        query('drop table if exists "%s"' % table)
1573        self.addCleanup(query, 'drop table "%s"' % table)
1574        query('create table "%s" ('
1575            '"Prime!" smallint primary key,'
1576            ' "much space" integer, "Questions?" text)' % table)
1577        r = clear(table)
1578        self.assertIsInstance(r, dict)
1579        self.assertEqual(r['Prime!'], 0)
1580        self.assertEqual(r['much space'], 0)
1581        self.assertEqual(r['Questions?'], '')
1582
    def testDelete(self):
        """Test delete() on a table without primary key, using key columns."""
        delete = self.db.delete
        query = self.db.query
        # deleting from 'test' by plain column values is rejected
        self.assertRaises(pg.ProgrammingError, delete,
            'test', dict(i2=2, i4=4, i8=8))
        table = 'delete_test_table'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        # note: this table has no primary key
        query('create table "%s" ('
            "n integer, t text) with oids" % table)
        for n, t in enumerate('xyz'):
            query('insert into "%s" values('
                "%d, '%s')" % (table, n + 1, t))
        # without a key column name, get() raises ProgrammingError here
        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
        r = self.db.get(table, 1, 'n')
        # delete() returns the number of deleted rows
        s = delete(table, r)
        self.assertEqual(s, 1)
        r = self.db.get(table, 3, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        # deleting the same row a second time deletes nothing
        s = delete(table, r)
        self.assertEqual(s, 0)
        r = query('select * from "%s"' % table).dictresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        result = {'n': 2, 't': 'y'}
        self.assertEqual(r, result)
        r = self.db.get(table, 2, 'n')
        s = delete(table, r)
        self.assertEqual(s, 1)
        s = delete(table, r)
        self.assertEqual(s, 0)
        # the table is empty now, so get() fails
        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
        # not existing columns and oid parameter should be ignored
        r.update(m=3, u='z', oid='invalid')
        s = delete(table, r)
        self.assertEqual(s, 0)
1620
    def testDeleteWithOid(self):
        """Test delete() on a table without primary key, using oids."""
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        query("create table test_table (n int) with oids")
        query("insert into test_table values (1)")
        query("insert into test_table values (2)")
        query("insert into test_table values (3)")
        # a row dict with neither oid nor primary key cannot be deleted
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        # get() returns the oid under the qualified key 'oid(test_table)'
        r = get('test_table', 1, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 1)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        s = delete('test_table', r)
        self.assertEqual(s, 1)
        # deleting the same row a second time deletes nothing
        s = delete('test_table', r)
        self.assertEqual(s, 0)
        r = get('test_table', 2, 'n')
        self.assertIsInstance(r, dict)
        self.assertEqual(r['n'], 2)
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        oid = r[qoid]
        self.assertIsInstance(oid, int)
        # a plain 'oid' key in the row dict is not sufficient,
        # but the oid can be passed as keyword argument
        r['oid'] = r.pop(qoid)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        s = delete('test_table', r, oid=oid)
        self.assertEqual(s, 1)
        s = delete('test_table', r)
        self.assertEqual(s, 0)
        # passing n=3 as keyword argument does not delete the row n=3
        s = delete('test_table', r, n=3)
        self.assertEqual(s, 0)
        # only the row with n=3 has survived
        q = 'select n from test_table order by 1 limit 3'
        r = query(q).getresult()
        self.assertEqual(r, [(3,)])
1662
1663    def testDeleteWithCompositeKey(self):
1664        query = self.db.query
1665        table = 'delete_test_table_1'
1666        query('drop table if exists "%s"' % table)
1667        self.addCleanup(query, 'drop table "%s"' % table)
1668        query('create table "%s" ('
1669            "n integer, t text, primary key (n))" % table)
1670        for n, t in enumerate('abc'):
1671            query("insert into %s values("
1672                "%d, '%s')" % (table, n + 1, t))
1673        self.assertRaises(pg.ProgrammingError, self.db.delete,
1674            table, dict(t='b'))
1675        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1676        r = query('select t from "%s" where n=2' % table
1677                  ).getresult()
1678        self.assertEqual(r, [])
1679        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1680        r = query('select t from "%s" where n=3' % table
1681                  ).getresult()[0][0]
1682        self.assertEqual(r, 'c')
1683        table = 'delete_test_table_2'
1684        query('drop table if exists "%s"' % table)
1685        self.addCleanup(query, 'drop table "%s"' % table)
1686        query('create table "%s" ('
1687            "n integer, m integer, t text, primary key (n, m))" % table)
1688        for n in range(3):
1689            for m in range(2):
1690                t = chr(ord('a') + 2 * n + m)
1691                query('insert into "%s" values('
1692                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1693        self.assertRaises(pg.ProgrammingError, self.db.delete,
1694            table, dict(n=2, t='b'))
1695        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1696        r = [r[0] for r in query('select t from "%s" where n=2'
1697            ' order by m' % table).getresult()]
1698        self.assertEqual(r, ['c'])
1699        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1700        r = [r[0] for r in query('select t from "%s" where n=3'
1701            ' order by m' % table).getresult()]
1702        self.assertEqual(r, ['e', 'f'])
1703        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1704        r = [r[0] for r in query('select t from "%s" where n=3'
1705            ' order by m' % table).getresult()]
1706        self.assertEqual(r, ['f'])
1707
1708    def testDeleteWithQuotedNames(self):
1709        delete = self.db.delete
1710        query = self.db.query
1711        table = 'test table for delete()'
1712        query('drop table if exists "%s"' % table)
1713        self.addCleanup(query, 'drop table "%s"' % table)
1714        query('create table "%s" ('
1715            '"Prime!" smallint primary key,'
1716            ' "much space" integer, "Questions?" text)' % table)
1717        query('insert into "%s"'
1718              " values(19, 5005, 'Yes!')" % table)
1719        r = {'Prime!': 17}
1720        r = delete(table, r)
1721        self.assertEqual(r, 0)
1722        r = query('select count(*) from "%s"' % table).getresult()
1723        self.assertEqual(r[0][0], 1)
1724        r = {'Prime!': 19}
1725        r = delete(table, r)
1726        self.assertEqual(r, 1)
1727        r = query('select count(*) from "%s"' % table).getresult()
1728        self.assertEqual(r[0][0], 0)
1729
    def testTruncate(self):
        """Test truncate() with single and multiple table names."""
        truncate = self.db.truncate
        # None, numbers and dicts are rejected as table arguments
        self.assertRaises(TypeError, truncate, None)
        self.assertRaises(TypeError, truncate, 42)
        self.assertRaises(TypeError, truncate, dict(test_table=None))
        query = self.db.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        query("create table test_table (n smallint)")
        for i in range(3):
            query("insert into test_table values (1)")
        q = "select count(*) from test_table"
        r = query(q).getresult()[0][0]
        self.assertEqual(r, 3)
        truncate('test_table')
        r = query(q).getresult()[0][0]
        self.assertEqual(r, 0)
        for i in range(3):
            query("insert into test_table values (1)")
        r = query(q).getresult()[0][0]
        self.assertEqual(r, 3)
        # schema-qualified table names work as well
        truncate('public.test_table')
        r = query(q).getresult()[0][0]
        self.assertEqual(r, 0)
        query("drop table if exists test_table_2")
        self.addCleanup(query, "drop table test_table_2")
        query('create table test_table_2 (n smallint)')
        # several tables can be passed as list, tuple or set
        for t in (list, tuple, set):
            for i in range(3):
                query("insert into test_table values (1)")
                query("insert into test_table_2 values (2)")
            q = ("select (select count(*) from test_table),"
                " (select count(*) from test_table_2)")
            r = query(q).getresult()[0]
            self.assertEqual(r, (3, 3))
            truncate(t(['test_table', 'test_table_2']))
            r = query(q).getresult()[0]
            self.assertEqual(r, (0, 0))
1768
    def testTruncateRestart(self):
        """Test the restart option of truncate() with a serial column."""
        truncate = self.db.truncate
        # an invalid restart value raises TypeError
        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
        query = self.db.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        query("create table test_table (n serial, t text)")
        for n in range(3):
            query("insert into test_table (t) values ('test')")
        q = "select count(n), min(n), max(n) from test_table"
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 1, 3))
        truncate('test_table')
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, None, None))
        # without restart, the serial values continue after truncation
        for n in range(3):
            query("insert into test_table (t) values ('test')")
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 4, 6))
        truncate('test_table', restart=True)
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, None, None))
        # with restart=True, the serial values start again at 1
        for n in range(3):
            query("insert into test_table (t) values ('test')")
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 1, 3))
1795
    def testTruncateCascade(self):
        """Test the cascade option of truncate() with a foreign key."""
        truncate = self.db.truncate
        # an invalid cascade value raises TypeError
        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
        query = self.db.query
        # drop the referencing table before the referenced one
        query("drop table if exists test_child")
        query("drop table if exists test_parent")
        # cleanups run in reverse order, so test_child is dropped first
        self.addCleanup(query, "drop table test_parent")
        query("create table test_parent (n smallint primary key)")
        self.addCleanup(query, "drop table test_child")
        query("create table test_child ("
            " n smallint primary key references test_parent (n))")
        for n in range(3):
            query("insert into test_parent (n) values (%d)" % n)
            query("insert into test_child (n) values (%d)" % n)
        q = ("select (select count(*) from test_parent),"
            " (select count(*) from test_child)")
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 3))
        # the referenced table cannot be truncated alone
        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
        # but both tables can be truncated together
        truncate(['test_parent', 'test_child'])
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        for n in range(3):
            query("insert into test_parent (n) values (%d)" % n)
            query("insert into test_child (n) values (%d)" % n)
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 3))
        # cascade=True truncates the referencing table as well
        truncate('test_parent', cascade=True)
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
        for n in range(3):
            query("insert into test_parent (n) values (%d)" % n)
            query("insert into test_child (n) values (%d)" % n)
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 3))
        # the referencing table alone can be truncated without cascade
        truncate('test_child')
        r = query(q).getresult()[0]
        self.assertEqual(r, (3, 0))
        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
        truncate('test_parent', cascade=True)
        r = query(q).getresult()[0]
        self.assertEqual(r, (0, 0))
1838
1839    def testTruncateOnly(self):
1840        truncate = self.db.truncate
1841        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
1842        query = self.db.query
1843        query("drop table if exists test_child")
1844        query("drop table if exists test_parent")
1845        self.addCleanup(query, "drop table test_parent")
1846        query("create table test_parent (n smallint)")
1847        self.addCleanup(query, "drop table test_child")
1848        query("create table test_child ("
1849            " m smallint) inherits (test_parent)")
1850        for n in range(3):
1851            query("insert into test_parent (n) values (1)")
1852            query("insert into test_child (n, m) values (2, 3)")
1853        q = ("select (select count(*) from test_parent),"
1854            " (select count(*) from test_child)")
1855        r = query(q).getresult()[0]
1856        self.assertEqual(r, (6, 3))
1857        truncate('test_parent')
1858        r = query(q).getresult()[0]
1859        self.assertEqual(r, (0, 0))
1860        for n in range(3):
1861            query("insert into test_parent (n) values (1)")
1862            query("insert into test_child (n, m) values (2, 3)")
1863        r = query(q).getresult()[0]
1864        self.assertEqual(r, (6, 3))
1865        truncate('test_parent*')
1866        r = query(q).getresult()[0]
1867        self.assertEqual(r, (0, 0))
1868        for n in range(3):
1869            query("insert into test_parent (n) values (1)")
1870            query("insert into test_child (n, m) values (2, 3)")
1871        r = query(q).getresult()[0]
1872        self.assertEqual(r, (6, 3))
1873        truncate('test_parent', only=True)
1874        r = query(q).getresult()[0]
1875        self.assertEqual(r, (3, 3))
1876        truncate('test_parent', only=False)
1877        r = query(q).getresult()[0]
1878        self.assertEqual(r, (0, 0))
1879        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
1880        truncate('test_parent*', only=False)
1881        query("drop table if exists test_parent_2")
1882        self.addCleanup(query, "drop table test_parent_2")
1883        query("create table test_parent_2 (n smallint)")
1884        query("drop table if exists test_child_2")
1885        self.addCleanup(query, "drop table test_child_2")
1886        query("create table test_child_2 ("
1887            " m smallint) inherits (test_parent_2)")
1888        for n in range(3):
1889            query("insert into test_parent (n) values (1)")
1890            query("insert into test_child (n, m) values (2, 3)")
1891            query("insert into test_parent_2 (n) values (1)")
1892            query("insert into test_child_2 (n, m) values (2, 3)")
1893        q = ("select (select count(*) from test_parent),"
1894            " (select count(*) from test_child),"
1895            " (select count(*) from test_parent_2),"
1896            " (select count(*) from test_child_2)")
1897        r = query(q).getresult()[0]
1898        self.assertEqual(r, (6, 3, 6, 3))
1899        truncate(['test_parent', 'test_parent_2'], only=[False, True])
1900        r = query(q).getresult()[0]
1901        self.assertEqual(r, (0, 0, 3, 3))
1902        truncate(['test_parent', 'test_parent_2'], only=False)
1903        r = query(q).getresult()[0]
1904        self.assertEqual(r, (0, 0, 0, 0))
1905        self.assertRaises(ValueError, truncate,
1906            ['test_parent*', 'test_child'], only=[True, False])
1907        truncate(['test_parent*', 'test_child'], only=[False, True])
1908
1909    def testTruncateQuoted(self):
1910        truncate = self.db.truncate
1911        query = self.db.query
1912        table = "test table for truncate()"
1913        query('drop table if exists "%s"' % table)
1914        self.addCleanup(query, 'drop table "%s"' % table)
1915        query('create table "%s" (n smallint)' % table)
1916        for i in range(3):
1917            query('insert into "%s" values (1)' % table)
1918        q = 'select count(*) from "%s"' % table
1919        r = query(q).getresult()[0][0]
1920        self.assertEqual(r, 3)
1921        truncate(table)
1922        r = query(q).getresult()[0][0]
1923        self.assertEqual(r, 0)
1924        for i in range(3):
1925            query('insert into "%s" values (1)' % table)
1926        r = query(q).getresult()[0][0]
1927        self.assertEqual(r, 3)
1928        truncate('public."%s"' % table)
1929        r = query(q).getresult()[0][0]
1930        self.assertEqual(r, 0)
1931
    def testTransaction(self):
        """Test explicit transaction control, savepoints and read only mode."""
        query = self.db.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        query("create table test_table (n integer)")
        # committed: 1 and 2 are kept
        self.db.begin()
        query("insert into test_table values (1)")
        query("insert into test_table values (2)")
        self.db.commit()
        # rolled back: 3 and 4 are discarded
        self.db.begin()
        query("insert into test_table values (3)")
        query("insert into test_table values (4)")
        self.db.rollback()
        # rollback to a savepoint only discards 6, 5 and 7 are kept
        self.db.begin()
        query("insert into test_table values (5)")
        self.db.savepoint('before6')
        query("insert into test_table values (6)")
        self.db.rollback('before6')
        query("insert into test_table values (7)")
        self.db.commit()
        # a released savepoint cannot be rolled back to any more;
        # 8 does not appear in the final result
        self.db.begin()
        self.db.savepoint('before8')
        query("insert into test_table values (8)")
        self.db.release('before8')
        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
        self.db.commit()
        # start() and end() can be used instead of begin() and commit()
        self.db.start()
        query("insert into test_table values (9)")
        self.db.end()
        r = [r[0] for r in query(
            "select * from test_table order by 1").getresult()]
        self.assertEqual(r, [1, 2, 5, 7, 9])
        # inserts are rejected in read only transactions
        self.db.begin(mode='read only')
        self.assertRaises(pg.ProgrammingError,
            query, "insert into test_table values (0)")
        self.db.rollback()
        self.db.start(mode='Read Only')
        self.assertRaises(pg.ProgrammingError,
            query, "insert into test_table values (0)")
        self.db.abort()
1972
1973    def testTransactionAliases(self):
1974        self.assertEqual(self.db.begin, self.db.start)
1975        self.assertEqual(self.db.commit, self.db.end)
1976        self.assertEqual(self.db.rollback, self.db.abort)
1977
1978    def testContextManager(self):
1979        query = self.db.query
1980        query("drop table if exists test_table")
1981        self.addCleanup(query, "drop table test_table")
1982        query("create table test_table (n integer check(n>0))")
1983        with self.db:
1984            query("insert into test_table values (1)")
1985            query("insert into test_table values (2)")
1986        try:
1987            with self.db:
1988                query("insert into test_table values (3)")
1989                query("insert into test_table values (4)")
1990                raise ValueError('test transaction should rollback')
1991        except ValueError as error:
1992            self.assertEqual(str(error), 'test transaction should rollback')
1993        with self.db:
1994            query("insert into test_table values (5)")
1995        try:
1996            with self.db:
1997                query("insert into test_table values (6)")
1998                query("insert into test_table values (-1)")
1999        except pg.ProgrammingError as error:
2000            self.assertTrue('check' in str(error))
2001        with self.db:
2002            query("insert into test_table values (7)")
2003        r = [r[0] for r in query(
2004            "select * from test_table order by 1").getresult()]
2005        self.assertEqual(r, [1, 2, 5, 7])
2006
    def testBytea(self):
        """Test round-tripping binary data with escape_bytea/unescape_bytea."""
        query = self.db.query
        query('drop table if exists bytea_test')
        self.addCleanup(query, 'drop table bytea_test')
        query('create table bytea_test (n smallint primary key, data bytea)')
        # binary data with backslash, null byte, control chars and \xff
        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
        r = self.db.escape_bytea(s)
        query('insert into bytea_test values(3,$1)', (r,))
        r = query('select * from bytea_test where n=3').getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 2)
        self.assertEqual(r[0], 3)
        r = r[1]
        # a raw query returns the bytea value as an escaped str,
        # which must be unescaped to get the original bytes back
        self.assertIsInstance(r, str)
        r = self.db.unescape_bytea(r)
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
2025
    def testInsertUpdateGetBytea(self):
        """Test that insert(), update() and get() return bytea as bytes."""
        query = self.db.query
        query('drop table if exists bytea_test')
        self.addCleanup(query, 'drop table bytea_test')
        query('create table bytea_test (n smallint primary key, data bytea)')
        # insert null value
        r = self.db.insert('bytea_test', n=0, data=None)
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 0)
        self.assertIn('data', r)
        self.assertIsNone(r['data'])
        # update the null value with bytes, then back to null
        s = b'None'
        r = self.db.update('bytea_test', n=0, data=s)
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 0)
        self.assertIn('data', r)
        r = r['data']
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
        r = self.db.update('bytea_test', n=0, data=None)
        self.assertIsNone(r['data'])
        # insert as bytes
        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
        r = self.db.insert('bytea_test', n=5, data=s)
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 5)
        self.assertIn('data', r)
        r = r['data']
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
        # update as bytes
        s += b"and now even more \x00 nasty \t stuff!\f"
        r = self.db.update('bytea_test', n=5, data=s)
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 5)
        self.assertIn('data', r)
        r = r['data']
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
        # a raw query returns the stored value as an escaped str
        r = query('select * from bytea_test where n=5').getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 2)
        self.assertEqual(r[0], 5)
        r = r[1]
        self.assertIsInstance(r, str)
        r = self.db.unescape_bytea(r)
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
        # get() returns the stored value as bytes again
        r = self.db.get('bytea_test', dict(n=5))
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 5)
        self.assertIn('data', r)
        r = r['data']
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
2087
2088    def testUpsertBytea(self):
2089        query = self.db.query
2090        query('drop table if exists bytea_test')
2091        self.addCleanup(query, 'drop table bytea_test')
2092        query('create table bytea_test (n smallint primary key, data bytea)')
2093        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2094        r = dict(n=7, data=s)
2095        try:
2096            r = self.db.upsert('bytea_test', r)
2097        except pg.ProgrammingError as error:
2098            if self.db.server_version < 90500:
2099                self.skipTest('database does not support upsert')
2100            self.fail(str(error))
2101        self.assertIsInstance(r, dict)
2102        self.assertIn('n', r)
2103        self.assertEqual(r['n'], 7)
2104        self.assertIn('data', r)
2105        self.assertIsInstance(r['data'], bytes)
2106        self.assertEqual(r['data'], s)
2107        r['data'] = None
2108        r = self.db.upsert('bytea_test', r)
2109        self.assertIsInstance(r, dict)
2110        self.assertIn('n', r)
2111        self.assertEqual(r['n'], 7)
2112        self.assertIn('data', r)
2113        self.assertIsNone(r['data'], bytes)
2114
2115    def testNotificationHandler(self):
2116        # the notification handler itself is tested separately
2117        f = self.db.notification_handler
2118        callback = lambda arg_dict: None
2119        handler = f('test', callback)
2120        self.assertIsInstance(handler, pg.NotificationHandler)
2121        self.assertIs(handler.db, self.db)
2122        self.assertEqual(handler.event, 'test')
2123        self.assertEqual(handler.stop_event, 'stop_test')
2124        self.assertIs(handler.callback, callback)
2125        self.assertIsInstance(handler.arg_dict, dict)
2126        self.assertEqual(handler.arg_dict, {})
2127        self.assertIsNone(handler.timeout)
2128        self.assertFalse(handler.listening)
2129        handler.close()
2130        self.assertIsNone(handler.db)
2131        self.db.reopen()
2132        self.assertIsNone(handler.db)
2133        handler = f('test2', callback, timeout=2)
2134        self.assertIsInstance(handler, pg.NotificationHandler)
2135        self.assertIs(handler.db, self.db)
2136        self.assertEqual(handler.event, 'test2')
2137        self.assertEqual(handler.stop_event, 'stop_test2')
2138        self.assertIs(handler.callback, callback)
2139        self.assertIsInstance(handler.arg_dict, dict)
2140        self.assertEqual(handler.arg_dict, {})
2141        self.assertEqual(handler.timeout, 2)
2142        self.assertFalse(handler.listening)
2143        handler.close()
2144        self.assertIsNone(handler.db)
2145        self.db.reopen()
2146        self.assertIsNone(handler.db)
2147        arg_dict = {'testing': 3}
2148        handler = f('test3', callback, arg_dict=arg_dict)
2149        self.assertIsInstance(handler, pg.NotificationHandler)
2150        self.assertIs(handler.db, self.db)
2151        self.assertEqual(handler.event, 'test3')
2152        self.assertEqual(handler.stop_event, 'stop_test3')
2153        self.assertIs(handler.callback, callback)
2154        self.assertIs(handler.arg_dict, arg_dict)
2155        self.assertEqual(arg_dict['testing'], 3)
2156        self.assertIsNone(handler.timeout)
2157        self.assertFalse(handler.listening)
2158        handler.close()
2159        self.assertIsNone(handler.db)
2160        self.db.reopen()
2161        self.assertIsNone(handler.db)
2162        handler = f('test4', callback, stop_event='stop4')
2163        self.assertIsInstance(handler, pg.NotificationHandler)
2164        self.assertIs(handler.db, self.db)
2165        self.assertEqual(handler.event, 'test4')
2166        self.assertEqual(handler.stop_event, 'stop4')
2167        self.assertIs(handler.callback, callback)
2168        self.assertIsInstance(handler.arg_dict, dict)
2169        self.assertEqual(handler.arg_dict, {})
2170        self.assertIsNone(handler.timeout)
2171        self.assertFalse(handler.listening)
2172        handler.close()
2173        self.assertIsNone(handler.db)
2174        self.db.reopen()
2175        self.assertIsNone(handler.db)
2176        arg_dict = {'testing': 5}
2177        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
2178        self.assertIsInstance(handler, pg.NotificationHandler)
2179        self.assertIs(handler.db, self.db)
2180        self.assertEqual(handler.event, 'test5')
2181        self.assertEqual(handler.stop_event, 'stop5')
2182        self.assertIs(handler.callback, callback)
2183        self.assertIs(handler.arg_dict, arg_dict)
2184        self.assertEqual(arg_dict['testing'], 5)
2185        self.assertEqual(handler.timeout, 1.5)
2186        self.assertFalse(handler.listening)
2187        handler.close()
2188        self.assertIsNone(handler.db)
2189        self.db.reopen()
2190        self.assertIsNone(handler.db)
2191
2192
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    The global options of the pg module are changed for the whole class
    and restored afterwards, even when setting up the class fails.
    """

    @classmethod
    def setUpClass(cls):
        cls.saved_options = {}

        def unnamed_result(q):
            # return plain tuples instead of named tuples
            return q.getresult()

        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # unittest does not call tearDownClass when setUpClass fails,
            # so the global options must be restored here, otherwise they
            # would leak into all test classes that run afterwards
            cls.reset_options()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # always restore the global options, even if teardown failed
            cls.reset_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option, then set it.

        Returns whatever the corresponding pg.set_<option>() returns.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option().

        Raises KeyError if the option has not been saved.
        """
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_options(cls):
        """Restore all global pg options changed in setUpClass.

        The options are restored in the reverse order of being set.
        Options that have never been saved (because setUpClass failed
        before changing them) are skipped.
        """
        for option in ('namedresult', 'bool', 'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
2221
2222
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    # number of schemas used in these tests: "public" and "s1" to "s4"
    num_schemas = 5

    @staticmethod
    def _schema_name(num_schema):
        """Return the name of the schema with the given number.

        Number 0 denotes the public schema; any other number n denotes
        the schema "s<n>" that is created for these tests.
        """
        return "s%d" % num_schema if num_schema else "public"

    @classmethod
    def setUpClass(cls):
        db = DB()
        try:
            query = db.query
            for num_schema in range(cls.num_schemas):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    # the public schema cannot be dropped, only its tables
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # unittest does not call tearDownClass when setUpClass fails,
            # so the connection must be closed here in any case
            db.close()

    @classmethod
    def tearDownClass(cls):
        db = DB()
        try:
            query = db.query
            for num_schema in range(cls.num_schemas):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema %s cascade" % (schema,))
                else:
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # the registered cleanups still need the connection,
        # so they must run before the connection is closed
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(self.num_schemas):
            schema = self._schema_name(num_schema)
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
2337
2338
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class."""

    def setUp(self):
        self.db = DB()
        # remember the initial debug setting so that it can be restored
        self.debug = self.db.debug
        # capture everything that is printed to stdout during the test
        self.output = StringIO()
        self.stdout, sys.stdout = sys.stdout, self.output

    def tearDown(self):
        # restore stdout first so that any later errors are visible again
        sys.stdout = self.stdout
        self.output.close()
        self.db.debug = self.debug
        self.db.close()

    def get_output(self):
        """Return everything that has been printed to stdout so far."""
        return self.output.getvalue()

    def send_queries(self):
        """Send two simple queries that should be logged when debugging."""
        self.db.query("select 1")
        self.db.query("select 2")

    def testDebugDefault(self):
        if debug:
            self.assertEqual(self.db.debug, debug)
        else:
            self.assertIsNone(self.db.debug)

    def testDebugIsFalse(self):
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        self.db.debug = True
        self.send_queries()
        self.assertEqual(self.get_output(), "select 1\nselect 2\n")

    def testDebugIsString(self):
        self.db.debug = "Test with string: %s."
        self.send_queries()
        self.assertEqual(self.get_output(),
            "Test with string: select 1.\nTest with string: select 2.\n")

    def testDebugIsFileLike(self):
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            output = debug_file.read()
            self.assertEqual(output, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        output = []
        self.db.debug = output.append
        self.send_queries()
        self.assertEqual(output, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")
2400
2401
if __name__ == "__main__":
    # Executing this module directly runs all of its test cases.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.