source: branches/4.x/tests/test_classic_dbwrapper.py @ 743

Last change on this file since 743 was 743, checked in by cito, 4 years ago

Test error messages and security of the get() method

The get() method should be immune against SQL hacking with apostrophes in
values, and give a proper and helpful error message if a row is not found.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 47.7 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import sys
21
22import pg  # the module under test
23
24from decimal import Decimal
25
26# check whether the "with" statement is supported
27no_with = sys.version_info[:2] < (2, 5)
28
29# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
30# get our information from that.  Otherwise we use the defaults.
31# The current user must have create schema privilege on the database.
32dbname = 'unittest'
33dbhost = None
34dbport = 5432
35
36debug = False  # let DB wrapper print debugging output
37
38try:
39    from LOCAL_PyGreSQL import *
40except ImportError:
41    pass
42
43windows = os.name == 'nt'
44
45# There is a known a bug in libpq under Windows which can cause
46# the interface to crash when calling PQhost():
47do_not_ask_for_host = windows
48do_not_ask_for_host_reason = 'libpq issue on Windows'
49
50
51def DB():
52    """Create a DB wrapper object connecting to the test database."""
53    db = pg.DB(dbname, dbhost, dbport)
54    if debug:
55        db.debug = debug
56    db.query("set client_min_messages=warning")
57    return db
58
59
60class TestDBClassBasic(unittest.TestCase):
61    """Test existence of the DB class wrapped pg connection methods."""
62
63    def setUp(self):
64        self.db = DB()
65
66    def tearDown(self):
67        try:
68            self.db.close()
69        except pg.InternalError:
70            pass
71
72    def testAllDBAttributes(self):
73        attributes = [
74            'begin',
75            'cancel',
76            'clear',
77            'close',
78            'commit',
79            'db',
80            'dbname',
81            'debug',
82            'delete',
83            'end',
84            'endcopy',
85            'error',
86            'escape_bytea',
87            'escape_identifier',
88            'escape_literal',
89            'escape_string',
90            'fileno',
91            'get',
92            'get_attnames',
93            'get_databases',
94            'get_notice_receiver',
95            'get_relations',
96            'get_tables',
97            'getline',
98            'getlo',
99            'getnotify',
100            'has_table_privilege',
101            'host',
102            'insert',
103            'inserttable',
104            'locreate',
105            'loimport',
106            'notification_handler',
107            'options',
108            'parameter',
109            'pkey',
110            'port',
111            'protocol_version',
112            'putline',
113            'query',
114            'release',
115            'reopen',
116            'reset',
117            'rollback',
118            'savepoint',
119            'server_version',
120            'set_notice_receiver',
121            'source',
122            'start',
123            'status',
124            'transaction',
125            'tty',
126            'unescape_bytea',
127            'update',
128            'use_regtypes',
129            'user',
130        ]
131        if self.db.server_version < 90000:  # PostgreSQL < 9.0
132            attributes.remove('escape_identifier')
133            attributes.remove('escape_literal')
134        db_attributes = [a for a in dir(self.db)
135            if not a.startswith('_')]
136        self.assertEqual(attributes, db_attributes)
137
138    def testAttributeDb(self):
139        self.assertEqual(self.db.db.db, dbname)
140
141    def testAttributeDbname(self):
142        self.assertEqual(self.db.dbname, dbname)
143
144    def testAttributeError(self):
145        error = self.db.error
146        self.assertTrue(not error or 'krb5_' in error)
147        self.assertEqual(self.db.error, self.db.db.error)
148
149    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
150    def testAttributeHost(self):
151        def_host = 'localhost'
152        host = self.db.host
153        self.assertIsInstance(host, str)
154        self.assertEqual(host, dbhost or def_host)
155        self.assertEqual(host, self.db.db.host)
156
157    def testAttributeOptions(self):
158        no_options = ''
159        options = self.db.options
160        self.assertEqual(options, no_options)
161        self.assertEqual(options, self.db.db.options)
162
163    def testAttributePort(self):
164        def_port = 5432
165        port = self.db.port
166        self.assertIsInstance(port, int)
167        self.assertEqual(port, dbport or def_port)
168        self.assertEqual(port, self.db.db.port)
169
170    def testAttributeProtocolVersion(self):
171        protocol_version = self.db.protocol_version
172        self.assertIsInstance(protocol_version, int)
173        self.assertTrue(2 <= protocol_version < 4)
174        self.assertEqual(protocol_version, self.db.db.protocol_version)
175
176    def testAttributeServerVersion(self):
177        server_version = self.db.server_version
178        self.assertIsInstance(server_version, int)
179        self.assertTrue(70400 <= server_version < 100000)
180        self.assertEqual(server_version, self.db.db.server_version)
181
182    def testAttributeStatus(self):
183        status_ok = 1
184        status = self.db.status
185        self.assertIsInstance(status, int)
186        self.assertEqual(status, status_ok)
187        self.assertEqual(status, self.db.db.status)
188
189    def testAttributeTty(self):
190        def_tty = ''
191        tty = self.db.tty
192        self.assertIsInstance(tty, str)
193        self.assertEqual(tty, def_tty)
194        self.assertEqual(tty, self.db.db.tty)
195
196    def testAttributeUser(self):
197        no_user = 'Deprecated facility'
198        user = self.db.user
199        self.assertTrue(user)
200        self.assertIsInstance(user, str)
201        self.assertNotEqual(user, no_user)
202        self.assertEqual(user, self.db.db.user)
203
204    def testMethodEscapeLiteral(self):
205        if self.db.server_version < 90000:  # PostgreSQL < 9.0
206            self.skipTest('Escaping functions not supported')
207        self.assertEqual(self.db.escape_literal(''), "''")
208
209    def testMethodEscapeIdentifier(self):
210        if self.db.server_version < 90000:  # PostgreSQL < 9.0
211            self.skipTest('Escaping functions not supported')
212        self.assertEqual(self.db.escape_identifier(''), '""')
213
214    def testMethodEscapeString(self):
215        self.assertEqual(self.db.escape_string(''), '')
216
217    def testMethodEscapeBytea(self):
218        self.assertEqual(self.db.escape_bytea('').replace(
219            '\\x', '').replace('\\', ''), '')
220
221    def testMethodUnescapeBytea(self):
222        self.assertEqual(self.db.unescape_bytea(''), '')
223
224    def testMethodQuery(self):
225        query = self.db.query
226        query("select 1+1")
227        query("select 1+$1+$2", 2, 3)
228        query("select 1+$1+$2", (2, 3))
229        query("select 1+$1+$2", [2, 3])
230        query("select 1+$1", 1)
231
232    def testMethodQueryEmpty(self):
233        self.assertRaises(ValueError, self.db.query, '')
234
235    def testMethodQueryProgrammingError(self):
236        try:
237            self.db.query("select 1/0")
238        except pg.ProgrammingError, error:
239            self.assertEqual(error.sqlstate, '22012')
240
241    def testMethodEndcopy(self):
242        try:
243            self.db.endcopy()
244        except IOError:
245            pass
246
247    def testMethodClose(self):
248        self.db.close()
249        try:
250            self.db.reset()
251        except pg.Error:
252            pass
253        else:
254            self.fail('Reset should give an error for a closed connection')
255        self.assertRaises(pg.InternalError, self.db.close)
256        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
257
258    def testExistingConnection(self):
259        db = pg.DB(self.db.db)
260        self.assertEqual(self.db.db, db.db)
261        self.assertTrue(db.db)
262        db.close()
263        self.assertTrue(db.db)
264        db.reopen()
265        self.assertTrue(db.db)
266        db.close()
267        self.assertTrue(db.db)
268        db = pg.DB(self.db)
269        self.assertEqual(self.db.db, db.db)
270        db = pg.DB(db=self.db.db)
271        self.assertEqual(self.db.db, db.db)
272
273        class DB2:
274            pass
275
276        db2 = DB2()
277        db2._cnx = self.db.db
278        db = pg.DB(db2)
279        self.assertEqual(self.db.db, db.db)
280
281
282class TestDBClass(unittest.TestCase):
283    """Test the methods of the DB class wrapped pg connection."""
284
285    @classmethod
286    def setUpClass(cls):
287        db = DB()
288        db.query("drop table if exists test cascade")
289        db.query("create table test ("
290            "i2 smallint, i4 integer, i8 bigint,"
291            "d numeric, f4 real, f8 double precision, m money, "
292            "v4 varchar(4), c4 char(4), t text)")
293        db.query("create or replace view test_view as"
294            " select i4, v4 from test")
295        db.close()
296
297    @classmethod
298    def tearDownClass(cls):
299        db = DB()
300        db.query("drop table test cascade")
301        db.close()
302
303    def setUp(self):
304        self.db = DB()
305        query = self.db.query
306        query('set client_encoding=utf8')
307        query('set standard_conforming_strings=on')
308        query("set lc_monetary='C'")
309        query("set datestyle='ISO,YMD'")
310        try:
311            query('set bytea_output=hex')
312        except pg.ProgrammingError:  # PostgreSQL < 9.0
313            pass
314
315    def tearDown(self):
316        self.db.close()
317
318    def testEscapeLiteral(self):
319        if self.db.server_version < 90000:  # PostgreSQL < 9.0
320            self.skipTest('Escaping functions not supported')
321        f = self.db.escape_literal
322        self.assertEqual(f("plain"), "'plain'")
323        self.assertEqual(f("that's k\xe4se"), "'that''s k\xe4se'")
324        self.assertEqual(f(r"It's fine to have a \ inside."),
325            r" E'It''s fine to have a \\ inside.'")
326        self.assertEqual(f('No "quotes" must be escaped.'),
327            "'No \"quotes\" must be escaped.'")
328
329    def testEscapeIdentifier(self):
330        if self.db.server_version < 90000:  # PostgreSQL < 9.0
331            self.skipTest('Escaping functions not supported')
332        f = self.db.escape_identifier
333        self.assertEqual(f("plain"), '"plain"')
334        self.assertEqual(f("that's k\xe4se"), '"that\'s k\xe4se"')
335        self.assertEqual(f(r"It's fine to have a \ inside."),
336            '"It\'s fine to have a \\ inside."')
337        self.assertEqual(f('All "quotes" must be escaped.'),
338            '"All ""quotes"" must be escaped."')
339
340    def testEscapeString(self):
341        f = self.db.escape_string
342        self.assertEqual(f("plain"), "plain")
343        self.assertEqual(f("that's k\xe4se"), "that''s k\xe4se")
344        self.assertEqual(f(r"It's fine to have a \ inside."),
345            r"It''s fine to have a \ inside.")
346
347    def testEscapeBytea(self):
348        f = self.db.escape_bytea
349        # note that escape_byte always returns hex output since PostgreSQL 9.0,
350        # regardless of the bytea_output setting
351        if self.db.server_version < 90000:
352            self.assertEqual(f("plain"), r"plain")
353            self.assertEqual(f("that's k\xe4se"), r"that''s k\344se")
354            self.assertEqual(f('O\x00ps\xff!'), r"O\000ps\377!")
355        else:
356            self.assertEqual(f("plain"), r"\x706c61696e")
357            self.assertEqual(f("that's k\xe4se"), r"\x746861742773206be47365")
358            self.assertEqual(f('O\x00ps\xff!'), r"\x4f007073ff21")
359
360    def testUnescapeBytea(self):
361        f = self.db.unescape_bytea
362        self.assertEqual(f("plain"), "plain")
363        self.assertEqual(f("that's k\\344se"), "that's k\xe4se")
364        self.assertEqual(f(r'O\000ps\377!'), 'O\x00ps\xff!')
365        self.assertEqual(f(r"\\x706c61696e"), r"\x706c61696e")
366        self.assertEqual(f(r"\\x746861742773206be47365"),
367            r"\x746861742773206be47365")
368        self.assertEqual(f(r"\\x4f007073ff21"), r"\x4f007073ff21")
369
370    def testQuote(self):
371        f = self.db._quote
372        self.assertEqual(f(None, None), 'NULL')
373        self.assertEqual(f(None, 'int'), 'NULL')
374        self.assertEqual(f(None, 'float'), 'NULL')
375        self.assertEqual(f(None, 'num'), 'NULL')
376        self.assertEqual(f(None, 'money'), 'NULL')
377        self.assertEqual(f(None, 'bool'), 'NULL')
378        self.assertEqual(f(None, 'date'), 'NULL')
379        self.assertEqual(f('', 'int'), 'NULL')
380        self.assertEqual(f('', 'float'), 'NULL')
381        self.assertEqual(f('', 'num'), 'NULL')
382        self.assertEqual(f('', 'money'), 'NULL')
383        self.assertEqual(f('', 'bool'), 'NULL')
384        self.assertEqual(f('', 'date'), 'NULL')
385        self.assertEqual(f('', 'text'), "''")
386        self.assertEqual(f(0, 'int'), '0')
387        self.assertEqual(f(0, 'num'), '0')
388        self.assertEqual(f(1, 'int'), '1')
389        self.assertEqual(f(1, 'num'), '1')
390        self.assertEqual(f(-1, 'int'), '-1')
391        self.assertEqual(f(-1, 'num'), '-1')
392        self.assertEqual(f(123456789, 'int'), '123456789')
393        self.assertEqual(f(123456987, 'num'), '123456987')
394        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
395        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
396        self.assertEqual(f('123456789', 'num'), '123456789')
397        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
398        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
399        self.assertEqual(f(123, 'money'), '123')
400        self.assertEqual(f('123', 'money'), '123')
401        self.assertEqual(f(123.45, 'money'), '123.45')
402        self.assertEqual(f('123.45', 'money'), '123.45')
403        self.assertEqual(f(123.454, 'money'), '123.454')
404        self.assertEqual(f('123.454', 'money'), '123.454')
405        self.assertEqual(f(123.456, 'money'), '123.456')
406        self.assertEqual(f('123.456', 'money'), '123.456')
407        self.assertEqual(f('f', 'bool'), "'f'")
408        self.assertEqual(f('F', 'bool'), "'f'")
409        self.assertEqual(f('false', 'bool'), "'f'")
410        self.assertEqual(f('False', 'bool'), "'f'")
411        self.assertEqual(f('FALSE', 'bool'), "'f'")
412        self.assertEqual(f(0, 'bool'), "'f'")
413        self.assertEqual(f('0', 'bool'), "'f'")
414        self.assertEqual(f('-', 'bool'), "'f'")
415        self.assertEqual(f('n', 'bool'), "'f'")
416        self.assertEqual(f('N', 'bool'), "'f'")
417        self.assertEqual(f('no', 'bool'), "'f'")
418        self.assertEqual(f('off', 'bool'), "'f'")
419        self.assertEqual(f('t', 'bool'), "'t'")
420        self.assertEqual(f('T', 'bool'), "'t'")
421        self.assertEqual(f('true', 'bool'), "'t'")
422        self.assertEqual(f('True', 'bool'), "'t'")
423        self.assertEqual(f('TRUE', 'bool'), "'t'")
424        self.assertEqual(f(1, 'bool'), "'t'")
425        self.assertEqual(f(2, 'bool'), "'t'")
426        self.assertEqual(f(-1, 'bool'), "'t'")
427        self.assertEqual(f(0.5, 'bool'), "'t'")
428        self.assertEqual(f('1', 'bool'), "'t'")
429        self.assertEqual(f('y', 'bool'), "'t'")
430        self.assertEqual(f('Y', 'bool'), "'t'")
431        self.assertEqual(f('yes', 'bool'), "'t'")
432        self.assertEqual(f('on', 'bool'), "'t'")
433        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
434        self.assertEqual(f(123, 'text'), "'123'")
435        self.assertEqual(f(1.23, 'text'), "'1.23'")
436        self.assertEqual(f('abc', 'text'), "'abc'")
437        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
438        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
439        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
440        self.db.query('set standard_conforming_strings=off')
441        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
442        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
443
444    def testQuery(self):
445        query = self.db.query
446        query("drop table if exists test_table")
447        q = "create table test_table (n integer) with oids"
448        r = query(q)
449        self.assertIsNone(r)
450        q = "insert into test_table values (1)"
451        r = query(q)
452        self.assertIsInstance(r, int)
453        q = "insert into test_table select 2"
454        r = query(q)
455        self.assertIsInstance(r, int)
456        oid = r
457        q = "select oid from test_table where n=2"
458        r = query(q).getresult()
459        self.assertEqual(len(r), 1)
460        r = r[0]
461        self.assertEqual(len(r), 1)
462        r = r[0]
463        self.assertEqual(r, oid)
464        q = "insert into test_table select 3 union select 4 union select 5"
465        r = query(q)
466        self.assertIsInstance(r, str)
467        self.assertEqual(r, '3')
468        q = "update test_table set n=4 where n<5"
469        r = query(q)
470        self.assertIsInstance(r, str)
471        self.assertEqual(r, '4')
472        q = "delete from test_table"
473        r = query(q)
474        self.assertIsInstance(r, str)
475        self.assertEqual(r, '5')
476        query("drop table test_table")
477
478    def testMultipleQueries(self):
479        self.assertEqual(self.db.query(
480            "create temporary table test_multi (n integer);"
481            "insert into test_multi values (4711);"
482            "select n from test_multi").getresult()[0][0], 4711)
483
484    def testQueryWithParams(self):
485        query = self.db.query
486        query("drop table if exists test_table")
487        q = "create table test_table (n1 integer, n2 integer) with oids"
488        query(q)
489        q = "insert into test_table values ($1, $2)"
490        r = query(q, (1, 2))
491        self.assertIsInstance(r, int)
492        r = query(q, [3, 4])
493        self.assertIsInstance(r, int)
494        r = query(q, [5, 6])
495        self.assertIsInstance(r, int)
496        q = "select * from test_table order by 1, 2"
497        self.assertEqual(query(q).getresult(),
498            [(1, 2), (3, 4), (5, 6)])
499        q = "select * from test_table where n1=$1 and n2=$2"
500        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
501        q = "update test_table set n2=$2 where n1=$1"
502        r = query(q, 3, 7)
503        self.assertEqual(r, '1')
504        q = "select * from test_table order by 1, 2"
505        self.assertEqual(query(q).getresult(),
506            [(1, 2), (3, 7), (5, 6)])
507        q = "delete from test_table where n2!=$1"
508        r = query(q, 4)
509        self.assertEqual(r, '3')
510        query("drop table test_table")
511
512    def testEmptyQuery(self):
513        self.assertRaises(ValueError, self.db.query, '')
514
515    def testQueryProgrammingError(self):
516        try:
517            self.db.query("select 1/0")
518        except pg.ProgrammingError, error:
519            self.assertEqual(error.sqlstate, '22012')
520
521    def testPkey(self):
522        query = self.db.query
523        for n in range(4):
524            query("drop table if exists pkeytest%d" % n)
525        query("create table pkeytest0 ("
526            "a smallint)")
527        query("create table pkeytest1 ("
528            "b smallint primary key)")
529        query("create table pkeytest2 ("
530            "c smallint, d smallint primary key)")
531        query("create table pkeytest3 ("
532            "e smallint, f smallint, g smallint, "
533            "h smallint, i smallint, "
534            "primary key (f,h))")
535        pkey = self.db.pkey
536        self.assertRaises(KeyError, pkey, 'pkeytest0')
537        self.assertEqual(pkey('pkeytest1'), 'b')
538        self.assertEqual(pkey('pkeytest2'), 'd')
539        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
540        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
541        self.assertEqual(pkey('pkeytest0'), 'none')
542        pkey(None, {'t': 'a', 'n.t': 'b'})
543        self.assertEqual(pkey('t'), 'a')
544        self.assertEqual(pkey('n.t'), 'b')
545        self.assertRaises(KeyError, pkey, 'pkeytest0')
546        for n in range(4):
547            query("drop table pkeytest%d" % n)
548
549    def testGetDatabases(self):
550        databases = self.db.get_databases()
551        self.assertIn('template0', databases)
552        self.assertIn('template1', databases)
553        self.assertNotIn('not existing database', databases)
554        self.assertIn('postgres', databases)
555        self.assertIn(dbname, databases)
556
557    def testGetTables(self):
558        get_tables = self.db.get_tables
559        result1 = get_tables()
560        self.assertIsInstance(result1, list)
561        for t in result1:
562            t = t.split('.', 1)
563            self.assertGreaterEqual(len(t), 2)
564            if len(t) > 2:
565                self.assertTrue(t[1].startswith('"'))
566            t = t[0]
567            self.assertNotEqual(t, 'information_schema')
568            self.assertFalse(t.startswith('pg_'))
569        tables = ('"A very Special Name"',
570            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
571            'A_MiXeD_NaMe', '"another special name"',
572            'averyveryveryveryveryveryverylongtablename',
573            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
574        for t in tables:
575            self.db.query('drop table if exists %s' % t)
576            self.db.query("create table %s"
577                " as select 0" % t)
578        result3 = get_tables()
579        result2 = []
580        for t in result3:
581            if t not in result1:
582                result2.append(t)
583        result3 = []
584        for t in tables:
585            if not t.startswith('"'):
586                t = t.lower()
587            result3.append('public.' + t)
588        self.assertEqual(result2, result3)
589        for t in result2:
590            self.db.query('drop table %s' % t)
591        result2 = get_tables()
592        self.assertEqual(result2, result1)
593
594    def testGetRelations(self):
595        get_relations = self.db.get_relations
596        result = get_relations()
597        self.assertIn('public.test', result)
598        self.assertIn('public.test_view', result)
599        result = get_relations('rv')
600        self.assertIn('public.test', result)
601        self.assertIn('public.test_view', result)
602        result = get_relations('r')
603        self.assertIn('public.test', result)
604        self.assertNotIn('public.test_view', result)
605        result = get_relations('v')
606        self.assertNotIn('public.test', result)
607        self.assertIn('public.test_view', result)
608        result = get_relations('cisSt')
609        self.assertNotIn('public.test', result)
610        self.assertNotIn('public.test_view', result)
611
612    def testAttnames(self):
613        self.assertRaises(pg.ProgrammingError,
614            self.db.get_attnames, 'does_not_exist')
615        self.assertRaises(pg.ProgrammingError,
616            self.db.get_attnames, 'has.too.many.dots')
617        for table in ('attnames_test_table', 'test table for attnames'):
618            self.db.query('drop table if exists "%s"' % table)
619            self.db.query('create table "%s" ('
620                'a smallint, b integer, c bigint, '
621                'e numeric, f float, f2 double precision, m money, '
622                'x smallint, y smallint, z smallint, '
623                'Normal_NaMe smallint, "Special Name" smallint, '
624                't text, u char(2), v varchar(2), '
625                'primary key (y, u)) with oids' % table)
626            attributes = self.db.get_attnames(table)
627            result = {'a': 'int', 'c': 'int', 'b': 'int',
628                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
629                'normal_name': 'int', 'Special Name': 'int',
630                'u': 'text', 't': 'text', 'v': 'text',
631                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
632            self.assertEqual(attributes, result)
633            self.db.query('drop table "%s"' % table)
634
635    def testHasTablePrivilege(self):
636        can = self.db.has_table_privilege
637        self.assertEqual(can('test'), True)
638        self.assertEqual(can('test', 'select'), True)
639        self.assertEqual(can('test', 'SeLeCt'), True)
640        self.assertEqual(can('test', 'SELECT'), True)
641        self.assertEqual(can('test', 'insert'), True)
642        self.assertEqual(can('test', 'update'), True)
643        self.assertEqual(can('test', 'delete'), True)
644        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
645        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
646
647    def testGet(self):
648        get = self.db.get
649        query = self.db.query
650        for table in ('get_test_table', 'test table for get'):
651            query('drop table if exists "%s"' % table)
652            query('create table "%s" ('
653                "n integer, t text) with oids" % table)
654            for n, t in enumerate('xyz'):
655                query('insert into "%s" values('"%d, '%s')"
656                    % (table, n + 1, t))
657            self.assertRaises(pg.ProgrammingError, get, table, 2)
658            r = get(table, 2, 'n')
659            oid_table = table
660            if ' ' in table:
661                oid_table = '"%s"' % oid_table
662            oid_table = 'oid(public.%s)' % oid_table
663            self.assertIn(oid_table, r)
664            oid = r[oid_table]
665            self.assertIsInstance(oid, int)
666            result = {'t': 'y', 'n': 2, oid_table: oid}
667            self.assertEqual(r, result)
668            self.assertEqual(get(table + ' *', 2, 'n'), r)
669            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
670            self.assertEqual(get(table, 1, 'n')['t'], 'x')
671            self.assertEqual(get(table, 3, 'n')['t'], 'z')
672            self.assertEqual(get(table, 2, 'n')['t'], 'y')
673            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
674            r['n'] = 3
675            self.assertEqual(get(table, r, 'n')['t'], 'z')
676            self.assertEqual(get(table, 1, 'n')['t'], 'x')
677            query('alter table "%s" alter n set not null' % table)
678            query('alter table "%s" add primary key (n)' % table)
679            self.assertEqual(get(table, 3)['t'], 'z')
680            self.assertEqual(get(table, 1)['t'], 'x')
681            self.assertEqual(get(table, 2)['t'], 'y')
682            r['n'] = 1
683            self.assertEqual(get(table, r)['t'], 'x')
684            r['n'] = 3
685            self.assertEqual(get(table, r)['t'], 'z')
686            r['n'] = 2
687            self.assertEqual(get(table, r)['t'], 'y')
688            query('drop table "%s"' % table)
689
690    def testGetWithCompositeKey(self):
691        get = self.db.get
692        query = self.db.query
693        table = 'get_test_table_1'
694        query("drop table if exists %s" % table)
695        query("create table %s ("
696            "n integer, t text, primary key (n))" % table)
697        for n, t in enumerate('abc'):
698            query("insert into %s values("
699                "%d, '%s')" % (table, n + 1, t))
700        self.assertEqual(get(table, 2)['t'], 'b')
701        query("drop table %s" % table)
702        table = 'get_test_table_2'
703        query("drop table if exists %s" % table)
704        query("create table %s ("
705            "n integer, m integer, t text, primary key (n, m))" % table)
706        for n in range(3):
707            for m in range(2):
708                t = chr(ord('a') + 2 * n + m)
709                query("insert into %s values("
710                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
711        self.assertRaises(pg.ProgrammingError, get, table, 2)
712        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
713        self.assertEqual(get(table, dict(n=1, m=2),
714                             ('n', 'm'))['t'], 'b')
715        self.assertEqual(get(table, dict(n=3, m=2),
716                             frozenset(['n', 'm']))['t'], 'f')
717        query("drop table %s" % table)
718
719    def testGetFromView(self):
720        self.db.query('delete from test where i4=14')
721        self.db.query('insert into test (i4, v4) values('
722            "14, 'abc4')")
723        r = self.db.get('test_view', 14, 'i4')
724        self.assertIn('v4', r)
725        self.assertEqual(r['v4'], 'abc4')
726
727    def testGetLittleBobbyTables(self):
728        get = self.db.get
729        query = self.db.query
730        query("drop table if exists test_students")
731        query("create table test_students (firstname varchar primary key,"
732            " nickname varchar, grade char(2))")
733        query("insert into test_students values ("
734              "'D''Arcy', 'Darcey', 'A+')")
735        query("insert into test_students values ("
736              "'Sheldon', 'Moonpie', 'A+')")
737        query("insert into test_students values ("
738              "'Robert', 'Little Bobby Tables', 'D-')")
739        r = get('test_students', 'Sheldon')
740        self.assertEqual(r, dict(
741            firstname="Sheldon", nickname='Moonpie', grade='A+'))
742        r = get('test_students', 'Robert')
743        self.assertEqual(r, dict(
744            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
745        r = get('test_students', "D'Arcy")
746        self.assertEqual(r, dict(
747            firstname="D'Arcy", nickname='Darcey', grade='A+'))
748        try:
749            get('test_students', "D' Arcy")
750        except pg.DatabaseError as error:
751            self.assertEqual(str(error),
752                'No such record in public.test_students where firstname = '
753                "'D'' Arcy'")
754        try:
755            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
756        except pg.DatabaseError as error:
757            self.assertEqual(str(error),
758                'No such record in public.test_students where firstname = '
759                "'Robert''); TRUNCATE TABLE test_students;--'")
760        q = "select * from test_students order by 1 limit 4"
761        r = query(q).getresult()
762        self.assertEqual(len(r), 3)
763        self.assertEqual(r[1][2], 'D-')
764        query('drop table test_students')
765
766    def testInsert(self):
767        insert = self.db.insert
768        query = self.db.query
769        server_version = self.db.server_version
770        for table in ('insert_test_table', 'test table for insert'):
771            query('drop table if exists "%s"' % table)
772            query('create table "%s" ('
773                "i2 smallint, i4 integer, i8 bigint,"
774                " d numeric, f4 real, f8 double precision, m money,"
775                " v4 varchar(4), c4 char(4), t text,"
776                " b boolean, ts timestamp) with oids" % table)
777            oid_table = table
778            if ' ' in table:
779                oid_table = '"%s"' % oid_table
780            oid_table = 'oid(public.%s)' % oid_table
781            tests = [dict(i2=None, i4=None, i8=None),
782                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
783                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
784                dict(i2=42, i4=123456, i8=9876543210),
785                dict(i2=2 ** 15 - 1,
786                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
787                dict(d=None), (dict(d=''), dict(d=None)),
788                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
789                dict(f4=None, f8=None), dict(f4=0, f8=0),
790                (dict(f4='', f8=''), dict(f4=None, f8=None)),
791                (dict(d=1234.5, f4=1234.5, f8=1234.5),
792                      dict(d=Decimal('1234.5'))),
793                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
794                dict(d=Decimal('123456789.9876543212345678987654321')),
795                dict(m=None), (dict(m=''), dict(m=None)),
796                dict(m=Decimal('-1234.56')),
797                (dict(m='-1234.56'), dict(m=Decimal('-1234.56'))),
798                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
799                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
800                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
801                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
802                (dict(m=123456), dict(m=Decimal('123456'))),
803                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
804                dict(b=None), (dict(b=''), dict(b=None)),
805                dict(b='f'), dict(b='t'),
806                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
807                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
808                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
809                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
810                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
811                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
812                dict(v4=None, c4=None, t=None),
813                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
814                dict(v4='1234', c4='1234', t='1234' * 10),
815                dict(v4='abcd', c4='abcd', t='abcdefg'),
816                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
817                dict(ts=None), (dict(ts=''), dict(ts=None)),
818                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
819                dict(ts='2012-12-21 00:00:00'),
820                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
821                dict(ts='2012-12-21 12:21:12'),
822                dict(ts='2013-01-05 12:13:14'),
823                dict(ts='current_timestamp')]
824            for test in tests:
825                if isinstance(test, dict):
826                    data = test
827                    change = {}
828                else:
829                    data, change = test
830                expect = data.copy()
831                expect.update(change)
832                if data.get('m') and server_version < 90100:
833                    # PostgreSQL < 9.1 cannot directly convert numbers to money
834                    data['m'] = "'%s'::money" % data['m']
835                self.assertEqual(insert(table, data), data)
836                self.assertIn(oid_table, data)
837                oid = data[oid_table]
838                self.assertIsInstance(oid, int)
839                data = dict(item for item in data.iteritems()
840                    if item[0] in expect)
841                ts = expect.get('ts')
842                if ts == 'current_timestamp':
843                    ts = expect['ts'] = data['ts']
844                    if len(ts) > 19:
845                        self.assertEqual(ts[19], '.')
846                        ts = ts[:19]
847                    else:
848                        self.assertEqual(len(ts), 19)
849                    self.assertTrue(ts[:4].isdigit())
850                    self.assertEqual(ts[4], '-')
851                    self.assertEqual(ts[10], ' ')
852                    self.assertTrue(ts[11:13].isdigit())
853                    self.assertEqual(ts[13], ':')
854                self.assertEqual(data, expect)
855                data = query(
856                    'select oid,* from "%s"' % table).dictresult()[0]
857                self.assertEqual(data['oid'], oid)
858                data = dict(item for item in data.iteritems()
859                    if item[0] in expect)
860                self.assertEqual(data, expect)
861                query('delete from "%s"' % table)
862            query('drop table "%s"' % table)
863
864    def testUpdate(self):
865        update = self.db.update
866        query = self.db.query
867        for table in ('update_test_table', 'test table for update'):
868            query('drop table if exists "%s"' % table)
869            query('create table "%s" ('
870                "n integer, t text) with oids" % table)
871            for n, t in enumerate('xyz'):
872                query('insert into "%s" values('
873                    "%d, '%s')" % (table, n + 1, t))
874            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
875            r = self.db.get(table, 2, 'n')
876            r['t'] = 'u'
877            s = update(table, r)
878            self.assertEqual(s, r)
879            r = query('select t from "%s" where n=2' % table
880                      ).getresult()[0][0]
881            self.assertEqual(r, 'u')
882            query('drop table "%s"' % table)
883
884    def testUpdateWithCompositeKey(self):
885        update = self.db.update
886        query = self.db.query
887        table = 'update_test_table_1'
888        query("drop table if exists %s" % table)
889        query("create table %s ("
890            "n integer, t text, primary key (n))" % table)
891        for n, t in enumerate('abc'):
892            query("insert into %s values("
893                "%d, '%s')" % (table, n + 1, t))
894        self.assertRaises(pg.ProgrammingError, update,
895                          table, dict(t='b'))
896        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
897        r = query('select t from "%s" where n=2' % table
898                  ).getresult()[0][0]
899        self.assertEqual(r, 'd')
900        query("drop table %s" % table)
901        table = 'update_test_table_2'
902        query("drop table if exists %s" % table)
903        query("create table %s ("
904            "n integer, m integer, t text, primary key (n, m))" % table)
905        for n in range(3):
906            for m in range(2):
907                t = chr(ord('a') + 2 * n + m)
908                query("insert into %s values("
909                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
910        self.assertRaises(pg.ProgrammingError, update,
911                          table, dict(n=2, t='b'))
912        self.assertEqual(update(table,
913                                dict(n=2, m=2, t='x'))['t'], 'x')
914        r = [r[0] for r in query('select t from "%s" where n=2'
915            ' order by m' % table).getresult()]
916        self.assertEqual(r, ['c', 'x'])
917        query("drop table %s" % table)
918
919    def testClear(self):
920        clear = self.db.clear
921        query = self.db.query
922        for table in ('clear_test_table', 'test table for clear'):
923            query('drop table if exists "%s"' % table)
924            query('create table "%s" ('
925                "n integer, b boolean, d date, t text)" % table)
926            r = clear(table)
927            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
928            self.assertEqual(r, result)
929            r['a'] = r['n'] = 1
930            r['d'] = r['t'] = 'x'
931            r['b'] = 't'
932            r['oid'] = 1L
933            r = clear(table, r)
934            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '', 'oid': 1L}
935            self.assertEqual(r, result)
936            query('drop table "%s"' % table)
937
938    def testDelete(self):
939        delete = self.db.delete
940        query = self.db.query
941        for table in ('delete_test_table', 'test table for delete'):
942            query('drop table if exists "%s"' % table)
943            query('create table "%s" ('
944                "n integer, t text) with oids" % table)
945            for n, t in enumerate('xyz'):
946                query('insert into "%s" values('
947                    "%d, '%s')" % (table, n + 1, t))
948            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
949            r = self.db.get(table, 1, 'n')
950            s = delete(table, r)
951            self.assertEqual(s, 1)
952            r = self.db.get(table, 3, 'n')
953            s = delete(table, r)
954            self.assertEqual(s, 1)
955            s = delete(table, r)
956            self.assertEqual(s, 0)
957            r = query('select * from "%s"' % table).dictresult()
958            self.assertEqual(len(r), 1)
959            r = r[0]
960            result = {'n': 2, 't': 'y'}
961            self.assertEqual(r, result)
962            r = self.db.get(table, 2, 'n')
963            s = delete(table, r)
964            self.assertEqual(s, 1)
965            s = delete(table, r)
966            self.assertEqual(s, 0)
967            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
968            query('drop table "%s"' % table)
969
970    def testDeleteWithCompositeKey(self):
971        query = self.db.query
972        table = 'delete_test_table_1'
973        query("drop table if exists %s" % table)
974        query("create table %s ("
975            "n integer, t text, primary key (n))" % table)
976        for n, t in enumerate('abc'):
977            query("insert into %s values("
978                "%d, '%s')" % (table, n + 1, t))
979        self.assertRaises(pg.ProgrammingError, self.db.delete,
980            table, dict(t='b'))
981        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
982        r = query('select t from "%s" where n=2' % table
983                  ).getresult()
984        self.assertEqual(r, [])
985        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
986        r = query('select t from "%s" where n=3' % table
987                  ).getresult()[0][0]
988        self.assertEqual(r, 'c')
989        query("drop table %s" % table)
990        table = 'delete_test_table_2'
991        query("drop table if exists %s" % table)
992        query("create table %s ("
993            "n integer, m integer, t text, primary key (n, m))" % table)
994        for n in range(3):
995            for m in range(2):
996                t = chr(ord('a') + 2 * n + m)
997                query("insert into %s values("
998                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
999        self.assertRaises(pg.ProgrammingError, self.db.delete,
1000            table, dict(n=2, t='b'))
1001        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1002        r = [r[0] for r in query('select t from "%s" where n=2'
1003            ' order by m' % table).getresult()]
1004        self.assertEqual(r, ['c'])
1005        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1006        r = [r[0] for r in query('select t from "%s" where n=3'
1007            ' order by m' % table).getresult()]
1008        self.assertEqual(r, ['e', 'f'])
1009        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1010        r = [r[0] for r in query('select t from "%s" where n=3'
1011            ' order by m' % table).getresult()]
1012        self.assertEqual(r, ['f'])
1013        query("drop table %s" % table)
1014
1015    def testTransaction(self):
1016        query = self.db.query
1017        query("drop table if exists test_table")
1018        query("create table test_table (n integer)")
1019        self.db.begin()
1020        query("insert into test_table values (1)")
1021        query("insert into test_table values (2)")
1022        self.db.commit()
1023        self.db.begin()
1024        query("insert into test_table values (3)")
1025        query("insert into test_table values (4)")
1026        self.db.rollback()
1027        self.db.begin()
1028        query("insert into test_table values (5)")
1029        self.db.savepoint('before6')
1030        query("insert into test_table values (6)")
1031        self.db.rollback('before6')
1032        query("insert into test_table values (7)")
1033        self.db.commit()
1034        self.db.begin()
1035        self.db.savepoint('before8')
1036        query("insert into test_table values (8)")
1037        self.db.release('before8')
1038        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1039        self.db.commit()
1040        self.db.start()
1041        query("insert into test_table values (9)")
1042        self.db.end()
1043        r = [r[0] for r in query(
1044            "select * from test_table order by 1").getresult()]
1045        self.assertEqual(r, [1, 2, 5, 7, 9])
1046        query("drop table test_table")
1047
1048    @unittest.skipIf(no_with, 'context managers not supported')
1049    def testContextManager(self):
1050        query = self.db.query
1051        query("drop table if exists test_table")
1052        query("create table test_table (n integer check(n>0))")
1053        # wrap "with" statements to avoid SyntaxError in Python < 2.5
1054        exec """from __future__ import with_statement\nif True:
1055        with self.db:
1056            query("insert into test_table values (1)")
1057            query("insert into test_table values (2)")
1058        try:
1059            with self.db:
1060                query("insert into test_table values (3)")
1061                query("insert into test_table values (4)")
1062                raise ValueError('test transaction should rollback')
1063        except ValueError, error:
1064            self.assertEqual(str(error), 'test transaction should rollback')
1065        with self.db:
1066            query("insert into test_table values (5)")
1067        try:
1068            with self.db:
1069                query("insert into test_table values (6)")
1070                query("insert into test_table values (-1)")
1071        except pg.ProgrammingError, error:
1072            self.assertTrue('check' in str(error))
1073        with self.db:
1074            query("insert into test_table values (7)")\n"""
1075        r = [r[0] for r in query(
1076            "select * from test_table order by 1").getresult()]
1077        self.assertEqual(r, [1, 2, 5, 7])
1078        query("drop table test_table")
1079
1080    def testBytea(self):
1081        query = self.db.query
1082        query('drop table if exists bytea_test')
1083        query('create table bytea_test ('
1084            'data bytea)')
1085        s = "It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1086        r = self.db.escape_bytea(s)
1087        query('insert into bytea_test values('
1088            "'%s')" % r)
1089        r = query('select * from bytea_test').getresult()
1090        self.assertTrue(len(r) == 1)
1091        r = r[0]
1092        self.assertTrue(len(r) == 1)
1093        r = r[0]
1094        r = self.db.unescape_bytea(r)
1095        self.assertEqual(r, s)
1096        query('drop table bytea_test')
1097
1098    def testDebugWithCallable(self):
1099        if debug:
1100            self.assertEqual(self.db.debug, debug)
1101        else:
1102            self.assertIsNone(self.db.debug)
1103        s = []
1104        self.db.debug = s.append
1105        try:
1106            self.db.query("select 1")
1107            self.db.query("select 2")
1108            self.assertEqual(s, ["select 1", "select 2"])
1109        finally:
1110            self.db.debug = debug
1111
1112
1113class TestSchemas(unittest.TestCase):
1114    """Test correct handling of schemas (namespaces)."""
1115
1116    @classmethod
1117    def setUpClass(cls):
1118        db = DB()
1119        query = db.query
1120        query("set client_min_messages=warning")
1121        for num_schema in range(5):
1122            if num_schema:
1123                schema = "s%d" % num_schema
1124                query("drop schema if exists %s cascade" % (schema,))
1125                try:
1126                    query("create schema %s" % (schema,))
1127                except pg.ProgrammingError:
1128                    raise RuntimeError("The test user cannot create schemas.\n"
1129                        "Grant create on database %s to the user"
1130                        " for running these tests." % dbname)
1131            else:
1132                schema = "public"
1133                query("drop table if exists %s.t" % (schema,))
1134                query("drop table if exists %s.t%d" % (schema, num_schema))
1135            query("create table %s.t with oids as select 1 as n, %d as d"
1136                  % (schema, num_schema))
1137            query("create table %s.t%d with oids as select 1 as n, %d as d"
1138                  % (schema, num_schema, num_schema))
1139        db.close()
1140
1141    @classmethod
1142    def tearDownClass(cls):
1143        db = DB()
1144        query = db.query
1145        query("set client_min_messages=warning")
1146        for num_schema in range(5):
1147            if num_schema:
1148                schema = "s%d" % num_schema
1149                query("drop schema %s cascade" % (schema,))
1150            else:
1151                schema = "public"
1152                query("drop table %s.t" % (schema,))
1153                query("drop table %s.t%d" % (schema, num_schema))
1154        db.close()
1155
1156    def setUp(self):
1157        self.db = DB()
1158        self.db.query("set client_min_messages=warning")
1159
1160    def tearDown(self):
1161        self.db.close()
1162
1163    def testGetTables(self):
1164        tables = self.db.get_tables()
1165        for num_schema in range(5):
1166            if num_schema:
1167                schema = "s" + str(num_schema)
1168            else:
1169                schema = "public"
1170            for t in (schema + ".t",
1171                    schema + ".t" + str(num_schema)):
1172                self.assertIn(t, tables)
1173
1174    def testGetAttnames(self):
1175        get_attnames = self.db.get_attnames
1176        query = self.db.query
1177        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
1178        r = get_attnames("t")
1179        self.assertEqual(r, result)
1180        r = get_attnames("s4.t4")
1181        self.assertEqual(r, result)
1182        query("drop table if exists s3.t3m")
1183        query("create table s3.t3m with oids as select 1 as m")
1184        result_m = {'oid': 'int', 'm': 'int'}
1185        r = get_attnames("s3.t3m")
1186        self.assertEqual(r, result_m)
1187        query("set search_path to s1,s3")
1188        r = get_attnames("t3")
1189        self.assertEqual(r, result)
1190        r = get_attnames("t3m")
1191        self.assertEqual(r, result_m)
1192        query("drop table s3.t3m")
1193
1194    def testGet(self):
1195        get = self.db.get
1196        query = self.db.query
1197        PrgError = pg.ProgrammingError
1198        self.assertEqual(get("t", 1, 'n')['d'], 0)
1199        self.assertEqual(get("t0", 1, 'n')['d'], 0)
1200        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
1201        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
1202        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
1203        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
1204        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
1205        query("set search_path to s2,s4")
1206        self.assertRaises(PrgError, get, "t1", 1, 'n')
1207        self.assertEqual(get("t4", 1, 'n')['d'], 4)
1208        self.assertRaises(PrgError, get, "t3", 1, 'n')
1209        self.assertEqual(get("t", 1, 'n')['d'], 2)
1210        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
1211        query("set search_path to s1,s3")
1212        self.assertRaises(PrgError, get, "t2", 1, 'n')
1213        self.assertEqual(get("t3", 1, 'n')['d'], 3)
1214        self.assertRaises(PrgError, get, "t4", 1, 'n')
1215        self.assertEqual(get("t", 1, 'n')['d'], 1)
1216        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
1217
1218    def testMangling(self):
1219        get = self.db.get
1220        query = self.db.query
1221        r = get("t", 1, 'n')
1222        self.assertIn('oid(public.t)', r)
1223        query("set search_path to s2")
1224        r = get("t2", 1, 'n')
1225        self.assertIn('oid(s2.t2)', r)
1226        query("set search_path to s3")
1227        r = get("t", 1, 'n')
1228        self.assertIn('oid(s3.t)', r)
1229
1230
1231if __name__ == '__main__':
1232    unittest.main()
Note: See TracBrowser for help on using the repository browser.