source: trunk/tests/test_classic_dbwrapper.py @ 745

Last change on this file since 745 was 745, checked in by cito, 4 years ago

Add methods get/set_parameter to DB wrapper class

These methods can be used to get/set/reset run-time parameters,
even several at once.

Since this is pretty useful and will not break anything, I have
also backported these additions to the 4.x branch.

Everything is well documented and tested, of course.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 69.3 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Optional local overrides of the connection settings above: the relative
# import is tried first, then the absolute one; if neither module exists
# the defaults are kept.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Python 2/3 compatibility aliases used by the tests below.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# Without OrderedDict a plain dict is used; tests that depend on ordering
# check for `OrderedDict is dict` and skip themselves in that case.
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
62
63
def DB():
    """Return a pg.DB wrapper connected to the configured test database.

    The module-level ``debug`` flag is copied onto the wrapper when set, and
    the session is told to report only messages of level warning or above.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
71
72
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # Some tests close the connection themselves, and closing it a
        # second time raises pg.InternalError (see testMethodClose).
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        """The public attributes of the DB wrapper are exactly these."""
        attributes = [
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'delete',
            'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        # dir() returns its names sorted, so the list above is sorted too
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        """The wrapped connection reports the configured database name."""
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        """No error is pending, apart from possible krb5 messages."""
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # the output may carry a hex prefix and/or backslashes;
        # strip both, then nothing must remain for empty input
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        """Parameters may be passed positionally, as tuple or as list."""
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        """Division by zero must raise ProgrammingError with SQLSTATE 22012.

        The check fails if no exception is raised at all.
        """
        with self.assertRaises(pg.ProgrammingError) as context:
            self.db.query("select 1/0")
        self.assertEqual(context.exception.sqlstate, '22012')

    def testMethodEndcopy(self):
        # IOError is tolerated here, presumably because no COPY
        # operation is in progress on this connection
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        """After close(), reset, close and query must all fail."""
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        """A DB wrapper can be created from an existing connection.

        Accepted forms are a raw connection, another DB wrapper, the ``db``
        keyword, or any object with a ``_cnx`` attribute.  Closing such a
        wrapper must leave the shared connection object in place.
        """
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
249
250
251class TestDBClass(unittest.TestCase):
252    """Test the methods of the DB class wrapped pg connection."""
253
254    @classmethod
255    def setUpClass(cls):
256        db = DB()
257        db.query("drop table if exists test cascade")
258        db.query("create table test ("
259            "i2 smallint, i4 integer, i8 bigint,"
260            " d numeric, f4 real, f8 double precision, m money,"
261            " v4 varchar(4), c4 char(4), t text)")
262        db.query("create or replace view test_view as"
263            " select i4, v4 from test")
264        db.close()
265
266    @classmethod
267    def tearDownClass(cls):
268        db = DB()
269        db.query("drop table test cascade")
270        db.close()
271
272    def setUp(self):
273        self.db = DB()
274        query = self.db.query
275        query('set client_encoding=utf8')
276        query('set standard_conforming_strings=on')
277        query("set lc_monetary='C'")
278        query("set datestyle='ISO,YMD'")
279        query('set bytea_output=hex')
280
281    def tearDown(self):
282        self.doCleanups()
283        self.db.close()
284
285    def testClassName(self):
286        self.assertEqual(self.db.__class__.__name__, 'DB')
287
288    def testModuleName(self):
289        self.assertEqual(self.db.__module__, 'pg')
290        self.assertEqual(self.db.__class__.__module__, 'pg')
291
292    def testEscapeLiteral(self):
293        f = self.db.escape_literal
294        r = f(b"plain")
295        self.assertIsInstance(r, bytes)
296        self.assertEqual(r, b"'plain'")
297        r = f(u"plain")
298        self.assertIsInstance(r, unicode)
299        self.assertEqual(r, u"'plain'")
300        r = f(u"that's kÀse".encode('utf-8'))
301        self.assertIsInstance(r, bytes)
302        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
303        r = f(u"that's kÀse")
304        self.assertIsInstance(r, unicode)
305        self.assertEqual(r, u"'that''s kÀse'")
306        self.assertEqual(f(r"It's fine to have a \ inside."),
307            r" E'It''s fine to have a \\ inside.'")
308        self.assertEqual(f('No "quotes" must be escaped.'),
309            "'No \"quotes\" must be escaped.'")
310
311    def testEscapeIdentifier(self):
312        f = self.db.escape_identifier
313        r = f(b"plain")
314        self.assertIsInstance(r, bytes)
315        self.assertEqual(r, b'"plain"')
316        r = f(u"plain")
317        self.assertIsInstance(r, unicode)
318        self.assertEqual(r, u'"plain"')
319        r = f(u"that's kÀse".encode('utf-8'))
320        self.assertIsInstance(r, bytes)
321        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
322        r = f(u"that's kÀse")
323        self.assertIsInstance(r, unicode)
324        self.assertEqual(r, u'"that\'s kÀse"')
325        self.assertEqual(f(r"It's fine to have a \ inside."),
326            '"It\'s fine to have a \\ inside."')
327        self.assertEqual(f('All "quotes" must be escaped.'),
328            '"All ""quotes"" must be escaped."')
329
330    def testEscapeString(self):
331        f = self.db.escape_string
332        r = f(b"plain")
333        self.assertIsInstance(r, bytes)
334        self.assertEqual(r, b"plain")
335        r = f(u"plain")
336        self.assertIsInstance(r, unicode)
337        self.assertEqual(r, u"plain")
338        r = f(u"that's kÀse".encode('utf-8'))
339        self.assertIsInstance(r, bytes)
340        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
341        r = f(u"that's kÀse")
342        self.assertIsInstance(r, unicode)
343        self.assertEqual(r, u"that''s kÀse")
344        self.assertEqual(f(r"It's fine to have a \ inside."),
345            r"It''s fine to have a \ inside.")
346
347    def testEscapeBytea(self):
348        f = self.db.escape_bytea
349        # note that escape_byte always returns hex output since Pg 9.0,
350        # regardless of the bytea_output setting
351        r = f(b'plain')
352        self.assertIsInstance(r, bytes)
353        self.assertEqual(r, b'\\x706c61696e')
354        r = f(u'plain')
355        self.assertIsInstance(r, unicode)
356        self.assertEqual(r, u'\\x706c61696e')
357        r = f(u"das is' kÀse".encode('utf-8'))
358        self.assertIsInstance(r, bytes)
359        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
360        r = f(u"das is' kÀse")
361        self.assertIsInstance(r, unicode)
362        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
363        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
364
365    def testUnescapeBytea(self):
366        f = self.db.unescape_bytea
367        r = f(b'plain')
368        self.assertIsInstance(r, bytes)
369        self.assertEqual(r, b'plain')
370        r = f(u'plain')
371        self.assertIsInstance(r, bytes)
372        self.assertEqual(r, b'plain')
373        r = f(b"das is' k\\303\\244se")
374        self.assertIsInstance(r, bytes)
375        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
376        r = f(u"das is' k\\303\\244se")
377        self.assertIsInstance(r, bytes)
378        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
379        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
380        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
381        self.assertEqual(f(r'\\x746861742773206be47365'),
382            b'\\x746861742773206be47365')
383        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
384
385    def testGetParameter(self):
386        f = self.db.get_parameter
387        self.assertRaises(TypeError, f)
388        self.assertRaises(TypeError, f, None)
389        self.assertRaises(TypeError, f, 42)
390        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
391        r = f('standard_conforming_strings')
392        self.assertEqual(r, 'on')
393        r = f('lc_monetary')
394        self.assertEqual(r, 'C')
395        r = f('datestyle')
396        self.assertEqual(r, 'ISO, YMD')
397        r = f('bytea_output')
398        self.assertEqual(r, 'hex')
399        r = f(('bytea_output', 'lc_monetary'))
400        self.assertIsInstance(r, list)
401        self.assertEqual(r, ['hex', 'C'])
402        r = f(['standard_conforming_strings', 'datestyle', 'bytea_output'])
403        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
404        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
405        r = f(s)
406        self.assertIs(r, s)
407        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
408        s = dict.fromkeys(('Bytea_Output', 'LC_Monetary'))
409        r = f(s)
410        self.assertIs(r, s)
411        self.assertEqual(r, {'Bytea_Output': 'hex', 'LC_Monetary': 'C'})
412
413    def testGetParameterServerVersion(self):
414        r = self.db.get_parameter('server_version_num')
415        self.assertIsInstance(r, str)
416        s = self.db.server_version
417        self.assertIsInstance(s, int)
418        self.assertEqual(r, str(s))
419
420    def testGetParameterAll(self):
421        f = self.db.get_parameter
422        r = f('all')
423        self.assertIsInstance(r, dict)
424        self.assertEqual(r['standard_conforming_strings'], 'on')
425        self.assertEqual(r['lc_monetary'], 'C')
426        self.assertEqual(r['DateStyle'], 'ISO, YMD')
427        self.assertEqual(r['bytea_output'], 'hex')
428
429    def testSetParameter(self):
430        f = self.db.set_parameter
431        g = self.db.get_parameter
432        self.assertRaises(TypeError, f)
433        self.assertRaises(TypeError, f, None)
434        self.assertRaises(TypeError, f, 42)
435        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
436        f('standard_conforming_strings', 'off')
437        self.assertEqual(g('standard_conforming_strings'), 'off')
438        f('datestyle', 'ISO, DMY')
439        self.assertEqual(g('datestyle'), 'ISO, DMY')
440        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
441        self.assertEqual(g('standard_conforming_strings'), 'on')
442        self.assertEqual(g('datestyle'), 'ISO, YMD')
443        f(['standard_conforming_strings', 'datestyle'], ['off', 'ISO, DMY'])
444        self.assertEqual(g('standard_conforming_strings'), 'off')
445        self.assertEqual(g('datestyle'), 'ISO, DMY')
446        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
447        self.assertEqual(g('standard_conforming_strings'), 'on')
448        self.assertEqual(g('datestyle'), 'ISO, YMD')
449        f(('default_with_oids', 'standard_conforming_strings'), 'off')
450        self.assertEqual(g('default_with_oids'), 'off')
451        self.assertEqual(g('standard_conforming_strings'), 'off')
452        f(['default_with_oids', 'standard_conforming_strings'], 'on')
453        self.assertEqual(g('default_with_oids'), 'on')
454        self.assertEqual(g('standard_conforming_strings'), 'on')
455
456    def testResetParameter(self):
457        db = DB()
458        f = db.set_parameter
459        g = db.get_parameter
460        r = g('default_with_oids')
461        self.assertIn(r, ('on', 'off'))
462        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
463        r = g('standard_conforming_strings')
464        self.assertIn(r, ('on', 'off'))
465        scs, not_scs = r, 'off' if r == 'on' else 'on'
466        f('default_with_oids', not_dwi)
467        f('standard_conforming_strings', not_scs)
468        self.assertEqual(g('default_with_oids'), not_dwi)
469        self.assertEqual(g('standard_conforming_strings'), not_scs)
470        f('default_with_oids')
471        f('standard_conforming_strings', None)
472        self.assertEqual(g('default_with_oids'), dwi)
473        self.assertEqual(g('standard_conforming_strings'), scs)
474        f('default_with_oids', not_dwi)
475        f('standard_conforming_strings', not_scs)
476        self.assertEqual(g('default_with_oids'), not_dwi)
477        self.assertEqual(g('standard_conforming_strings'), not_scs)
478        f(('default_with_oids', 'standard_conforming_strings'))
479        self.assertEqual(g('default_with_oids'), dwi)
480        self.assertEqual(g('standard_conforming_strings'), scs)
481        f('default_with_oids', not_dwi)
482        f('standard_conforming_strings', not_scs)
483        self.assertEqual(g('default_with_oids'), not_dwi)
484        self.assertEqual(g('standard_conforming_strings'), not_scs)
485        f(['default_with_oids', 'standard_conforming_strings'], None)
486        self.assertEqual(g('default_with_oids'), dwi)
487        self.assertEqual(g('standard_conforming_strings'), scs)
488
489    def testResetParameterAll(self):
490        db = DB()
491        f = db.set_parameter
492        self.assertRaises(ValueError, f, 'all', 0)
493        self.assertRaises(ValueError, f, 'all', 'off')
494        g = db.get_parameter
495        r = g('default_with_oids')
496        self.assertIn(r, ('on', 'off'))
497        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
498        r = g('standard_conforming_strings')
499        self.assertIn(r, ('on', 'off'))
500        scs, not_scs = r, 'off' if r == 'on' else 'on'
501        f('default_with_oids', not_dwi)
502        f('standard_conforming_strings', not_scs)
503        self.assertEqual(g('default_with_oids'), not_dwi)
504        self.assertEqual(g('standard_conforming_strings'), not_scs)
505        f('all')
506        self.assertEqual(g('default_with_oids'), dwi)
507        self.assertEqual(g('standard_conforming_strings'), scs)
508
509    def testSetParameterLocal(self):
510        f = self.db.set_parameter
511        g = self.db.get_parameter
512        self.assertEqual(g('standard_conforming_strings'), 'on')
513        self.db.begin()
514        f('standard_conforming_strings', 'off', local=True)
515        self.assertEqual(g('standard_conforming_strings'), 'off')
516        self.db.end()
517        self.assertEqual(g('standard_conforming_strings'), 'on')
518
519    def testSetParameterSession(self):
520        f = self.db.set_parameter
521        g = self.db.get_parameter
522        self.assertEqual(g('standard_conforming_strings'), 'on')
523        self.db.begin()
524        f('standard_conforming_strings', 'off', local=False)
525        self.assertEqual(g('standard_conforming_strings'), 'off')
526        self.db.end()
527        self.assertEqual(g('standard_conforming_strings'), 'off')
528
529    def testQuery(self):
530        query = self.db.query
531        query("drop table if exists test_table")
532        self.addCleanup(query, "drop table test_table")
533        q = "create table test_table (n integer) with oids"
534        r = query(q)
535        self.assertIsNone(r)
536        q = "insert into test_table values (1)"
537        r = query(q)
538        self.assertIsInstance(r, int)
539        q = "insert into test_table select 2"
540        r = query(q)
541        self.assertIsInstance(r, int)
542        oid = r
543        q = "select oid from test_table where n=2"
544        r = query(q).getresult()
545        self.assertEqual(len(r), 1)
546        r = r[0]
547        self.assertEqual(len(r), 1)
548        r = r[0]
549        self.assertEqual(r, oid)
550        q = "insert into test_table select 3 union select 4 union select 5"
551        r = query(q)
552        self.assertIsInstance(r, str)
553        self.assertEqual(r, '3')
554        q = "update test_table set n=4 where n<5"
555        r = query(q)
556        self.assertIsInstance(r, str)
557        self.assertEqual(r, '4')
558        q = "delete from test_table"
559        r = query(q)
560        self.assertIsInstance(r, str)
561        self.assertEqual(r, '5')
562
563    def testMultipleQueries(self):
564        self.assertEqual(self.db.query(
565            "create temporary table test_multi (n integer);"
566            "insert into test_multi values (4711);"
567            "select n from test_multi").getresult()[0][0], 4711)
568
569    def testQueryWithParams(self):
570        query = self.db.query
571        query("drop table if exists test_table")
572        self.addCleanup(query, "drop table test_table")
573        q = "create table test_table (n1 integer, n2 integer) with oids"
574        query(q)
575        q = "insert into test_table values ($1, $2)"
576        r = query(q, (1, 2))
577        self.assertIsInstance(r, int)
578        r = query(q, [3, 4])
579        self.assertIsInstance(r, int)
580        r = query(q, [5, 6])
581        self.assertIsInstance(r, int)
582        q = "select * from test_table order by 1, 2"
583        self.assertEqual(query(q).getresult(),
584            [(1, 2), (3, 4), (5, 6)])
585        q = "select * from test_table where n1=$1 and n2=$2"
586        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
587        q = "update test_table set n2=$2 where n1=$1"
588        r = query(q, 3, 7)
589        self.assertEqual(r, '1')
590        q = "select * from test_table order by 1, 2"
591        self.assertEqual(query(q).getresult(),
592            [(1, 2), (3, 7), (5, 6)])
593        q = "delete from test_table where n2!=$1"
594        r = query(q, 4)
595        self.assertEqual(r, '3')
596
597    def testEmptyQuery(self):
598        self.assertRaises(ValueError, self.db.query, '')
599
600    def testQueryProgrammingError(self):
601        try:
602            self.db.query("select 1/0")
603        except pg.ProgrammingError as error:
604            self.assertEqual(error.sqlstate, '22012')
605
606    def testPkey(self):
607        query = self.db.query
608        pkey = self.db.pkey
609        for t in ('pkeytest', 'primary key test'):
610            for n in range(7):
611                query('drop table if exists "%s%d"' % (t, n))
612                self.addCleanup(query, 'drop table "%s%d"' % (t, n))
613            query('create table "%s0" ('
614                "a smallint)" % t)
615            query('create table "%s1" ('
616                "b smallint primary key)" % t)
617            query('create table "%s2" ('
618                "c smallint, d smallint primary key)" % t)
619            query('create table "%s3" ('
620                "e smallint, f smallint, g smallint,"
621                " h smallint, i smallint,"
622                " primary key (f, h))" % t)
623            query('create table "%s4" ('
624                "more_than_one_letter varchar primary key)" % t)
625            query('create table "%s5" ('
626                '"with space" date primary key)' % t)
627            query('create table "%s6" ('
628                'a_very_long_column_name varchar,'
629                ' "with space" date,'
630                ' "42" int,'
631                " primary key (a_very_long_column_name,"
632                ' "with space", "42"))' % t)
633            self.assertRaises(KeyError, pkey, '%s0' % t)
634            self.assertEqual(pkey('%s1' % t), 'b')
635            self.assertEqual(pkey('%s2' % t), 'd')
636            r = pkey('%s3' % t)
637            self.assertIsInstance(r, frozenset)
638            self.assertEqual(r, frozenset('fh'))
639            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
640            self.assertEqual(pkey('%s5' % t), 'with space')
641            r = pkey('%s6' % t)
642            self.assertIsInstance(r, frozenset)
643            self.assertEqual(r, frozenset([
644                'a_very_long_column_name', 'with space', '42']))
645            # a newly added primary key will be detected
646            query('alter table "%s0" add primary key (a)' % t)
647            self.assertEqual(pkey('%s0' % t), 'a')
648            # a changed primary key will not be detected,
649            # indicating that the internal cache is operating
650            query('alter table "%s1" rename column b to x' % t)
651            self.assertEqual(pkey('%s1' % t), 'b')
652            # we get the changed primary key when the cache is flushed
653            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
654
655    def testGetDatabases(self):
656        databases = self.db.get_databases()
657        self.assertIn('template0', databases)
658        self.assertIn('template1', databases)
659        self.assertNotIn('not existing database', databases)
660        self.assertIn('postgres', databases)
661        self.assertIn(dbname, databases)
662
663    def testGetTables(self):
664        get_tables = self.db.get_tables
665        result1 = get_tables()
666        self.assertIsInstance(result1, list)
667        for t in result1:
668            t = t.split('.', 1)
669            self.assertGreaterEqual(len(t), 2)
670            if len(t) > 2:
671                self.assertTrue(t[1].startswith('"'))
672            t = t[0]
673            self.assertNotEqual(t, 'information_schema')
674            self.assertFalse(t.startswith('pg_'))
675        tables = ('"A very Special Name"',
676            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
677            'A_MiXeD_NaMe', '"another special name"',
678            'averyveryveryveryveryveryverylongtablename',
679            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
680        for t in tables:
681            self.db.query('drop table if exists %s' % t)
682            self.db.query("create table %s"
683                " as select 0" % t)
684        result3 = get_tables()
685        result2 = []
686        for t in result3:
687            if t not in result1:
688                result2.append(t)
689        result3 = []
690        for t in tables:
691            if not t.startswith('"'):
692                t = t.lower()
693            result3.append('public.' + t)
694        self.assertEqual(result2, result3)
695        for t in result2:
696            self.db.query('drop table %s' % t)
697        result2 = get_tables()
698        self.assertEqual(result2, result1)
699
700    def testGetRelations(self):
701        get_relations = self.db.get_relations
702        result = get_relations()
703        self.assertIn('public.test', result)
704        self.assertIn('public.test_view', result)
705        result = get_relations('rv')
706        self.assertIn('public.test', result)
707        self.assertIn('public.test_view', result)
708        result = get_relations('r')
709        self.assertIn('public.test', result)
710        self.assertNotIn('public.test_view', result)
711        result = get_relations('v')
712        self.assertNotIn('public.test', result)
713        self.assertIn('public.test_view', result)
714        result = get_relations('cisSt')
715        self.assertNotIn('public.test', result)
716        self.assertNotIn('public.test_view', result)
717
718    def testGetAttnames(self):
719        get_attnames = self.db.get_attnames
720        self.assertRaises(pg.ProgrammingError,
721            self.db.get_attnames, 'does_not_exist')
722        self.assertRaises(pg.ProgrammingError,
723            self.db.get_attnames, 'has.too.many.dots')
724        query = self.db.query
725        query("drop table if exists test_table")
726        self.addCleanup(query, "drop table test_table")
727        query("create table test_table("
728            " n int, alpha smallint, beta bool,"
729            " gamma char(5), tau text, v varchar(3))")
730        r = get_attnames('test_table')
731        self.assertIsInstance(r, dict)
732        self.assertEqual(r, dict(
733            n='int', alpha='int', beta='bool',
734            gamma='text', tau='text', v='text'))
735
736    def testGetAttnamesWithQuotes(self):
737        get_attnames = self.db.get_attnames
738        query = self.db.query
739        table = 'test table for get_attnames()'
740        query('drop table if exists "%s"' % table)
741        self.addCleanup(query, 'drop table "%s"' % table)
742        query('create table "%s"('
743            '"Prime!" smallint,'
744            ' "much space" integer, "Questions?" text)' % table)
745        r = get_attnames(table)
746        self.assertIsInstance(r, dict)
747        self.assertEqual(r, {
748            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
749        table = 'yet another test table for get_attnames()'
750        query('drop table if exists "%s"' % table)
751        self.addCleanup(query, 'drop table "%s"' % table)
752        self.db.query('create table "%s" ('
753            'a smallint, b integer, c bigint,'
754            ' e numeric, f float, f2 double precision, m money,'
755            ' x smallint, y smallint, z smallint,'
756            ' Normal_NaMe smallint, "Special Name" smallint,'
757            ' t text, u char(2), v varchar(2),'
758            ' primary key (y, u)) with oids' % table)
759        r = get_attnames(table)
760        self.assertIsInstance(r, dict)
761        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
762            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
763            'normal_name': 'int', 'Special Name': 'int',
764            'u': 'text', 't': 'text', 'v': 'text',
765            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
766
767    def testGetAttnamesWithRegtypes(self):
768        get_attnames = self.db.get_attnames
769        query = self.db.query
770        query("drop table if exists test_table")
771        self.addCleanup(query, "drop table test_table")
772        query("create table test_table("
773            " n int, alpha smallint, beta bool,"
774            " gamma char(5), tau text, v varchar(3))")
775        self.db.use_regtypes(True)
776        try:
777            r = get_attnames("test_table")
778            self.assertIsInstance(r, dict)
779        finally:
780            self.db.use_regtypes(False)
781        self.assertEqual(r, dict(
782            n='integer', alpha='smallint', beta='boolean',
783            gamma='character', tau='text', v='character varying'))
784
785    def testGetAttnamesIsCached(self):
786        get_attnames = self.db.get_attnames
787        query = self.db.query
788        query("drop table if exists test_table")
789        self.addCleanup(query, "drop table if exists test_table")
790        query("create table test_table(col int)")
791        r = get_attnames("test_table")
792        self.assertIsInstance(r, dict)
793        self.assertEqual(r, dict(col='int'))
794        query("drop table test_table")
795        query("create table test_table(col text)")
796        r = get_attnames("test_table")
797        self.assertEqual(r, dict(col='int'))
798        r = get_attnames("test_table", flush=True)
799        self.assertEqual(r, dict(col='text'))
800        query("drop table test_table")
801        r = get_attnames("test_table")
802        self.assertEqual(r, dict(col='text'))
803        self.assertRaises(pg.ProgrammingError,
804            get_attnames, "test_table", flush=True)
805
806    def testGetAttnamesIsOrdered(self):
807        get_attnames = self.db.get_attnames
808        query = self.db.query
809        query("drop table if exists test_table")
810        self.addCleanup(query, "drop table test_table")
811        query("create table test_table("
812            " n int, alpha smallint, v varchar(3),"
813            " gamma char(5), tau text, beta bool)")
814        r = get_attnames("test_table")
815        self.assertIsInstance(r, OrderedDict)
816        self.assertEqual(r, OrderedDict([
817            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
818            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
819        if OrderedDict is dict:
820            self.skipTest('OrderedDict is not supported')
821        r = ' '.join(list(r.keys()))
822        self.assertEqual(r, 'n alpha v gamma tau beta')
823
824    def testHasTablePrivilege(self):
825        can = self.db.has_table_privilege
826        self.assertEqual(can('test'), True)
827        self.assertEqual(can('test', 'select'), True)
828        self.assertEqual(can('test', 'SeLeCt'), True)
829        self.assertEqual(can('test', 'SELECT'), True)
830        self.assertEqual(can('test', 'insert'), True)
831        self.assertEqual(can('test', 'update'), True)
832        self.assertEqual(can('test', 'delete'), True)
833        self.assertEqual(can('pg_views', 'select'), True)
834        self.assertEqual(can('pg_views', 'delete'), False)
835        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
836        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
837
838    def testGet(self):
839        get = self.db.get
840        query = self.db.query
841        table = 'get_test_table'
842        query('drop table if exists "%s"' % table)
843        self.addCleanup(query, 'drop table "%s"' % table)
844        query('create table "%s" ('
845            "n integer, t text) with oids" % table)
846        for n, t in enumerate('xyz'):
847            query('insert into "%s" values('"%d, '%s')"
848                % (table, n + 1, t))
849        self.assertRaises(pg.ProgrammingError, get, table, 2)
850        r = get(table, 2, 'n')
851        oid_table = 'oid(%s)' % table
852        self.assertIn(oid_table, r)
853        oid = r[oid_table]
854        self.assertIsInstance(oid, int)
855        result = {'t': 'y', 'n': 2, oid_table: oid}
856        self.assertEqual(r, result)
857        self.assertEqual(get(table + ' *', 2, 'n'), r)
858        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
859        self.assertEqual(get(table, 1, 'n')['t'], 'x')
860        self.assertEqual(get(table, 3, 'n')['t'], 'z')
861        self.assertEqual(get(table, 2, 'n')['t'], 'y')
862        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
863        r['n'] = 3
864        self.assertEqual(get(table, r, 'n')['t'], 'z')
865        self.assertEqual(get(table, 1, 'n')['t'], 'x')
866        query('alter table "%s" alter n set not null' % table)
867        query('alter table "%s" add primary key (n)' % table)
868        self.assertEqual(get(table, 3)['t'], 'z')
869        self.assertEqual(get(table, 1)['t'], 'x')
870        self.assertEqual(get(table, 2)['t'], 'y')
871        r['n'] = 1
872        self.assertEqual(get(table, r)['t'], 'x')
873        r['n'] = 3
874        self.assertEqual(get(table, r)['t'], 'z')
875        r['n'] = 2
876        self.assertEqual(get(table, r)['t'], 'y')
877
878    def testGetWithCompositeKey(self):
879        get = self.db.get
880        query = self.db.query
881        table = 'get_test_table_1'
882        query('drop table if exists "%s"' % table)
883        self.addCleanup(query, 'drop table "%s"' % table)
884        query('create table "%s" ('
885            "n integer, t text, primary key (n))" % table)
886        for n, t in enumerate('abc'):
887            query('insert into "%s" values('
888                "%d, '%s')" % (table, n + 1, t))
889        self.assertEqual(get(table, 2)['t'], 'b')
890        table = 'get_test_table_2'
891        query('drop table if exists "%s"' % table)
892        self.addCleanup(query, 'drop table "%s"' % table)
893        query('create table "%s" ('
894            "n integer, m integer, t text, primary key (n, m))" % table)
895        for n in range(3):
896            for m in range(2):
897                t = chr(ord('a') + 2 * n + m)
898                query('insert into "%s" values('
899                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
900        self.assertRaises(pg.ProgrammingError, get, table, 2)
901        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
902        r = get(table, dict(n=1, m=2), ('n', 'm'))
903        self.assertEqual(r['t'], 'b')
904        r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
905        self.assertEqual(r['t'], 'f')
906
907    def testGetWithQuotedNames(self):
908        get = self.db.get
909        query = self.db.query
910        table = 'test table for get()'
911        query('drop table if exists "%s"' % table)
912        self.addCleanup(query, 'drop table "%s"' % table)
913        query('create table "%s" ('
914            '"Prime!" smallint primary key,'
915            ' "much space" integer, "Questions?" text)' % table)
916        query('insert into "%s"'
917              " values(17, 1001, 'No!')" % table)
918        r = get(table, 17)
919        self.assertIsInstance(r, dict)
920        self.assertEqual(r['Prime!'], 17)
921        self.assertEqual(r['much space'], 1001)
922        self.assertEqual(r['Questions?'], 'No!')
923
924    def testGetFromView(self):
925        self.db.query('delete from test where i4=14')
926        self.db.query('insert into test (i4, v4) values('
927            "14, 'abc4')")
928        r = self.db.get('test_view', 14, 'i4')
929        self.assertIn('v4', r)
930        self.assertEqual(r['v4'], 'abc4')
931
932    def testGetLittleBobbyTables(self):
933        get = self.db.get
934        query = self.db.query
935        query("drop table if exists test_students")
936        self.addCleanup(query, "drop table test_students")
937        query("create table test_students (firstname varchar primary key,"
938            " nickname varchar, grade char(2))")
939        query("insert into test_students values ("
940              "'D''Arcy', 'Darcey', 'A+')")
941        query("insert into test_students values ("
942              "'Sheldon', 'Moonpie', 'A+')")
943        query("insert into test_students values ("
944              "'Robert', 'Little Bobby Tables', 'D-')")
945        r = get('test_students', 'Sheldon')
946        self.assertEqual(r, dict(
947            firstname="Sheldon", nickname='Moonpie', grade='A+'))
948        r = get('test_students', 'Robert')
949        self.assertEqual(r, dict(
950            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
951        r = get('test_students', "D'Arcy")
952        self.assertEqual(r, dict(
953            firstname="D'Arcy", nickname='Darcey', grade='A+'))
954        try:
955            get('test_students', "D' Arcy")
956        except pg.DatabaseError as error:
957            self.assertEqual(str(error),
958                'No such record in test_students\nwhere "firstname" = $1\n'
959                'with $1="D\' Arcy"')
960        try:
961            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
962        except pg.DatabaseError as error:
963            self.assertEqual(str(error),
964                'No such record in test_students\nwhere "firstname" = $1\n'
965                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
966        q = "select * from test_students order by 1 limit 4"
967        r = query(q).getresult()
968        self.assertEqual(len(r), 3)
969        self.assertEqual(r[1][2], 'D-')
970
    def testInsert(self):
        """Round-trip values of many column types through insert().

        Each case is inserted, then compared both with the dict returned
        by insert() and with the row actually stored, and deleted again.
        """
        insert = self.db.insert
        query = self.db.query
        # when true, boolean expectations ('t'/'f') are converted to bools
        bool_on = pg.get_bool()
        # type used for numeric/money values; expectations are converted
        # to it below whenever it is not Decimal
        decimal = pg.get_decimal()
        table = 'insert_test_table'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        query('create table "%s" ('
            "i2 smallint, i4 integer, i8 bigint,"
            " d numeric, f4 real, f8 double precision, m money,"
            " v4 varchar(4), c4 char(4), t text,"
            " b boolean, ts timestamp) with oids" % table)
        # key under which insert() reports the oid of the new row
        oid_table = 'oid(%s)' % table
        # Each test is either a dict expected to round-trip unchanged, or
        # a pair (data, change) where change overrides expected values.
        tests = [dict(i2=None, i4=None, i8=None),
            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
            dict(i2=42, i4=123456, i8=9876543210),
            dict(i2=2 ** 15 - 1,
                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
            dict(d=None), (dict(d=''), dict(d=None)),
            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
            dict(f4=None, f8=None), dict(f4=0, f8=0),
            (dict(f4='', f8=''), dict(f4=None, f8=None)),
            (dict(d=1234.5, f4=1234.5, f8=1234.5),
                  dict(d=Decimal('1234.5'))),
            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
            dict(d=Decimal('123456789.9876543212345678987654321')),
            dict(m=None), (dict(m=''), dict(m=None)),
            dict(m=Decimal('-1234.56')),
            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
            (dict(m=123456), dict(m=Decimal('123456'))),
            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
            dict(b=None), (dict(b=''), dict(b=None)),
            dict(b='f'), dict(b='t'),
            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
            dict(v4=None, c4=None, t=None),
            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
            dict(v4='1234', c4='1234', t='1234' * 10),
            dict(v4='abcd', c4='abcd', t='abcdefg'),
            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
            dict(ts=None), (dict(ts=''), dict(ts=None)),
            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
            dict(ts='2012-12-21 00:00:00'),
            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
            dict(ts='2012-12-21 12:21:12'),
            dict(ts='2013-01-05 12:13:14'),
            dict(ts='current_timestamp')]
        for test in tests:
            if isinstance(test, dict):
                data = test
                change = {}
            else:
                data, change = test
            expect = data.copy()
            expect.update(change)
            if bool_on:
                # expectations are written as 't'/'f'; turn them into bools
                b = expect.get('b')
                if b is not None:
                    expect['b'] = b == 't'
            if decimal is not Decimal:
                # expectations use Decimal; convert to the configured type
                d = expect.get('d')
                if d is not None:
                    expect['d'] = decimal(d)
                m = expect.get('m')
                if m is not None:
                    expect['m'] = decimal(m)
            # insert() returns a dict equal to the passed one, which now
            # also carries the oid of the inserted row
            self.assertEqual(insert(table, data), data)
            self.assertIn(oid_table, data)
            oid = data[oid_table]
            self.assertIsInstance(oid, int)
            # compare only the columns this test case is about
            data = dict(item for item in data.items()
                if item[0] in expect)
            ts = expect.get('ts')
            if ts == 'current_timestamp':
                # the actual time is not known in advance, so accept the
                # returned value and only check that it looks like
                # 'YYYY-MM-DD HH:MM:SS', optionally with fractional seconds
                ts = expect['ts'] = data['ts']
                if len(ts) > 19:
                    self.assertEqual(ts[19], '.')
                    ts = ts[:19]
                else:
                    self.assertEqual(len(ts), 19)
                self.assertTrue(ts[:4].isdigit())
                self.assertEqual(ts[4], '-')
                self.assertEqual(ts[10], ' ')
                self.assertTrue(ts[11:13].isdigit())
                self.assertEqual(ts[13], ':')
            self.assertEqual(data, expect)
            # also verify what was actually stored in the database
            data = query(
                'select oid,* from "%s"' % table).dictresult()[0]
            self.assertEqual(data['oid'], oid)
            data = dict(item for item in data.items()
                if item[0] in expect)
            self.assertEqual(data, expect)
            # empty the table so the next case reads back only its own row
            query('delete from "%s"' % table)
1074
1075    def testInsertWithQuotedNames(self):
1076        insert = self.db.insert
1077        query = self.db.query
1078        table = 'test table for insert()'
1079        query('drop table if exists "%s"' % table)
1080        self.addCleanup(query, 'drop table "%s"' % table)
1081        query('create table "%s" ('
1082            '"Prime!" smallint primary key,'
1083            ' "much space" integer, "Questions?" text)' % table)
1084        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1085        r = insert(table, r)
1086        self.assertIsInstance(r, dict)
1087        self.assertEqual(r['Prime!'], 11)
1088        self.assertEqual(r['much space'], 2002)
1089        self.assertEqual(r['Questions?'], 'What?')
1090        r = query('select * from "%s" limit 2' % table).dictresult()
1091        self.assertEqual(len(r), 1)
1092        r = r[0]
1093        self.assertEqual(r['Prime!'], 11)
1094        self.assertEqual(r['much space'], 2002)
1095        self.assertEqual(r['Questions?'], 'What?')
1096
1097    def testUpdate(self):
1098        update = self.db.update
1099        query = self.db.query
1100        table = 'update_test_table'
1101        query('drop table if exists "%s"' % table)
1102        self.addCleanup(query, 'drop table "%s"' % table)
1103        query('create table "%s" ('
1104            "n integer, t text) with oids" % table)
1105        for n, t in enumerate('xyz'):
1106            query('insert into "%s" values('
1107                "%d, '%s')" % (table, n + 1, t))
1108        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1109        r = self.db.get(table, 2, 'n')
1110        r['t'] = 'u'
1111        s = update(table, r)
1112        self.assertEqual(s, r)
1113        q = 'select t from "%s" where n=2' % table
1114        r = query(q).getresult()[0][0]
1115        self.assertEqual(r, 'u')
1116
1117    def testUpdateWithCompositeKey(self):
1118        update = self.db.update
1119        query = self.db.query
1120        table = 'update_test_table_1'
1121        query('drop table if exists "%s"' % table)
1122        self.addCleanup(query, 'drop table if exists "%s"' % table)
1123        query('create table "%s" ('
1124            "n integer, t text, primary key (n))" % table)
1125        for n, t in enumerate('abc'):
1126            query('insert into "%s" values('
1127                "%d, '%s')" % (table, n + 1, t))
1128        self.assertRaises(pg.ProgrammingError, update,
1129                          table, dict(t='b'))
1130        s = dict(n=2, t='d')
1131        r = update(table, s)
1132        self.assertIs(r, s)
1133        self.assertEqual(r['n'], 2)
1134        self.assertEqual(r['t'], 'd')
1135        q = 'select t from "%s" where n=2' % table
1136        r = query(q).getresult()[0][0]
1137        self.assertEqual(r, 'd')
1138        s.update(dict(n=4, t='e'))
1139        r = update(table, s)
1140        self.assertEqual(r['n'], 4)
1141        self.assertEqual(r['t'], 'e')
1142        q = 'select t from "%s" where n=2' % table
1143        r = query(q).getresult()[0][0]
1144        self.assertEqual(r, 'd')
1145        q = 'select t from "%s" where n=4' % table
1146        r = query(q).getresult()
1147        self.assertEqual(len(r), 0)
1148        query('drop table "%s"' % table)
1149        table = 'update_test_table_2'
1150        query('drop table if exists "%s"' % table)
1151        query('create table "%s" ('
1152            "n integer, m integer, t text, primary key (n, m))" % table)
1153        for n in range(3):
1154            for m in range(2):
1155                t = chr(ord('a') + 2 * n + m)
1156                query('insert into "%s" values('
1157                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1158        self.assertRaises(pg.ProgrammingError, update,
1159                          table, dict(n=2, t='b'))
1160        self.assertEqual(update(table,
1161                                dict(n=2, m=2, t='x'))['t'], 'x')
1162        q = 'select t from "%s" where n=2 order by m' % table
1163        r = [r[0] for r in query(q).getresult()]
1164        self.assertEqual(r, ['c', 'x'])
1165
1166    def testUpdateWithQuotedNames(self):
1167        update = self.db.update
1168        query = self.db.query
1169        table = 'test table for update()'
1170        query('drop table if exists "%s"' % table)
1171        self.addCleanup(query, 'drop table "%s"' % table)
1172        query('create table "%s" ('
1173            '"Prime!" smallint primary key,'
1174            ' "much space" integer, "Questions?" text)' % table)
1175        query('insert into "%s"'
1176              " values(13, 3003, 'Why!')" % table)
1177        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1178        r = update(table, r)
1179        self.assertIsInstance(r, dict)
1180        self.assertEqual(r['Prime!'], 13)
1181        self.assertEqual(r['much space'], 7007)
1182        self.assertEqual(r['Questions?'], 'When?')
1183        r = query('select * from "%s" limit 2' % table).dictresult()
1184        self.assertEqual(len(r), 1)
1185        r = r[0]
1186        self.assertEqual(r['Prime!'], 13)
1187        self.assertEqual(r['much space'], 7007)
1188        self.assertEqual(r['Questions?'], 'When?')
1189
1190    def testUpsert(self):
1191        upsert = self.db.upsert
1192        query = self.db.query
1193        table = 'upsert_test_table'
1194        query('drop table if exists "%s"' % table)
1195        self.addCleanup(query, 'drop table "%s"' % table)
1196        query('create table "%s" ('
1197            "n integer primary key, t text) with oids" % table)
1198        s = dict(n=1, t='x')
1199        try:
1200            r = upsert(table, s)
1201        except pg.ProgrammingError as error:
1202            if self.db.server_version < 90500:
1203                self.skipTest('database does not support upsert')
1204            self.fail(str(error))
1205        self.assertIs(r, s)
1206        self.assertEqual(r['n'], 1)
1207        self.assertEqual(r['t'], 'x')
1208        s.update(n=2, t='y')
1209        r = upsert(table, s, **dict.fromkeys(s))
1210        self.assertIs(r, s)
1211        self.assertEqual(r['n'], 2)
1212        self.assertEqual(r['t'], 'y')
1213        q = 'select n, t from "%s" order by n limit 3' % table
1214        r = query(q).getresult()
1215        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1216        s.update(t='z')
1217        r = upsert(table, s)
1218        self.assertIs(r, s)
1219        self.assertEqual(r['n'], 2)
1220        self.assertEqual(r['t'], 'z')
1221        r = query(q).getresult()
1222        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1223        s.update(t='n')
1224        r = upsert(table, s, t=False)
1225        self.assertIs(r, s)
1226        self.assertEqual(r['n'], 2)
1227        self.assertEqual(r['t'], 'z')
1228        r = query(q).getresult()
1229        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1230        s.update(t='y')
1231        r = upsert(table, s, t=True)
1232        self.assertIs(r, s)
1233        self.assertEqual(r['n'], 2)
1234        self.assertEqual(r['t'], 'y')
1235        r = query(q).getresult()
1236        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1237        s.update(t='n')
1238        r = upsert(table, s, t="included.t || '2'")
1239        self.assertIs(r, s)
1240        self.assertEqual(r['n'], 2)
1241        self.assertEqual(r['t'], 'y2')
1242        r = query(q).getresult()
1243        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1244        s.update(t='y')
1245        r = upsert(table, s, t="excluded.t || '3'")
1246        self.assertIs(r, s)
1247        self.assertEqual(r['n'], 2)
1248        self.assertEqual(r['t'], 'y3')
1249        r = query(q).getresult()
1250        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1251        s.update(n=1, t='2')
1252        r = upsert(table, s, t="included.t || excluded.t")
1253        self.assertIs(r, s)
1254        self.assertEqual(r['n'], 1)
1255        self.assertEqual(r['t'], 'x2')
1256        r = query(q).getresult()
1257        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1258
1259    def testUpsertWithCompositeKey(self):
1260        upsert = self.db.upsert
1261        query = self.db.query
1262        table = 'upsert_test_table_2'
1263        query('drop table if exists "%s"' % table)
1264        self.addCleanup(query, 'drop table "%s"' % table)
1265        query('create table "%s" ('
1266            "n integer, m integer, t text, primary key (n, m))" % table)
1267        s = dict(n=1, m=2, t='x')
1268        try:
1269            r = upsert(table, s)
1270        except pg.ProgrammingError as error:
1271            if self.db.server_version < 90500:
1272                self.skipTest('database does not support upsert')
1273            self.fail(str(error))
1274        self.assertIs(r, s)
1275        self.assertEqual(r['n'], 1)
1276        self.assertEqual(r['m'], 2)
1277        self.assertEqual(r['t'], 'x')
1278        s.update(m=3, t='y')
1279        r = upsert(table, s, **dict.fromkeys(s))
1280        self.assertIs(r, s)
1281        self.assertEqual(r['n'], 1)
1282        self.assertEqual(r['m'], 3)
1283        self.assertEqual(r['t'], 'y')
1284        q = 'select n, m, t from "%s" order by n, m limit 3' % table
1285        r = query(q).getresult()
1286        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
1287        s.update(t='z')
1288        r = upsert(table, s)
1289        self.assertIs(r, s)
1290        self.assertEqual(r['n'], 1)
1291        self.assertEqual(r['m'], 3)
1292        self.assertEqual(r['t'], 'z')
1293        r = query(q).getresult()
1294        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1295        s.update(t='n')
1296        r = upsert(table, s, t=False)
1297        self.assertIs(r, s)
1298        self.assertEqual(r['n'], 1)
1299        self.assertEqual(r['m'], 3)
1300        self.assertEqual(r['t'], 'z')
1301        r = query(q).getresult()
1302        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1303        s.update(t='n')
1304        r = upsert(table, s, t=True)
1305        self.assertIs(r, s)
1306        self.assertEqual(r['n'], 1)
1307        self.assertEqual(r['m'], 3)
1308        self.assertEqual(r['t'], 'n')
1309        r = query(q).getresult()
1310        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
1311        s.update(n=2, t='y')
1312        r = upsert(table, s, t="'z'")
1313        self.assertIs(r, s)
1314        self.assertEqual(r['n'], 2)
1315        self.assertEqual(r['m'], 3)
1316        self.assertEqual(r['t'], 'y')
1317        r = query(q).getresult()
1318        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
1319        s.update(n=1, t='m')
1320        r = upsert(table, s, t='included.t || excluded.t')
1321        self.assertIs(r, s)
1322        self.assertEqual(r['n'], 1)
1323        self.assertEqual(r['m'], 3)
1324        self.assertEqual(r['t'], 'nm')
1325        r = query(q).getresult()
1326        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1327
1328    def testUpsertWithQuotedNames(self):
1329        upsert = self.db.upsert
1330        query = self.db.query
1331        table = 'test table for upsert()'
1332        query('drop table if exists "%s"' % table)
1333        self.addCleanup(query, 'drop table "%s"' % table)
1334        query('create table "%s" ('
1335            '"Prime!" smallint primary key,'
1336            ' "much space" integer, "Questions?" text)' % table)
1337        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1338        try:
1339            r = upsert(table, s)
1340        except pg.ProgrammingError as error:
1341            if self.db.server_version < 90500:
1342                self.skipTest('database does not support upsert')
1343            self.fail(str(error))
1344        self.assertIs(r, s)
1345        self.assertEqual(r['Prime!'], 31)
1346        self.assertEqual(r['much space'], 9009)
1347        self.assertEqual(r['Questions?'], 'Yes.')
1348        q = 'select * from "%s" limit 2' % table
1349        r = query(q).getresult()
1350        self.assertEqual(r, [(31, 9009, 'Yes.')])
1351        s.update({'Questions?': 'No.'})
1352        r = upsert(table, s)
1353        self.assertIs(r, s)
1354        self.assertEqual(r['Prime!'], 31)
1355        self.assertEqual(r['much space'], 9009)
1356        self.assertEqual(r['Questions?'], 'No.')
1357        r = query(q).getresult()
1358        self.assertEqual(r, [(31, 9009, 'No.')])
1359
1360    def testClear(self):
1361        clear = self.db.clear
1362        query = self.db.query
1363        f = False if pg.get_bool() else 'f'
1364        table = 'clear_test_table'
1365        query('drop table if exists "%s"' % table)
1366        self.addCleanup(query, 'drop table "%s"' % table)
1367        query('create table "%s" ('
1368            "n integer, b boolean, d date, t text)" % table)
1369        r = clear(table)
1370        result = {'n': 0, 'b': f, 'd': '', 't': ''}
1371        self.assertEqual(r, result)
1372        r['a'] = r['n'] = 1
1373        r['d'] = r['t'] = 'x'
1374        r['b'] = 't'
1375        r['oid'] = long(1)
1376        r = clear(table, r)
1377        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
1378            'oid': long(1)}
1379        self.assertEqual(r, result)
1380
1381    def testClearWithQuotedNames(self):
1382        clear = self.db.clear
1383        query = self.db.query
1384        table = 'test table for clear()'
1385        query('drop table if exists "%s"' % table)
1386        self.addCleanup(query, 'drop table "%s"' % table)
1387        query('create table "%s" ('
1388            '"Prime!" smallint primary key,'
1389            ' "much space" integer, "Questions?" text)' % table)
1390        r = clear(table)
1391        self.assertIsInstance(r, dict)
1392        self.assertEqual(r['Prime!'], 0)
1393        self.assertEqual(r['much space'], 0)
1394        self.assertEqual(r['Questions?'], '')
1395
1396    def testDelete(self):
1397        delete = self.db.delete
1398        query = self.db.query
1399        table = 'delete_test_table'
1400        query('drop table if exists "%s"' % table)
1401        self.addCleanup(query, 'drop table "%s"' % table)
1402        query('create table "%s" ('
1403            "n integer, t text) with oids" % table)
1404        for n, t in enumerate('xyz'):
1405            query('insert into "%s" values('
1406                "%d, '%s')" % (table, n + 1, t))
1407        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1408        r = self.db.get(table, 1, 'n')
1409        s = delete(table, r)
1410        self.assertEqual(s, 1)
1411        r = self.db.get(table, 3, 'n')
1412        s = delete(table, r)
1413        self.assertEqual(s, 1)
1414        s = delete(table, r)
1415        self.assertEqual(s, 0)
1416        r = query('select * from "%s"' % table).dictresult()
1417        self.assertEqual(len(r), 1)
1418        r = r[0]
1419        result = {'n': 2, 't': 'y'}
1420        self.assertEqual(r, result)
1421        r = self.db.get(table, 2, 'n')
1422        s = delete(table, r)
1423        self.assertEqual(s, 1)
1424        s = delete(table, r)
1425        self.assertEqual(s, 0)
1426        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1427
1428    def testDeleteWithCompositeKey(self):
1429        query = self.db.query
1430        table = 'delete_test_table_1'
1431        query('drop table if exists "%s"' % table)
1432        self.addCleanup(query, 'drop table "%s"' % table)
1433        query('create table "%s" ('
1434            "n integer, t text, primary key (n))" % table)
1435        for n, t in enumerate('abc'):
1436            query("insert into %s values("
1437                "%d, '%s')" % (table, n + 1, t))
1438        self.assertRaises(pg.ProgrammingError, self.db.delete,
1439            table, dict(t='b'))
1440        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1441        r = query('select t from "%s" where n=2' % table
1442                  ).getresult()
1443        self.assertEqual(r, [])
1444        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1445        r = query('select t from "%s" where n=3' % table
1446                  ).getresult()[0][0]
1447        self.assertEqual(r, 'c')
1448        table = 'delete_test_table_2'
1449        query('drop table if exists "%s"' % table)
1450        self.addCleanup(query, 'drop table "%s"' % table)
1451        query('create table "%s" ('
1452            "n integer, m integer, t text, primary key (n, m))" % table)
1453        for n in range(3):
1454            for m in range(2):
1455                t = chr(ord('a') + 2 * n + m)
1456                query('insert into "%s" values('
1457                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1458        self.assertRaises(pg.ProgrammingError, self.db.delete,
1459            table, dict(n=2, t='b'))
1460        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1461        r = [r[0] for r in query('select t from "%s" where n=2'
1462            ' order by m' % table).getresult()]
1463        self.assertEqual(r, ['c'])
1464        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1465        r = [r[0] for r in query('select t from "%s" where n=3'
1466            ' order by m' % table).getresult()]
1467        self.assertEqual(r, ['e', 'f'])
1468        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1469        r = [r[0] for r in query('select t from "%s" where n=3'
1470            ' order by m' % table).getresult()]
1471        self.assertEqual(r, ['f'])
1472
1473    def testDeleteWithQuotedNames(self):
1474        delete = self.db.delete
1475        query = self.db.query
1476        table = 'test table for delete()'
1477        query('drop table if exists "%s"' % table)
1478        self.addCleanup(query, 'drop table "%s"' % table)
1479        query('create table "%s" ('
1480            '"Prime!" smallint primary key,'
1481            ' "much space" integer, "Questions?" text)' % table)
1482        query('insert into "%s"'
1483              " values(19, 5005, 'Yes!')" % table)
1484        r = {'Prime!': 17}
1485        r = delete(table, r)
1486        self.assertEqual(r, 0)
1487        r = query('select count(*) from "%s"' % table).getresult()
1488        self.assertEqual(r[0][0], 1)
1489        r = {'Prime!': 19}
1490        r = delete(table, r)
1491        self.assertEqual(r, 1)
1492        r = query('select count(*) from "%s"' % table).getresult()
1493        self.assertEqual(r[0][0], 0)
1494
1495    def testTransaction(self):
1496        query = self.db.query
1497        query("drop table if exists test_table")
1498        self.addCleanup(query, "drop table test_table")
1499        query("create table test_table (n integer)")
1500        self.db.begin()
1501        query("insert into test_table values (1)")
1502        query("insert into test_table values (2)")
1503        self.db.commit()
1504        self.db.begin()
1505        query("insert into test_table values (3)")
1506        query("insert into test_table values (4)")
1507        self.db.rollback()
1508        self.db.begin()
1509        query("insert into test_table values (5)")
1510        self.db.savepoint('before6')
1511        query("insert into test_table values (6)")
1512        self.db.rollback('before6')
1513        query("insert into test_table values (7)")
1514        self.db.commit()
1515        self.db.begin()
1516        self.db.savepoint('before8')
1517        query("insert into test_table values (8)")
1518        self.db.release('before8')
1519        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1520        self.db.commit()
1521        self.db.start()
1522        query("insert into test_table values (9)")
1523        self.db.end()
1524        r = [r[0] for r in query(
1525            "select * from test_table order by 1").getresult()]
1526        self.assertEqual(r, [1, 2, 5, 7, 9])
1527
1528    def testContextManager(self):
1529        query = self.db.query
1530        query("drop table if exists test_table")
1531        self.addCleanup(query, "drop table test_table")
1532        query("create table test_table (n integer check(n>0))")
1533        with self.db:
1534            query("insert into test_table values (1)")
1535            query("insert into test_table values (2)")
1536        try:
1537            with self.db:
1538                query("insert into test_table values (3)")
1539                query("insert into test_table values (4)")
1540                raise ValueError('test transaction should rollback')
1541        except ValueError as error:
1542            self.assertEqual(str(error), 'test transaction should rollback')
1543        with self.db:
1544            query("insert into test_table values (5)")
1545        try:
1546            with self.db:
1547                query("insert into test_table values (6)")
1548                query("insert into test_table values (-1)")
1549        except pg.ProgrammingError as error:
1550            self.assertTrue('check' in str(error))
1551        with self.db:
1552            query("insert into test_table values (7)")
1553        r = [r[0] for r in query(
1554            "select * from test_table order by 1").getresult()]
1555        self.assertEqual(r, [1, 2, 5, 7])
1556
1557    def testBytea(self):
1558        query = self.db.query
1559        query('drop table if exists bytea_test')
1560        self.addCleanup(query, 'drop table bytea_test')
1561        query('create table bytea_test (n smallint primary key, data bytea)')
1562        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1563        r = self.db.escape_bytea(s)
1564        query('insert into bytea_test values(3,$1)', (r,))
1565        r = query('select * from bytea_test where n=3').getresult()
1566        self.assertEqual(len(r), 1)
1567        r = r[0]
1568        self.assertEqual(len(r), 2)
1569        self.assertEqual(r[0], 3)
1570        r = r[1]
1571        self.assertIsInstance(r, str)
1572        r = self.db.unescape_bytea(r)
1573        self.assertIsInstance(r, bytes)
1574        self.assertEqual(r, s)
1575
1576    def testInsertUpdateGetBytea(self):
1577        query = self.db.query
1578        query('drop table if exists bytea_test')
1579        self.addCleanup(query, 'drop table bytea_test')
1580        query('create table bytea_test (n smallint primary key, data bytea)')
1581        # insert null value
1582        r = self.db.insert('bytea_test', n=0, data=None)
1583        self.assertIsInstance(r, dict)
1584        self.assertIn('n', r)
1585        self.assertEqual(r['n'], 0)
1586        self.assertIn('data', r)
1587        self.assertIsNone(r['data'])
1588        s = b'None'
1589        r = self.db.update('bytea_test', n=0, data=s)
1590        self.assertIsInstance(r, dict)
1591        self.assertIn('n', r)
1592        self.assertEqual(r['n'], 0)
1593        self.assertIn('data', r)
1594        r = r['data']
1595        self.assertIsInstance(r, bytes)
1596        self.assertEqual(r, s)
1597        r = self.db.update('bytea_test', n=0, data=None)
1598        self.assertIsNone(r['data'])
1599        # insert as bytes
1600        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1601        r = self.db.insert('bytea_test', n=5, data=s)
1602        self.assertIsInstance(r, dict)
1603        self.assertIn('n', r)
1604        self.assertEqual(r['n'], 5)
1605        self.assertIn('data', r)
1606        r = r['data']
1607        self.assertIsInstance(r, bytes)
1608        self.assertEqual(r, s)
1609        # update as bytes
1610        s += b"and now even more \x00 nasty \t stuff!\f"
1611        r = self.db.update('bytea_test', n=5, data=s)
1612        self.assertIsInstance(r, dict)
1613        self.assertIn('n', r)
1614        self.assertEqual(r['n'], 5)
1615        self.assertIn('data', r)
1616        r = r['data']
1617        self.assertIsInstance(r, bytes)
1618        self.assertEqual(r, s)
1619        r = query('select * from bytea_test where n=5').getresult()
1620        self.assertEqual(len(r), 1)
1621        r = r[0]
1622        self.assertEqual(len(r), 2)
1623        self.assertEqual(r[0], 5)
1624        r = r[1]
1625        self.assertIsInstance(r, str)
1626        r = self.db.unescape_bytea(r)
1627        self.assertIsInstance(r, bytes)
1628        self.assertEqual(r, s)
1629        r = self.db.get('bytea_test', dict(n=5))
1630        self.assertIsInstance(r, dict)
1631        self.assertIn('n', r)
1632        self.assertEqual(r['n'], 5)
1633        self.assertIn('data', r)
1634        r = r['data']
1635        self.assertIsInstance(r, bytes)
1636        self.assertEqual(r, s)
1637
1638    def testDebugWithCallable(self):
1639        if debug:
1640            self.assertEqual(self.db.debug, debug)
1641        else:
1642            self.assertIsNone(self.db.debug)
1643        s = []
1644        self.db.debug = s.append
1645        try:
1646            self.db.query("select 1")
1647            self.db.query("select 2")
1648            self.assertEqual(s, ["select 1", "select 2"])
1649        finally:
1650            self.db.debug = debug
1651
1652
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        """Change the global pg options, then set up like TestDBClass.

        The options are restored if anything here fails.  unittest does not
        call tearDownClass() in that case, so the changed module-level
        options would otherwise leak into all tests that run afterwards.
        """
        cls.saved_options = {}

        def unnamed_result(q):
            return q.getresult()

        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            cls._reset_saved_options()
            raise

    @classmethod
    def tearDownClass(cls):
        """Tear down like TestDBClass, then restore the global pg options."""
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # restore the options even if the inherited teardown fails
            cls._reset_saved_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option and set a new one."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _reset_saved_options(cls):
        """Restore all options changed in setUpClass() that were saved."""
        # reverse order of setting; skip options never reached on failure
        for option in ('namedresult', 'bool', 'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
1681
1682
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    @classmethod
    def setUpClass(cls):
        """Create the tables t and t<num> in the schemas public, s1 .. s4.

        Schema number 0 is the public schema.  Every table is created from
        a query selecting one row with n = 1 and d = <number of its schema>.
        """
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # release the connection even if creating the fixtures failed
            # (tearDownClass is not called by unittest in that case)
            db.close()

    @classmethod
    def tearDownClass(cls):
        """Drop the schemas s1 .. s4 and the tables in the public schema."""
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            # release the connection even if dropping something failed
            db.close()

    def setUp(self):
        self.db = DB()
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
1801
if __name__ == '__main__':
    # run all tests defined in this module when executed as a script
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.