source: trunk/tests/test_classic_dbwrapper.py @ 730

Last change on this file since 730 was 730, checked in by cito, 4 years ago

Use query parameters instead of inline values

The single row methods of the DB wrapper class created queries with inline values
instead of passing them separately as parameters, even though our query method
does have this capability. Using query parameters also spares us a lot of quoting
and escaping that is necessary when passing values inline.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 47.8 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
24# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
25# get our information from that.  Otherwise we use the defaults.
26# The current user must have create schema privilege on the database.
27dbname = 'unittest'
28dbhost = None
29dbport = 5432
30
31debug = False  # let DB wrapper print debugging output
32
33try:
34    from .LOCAL_PyGreSQL import *
35except (ImportError, ValueError):
36    try:
37        from LOCAL_PyGreSQL import *
38    except ImportError:
39        pass
40
41try:
42    long
43except NameError:  # Python >= 3.0
44    long = int
45
46try:
47    unicode
48except NameError:  # Python >= 3.0
49    unicode = str
50
51windows = os.name == 'nt'
52
53# There is a known a bug in libpq under Windows which can cause
54# the interface to crash when calling PQhost():
55do_not_ask_for_host = windows
56do_not_ask_for_host_reason = 'libpq issue on Windows'
57
58
59def DB():
60    """Create a DB wrapper object connecting to the test database."""
61    db = pg.DB(dbname, dbhost, dbport)
62    if debug:
63        db.debug = debug
64    db.query("set client_min_messages=warning")
65    return db
66
67
68class TestDBClassBasic(unittest.TestCase):
69    """Test existence of the DB class wrapped pg connection methods."""
70
71    def setUp(self):
72        self.db = DB()
73
74    def tearDown(self):
75        try:
76            self.db.close()
77        except pg.InternalError:
78            pass
79
80    def testAllDBAttributes(self):
81        attributes = [
82            'begin',
83            'cancel',
84            'clear',
85            'close',
86            'commit',
87            'db',
88            'dbname',
89            'debug',
90            'delete',
91            'end',
92            'endcopy',
93            'error',
94            'escape_bytea',
95            'escape_identifier',
96            'escape_literal',
97            'escape_string',
98            'fileno',
99            'get',
100            'get_attnames',
101            'get_databases',
102            'get_notice_receiver',
103            'get_relations',
104            'get_tables',
105            'getline',
106            'getlo',
107            'getnotify',
108            'has_table_privilege',
109            'host',
110            'insert',
111            'inserttable',
112            'locreate',
113            'loimport',
114            'notification_handler',
115            'options',
116            'parameter',
117            'pkey',
118            'port',
119            'protocol_version',
120            'putline',
121            'query',
122            'release',
123            'reopen',
124            'reset',
125            'rollback',
126            'savepoint',
127            'server_version',
128            'set_notice_receiver',
129            'source',
130            'start',
131            'status',
132            'transaction',
133            'unescape_bytea',
134            'update',
135            'use_regtypes',
136            'user',
137        ]
138        db_attributes = [a for a in dir(self.db)
139            if not a.startswith('_')]
140        self.assertEqual(attributes, db_attributes)
141
142    def testAttributeDb(self):
143        self.assertEqual(self.db.db.db, dbname)
144
145    def testAttributeDbname(self):
146        self.assertEqual(self.db.dbname, dbname)
147
148    def testAttributeError(self):
149        error = self.db.error
150        self.assertTrue(not error or 'krb5_' in error)
151        self.assertEqual(self.db.error, self.db.db.error)
152
153    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
154    def testAttributeHost(self):
155        def_host = 'localhost'
156        host = self.db.host
157        self.assertIsInstance(host, str)
158        self.assertEqual(host, dbhost or def_host)
159        self.assertEqual(host, self.db.db.host)
160
161    def testAttributeOptions(self):
162        no_options = ''
163        options = self.db.options
164        self.assertEqual(options, no_options)
165        self.assertEqual(options, self.db.db.options)
166
167    def testAttributePort(self):
168        def_port = 5432
169        port = self.db.port
170        self.assertIsInstance(port, int)
171        self.assertEqual(port, dbport or def_port)
172        self.assertEqual(port, self.db.db.port)
173
174    def testAttributeProtocolVersion(self):
175        protocol_version = self.db.protocol_version
176        self.assertIsInstance(protocol_version, int)
177        self.assertTrue(2 <= protocol_version < 4)
178        self.assertEqual(protocol_version, self.db.db.protocol_version)
179
180    def testAttributeServerVersion(self):
181        server_version = self.db.server_version
182        self.assertIsInstance(server_version, int)
183        self.assertTrue(70400 <= server_version < 100000)
184        self.assertEqual(server_version, self.db.db.server_version)
185
186    def testAttributeStatus(self):
187        status_ok = 1
188        status = self.db.status
189        self.assertIsInstance(status, int)
190        self.assertEqual(status, status_ok)
191        self.assertEqual(status, self.db.db.status)
192
193    def testAttributeUser(self):
194        no_user = 'Deprecated facility'
195        user = self.db.user
196        self.assertTrue(user)
197        self.assertIsInstance(user, str)
198        self.assertNotEqual(user, no_user)
199        self.assertEqual(user, self.db.db.user)
200
201    def testMethodEscapeLiteral(self):
202        self.assertEqual(self.db.escape_literal(''), "''")
203
204    def testMethodEscapeIdentifier(self):
205        self.assertEqual(self.db.escape_identifier(''), '""')
206
207    def testMethodEscapeString(self):
208        self.assertEqual(self.db.escape_string(''), '')
209
210    def testMethodEscapeBytea(self):
211        self.assertEqual(self.db.escape_bytea('').replace(
212            '\\x', '').replace('\\', ''), '')
213
214    def testMethodUnescapeBytea(self):
215        self.assertEqual(self.db.unescape_bytea(''), b'')
216
217    def testMethodQuery(self):
218        query = self.db.query
219        query("select 1+1")
220        query("select 1+$1+$2", 2, 3)
221        query("select 1+$1+$2", (2, 3))
222        query("select 1+$1+$2", [2, 3])
223        query("select 1+$1", 1)
224
225    def testMethodQueryEmpty(self):
226        self.assertRaises(ValueError, self.db.query, '')
227
228    def testMethodQueryProgrammingError(self):
229        try:
230            self.db.query("select 1/0")
231        except pg.ProgrammingError as error:
232            self.assertEqual(error.sqlstate, '22012')
233
234    def testMethodEndcopy(self):
235        try:
236            self.db.endcopy()
237        except IOError:
238            pass
239
240    def testMethodClose(self):
241        self.db.close()
242        try:
243            self.db.reset()
244        except pg.Error:
245            pass
246        else:
247            self.fail('Reset should give an error for a closed connection')
248        self.assertRaises(pg.InternalError, self.db.close)
249        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
250
251    def testExistingConnection(self):
252        db = pg.DB(self.db.db)
253        self.assertEqual(self.db.db, db.db)
254        self.assertTrue(db.db)
255        db.close()
256        self.assertTrue(db.db)
257        db.reopen()
258        self.assertTrue(db.db)
259        db.close()
260        self.assertTrue(db.db)
261        db = pg.DB(self.db)
262        self.assertEqual(self.db.db, db.db)
263        db = pg.DB(db=self.db.db)
264        self.assertEqual(self.db.db, db.db)
265
266        class DB2:
267            pass
268
269        db2 = DB2()
270        db2._cnx = self.db.db
271        db = pg.DB(db2)
272        self.assertEqual(self.db.db, db.db)
273
274
275class TestDBClass(unittest.TestCase):
276    """Test the methods of the DB class wrapped pg connection."""
277
278    @classmethod
279    def setUpClass(cls):
280        db = DB()
281        db.query("drop table if exists test cascade")
282        db.query("create table test ("
283            "i2 smallint, i4 integer, i8 bigint,"
284            "d numeric, f4 real, f8 double precision, m money, "
285            "v4 varchar(4), c4 char(4), t text)")
286        db.query("create or replace view test_view as"
287            " select i4, v4 from test")
288        db.close()
289
290    @classmethod
291    def tearDownClass(cls):
292        db = DB()
293        db.query("drop table test cascade")
294        db.close()
295
296    def setUp(self):
297        self.db = DB()
298        query = self.db.query
299        query('set client_encoding=utf8')
300        query('set standard_conforming_strings=on')
301        query("set lc_monetary='C'")
302        query("set datestyle='ISO,YMD'")
303        query('set bytea_output=hex')
304
305    def tearDown(self):
306        self.db.close()
307
308    def testClassName(self):
309        self.assertEqual(self.db.__class__.__name__, 'DB')
310
311    def testModuleName(self):
312        self.assertEqual(self.db.__module__, 'pg')
313        self.assertEqual(self.db.__class__.__module__, 'pg')
314
315    def testEscapeLiteral(self):
316        f = self.db.escape_literal
317        r = f(b"plain")
318        self.assertIsInstance(r, bytes)
319        self.assertEqual(r, b"'plain'")
320        r = f(u"plain")
321        self.assertIsInstance(r, unicode)
322        self.assertEqual(r, u"'plain'")
323        r = f(u"that's kÀse".encode('utf-8'))
324        self.assertIsInstance(r, bytes)
325        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
326        r = f(u"that's kÀse")
327        self.assertIsInstance(r, unicode)
328        self.assertEqual(r, u"'that''s kÀse'")
329        self.assertEqual(f(r"It's fine to have a \ inside."),
330            r" E'It''s fine to have a \\ inside.'")
331        self.assertEqual(f('No "quotes" must be escaped.'),
332            "'No \"quotes\" must be escaped.'")
333
334    def testEscapeIdentifier(self):
335        f = self.db.escape_identifier
336        r = f(b"plain")
337        self.assertIsInstance(r, bytes)
338        self.assertEqual(r, b'"plain"')
339        r = f(u"plain")
340        self.assertIsInstance(r, unicode)
341        self.assertEqual(r, u'"plain"')
342        r = f(u"that's kÀse".encode('utf-8'))
343        self.assertIsInstance(r, bytes)
344        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
345        r = f(u"that's kÀse")
346        self.assertIsInstance(r, unicode)
347        self.assertEqual(r, u'"that\'s kÀse"')
348        self.assertEqual(f(r"It's fine to have a \ inside."),
349            '"It\'s fine to have a \\ inside."')
350        self.assertEqual(f('All "quotes" must be escaped.'),
351            '"All ""quotes"" must be escaped."')
352
353    def testEscapeString(self):
354        f = self.db.escape_string
355        r = f(b"plain")
356        self.assertIsInstance(r, bytes)
357        self.assertEqual(r, b"plain")
358        r = f(u"plain")
359        self.assertIsInstance(r, unicode)
360        self.assertEqual(r, u"plain")
361        r = f(u"that's kÀse".encode('utf-8'))
362        self.assertIsInstance(r, bytes)
363        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
364        r = f(u"that's kÀse")
365        self.assertIsInstance(r, unicode)
366        self.assertEqual(r, u"that''s kÀse")
367        self.assertEqual(f(r"It's fine to have a \ inside."),
368            r"It''s fine to have a \ inside.")
369
370    def testEscapeBytea(self):
371        f = self.db.escape_bytea
372        # note that escape_byte always returns hex output since Pg 9.0,
373        # regardless of the bytea_output setting
374        r = f(b'plain')
375        self.assertIsInstance(r, bytes)
376        self.assertEqual(r, b'\\x706c61696e')
377        r = f(u'plain')
378        self.assertIsInstance(r, unicode)
379        self.assertEqual(r, u'\\x706c61696e')
380        r = f(u"das is' kÀse".encode('utf-8'))
381        self.assertIsInstance(r, bytes)
382        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
383        r = f(u"das is' kÀse")
384        self.assertIsInstance(r, unicode)
385        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
386        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
387
388    def testUnescapeBytea(self):
389        f = self.db.unescape_bytea
390        r = f(b'plain')
391        self.assertIsInstance(r, bytes)
392        self.assertEqual(r, b'plain')
393        r = f(u'plain')
394        self.assertIsInstance(r, bytes)
395        self.assertEqual(r, b'plain')
396        r = f(b"das is' k\\303\\244se")
397        self.assertIsInstance(r, bytes)
398        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
399        r = f(u"das is' k\\303\\244se")
400        self.assertIsInstance(r, bytes)
401        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
402        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
403        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
404        self.assertEqual(f(r'\\x746861742773206be47365'),
405            b'\\x746861742773206be47365')
406        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
407
408    def testQuery(self):
409        query = self.db.query
410        query("drop table if exists test_table")
411        q = "create table test_table (n integer) with oids"
412        r = query(q)
413        self.assertIsNone(r)
414        q = "insert into test_table values (1)"
415        r = query(q)
416        self.assertIsInstance(r, int)
417        q = "insert into test_table select 2"
418        r = query(q)
419        self.assertIsInstance(r, int)
420        oid = r
421        q = "select oid from test_table where n=2"
422        r = query(q).getresult()
423        self.assertEqual(len(r), 1)
424        r = r[0]
425        self.assertEqual(len(r), 1)
426        r = r[0]
427        self.assertEqual(r, oid)
428        q = "insert into test_table select 3 union select 4 union select 5"
429        r = query(q)
430        self.assertIsInstance(r, str)
431        self.assertEqual(r, '3')
432        q = "update test_table set n=4 where n<5"
433        r = query(q)
434        self.assertIsInstance(r, str)
435        self.assertEqual(r, '4')
436        q = "delete from test_table"
437        r = query(q)
438        self.assertIsInstance(r, str)
439        self.assertEqual(r, '5')
440        query("drop table test_table")
441
442    def testMultipleQueries(self):
443        self.assertEqual(self.db.query(
444            "create temporary table test_multi (n integer);"
445            "insert into test_multi values (4711);"
446            "select n from test_multi").getresult()[0][0], 4711)
447
448    def testQueryWithParams(self):
449        query = self.db.query
450        query("drop table if exists test_table")
451        q = "create table test_table (n1 integer, n2 integer) with oids"
452        query(q)
453        q = "insert into test_table values ($1, $2)"
454        r = query(q, (1, 2))
455        self.assertIsInstance(r, int)
456        r = query(q, [3, 4])
457        self.assertIsInstance(r, int)
458        r = query(q, [5, 6])
459        self.assertIsInstance(r, int)
460        q = "select * from test_table order by 1, 2"
461        self.assertEqual(query(q).getresult(),
462            [(1, 2), (3, 4), (5, 6)])
463        q = "select * from test_table where n1=$1 and n2=$2"
464        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
465        q = "update test_table set n2=$2 where n1=$1"
466        r = query(q, 3, 7)
467        self.assertEqual(r, '1')
468        q = "select * from test_table order by 1, 2"
469        self.assertEqual(query(q).getresult(),
470            [(1, 2), (3, 7), (5, 6)])
471        q = "delete from test_table where n2!=$1"
472        r = query(q, 4)
473        self.assertEqual(r, '3')
474        query("drop table test_table")
475
476    def testEmptyQuery(self):
477        self.assertRaises(ValueError, self.db.query, '')
478
479    def testQueryProgrammingError(self):
480        try:
481            self.db.query("select 1/0")
482        except pg.ProgrammingError as error:
483            self.assertEqual(error.sqlstate, '22012')
484
485    def testPkey(self):
486        query = self.db.query
487        pkey = self.db.pkey
488        for t in ('pkeytest', 'primary key test'):
489            for n in range(7):
490                query('drop table if exists "%s%d"' % (t, n))
491            query('create table "%s0" ('
492                "a smallint)" % t)
493            query('create table "%s1" ('
494                "b smallint primary key)" % t)
495            query('create table "%s2" ('
496                "c smallint, d smallint primary key)" % t)
497            query('create table "%s3" ('
498                "e smallint, f smallint, g smallint, "
499                "h smallint, i smallint, "
500                "primary key (f, h))" % t)
501            query('create table "%s4" ('
502                "more_than_one_letter varchar primary key)" % t)
503            query('create table "%s5" ('
504                '"with space" date primary key)' % t)
505            query('create table "%s6" ('
506                'a_very_long_column_name varchar, '
507                '"with space" date, '
508                '"42" int, '
509                "primary key (a_very_long_column_name, "
510                '"with space", "42"))' % t)
511            self.assertRaises(KeyError, pkey, '%s0' % t)
512            self.assertEqual(pkey('%s1' % t), 'b')
513            self.assertEqual(pkey('%s2' % t), 'd')
514            r = pkey('%s3' % t)
515            self.assertIsInstance(r, frozenset)
516            self.assertEqual(r, frozenset('fh'))
517            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
518            self.assertEqual(pkey('%s5' % t), 'with space')
519            r = pkey('%s6' % t)
520            self.assertIsInstance(r, frozenset)
521            self.assertEqual(r, frozenset([
522                'a_very_long_column_name', 'with space', '42']))
523            # a newly added primary key will be detected
524            query('alter table "%s0" add primary key (a)' % t)
525            self.assertEqual(pkey('%s0' % t), 'a')
526            # a changed primary key will not be detected,
527            # indicating that the internal cache is operating
528            query('alter table "%s1" rename column b to x' % t)
529            self.assertEqual(pkey('%s1' % t), 'b')
530            # we get the changed primary key when the cache is flushed
531            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
532            for n in range(7):
533                query('drop table "%s%d"' % (t, n))
534
535    def testGetDatabases(self):
536        databases = self.db.get_databases()
537        self.assertIn('template0', databases)
538        self.assertIn('template1', databases)
539        self.assertNotIn('not existing database', databases)
540        self.assertIn('postgres', databases)
541        self.assertIn(dbname, databases)
542
543    def testGetTables(self):
544        get_tables = self.db.get_tables
545        result1 = get_tables()
546        self.assertIsInstance(result1, list)
547        for t in result1:
548            t = t.split('.', 1)
549            self.assertGreaterEqual(len(t), 2)
550            if len(t) > 2:
551                self.assertTrue(t[1].startswith('"'))
552            t = t[0]
553            self.assertNotEqual(t, 'information_schema')
554            self.assertFalse(t.startswith('pg_'))
555        tables = ('"A very Special Name"',
556            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
557            'A_MiXeD_NaMe', '"another special name"',
558            'averyveryveryveryveryveryverylongtablename',
559            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
560        for t in tables:
561            self.db.query('drop table if exists %s' % t)
562            self.db.query("create table %s"
563                " as select 0" % t)
564        result3 = get_tables()
565        result2 = []
566        for t in result3:
567            if t not in result1:
568                result2.append(t)
569        result3 = []
570        for t in tables:
571            if not t.startswith('"'):
572                t = t.lower()
573            result3.append('public.' + t)
574        self.assertEqual(result2, result3)
575        for t in result2:
576            self.db.query('drop table %s' % t)
577        result2 = get_tables()
578        self.assertEqual(result2, result1)
579
580    def testGetRelations(self):
581        get_relations = self.db.get_relations
582        result = get_relations()
583        self.assertIn('public.test', result)
584        self.assertIn('public.test_view', result)
585        result = get_relations('rv')
586        self.assertIn('public.test', result)
587        self.assertIn('public.test_view', result)
588        result = get_relations('r')
589        self.assertIn('public.test', result)
590        self.assertNotIn('public.test_view', result)
591        result = get_relations('v')
592        self.assertNotIn('public.test', result)
593        self.assertIn('public.test_view', result)
594        result = get_relations('cisSt')
595        self.assertNotIn('public.test', result)
596        self.assertNotIn('public.test_view', result)
597
598    def testAttnames(self):
599        self.assertRaises(pg.ProgrammingError,
600            self.db.get_attnames, 'does_not_exist')
601        self.assertRaises(pg.ProgrammingError,
602            self.db.get_attnames, 'has.too.many.dots')
603        for table in ('attnames_test_table', 'test table for attnames'):
604            self.db.query('drop table if exists "%s"' % table)
605            self.db.query('create table "%s" ('
606                'a smallint, b integer, c bigint, '
607                'e numeric, f float, f2 double precision, m money, '
608                'x smallint, y smallint, z smallint, '
609                'Normal_NaMe smallint, "Special Name" smallint, '
610                't text, u char(2), v varchar(2), '
611                'primary key (y, u)) with oids' % table)
612            attributes = self.db.get_attnames(table)
613            result = {'a': 'int', 'c': 'int', 'b': 'int',
614                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
615                'normal_name': 'int', 'Special Name': 'int',
616                'u': 'text', 't': 'text', 'v': 'text',
617                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
618            self.assertEqual(attributes, result)
619            self.db.query('drop table "%s"' % table)
620
621    def testHasTablePrivilege(self):
622        can = self.db.has_table_privilege
623        self.assertEqual(can('test'), True)
624        self.assertEqual(can('test', 'select'), True)
625        self.assertEqual(can('test', 'SeLeCt'), True)
626        self.assertEqual(can('test', 'SELECT'), True)
627        self.assertEqual(can('test', 'insert'), True)
628        self.assertEqual(can('test', 'update'), True)
629        self.assertEqual(can('test', 'delete'), True)
630        self.assertEqual(can('pg_views', 'select'), True)
631        self.assertEqual(can('pg_views', 'delete'), False)
632        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
633        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
634
635    def testGet(self):
636        get = self.db.get
637        query = self.db.query
638        for table in ('get_test_table', 'test table for get'):
639            query('drop table if exists "%s"' % table)
640            query('create table "%s" ('
641                "n integer, t text) with oids" % table)
642            for n, t in enumerate('xyz'):
643                query('insert into "%s" values('"%d, '%s')"
644                    % (table, n + 1, t))
645            self.assertRaises(pg.ProgrammingError, get, table, 2)
646            r = get(table, 2, 'n')
647            oid_table = 'oid(%s)' % table
648            self.assertIn(oid_table, r)
649            oid = r[oid_table]
650            self.assertIsInstance(oid, int)
651            result = {'t': 'y', 'n': 2, oid_table: oid}
652            self.assertEqual(r, result)
653            self.assertEqual(get(table + ' *', 2, 'n'), r)
654            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
655            self.assertEqual(get(table, 1, 'n')['t'], 'x')
656            self.assertEqual(get(table, 3, 'n')['t'], 'z')
657            self.assertEqual(get(table, 2, 'n')['t'], 'y')
658            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
659            r['n'] = 3
660            self.assertEqual(get(table, r, 'n')['t'], 'z')
661            self.assertEqual(get(table, 1, 'n')['t'], 'x')
662            query('alter table "%s" alter n set not null' % table)
663            query('alter table "%s" add primary key (n)' % table)
664            self.assertEqual(get(table, 3)['t'], 'z')
665            self.assertEqual(get(table, 1)['t'], 'x')
666            self.assertEqual(get(table, 2)['t'], 'y')
667            r['n'] = 1
668            self.assertEqual(get(table, r)['t'], 'x')
669            r['n'] = 3
670            self.assertEqual(get(table, r)['t'], 'z')
671            r['n'] = 2
672            self.assertEqual(get(table, r)['t'], 'y')
673            query('drop table "%s"' % table)
674
675    def testGetWithCompositeKey(self):
676        get = self.db.get
677        query = self.db.query
678        table = 'get_test_table_1'
679        query("drop table if exists %s" % table)
680        query("create table %s ("
681            "n integer, t text, primary key (n))" % table)
682        for n, t in enumerate('abc'):
683            query("insert into %s values("
684                "%d, '%s')" % (table, n + 1, t))
685        self.assertEqual(get(table, 2)['t'], 'b')
686        query("drop table %s" % table)
687        table = 'get_test_table_2'
688        query("drop table if exists %s" % table)
689        query("create table %s ("
690            "n integer, m integer, t text, primary key (n, m))" % table)
691        for n in range(3):
692            for m in range(2):
693                t = chr(ord('a') + 2 * n + m)
694                query("insert into %s values("
695                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
696        self.assertRaises(pg.ProgrammingError, get, table, 2)
697        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
698        self.assertEqual(get(table, dict(n=1, m=2),
699                             ('n', 'm'))['t'], 'b')
700        self.assertEqual(get(table, dict(n=3, m=2),
701                             frozenset(['n', 'm']))['t'], 'f')
702        query("drop table %s" % table)
703
704    def testGetFromView(self):
705        self.db.query('delete from test where i4=14')
706        self.db.query('insert into test (i4, v4) values('
707            "14, 'abc4')")
708        r = self.db.get('test_view', 14, 'i4')
709        self.assertIn('v4', r)
710        self.assertEqual(r['v4'], 'abc4')
711
712    def testInsert(self):
713        insert = self.db.insert
714        query = self.db.query
715        bool_on = pg.get_bool()
716        decimal = pg.get_decimal()
717        for table in ('insert_test_table', 'test table for insert'):
718            query('drop table if exists "%s"' % table)
719            query('create table "%s" ('
720                "i2 smallint, i4 integer, i8 bigint,"
721                " d numeric, f4 real, f8 double precision, m money,"
722                " v4 varchar(4), c4 char(4), t text,"
723                " b boolean, ts timestamp) with oids" % table)
724            oid_table = 'oid(%s)' % table
725            tests = [dict(i2=None, i4=None, i8=None),
726                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
727                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
728                dict(i2=42, i4=123456, i8=9876543210),
729                dict(i2=2 ** 15 - 1,
730                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
731                dict(d=None), (dict(d=''), dict(d=None)),
732                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
733                dict(f4=None, f8=None), dict(f4=0, f8=0),
734                (dict(f4='', f8=''), dict(f4=None, f8=None)),
735                (dict(d=1234.5, f4=1234.5, f8=1234.5),
736                      dict(d=Decimal('1234.5'))),
737                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
738                dict(d=Decimal('123456789.9876543212345678987654321')),
739                dict(m=None), (dict(m=''), dict(m=None)),
740                dict(m=Decimal('-1234.56')),
741                (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
742                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
743                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
744                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
745                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
746                (dict(m=123456), dict(m=Decimal('123456'))),
747                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
748                dict(b=None), (dict(b=''), dict(b=None)),
749                dict(b='f'), dict(b='t'),
750                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
751                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
752                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
753                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
754                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
755                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
756                dict(v4=None, c4=None, t=None),
757                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
758                dict(v4='1234', c4='1234', t='1234' * 10),
759                dict(v4='abcd', c4='abcd', t='abcdefg'),
760                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
761                dict(ts=None), (dict(ts=''), dict(ts=None)),
762                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
763                dict(ts='2012-12-21 00:00:00'),
764                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
765                dict(ts='2012-12-21 12:21:12'),
766                dict(ts='2013-01-05 12:13:14'),
767                dict(ts='current_timestamp')]
768            for test in tests:
769                if isinstance(test, dict):
770                    data = test
771                    change = {}
772                else:
773                    data, change = test
774                expect = data.copy()
775                expect.update(change)
776                if bool_on:
777                    b = expect.get('b')
778                    if b is not None:
779                        expect['b'] = b == 't'
780                if decimal is not Decimal:
781                    d = expect.get('d')
782                    if d is not None:
783                        expect['d'] = decimal(d)
784                    m = expect.get('m')
785                    if m is not None:
786                        expect['m'] = decimal(m)
787                self.assertEqual(insert(table, data), data)
788                self.assertIn(oid_table, data)
789                oid = data[oid_table]
790                self.assertIsInstance(oid, int)
791                data = dict(item for item in data.items()
792                    if item[0] in expect)
793                ts = expect.get('ts')
794                if ts == 'current_timestamp':
795                    ts = expect['ts'] = data['ts']
796                    if len(ts) > 19:
797                        self.assertEqual(ts[19], '.')
798                        ts = ts[:19]
799                    else:
800                        self.assertEqual(len(ts), 19)
801                    self.assertTrue(ts[:4].isdigit())
802                    self.assertEqual(ts[4], '-')
803                    self.assertEqual(ts[10], ' ')
804                    self.assertTrue(ts[11:13].isdigit())
805                    self.assertEqual(ts[13], ':')
806                self.assertEqual(data, expect)
807                data = query(
808                    'select oid,* from "%s"' % table).dictresult()[0]
809                self.assertEqual(data['oid'], oid)
810                data = dict(item for item in data.items()
811                    if item[0] in expect)
812                self.assertEqual(data, expect)
813                query('delete from "%s"' % table)
814            query('drop table "%s"' % table)
815
816    def testUpdate(self):
817        update = self.db.update
818        query = self.db.query
819        for table in ('update_test_table', 'test table for update'):
820            query('drop table if exists "%s"' % table)
821            query('create table "%s" ('
822                "n integer, t text) with oids" % table)
823            for n, t in enumerate('xyz'):
824                query('insert into "%s" values('
825                    "%d, '%s')" % (table, n + 1, t))
826            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
827            r = self.db.get(table, 2, 'n')
828            r['t'] = 'u'
829            s = update(table, r)
830            self.assertEqual(s, r)
831            r = query('select t from "%s" where n=2' % table
832                      ).getresult()[0][0]
833            self.assertEqual(r, 'u')
834            query('drop table "%s"' % table)
835
836    def testUpdateWithCompositeKey(self):
837        update = self.db.update
838        query = self.db.query
839        table = 'update_test_table_1'
840        query("drop table if exists %s" % table)
841        query("create table %s ("
842            "n integer, t text, primary key (n))" % table)
843        for n, t in enumerate('abc'):
844            query("insert into %s values("
845                "%d, '%s')" % (table, n + 1, t))
846        self.assertRaises(pg.ProgrammingError, update,
847                          table, dict(t='b'))
848        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
849        r = query('select t from "%s" where n=2' % table
850                  ).getresult()[0][0]
851        self.assertEqual(r, 'd')
852        query("drop table %s" % table)
853        table = 'update_test_table_2'
854        query("drop table if exists %s" % table)
855        query("create table %s ("
856            "n integer, m integer, t text, primary key (n, m))" % table)
857        for n in range(3):
858            for m in range(2):
859                t = chr(ord('a') + 2 * n + m)
860                query("insert into %s values("
861                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
862        self.assertRaises(pg.ProgrammingError, update,
863                          table, dict(n=2, t='b'))
864        self.assertEqual(update(table,
865                                dict(n=2, m=2, t='x'))['t'], 'x')
866        r = [r[0] for r in query('select t from "%s" where n=2'
867            ' order by m' % table).getresult()]
868        self.assertEqual(r, ['c', 'x'])
869        query("drop table %s" % table)
870
871    def testClear(self):
872        clear = self.db.clear
873        query = self.db.query
874        f = False if pg.get_bool() else 'f'
875        for table in ('clear_test_table', 'test table for clear'):
876            query('drop table if exists "%s"' % table)
877            query('create table "%s" ('
878                "n integer, b boolean, d date, t text)" % table)
879            r = clear(table)
880            result = {'n': 0, 'b': f, 'd': '', 't': ''}
881            self.assertEqual(r, result)
882            r['a'] = r['n'] = 1
883            r['d'] = r['t'] = 'x'
884            r['b'] = 't'
885            r['oid'] = long(1)
886            r = clear(table, r)
887            result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
888                'oid': long(1)}
889            self.assertEqual(r, result)
890            query('drop table "%s"' % table)
891
892    def testDelete(self):
893        delete = self.db.delete
894        query = self.db.query
895        for table in ('delete_test_table', 'test table for delete'):
896            query('drop table if exists "%s"' % table)
897            query('create table "%s" ('
898                "n integer, t text) with oids" % table)
899            for n, t in enumerate('xyz'):
900                query('insert into "%s" values('
901                    "%d, '%s')" % (table, n + 1, t))
902            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
903            r = self.db.get(table, 1, 'n')
904            s = delete(table, r)
905            self.assertEqual(s, 1)
906            r = self.db.get(table, 3, 'n')
907            s = delete(table, r)
908            self.assertEqual(s, 1)
909            s = delete(table, r)
910            self.assertEqual(s, 0)
911            r = query('select * from "%s"' % table).dictresult()
912            self.assertEqual(len(r), 1)
913            r = r[0]
914            result = {'n': 2, 't': 'y'}
915            self.assertEqual(r, result)
916            r = self.db.get(table, 2, 'n')
917            s = delete(table, r)
918            self.assertEqual(s, 1)
919            s = delete(table, r)
920            self.assertEqual(s, 0)
921            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
922            query('drop table "%s"' % table)
923
924    def testDeleteWithCompositeKey(self):
925        query = self.db.query
926        table = 'delete_test_table_1'
927        query("drop table if exists %s" % table)
928        query("create table %s ("
929            "n integer, t text, primary key (n))" % table)
930        for n, t in enumerate('abc'):
931            query("insert into %s values("
932                "%d, '%s')" % (table, n + 1, t))
933        self.assertRaises(pg.ProgrammingError, self.db.delete,
934            table, dict(t='b'))
935        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
936        r = query('select t from "%s" where n=2' % table
937                  ).getresult()
938        self.assertEqual(r, [])
939        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
940        r = query('select t from "%s" where n=3' % table
941                  ).getresult()[0][0]
942        self.assertEqual(r, 'c')
943        query("drop table %s" % table)
944        table = 'delete_test_table_2'
945        query("drop table if exists %s" % table)
946        query("create table %s ("
947            "n integer, m integer, t text, primary key (n, m))" % table)
948        for n in range(3):
949            for m in range(2):
950                t = chr(ord('a') + 2 * n + m)
951                query("insert into %s values("
952                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
953        self.assertRaises(pg.ProgrammingError, self.db.delete,
954            table, dict(n=2, t='b'))
955        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
956        r = [r[0] for r in query('select t from "%s" where n=2'
957            ' order by m' % table).getresult()]
958        self.assertEqual(r, ['c'])
959        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
960        r = [r[0] for r in query('select t from "%s" where n=3'
961            ' order by m' % table).getresult()]
962        self.assertEqual(r, ['e', 'f'])
963        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
964        r = [r[0] for r in query('select t from "%s" where n=3'
965            ' order by m' % table).getresult()]
966        self.assertEqual(r, ['f'])
967        query("drop table %s" % table)
968
969    def testTransaction(self):
970        query = self.db.query
971        query("drop table if exists test_table")
972        query("create table test_table (n integer)")
973        self.db.begin()
974        query("insert into test_table values (1)")
975        query("insert into test_table values (2)")
976        self.db.commit()
977        self.db.begin()
978        query("insert into test_table values (3)")
979        query("insert into test_table values (4)")
980        self.db.rollback()
981        self.db.begin()
982        query("insert into test_table values (5)")
983        self.db.savepoint('before6')
984        query("insert into test_table values (6)")
985        self.db.rollback('before6')
986        query("insert into test_table values (7)")
987        self.db.commit()
988        self.db.begin()
989        self.db.savepoint('before8')
990        query("insert into test_table values (8)")
991        self.db.release('before8')
992        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
993        self.db.commit()
994        self.db.start()
995        query("insert into test_table values (9)")
996        self.db.end()
997        r = [r[0] for r in query(
998            "select * from test_table order by 1").getresult()]
999        self.assertEqual(r, [1, 2, 5, 7, 9])
1000        query("drop table test_table")
1001
1002    def testContextManager(self):
1003        query = self.db.query
1004        query("drop table if exists test_table")
1005        query("create table test_table (n integer check(n>0))")
1006        with self.db:
1007            query("insert into test_table values (1)")
1008            query("insert into test_table values (2)")
1009        try:
1010            with self.db:
1011                query("insert into test_table values (3)")
1012                query("insert into test_table values (4)")
1013                raise ValueError('test transaction should rollback')
1014        except ValueError as error:
1015            self.assertEqual(str(error), 'test transaction should rollback')
1016        with self.db:
1017            query("insert into test_table values (5)")
1018        try:
1019            with self.db:
1020                query("insert into test_table values (6)")
1021                query("insert into test_table values (-1)")
1022        except pg.ProgrammingError as error:
1023            self.assertTrue('check' in str(error))
1024        with self.db:
1025            query("insert into test_table values (7)")
1026        r = [r[0] for r in query(
1027            "select * from test_table order by 1").getresult()]
1028        self.assertEqual(r, [1, 2, 5, 7])
1029        query("drop table test_table")
1030
1031    def testBytea(self):
1032        query = self.db.query
1033        query('drop table if exists bytea_test')
1034        query('create table bytea_test (n smallint primary key, data bytea)')
1035        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1036        r = self.db.escape_bytea(s)
1037        query('insert into bytea_test values(3,$1)', (r,))
1038        r = query('select * from bytea_test where n=3').getresult()
1039        self.assertEqual(len(r), 1)
1040        r = r[0]
1041        self.assertEqual(len(r), 2)
1042        self.assertEqual(r[0], 3)
1043        r = r[1]
1044        self.assertIsInstance(r, str)
1045        r = self.db.unescape_bytea(r)
1046        self.assertIsInstance(r, bytes)
1047        self.assertEqual(r, s)
1048        query('drop table bytea_test')
1049
1050    def testInsertUpdateGetBytea(self):
1051        query = self.db.query
1052        query('drop table if exists bytea_test')
1053        query('create table bytea_test (n smallint primary key, data bytea)')
1054        # insert null value
1055        r = self.db.insert('bytea_test', n=0, data=None)
1056        self.assertIsInstance(r, dict)
1057        self.assertIn('n', r)
1058        self.assertEqual(r['n'], 0)
1059        self.assertIn('data', r)
1060        self.assertIsNone(r['data'])
1061        s = b'None'
1062        r = self.db.update('bytea_test', n=0, data=s)
1063        self.assertIsInstance(r, dict)
1064        self.assertIn('n', r)
1065        self.assertEqual(r['n'], 0)
1066        self.assertIn('data', r)
1067        r = r['data']
1068        self.assertIsInstance(r, bytes)
1069        self.assertEqual(r, s)
1070        r = self.db.update('bytea_test', n=0, data=None)
1071        self.assertIsNone(r['data'])
1072        # insert as bytes
1073        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1074        r = self.db.insert('bytea_test', n=5, data=s)
1075        self.assertIsInstance(r, dict)
1076        self.assertIn('n', r)
1077        self.assertEqual(r['n'], 5)
1078        self.assertIn('data', r)
1079        r = r['data']
1080        self.assertIsInstance(r, bytes)
1081        self.assertEqual(r, s)
1082        # update as bytes
1083        s += b"and now even more \x00 nasty \t stuff!\f"
1084        r = self.db.update('bytea_test', n=5, data=s)
1085        self.assertIsInstance(r, dict)
1086        self.assertIn('n', r)
1087        self.assertEqual(r['n'], 5)
1088        self.assertIn('data', r)
1089        r = r['data']
1090        self.assertIsInstance(r, bytes)
1091        self.assertEqual(r, s)
1092        r = query('select * from bytea_test where n=5').getresult()
1093        self.assertEqual(len(r), 1)
1094        r = r[0]
1095        self.assertEqual(len(r), 2)
1096        self.assertEqual(r[0], 5)
1097        r = r[1]
1098        self.assertIsInstance(r, str)
1099        r = self.db.unescape_bytea(r)
1100        self.assertIsInstance(r, bytes)
1101        self.assertEqual(r, s)
1102        r = self.db.get('bytea_test', dict(n=5))
1103        self.assertIsInstance(r, dict)
1104        self.assertIn('n', r)
1105        self.assertEqual(r['n'], 5)
1106        self.assertIn('data', r)
1107        r = r['data']
1108        self.assertIsInstance(r, bytes)
1109        self.assertEqual(r, s)
1110        query('drop table bytea_test')
1111
1112    def testDebugWithCallable(self):
1113        if debug:
1114            self.assertEqual(self.db.debug, debug)
1115        else:
1116            self.assertIsNone(self.db.debug)
1117        s = []
1118        self.db.debug = s.append
1119        try:
1120            self.db.query("select 1")
1121            self.db.query("select 2")
1122            self.assertEqual(s, ["select 1", "select 2"])
1123        finally:
1124            self.db.debug = debug
1125
1126
1127class TestDBClassNonStdOpts(TestDBClass):
1128    """Test the methods of the DB class with non-standard global options."""
1129
1130    @classmethod
1131    def setUpClass(cls):
1132        cls.saved_options = {}
1133        cls.set_option('decimal', float)
1134        not_bool = not pg.get_bool()
1135        cls.set_option('bool', not_bool)
1136        unnamed_result = lambda q: q.getresult()
1137        cls.set_option('namedresult', unnamed_result)
1138        super(TestDBClassNonStdOpts, cls).setUpClass()
1139
1140    @classmethod
1141    def tearDownClass(cls):
1142        super(TestDBClassNonStdOpts, cls).tearDownClass()
1143        cls.reset_option('namedresult')
1144        cls.reset_option('bool')
1145        cls.reset_option('decimal')
1146
1147    @classmethod
1148    def set_option(cls, option, value):
1149        cls.saved_options[option] = getattr(pg, 'get_' + option)()
1150        return getattr(pg, 'set_' + option)(value)
1151
1152    @classmethod
1153    def reset_option(cls, option):
1154        return getattr(pg, 'set_' + option)(cls.saved_options[option])
1155
1156
1157class TestSchemas(unittest.TestCase):
1158    """Test correct handling of schemas (namespaces)."""
1159
1160    @classmethod
1161    def setUpClass(cls):
1162        db = DB()
1163        query = db.query
1164        query("set client_min_messages=warning")
1165        for num_schema in range(5):
1166            if num_schema:
1167                schema = "s%d" % num_schema
1168                query("drop schema if exists %s cascade" % (schema,))
1169                try:
1170                    query("create schema %s" % (schema,))
1171                except pg.ProgrammingError:
1172                    raise RuntimeError("The test user cannot create schemas.\n"
1173                        "Grant create on database %s to the user"
1174                        " for running these tests." % dbname)
1175            else:
1176                schema = "public"
1177                query("drop table if exists %s.t" % (schema,))
1178                query("drop table if exists %s.t%d" % (schema, num_schema))
1179            query("create table %s.t with oids as select 1 as n, %d as d"
1180                  % (schema, num_schema))
1181            query("create table %s.t%d with oids as select 1 as n, %d as d"
1182                  % (schema, num_schema, num_schema))
1183        db.close()
1184
1185    @classmethod
1186    def tearDownClass(cls):
1187        db = DB()
1188        query = db.query
1189        query("set client_min_messages=warning")
1190        for num_schema in range(5):
1191            if num_schema:
1192                schema = "s%d" % num_schema
1193                query("drop schema %s cascade" % (schema,))
1194            else:
1195                schema = "public"
1196                query("drop table %s.t" % (schema,))
1197                query("drop table %s.t%d" % (schema, num_schema))
1198        db.close()
1199
1200    def setUp(self):
1201        self.db = DB()
1202        self.db.query("set client_min_messages=warning")
1203
1204    def tearDown(self):
1205        self.db.close()
1206
1207    def testGetTables(self):
1208        tables = self.db.get_tables()
1209        for num_schema in range(5):
1210            if num_schema:
1211                schema = "s" + str(num_schema)
1212            else:
1213                schema = "public"
1214            for t in (schema + ".t",
1215                    schema + ".t" + str(num_schema)):
1216                self.assertIn(t, tables)
1217
1218    def testGetAttnames(self):
1219        get_attnames = self.db.get_attnames
1220        query = self.db.query
1221        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
1222        r = get_attnames("t")
1223        self.assertEqual(r, result)
1224        r = get_attnames("s4.t4")
1225        self.assertEqual(r, result)
1226        query("drop table if exists s3.t3m")
1227        query("create table s3.t3m with oids as select 1 as m")
1228        result_m = {'oid': 'int', 'm': 'int'}
1229        r = get_attnames("s3.t3m")
1230        self.assertEqual(r, result_m)
1231        query("set search_path to s1,s3")
1232        r = get_attnames("t3")
1233        self.assertEqual(r, result)
1234        r = get_attnames("t3m")
1235        self.assertEqual(r, result_m)
1236        query("drop table s3.t3m")
1237
1238    def testGet(self):
1239        get = self.db.get
1240        query = self.db.query
1241        PrgError = pg.ProgrammingError
1242        self.assertEqual(get("t", 1, 'n')['d'], 0)
1243        self.assertEqual(get("t0", 1, 'n')['d'], 0)
1244        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
1245        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
1246        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
1247        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
1248        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
1249        query("set search_path to s2,s4")
1250        self.assertRaises(PrgError, get, "t1", 1, 'n')
1251        self.assertEqual(get("t4", 1, 'n')['d'], 4)
1252        self.assertRaises(PrgError, get, "t3", 1, 'n')
1253        self.assertEqual(get("t", 1, 'n')['d'], 2)
1254        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
1255        query("set search_path to s1,s3")
1256        self.assertRaises(PrgError, get, "t2", 1, 'n')
1257        self.assertEqual(get("t3", 1, 'n')['d'], 3)
1258        self.assertRaises(PrgError, get, "t4", 1, 'n')
1259        self.assertEqual(get("t", 1, 'n')['d'], 1)
1260        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
1261
1262    def testMunging(self):
1263        get = self.db.get
1264        query = self.db.query
1265        r = get("t", 1, 'n')
1266        self.assertIn('oid(t)', r)
1267        query("set search_path to s2")
1268        r = get("t2", 1, 'n')
1269        self.assertIn('oid(t2)', r)
1270        query("set search_path to s3")
1271        r = get("t", 1, 'n')
1272        self.assertIn('oid(t)', r)
1273
1274
1275if __name__ == '__main__':
1276    unittest.main()
Note: See TracBrowser for help on using the repository browser.