source: trunk/module/tests/test_classic_dbwrapper.py @ 680

Last change on this file since 680 was 680, checked in by cito, 4 years ago

Remove the deprecated tty parameter and attribute

This parameter has been ignored by PostgreSQL since version 7.4.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 47.8 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
24# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
25# get our information from that.  Otherwise we use the defaults.
26# The current user must have create schema privilege on the database.
27dbname = 'unittest'
28dbhost = None
29dbport = 5432
30
31debug = False  # let DB wrapper print debugging output
32
33try:
34    from .LOCAL_PyGreSQL import *
35except (ImportError, ValueError):
36    try:
37        from LOCAL_PyGreSQL import *
38    except ImportError:
39        pass
40
41try:
42    long
43except NameError:  # Python >= 3.0
44    long = int
45
46try:
47    unicode
48except NameError:  # Python >= 3.0
49    unicode = str
50
51windows = os.name == 'nt'
52
53# There is a known a bug in libpq under Windows which can cause
54# the interface to crash when calling PQhost():
55do_not_ask_for_host = windows
56do_not_ask_for_host_reason = 'libpq issue on Windows'
57
58
59def DB():
60    """Create a DB wrapper object connecting to the test database."""
61    db = pg.DB(dbname, dbhost, dbport)
62    if debug:
63        db.debug = debug
64    db.query("set client_min_messages=warning")
65    return db
66
67
68class TestDBClassBasic(unittest.TestCase):
69    """Test existence of the DB class wrapped pg connection methods."""
70
71    def setUp(self):
72        self.db = DB()
73
74    def tearDown(self):
75        try:
76            self.db.close()
77        except pg.InternalError:
78            pass
79
80    def testAllDBAttributes(self):
81        attributes = [
82            'begin',
83            'cancel',
84            'clear',
85            'close',
86            'commit',
87            'db',
88            'dbname',
89            'debug',
90            'delete',
91            'end',
92            'endcopy',
93            'error',
94            'escape_bytea',
95            'escape_identifier',
96            'escape_literal',
97            'escape_string',
98            'fileno',
99            'get',
100            'get_attnames',
101            'get_databases',
102            'get_notice_receiver',
103            'get_relations',
104            'get_tables',
105            'getline',
106            'getlo',
107            'getnotify',
108            'has_table_privilege',
109            'host',
110            'insert',
111            'inserttable',
112            'locreate',
113            'loimport',
114            'notification_handler',
115            'options',
116            'parameter',
117            'pkey',
118            'port',
119            'protocol_version',
120            'putline',
121            'query',
122            'release',
123            'reopen',
124            'reset',
125            'rollback',
126            'savepoint',
127            'server_version',
128            'set_notice_receiver',
129            'source',
130            'start',
131            'status',
132            'transaction',
133            'unescape_bytea',
134            'update',
135            'use_regtypes',
136            'user',
137        ]
138        db_attributes = [a for a in dir(self.db)
139            if not a.startswith('_')]
140        self.assertEqual(attributes, db_attributes)
141
142    def testAttributeDb(self):
143        self.assertEqual(self.db.db.db, dbname)
144
145    def testAttributeDbname(self):
146        self.assertEqual(self.db.dbname, dbname)
147
148    def testAttributeError(self):
149        error = self.db.error
150        self.assertTrue(not error or 'krb5_' in error)
151        self.assertEqual(self.db.error, self.db.db.error)
152
153    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
154    def testAttributeHost(self):
155        def_host = 'localhost'
156        host = self.db.host
157        self.assertIsInstance(host, str)
158        self.assertEqual(host, dbhost or def_host)
159        self.assertEqual(host, self.db.db.host)
160
161    def testAttributeOptions(self):
162        no_options = ''
163        options = self.db.options
164        self.assertEqual(options, no_options)
165        self.assertEqual(options, self.db.db.options)
166
167    def testAttributePort(self):
168        def_port = 5432
169        port = self.db.port
170        self.assertIsInstance(port, int)
171        self.assertEqual(port, dbport or def_port)
172        self.assertEqual(port, self.db.db.port)
173
174    def testAttributeProtocolVersion(self):
175        protocol_version = self.db.protocol_version
176        self.assertIsInstance(protocol_version, int)
177        self.assertTrue(2 <= protocol_version < 4)
178        self.assertEqual(protocol_version, self.db.db.protocol_version)
179
180    def testAttributeServerVersion(self):
181        server_version = self.db.server_version
182        self.assertIsInstance(server_version, int)
183        self.assertTrue(70400 <= server_version < 100000)
184        self.assertEqual(server_version, self.db.db.server_version)
185
186    def testAttributeStatus(self):
187        status_ok = 1
188        status = self.db.status
189        self.assertIsInstance(status, int)
190        self.assertEqual(status, status_ok)
191        self.assertEqual(status, self.db.db.status)
192
193    def testAttributeUser(self):
194        no_user = 'Deprecated facility'
195        user = self.db.user
196        self.assertTrue(user)
197        self.assertIsInstance(user, str)
198        self.assertNotEqual(user, no_user)
199        self.assertEqual(user, self.db.db.user)
200
201    def testMethodEscapeLiteral(self):
202        self.assertEqual(self.db.escape_literal(''), "''")
203
204    def testMethodEscapeIdentifier(self):
205        self.assertEqual(self.db.escape_identifier(''), '""')
206
207    def testMethodEscapeString(self):
208        self.assertEqual(self.db.escape_string(''), '')
209
210    def testMethodEscapeBytea(self):
211        self.assertEqual(self.db.escape_bytea('').replace(
212            '\\x', '').replace('\\', ''), '')
213
214    def testMethodUnescapeBytea(self):
215        self.assertEqual(self.db.unescape_bytea(''), b'')
216
217    def testMethodQuery(self):
218        query = self.db.query
219        query("select 1+1")
220        query("select 1+$1+$2", 2, 3)
221        query("select 1+$1+$2", (2, 3))
222        query("select 1+$1+$2", [2, 3])
223        query("select 1+$1", 1)
224
225    def testMethodQueryEmpty(self):
226        self.assertRaises(ValueError, self.db.query, '')
227
228    def testMethodQueryProgrammingError(self):
229        try:
230            self.db.query("select 1/0")
231        except pg.ProgrammingError as error:
232            self.assertEqual(error.sqlstate, '22012')
233
234    def testMethodEndcopy(self):
235        try:
236            self.db.endcopy()
237        except IOError:
238            pass
239
240    def testMethodClose(self):
241        self.db.close()
242        try:
243            self.db.reset()
244        except pg.Error:
245            pass
246        else:
247            self.fail('Reset should give an error for a closed connection')
248        self.assertRaises(pg.InternalError, self.db.close)
249        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
250
251    def testExistingConnection(self):
252        db = pg.DB(self.db.db)
253        self.assertEqual(self.db.db, db.db)
254        self.assertTrue(db.db)
255        db.close()
256        self.assertTrue(db.db)
257        db.reopen()
258        self.assertTrue(db.db)
259        db.close()
260        self.assertTrue(db.db)
261        db = pg.DB(self.db)
262        self.assertEqual(self.db.db, db.db)
263        db = pg.DB(db=self.db.db)
264        self.assertEqual(self.db.db, db.db)
265
266        class DB2:
267            pass
268
269        db2 = DB2()
270        db2._cnx = self.db.db
271        db = pg.DB(db2)
272        self.assertEqual(self.db.db, db.db)
273
274
275class TestDBClass(unittest.TestCase):
276    """Test the methods of the DB class wrapped pg connection."""
277
278    @classmethod
279    def setUpClass(cls):
280        db = DB()
281        db.query("drop table if exists test cascade")
282        db.query("create table test ("
283            "i2 smallint, i4 integer, i8 bigint,"
284            "d numeric, f4 real, f8 double precision, m money, "
285            "v4 varchar(4), c4 char(4), t text)")
286        db.query("create or replace view test_view as"
287            " select i4, v4 from test")
288        db.close()
289
290    @classmethod
291    def tearDownClass(cls):
292        db = DB()
293        db.query("drop table test cascade")
294        db.close()
295
296    def setUp(self):
297        self.db = DB()
298        query = self.db.query
299        query('set client_encoding=utf8')
300        query('set standard_conforming_strings=on')
301        query('set bytea_output=hex')
302        query("set lc_monetary='C'")
303
304    def tearDown(self):
305        self.db.close()
306
307    def testClassName(self):
308        self.assertEqual(self.db.__class__.__name__, 'DB')
309
310    def testModuleName(self):
311        self.assertEqual(self.db.__module__, 'pg')
312        self.assertEqual(self.db.__class__.__module__, 'pg')
313
314    def testEscapeLiteral(self):
315        f = self.db.escape_literal
316        r = f(b"plain")
317        self.assertIsInstance(r, bytes)
318        self.assertEqual(r, b"'plain'")
319        r = f(u"plain")
320        self.assertIsInstance(r, unicode)
321        self.assertEqual(r, u"'plain'")
322        r = f(u"that's kÀse".encode('utf-8'))
323        self.assertIsInstance(r, bytes)
324        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
325        r = f(u"that's kÀse")
326        self.assertIsInstance(r, unicode)
327        self.assertEqual(r, u"'that''s kÀse'")
328        self.assertEqual(f(r"It's fine to have a \ inside."),
329            r" E'It''s fine to have a \\ inside.'")
330        self.assertEqual(f('No "quotes" must be escaped.'),
331            "'No \"quotes\" must be escaped.'")
332
333    def testEscapeIdentifier(self):
334        f = self.db.escape_identifier
335        r = f(b"plain")
336        self.assertIsInstance(r, bytes)
337        self.assertEqual(r, b'"plain"')
338        r = f(u"plain")
339        self.assertIsInstance(r, unicode)
340        self.assertEqual(r, u'"plain"')
341        r = f(u"that's kÀse".encode('utf-8'))
342        self.assertIsInstance(r, bytes)
343        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
344        r = f(u"that's kÀse")
345        self.assertIsInstance(r, unicode)
346        self.assertEqual(r, u'"that\'s kÀse"')
347        self.assertEqual(f(r"It's fine to have a \ inside."),
348            '"It\'s fine to have a \\ inside."')
349        self.assertEqual(f('All "quotes" must be escaped.'),
350            '"All ""quotes"" must be escaped."')
351
352    def testEscapeString(self):
353        f = self.db.escape_string
354        r = f(b"plain")
355        self.assertIsInstance(r, bytes)
356        self.assertEqual(r, b"plain")
357        r = f(u"plain")
358        self.assertIsInstance(r, unicode)
359        self.assertEqual(r, u"plain")
360        r = f(u"that's kÀse".encode('utf-8'))
361        self.assertIsInstance(r, bytes)
362        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
363        r = f(u"that's kÀse")
364        self.assertIsInstance(r, unicode)
365        self.assertEqual(r, u"that''s kÀse")
366        self.assertEqual(f(r"It's fine to have a \ inside."),
367            r"It''s fine to have a \ inside.")
368
369    def testEscapeBytea(self):
370        f = self.db.escape_bytea
371        # note that escape_byte always returns hex output since Pg 9.0,
372        # regardless of the bytea_output setting
373        r = f(b'plain')
374        self.assertIsInstance(r, bytes)
375        self.assertEqual(r, b'\\x706c61696e')
376        r = f(u'plain')
377        self.assertIsInstance(r, unicode)
378        self.assertEqual(r, u'\\x706c61696e')
379        r = f(u"das is' kÀse".encode('utf-8'))
380        self.assertIsInstance(r, bytes)
381        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
382        r = f(u"das is' kÀse")
383        self.assertIsInstance(r, unicode)
384        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
385        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
386
387    def testUnescapeBytea(self):
388        f = self.db.unescape_bytea
389        r = f(b'plain')
390        self.assertIsInstance(r, bytes)
391        self.assertEqual(r, b'plain')
392        r = f(u'plain')
393        self.assertIsInstance(r, bytes)
394        self.assertEqual(r, b'plain')
395        r = f(b"das is' k\\303\\244se")
396        self.assertIsInstance(r, bytes)
397        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
398        r = f(u"das is' k\\303\\244se")
399        self.assertIsInstance(r, bytes)
400        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
401        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
402        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
403        self.assertEqual(f(r'\\x746861742773206be47365'),
404            b'\\x746861742773206be47365')
405        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
406
407    def testQuote(self):
408        f = self.db._quote
409        self.assertEqual(f(None, None), 'NULL')
410        self.assertEqual(f(None, 'int'), 'NULL')
411        self.assertEqual(f(None, 'float'), 'NULL')
412        self.assertEqual(f(None, 'num'), 'NULL')
413        self.assertEqual(f(None, 'money'), 'NULL')
414        self.assertEqual(f(None, 'bool'), 'NULL')
415        self.assertEqual(f(None, 'date'), 'NULL')
416        self.assertEqual(f('', 'int'), 'NULL')
417        self.assertEqual(f('', 'float'), 'NULL')
418        self.assertEqual(f('', 'num'), 'NULL')
419        self.assertEqual(f('', 'money'), 'NULL')
420        self.assertEqual(f('', 'bool'), 'NULL')
421        self.assertEqual(f('', 'date'), 'NULL')
422        self.assertEqual(f('', 'text'), "''")
423        self.assertEqual(f(0, 'int'), '0')
424        self.assertEqual(f(0, 'num'), '0')
425        self.assertEqual(f(1, 'int'), '1')
426        self.assertEqual(f(1, 'num'), '1')
427        self.assertEqual(f(-1, 'int'), '-1')
428        self.assertEqual(f(-1, 'num'), '-1')
429        self.assertEqual(f(123456789, 'int'), '123456789')
430        self.assertEqual(f(123456987, 'num'), '123456987')
431        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
432        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
433        self.assertEqual(f('123456789', 'num'), '123456789')
434        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
435        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
436        self.assertEqual(f(123, 'money'), '123')
437        self.assertEqual(f('123', 'money'), '123')
438        self.assertEqual(f(123.45, 'money'), '123.45')
439        self.assertEqual(f('123.45', 'money'), '123.45')
440        self.assertEqual(f(123.454, 'money'), '123.454')
441        self.assertEqual(f('123.454', 'money'), '123.454')
442        self.assertEqual(f(123.456, 'money'), '123.456')
443        self.assertEqual(f('123.456', 'money'), '123.456')
444        self.assertEqual(f('f', 'bool'), "'f'")
445        self.assertEqual(f('F', 'bool'), "'f'")
446        self.assertEqual(f('false', 'bool'), "'f'")
447        self.assertEqual(f('False', 'bool'), "'f'")
448        self.assertEqual(f('FALSE', 'bool'), "'f'")
449        self.assertEqual(f(0, 'bool'), "'f'")
450        self.assertEqual(f('0', 'bool'), "'f'")
451        self.assertEqual(f('-', 'bool'), "'f'")
452        self.assertEqual(f('n', 'bool'), "'f'")
453        self.assertEqual(f('N', 'bool'), "'f'")
454        self.assertEqual(f('no', 'bool'), "'f'")
455        self.assertEqual(f('off', 'bool'), "'f'")
456        self.assertEqual(f('t', 'bool'), "'t'")
457        self.assertEqual(f('T', 'bool'), "'t'")
458        self.assertEqual(f('true', 'bool'), "'t'")
459        self.assertEqual(f('True', 'bool'), "'t'")
460        self.assertEqual(f('TRUE', 'bool'), "'t'")
461        self.assertEqual(f(1, 'bool'), "'t'")
462        self.assertEqual(f(2, 'bool'), "'t'")
463        self.assertEqual(f(-1, 'bool'), "'t'")
464        self.assertEqual(f(0.5, 'bool'), "'t'")
465        self.assertEqual(f('1', 'bool'), "'t'")
466        self.assertEqual(f('y', 'bool'), "'t'")
467        self.assertEqual(f('Y', 'bool'), "'t'")
468        self.assertEqual(f('yes', 'bool'), "'t'")
469        self.assertEqual(f('on', 'bool'), "'t'")
470        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
471        self.assertEqual(f(123, 'text'), "'123'")
472        self.assertEqual(f(1.23, 'text'), "'1.23'")
473        self.assertEqual(f('abc', 'text'), "'abc'")
474        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
475        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
476        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
477        self.db.query('set standard_conforming_strings=off')
478        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
479        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
480
481    def testQuery(self):
482        query = self.db.query
483        query("drop table if exists test_table")
484        q = "create table test_table (n integer) with oids"
485        r = query(q)
486        self.assertIsNone(r)
487        q = "insert into test_table values (1)"
488        r = query(q)
489        self.assertIsInstance(r, int)
490        q = "insert into test_table select 2"
491        r = query(q)
492        self.assertIsInstance(r, int)
493        oid = r
494        q = "select oid from test_table where n=2"
495        r = query(q).getresult()
496        self.assertEqual(len(r), 1)
497        r = r[0]
498        self.assertEqual(len(r), 1)
499        r = r[0]
500        self.assertEqual(r, oid)
501        q = "insert into test_table select 3 union select 4 union select 5"
502        r = query(q)
503        self.assertIsInstance(r, str)
504        self.assertEqual(r, '3')
505        q = "update test_table set n=4 where n<5"
506        r = query(q)
507        self.assertIsInstance(r, str)
508        self.assertEqual(r, '4')
509        q = "delete from test_table"
510        r = query(q)
511        self.assertIsInstance(r, str)
512        self.assertEqual(r, '5')
513        query("drop table test_table")
514
515    def testMultipleQueries(self):
516        self.assertEqual(self.db.query(
517            "create temporary table test_multi (n integer);"
518            "insert into test_multi values (4711);"
519            "select n from test_multi").getresult()[0][0], 4711)
520
521    def testQueryWithParams(self):
522        query = self.db.query
523        query("drop table if exists test_table")
524        q = "create table test_table (n1 integer, n2 integer) with oids"
525        query(q)
526        q = "insert into test_table values ($1, $2)"
527        r = query(q, (1, 2))
528        self.assertIsInstance(r, int)
529        r = query(q, [3, 4])
530        self.assertIsInstance(r, int)
531        r = query(q, [5, 6])
532        self.assertIsInstance(r, int)
533        q = "select * from test_table order by 1, 2"
534        self.assertEqual(query(q).getresult(),
535            [(1, 2), (3, 4), (5, 6)])
536        q = "select * from test_table where n1=$1 and n2=$2"
537        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
538        q = "update test_table set n2=$2 where n1=$1"
539        r = query(q, 3, 7)
540        self.assertEqual(r, '1')
541        q = "select * from test_table order by 1, 2"
542        self.assertEqual(query(q).getresult(),
543            [(1, 2), (3, 7), (5, 6)])
544        q = "delete from test_table where n2!=$1"
545        r = query(q, 4)
546        self.assertEqual(r, '3')
547        query("drop table test_table")
548
549    def testEmptyQuery(self):
550        self.assertRaises(ValueError, self.db.query, '')
551
552    def testQueryProgrammingError(self):
553        try:
554            self.db.query("select 1/0")
555        except pg.ProgrammingError as error:
556            self.assertEqual(error.sqlstate, '22012')
557
558    def testPkey(self):
559        query = self.db.query
560        for n in range(4):
561            query("drop table if exists pkeytest%d" % n)
562        query("create table pkeytest0 ("
563            "a smallint)")
564        query("create table pkeytest1 ("
565            "b smallint primary key)")
566        query("create table pkeytest2 ("
567            "c smallint, d smallint primary key)")
568        query("create table pkeytest3 ("
569            "e smallint, f smallint, g smallint, "
570            "h smallint, i smallint, "
571            "primary key (f,h))")
572        pkey = self.db.pkey
573        self.assertRaises(KeyError, pkey, 'pkeytest0')
574        self.assertEqual(pkey('pkeytest1'), 'b')
575        self.assertEqual(pkey('pkeytest2'), 'd')
576        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
577        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
578        self.assertEqual(pkey('pkeytest0'), 'none')
579        pkey(None, {'t': 'a', 'n.t': 'b'})
580        self.assertEqual(pkey('t'), 'a')
581        self.assertEqual(pkey('n.t'), 'b')
582        self.assertRaises(KeyError, pkey, 'pkeytest0')
583        for n in range(4):
584            query("drop table pkeytest%d" % n)
585
586    def testGetDatabases(self):
587        databases = self.db.get_databases()
588        self.assertIn('template0', databases)
589        self.assertIn('template1', databases)
590        self.assertNotIn('not existing database', databases)
591        self.assertIn('postgres', databases)
592        self.assertIn(dbname, databases)
593
594    def testGetTables(self):
595        get_tables = self.db.get_tables
596        result1 = get_tables()
597        tables = ('"A very Special Name"',
598            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
599            'A_MiXeD_NaMe', '"another special name"',
600            'averyveryveryveryveryveryverylongtablename',
601            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
602        for t in tables:
603            self.db.query('drop table if exists %s' % t)
604            self.db.query("create table %s"
605                " as select 0" % t)
606        result3 = get_tables()
607        result2 = []
608        for t in result3:
609            if t not in result1:
610                result2.append(t)
611        result3 = []
612        for t in tables:
613            if not t.startswith('"'):
614                t = t.lower()
615            result3.append('public.' + t)
616        self.assertEqual(result2, result3)
617        for t in result2:
618            self.db.query('drop table %s' % t)
619        result2 = get_tables()
620        self.assertEqual(result2, result1)
621
622    def testGetRelations(self):
623        get_relations = self.db.get_relations
624        result = get_relations()
625        self.assertIn('public.test', result)
626        self.assertIn('public.test_view', result)
627        result = get_relations('rv')
628        self.assertIn('public.test', result)
629        self.assertIn('public.test_view', result)
630        result = get_relations('r')
631        self.assertIn('public.test', result)
632        self.assertNotIn('public.test_view', result)
633        result = get_relations('v')
634        self.assertNotIn('public.test', result)
635        self.assertIn('public.test_view', result)
636        result = get_relations('cisSt')
637        self.assertNotIn('public.test', result)
638        self.assertNotIn('public.test_view', result)
639
640    def testAttnames(self):
641        self.assertRaises(pg.ProgrammingError,
642            self.db.get_attnames, 'does_not_exist')
643        self.assertRaises(pg.ProgrammingError,
644            self.db.get_attnames, 'has.too.many.dots')
645        for table in ('attnames_test_table', 'test table for attnames'):
646            self.db.query('drop table if exists "%s"' % table)
647            self.db.query('create table "%s" ('
648                'a smallint, b integer, c bigint, '
649                'e numeric, f float, f2 double precision, m money, '
650                'x smallint, y smallint, z smallint, '
651                'Normal_NaMe smallint, "Special Name" smallint, '
652                't text, u char(2), v varchar(2), '
653                'primary key (y, u)) with oids' % table)
654            attributes = self.db.get_attnames(table)
655            result = {'a': 'int', 'c': 'int', 'b': 'int',
656                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
657                'normal_name': 'int', 'Special Name': 'int',
658                'u': 'text', 't': 'text', 'v': 'text',
659                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
660            self.assertEqual(attributes, result)
661            self.db.query('drop table "%s"' % table)
662
663    def testHasTablePrivilege(self):
664        can = self.db.has_table_privilege
665        self.assertEqual(can('test'), True)
666        self.assertEqual(can('test', 'select'), True)
667        self.assertEqual(can('test', 'SeLeCt'), True)
668        self.assertEqual(can('test', 'SELECT'), True)
669        self.assertEqual(can('test', 'insert'), True)
670        self.assertEqual(can('test', 'update'), True)
671        self.assertEqual(can('test', 'delete'), True)
672        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
673        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
674
675    def testGet(self):
676        get = self.db.get
677        query = self.db.query
678        for table in ('get_test_table', 'test table for get'):
679            query('drop table if exists "%s"' % table)
680            query('create table "%s" ('
681                "n integer, t text) with oids" % table)
682            for n, t in enumerate('xyz'):
683                query('insert into "%s" values('"%d, '%s')"
684                    % (table, n + 1, t))
685            self.assertRaises(pg.ProgrammingError, get, table, 2)
686            r = get(table, 2, 'n')
687            oid_table = table
688            if ' ' in table:
689                oid_table = '"%s"' % oid_table
690            oid_table = 'oid(public.%s)' % oid_table
691            self.assertIn(oid_table, r)
692            oid = r[oid_table]
693            self.assertIsInstance(oid, int)
694            result = {'t': 'y', 'n': 2, oid_table: oid}
695            self.assertEqual(r, result)
696            self.assertEqual(get(table + ' *', 2, 'n'), r)
697            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
698            self.assertEqual(get(table, 1, 'n')['t'], 'x')
699            self.assertEqual(get(table, 3, 'n')['t'], 'z')
700            self.assertEqual(get(table, 2, 'n')['t'], 'y')
701            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
702            r['n'] = 3
703            self.assertEqual(get(table, r, 'n')['t'], 'z')
704            self.assertEqual(get(table, 1, 'n')['t'], 'x')
705            query('alter table "%s" alter n set not null' % table)
706            query('alter table "%s" add primary key (n)' % table)
707            self.assertEqual(get(table, 3)['t'], 'z')
708            self.assertEqual(get(table, 1)['t'], 'x')
709            self.assertEqual(get(table, 2)['t'], 'y')
710            r['n'] = 1
711            self.assertEqual(get(table, r)['t'], 'x')
712            r['n'] = 3
713            self.assertEqual(get(table, r)['t'], 'z')
714            r['n'] = 2
715            self.assertEqual(get(table, r)['t'], 'y')
716            query('drop table "%s"' % table)
717
718    def testGetWithCompositeKey(self):
719        get = self.db.get
720        query = self.db.query
721        table = 'get_test_table_1'
722        query("drop table if exists %s" % table)
723        query("create table %s ("
724            "n integer, t text, primary key (n))" % table)
725        for n, t in enumerate('abc'):
726            query("insert into %s values("
727                "%d, '%s')" % (table, n + 1, t))
728        self.assertEqual(get(table, 2)['t'], 'b')
729        query("drop table %s" % table)
730        table = 'get_test_table_2'
731        query("drop table if exists %s" % table)
732        query("create table %s ("
733            "n integer, m integer, t text, primary key (n, m))" % table)
734        for n in range(3):
735            for m in range(2):
736                t = chr(ord('a') + 2 * n + m)
737                query("insert into %s values("
738                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
739        self.assertRaises(pg.ProgrammingError, get, table, 2)
740        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
741        self.assertEqual(get(table, dict(n=1, m=2),
742                             ('n', 'm'))['t'], 'b')
743        self.assertEqual(get(table, dict(n=3, m=2),
744                             frozenset(['n', 'm']))['t'], 'f')
745        query("drop table %s" % table)
746
747    def testGetFromView(self):
748        self.db.query('delete from test where i4=14')
749        self.db.query('insert into test (i4, v4) values('
750            "14, 'abc4')")
751        r = self.db.get('test_view', 14, 'i4')
752        self.assertIn('v4', r)
753        self.assertEqual(r['v4'], 'abc4')
754
755    def testInsert(self):
756        insert = self.db.insert
757        query = self.db.query
758        for table in ('insert_test_table', 'test table for insert'):
759            query('drop table if exists "%s"' % table)
760            query('create table "%s" ('
761                "i2 smallint, i4 integer, i8 bigint,"
762                " d numeric, f4 real, f8 double precision, m money,"
763                " v4 varchar(4), c4 char(4), t text,"
764                " b boolean, ts timestamp) with oids" % table)
765            oid_table = table
766            if ' ' in table:
767                oid_table = '"%s"' % oid_table
768            oid_table = 'oid(public.%s)' % oid_table
769            tests = [dict(i2=None, i4=None, i8=None),
770                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
771                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
772                dict(i2=42, i4=123456, i8=9876543210),
773                dict(i2=2 ** 15 - 1,
774                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
775                dict(d=None), (dict(d=''), dict(d=None)),
776                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
777                dict(f4=None, f8=None), dict(f4=0, f8=0),
778                (dict(f4='', f8=''), dict(f4=None, f8=None)),
779                (dict(d=1234.5, f4=1234.5, f8=1234.5),
780                      dict(d=Decimal('1234.5'))),
781                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
782                dict(d=Decimal('123456789.9876543212345678987654321')),
783                dict(m=None), (dict(m=''), dict(m=None)),
784                dict(m=Decimal('-1234.56')),
785                (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
786                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
787                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
788                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
789                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
790                (dict(m=123456), dict(m=Decimal('123456'))),
791                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
792                dict(b=None), (dict(b=''), dict(b=None)),
793                dict(b='f'), dict(b='t'),
794                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
795                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
796                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
797                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
798                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
799                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
800                dict(v4=None, c4=None, t=None),
801                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
802                dict(v4='1234', c4='1234', t='1234' * 10),
803                dict(v4='abcd', c4='abcd', t='abcdefg'),
804                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
805                dict(ts=None), (dict(ts=''), dict(ts=None)),
806                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
807                dict(ts='2012-12-21 00:00:00'),
808                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
809                dict(ts='2012-12-21 12:21:12'),
810                dict(ts='2013-01-05 12:13:14'),
811                dict(ts='current_timestamp')]
812            for test in tests:
813                if isinstance(test, dict):
814                    data = test
815                    change = {}
816                else:
817                    data, change = test
818                expect = data.copy()
819                expect.update(change)
820                self.assertEqual(insert(table, data), data)
821                self.assertIn(oid_table, data)
822                oid = data[oid_table]
823                self.assertIsInstance(oid, int)
824                data = dict(item for item in data.items()
825                    if item[0] in expect)
826                ts = expect.get('ts')
827                if ts == 'current_timestamp':
828                    ts = expect['ts'] = data['ts']
829                    if len(ts) > 19:
830                        self.assertEqual(ts[19], '.')
831                        ts = ts[:19]
832                    else:
833                        self.assertEqual(len(ts), 19)
834                    self.assertTrue(ts[:4].isdigit())
835                    self.assertEqual(ts[4], '-')
836                    self.assertEqual(ts[10], ' ')
837                    self.assertTrue(ts[11:13].isdigit())
838                    self.assertEqual(ts[13], ':')
839                self.assertEqual(data, expect)
840                data = query(
841                    'select oid,* from "%s"' % table).dictresult()[0]
842                self.assertEqual(data['oid'], oid)
843                data = dict(item for item in data.items()
844                    if item[0] in expect)
845                self.assertEqual(data, expect)
846                query('delete from "%s"' % table)
847            query('drop table "%s"' % table)
848
849    def testUpdate(self):
850        update = self.db.update
851        query = self.db.query
852        for table in ('update_test_table', 'test table for update'):
853            query('drop table if exists "%s"' % table)
854            query('create table "%s" ('
855                "n integer, t text) with oids" % table)
856            for n, t in enumerate('xyz'):
857                query('insert into "%s" values('
858                    "%d, '%s')" % (table, n + 1, t))
859            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
860            r = self.db.get(table, 2, 'n')
861            r['t'] = 'u'
862            s = update(table, r)
863            self.assertEqual(s, r)
864            r = query('select t from "%s" where n=2' % table
865                      ).getresult()[0][0]
866            self.assertEqual(r, 'u')
867            query('drop table "%s"' % table)
868
869    def testUpdateWithCompositeKey(self):
870        update = self.db.update
871        query = self.db.query
872        table = 'update_test_table_1'
873        query("drop table if exists %s" % table)
874        query("create table %s ("
875            "n integer, t text, primary key (n))" % table)
876        for n, t in enumerate('abc'):
877            query("insert into %s values("
878                "%d, '%s')" % (table, n + 1, t))
879        self.assertRaises(pg.ProgrammingError, update,
880                          table, dict(t='b'))
881        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
882        r = query('select t from "%s" where n=2' % table
883                  ).getresult()[0][0]
884        self.assertEqual(r, 'd')
885        query("drop table %s" % table)
886        table = 'update_test_table_2'
887        query("drop table if exists %s" % table)
888        query("create table %s ("
889            "n integer, m integer, t text, primary key (n, m))" % table)
890        for n in range(3):
891            for m in range(2):
892                t = chr(ord('a') + 2 * n + m)
893                query("insert into %s values("
894                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
895        self.assertRaises(pg.ProgrammingError, update,
896                          table, dict(n=2, t='b'))
897        self.assertEqual(update(table,
898                                dict(n=2, m=2, t='x'))['t'], 'x')
899        r = [r[0] for r in query('select t from "%s" where n=2'
900            ' order by m' % table).getresult()]
901        self.assertEqual(r, ['c', 'x'])
902        query("drop table %s" % table)
903
904    def testClear(self):
905        clear = self.db.clear
906        query = self.db.query
907        for table in ('clear_test_table', 'test table for clear'):
908            query('drop table if exists "%s"' % table)
909            query('create table "%s" ('
910                "n integer, b boolean, d date, t text)" % table)
911            r = clear(table)
912            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
913            self.assertEqual(r, result)
914            r['a'] = r['n'] = 1
915            r['d'] = r['t'] = 'x'
916            r['b'] = 't'
917            r['oid'] = long(1)
918            r = clear(table, r)
919            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '',
920                'oid': long(1)}
921            self.assertEqual(r, result)
922            query('drop table "%s"' % table)
923
924    def testDelete(self):
925        delete = self.db.delete
926        query = self.db.query
927        for table in ('delete_test_table', 'test table for delete'):
928            query('drop table if exists "%s"' % table)
929            query('create table "%s" ('
930                "n integer, t text) with oids" % table)
931            for n, t in enumerate('xyz'):
932                query('insert into "%s" values('
933                    "%d, '%s')" % (table, n + 1, t))
934            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
935            r = self.db.get(table, 1, 'n')
936            s = delete(table, r)
937            self.assertEqual(s, 1)
938            r = self.db.get(table, 3, 'n')
939            s = delete(table, r)
940            self.assertEqual(s, 1)
941            s = delete(table, r)
942            self.assertEqual(s, 0)
943            r = query('select * from "%s"' % table).dictresult()
944            self.assertEqual(len(r), 1)
945            r = r[0]
946            result = {'n': 2, 't': 'y'}
947            self.assertEqual(r, result)
948            r = self.db.get(table, 2, 'n')
949            s = delete(table, r)
950            self.assertEqual(s, 1)
951            s = delete(table, r)
952            self.assertEqual(s, 0)
953            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
954            query('drop table "%s"' % table)
955
956    def testDeleteWithCompositeKey(self):
957        query = self.db.query
958        table = 'delete_test_table_1'
959        query("drop table if exists %s" % table)
960        query("create table %s ("
961            "n integer, t text, primary key (n))" % table)
962        for n, t in enumerate('abc'):
963            query("insert into %s values("
964                "%d, '%s')" % (table, n + 1, t))
965        self.assertRaises(pg.ProgrammingError, self.db.delete,
966            table, dict(t='b'))
967        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
968        r = query('select t from "%s" where n=2' % table
969                  ).getresult()
970        self.assertEqual(r, [])
971        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
972        r = query('select t from "%s" where n=3' % table
973                  ).getresult()[0][0]
974        self.assertEqual(r, 'c')
975        query("drop table %s" % table)
976        table = 'delete_test_table_2'
977        query("drop table if exists %s" % table)
978        query("create table %s ("
979            "n integer, m integer, t text, primary key (n, m))" % table)
980        for n in range(3):
981            for m in range(2):
982                t = chr(ord('a') + 2 * n + m)
983                query("insert into %s values("
984                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
985        self.assertRaises(pg.ProgrammingError, self.db.delete,
986            table, dict(n=2, t='b'))
987        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
988        r = [r[0] for r in query('select t from "%s" where n=2'
989            ' order by m' % table).getresult()]
990        self.assertEqual(r, ['c'])
991        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
992        r = [r[0] for r in query('select t from "%s" where n=3'
993            ' order by m' % table).getresult()]
994        self.assertEqual(r, ['e', 'f'])
995        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
996        r = [r[0] for r in query('select t from "%s" where n=3'
997            ' order by m' % table).getresult()]
998        self.assertEqual(r, ['f'])
999        query("drop table %s" % table)
1000
1001    def testTransaction(self):
1002        query = self.db.query
1003        query("drop table if exists test_table")
1004        query("create table test_table (n integer)")
1005        self.db.begin()
1006        query("insert into test_table values (1)")
1007        query("insert into test_table values (2)")
1008        self.db.commit()
1009        self.db.begin()
1010        query("insert into test_table values (3)")
1011        query("insert into test_table values (4)")
1012        self.db.rollback()
1013        self.db.begin()
1014        query("insert into test_table values (5)")
1015        self.db.savepoint('before6')
1016        query("insert into test_table values (6)")
1017        self.db.rollback('before6')
1018        query("insert into test_table values (7)")
1019        self.db.commit()
1020        self.db.begin()
1021        self.db.savepoint('before8')
1022        query("insert into test_table values (8)")
1023        self.db.release('before8')
1024        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1025        self.db.commit()
1026        self.db.start()
1027        query("insert into test_table values (9)")
1028        self.db.end()
1029        r = [r[0] for r in query(
1030            "select * from test_table order by 1").getresult()]
1031        self.assertEqual(r, [1, 2, 5, 7, 9])
1032        query("drop table test_table")
1033
1034    def testContextManager(self):
1035        query = self.db.query
1036        query("drop table if exists test_table")
1037        query("create table test_table (n integer check(n>0))")
1038        with self.db:
1039            query("insert into test_table values (1)")
1040            query("insert into test_table values (2)")
1041        try:
1042            with self.db:
1043                query("insert into test_table values (3)")
1044                query("insert into test_table values (4)")
1045                raise ValueError('test transaction should rollback')
1046        except ValueError as error:
1047            self.assertEqual(str(error), 'test transaction should rollback')
1048        with self.db:
1049            query("insert into test_table values (5)")
1050        try:
1051            with self.db:
1052                query("insert into test_table values (6)")
1053                query("insert into test_table values (-1)")
1054        except pg.ProgrammingError as error:
1055            self.assertTrue('check' in str(error))
1056        with self.db:
1057            query("insert into test_table values (7)")
1058        r = [r[0] for r in query(
1059            "select * from test_table order by 1").getresult()]
1060        self.assertEqual(r, [1, 2, 5, 7])
1061        query("drop table test_table")
1062
1063    def testBytea(self):
1064        query = self.db.query
1065        query('drop table if exists bytea_test')
1066        query('create table bytea_test (n smallint primary key, data bytea)')
1067        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1068        r = self.db.escape_bytea(s)
1069        query('insert into bytea_test values(3,$1)', (r,))
1070        r = query('select * from bytea_test where n=3').getresult()
1071        self.assertEqual(len(r), 1)
1072        r = r[0]
1073        self.assertEqual(len(r), 2)
1074        self.assertEqual(r[0], 3)
1075        r = r[1]
1076        self.assertIsInstance(r, str)
1077        r = self.db.unescape_bytea(r)
1078        self.assertIsInstance(r, bytes)
1079        self.assertEqual(r, s)
1080        query('drop table bytea_test')
1081
1082    def testInsertUpdateGetBytea(self):
1083        query = self.db.query
1084        query('drop table if exists bytea_test')
1085        query('create table bytea_test (n smallint primary key, data bytea)')
1086        # insert as bytes
1087        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1088        r = self.db.insert('bytea_test', n=5, data=s)
1089        self.assertIsInstance(r, dict)
1090        self.assertIn('n', r)
1091        self.assertEqual(r['n'], 5)
1092        self.assertIn('data', r)
1093        r = r['data']
1094        self.assertIsInstance(r, bytes)
1095        self.assertEqual(r, s)
1096        # update as bytes
1097        s += b"and now even more \x00 nasty \t stuff!\f"
1098        r = self.db.update('bytea_test', n=5, data=s)
1099        self.assertIsInstance(r, dict)
1100        self.assertIn('n', r)
1101        self.assertEqual(r['n'], 5)
1102        self.assertIn('data', r)
1103        r = r['data']
1104        self.assertIsInstance(r, bytes)
1105        self.assertEqual(r, s)
1106        r = query('select * from bytea_test where n=5').getresult()
1107        self.assertEqual(len(r), 1)
1108        r = r[0]
1109        self.assertEqual(len(r), 2)
1110        self.assertEqual(r[0], 5)
1111        r = r[1]
1112        self.assertIsInstance(r, str)
1113        r = self.db.unescape_bytea(r)
1114        self.assertIsInstance(r, bytes)
1115        self.assertEqual(r, s)
1116        r = self.db.get('bytea_test', dict(n=5))
1117        self.assertIsInstance(r, dict)
1118        self.assertIn('n', r)
1119        self.assertEqual(r['n'], 5)
1120        self.assertIn('data', r)
1121        r = r['data']
1122        self.assertIsInstance(r, bytes)
1123        self.assertEqual(r, s)
1124        query('drop table bytea_test')
1125
1126    def testDebugWithCallable(self):
1127        if debug:
1128            self.assertEqual(self.db.debug, debug)
1129        else:
1130            self.assertIsNone(self.db.debug)
1131        s = []
1132        self.db.debug = s.append
1133        try:
1134            self.db.query("select 1")
1135            self.db.query("select 2")
1136            self.assertEqual(s, ["select 1", "select 2"])
1137        finally:
1138            self.db.debug = debug
1139
1140
1141class TestSchemas(unittest.TestCase):
1142    """Test correct handling of schemas (namespaces)."""
1143
1144    @classmethod
1145    def setUpClass(cls):
1146        db = DB()
1147        query = db.query
1148        query("set client_min_messages=warning")
1149        for num_schema in range(5):
1150            if num_schema:
1151                schema = "s%d" % num_schema
1152                query("drop schema if exists %s cascade" % (schema,))
1153                try:
1154                    query("create schema %s" % (schema,))
1155                except pg.ProgrammingError:
1156                    raise RuntimeError("The test user cannot create schemas.\n"
1157                        "Grant create on database %s to the user"
1158                        " for running these tests." % dbname)
1159            else:
1160                schema = "public"
1161                query("drop table if exists %s.t" % (schema,))
1162                query("drop table if exists %s.t%d" % (schema, num_schema))
1163            query("create table %s.t with oids as select 1 as n, %d as d"
1164                  % (schema, num_schema))
1165            query("create table %s.t%d with oids as select 1 as n, %d as d"
1166                  % (schema, num_schema, num_schema))
1167        db.close()
1168
1169    @classmethod
1170    def tearDownClass(cls):
1171        db = DB()
1172        query = db.query
1173        query("set client_min_messages=warning")
1174        for num_schema in range(5):
1175            if num_schema:
1176                schema = "s%d" % num_schema
1177                query("drop schema %s cascade" % (schema,))
1178            else:
1179                schema = "public"
1180                query("drop table %s.t" % (schema,))
1181                query("drop table %s.t%d" % (schema, num_schema))
1182        db.close()
1183
1184    def setUp(self):
1185        self.db = DB()
1186        self.db.query("set client_min_messages=warning")
1187
1188    def tearDown(self):
1189        self.db.close()
1190
1191    def testGetTables(self):
1192        tables = self.db.get_tables()
1193        for num_schema in range(5):
1194            if num_schema:
1195                schema = "s" + str(num_schema)
1196            else:
1197                schema = "public"
1198            for t in (schema + ".t",
1199                    schema + ".t" + str(num_schema)):
1200                self.assertIn(t, tables)
1201
1202    def testGetAttnames(self):
1203        get_attnames = self.db.get_attnames
1204        query = self.db.query
1205        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
1206        r = get_attnames("t")
1207        self.assertEqual(r, result)
1208        r = get_attnames("s4.t4")
1209        self.assertEqual(r, result)
1210        query("drop table if exists s3.t3m")
1211        query("create table s3.t3m with oids as select 1 as m")
1212        result_m = {'oid': 'int', 'm': 'int'}
1213        r = get_attnames("s3.t3m")
1214        self.assertEqual(r, result_m)
1215        query("set search_path to s1,s3")
1216        r = get_attnames("t3")
1217        self.assertEqual(r, result)
1218        r = get_attnames("t3m")
1219        self.assertEqual(r, result_m)
1220        query("drop table s3.t3m")
1221
1222    def testGet(self):
1223        get = self.db.get
1224        query = self.db.query
1225        PrgError = pg.ProgrammingError
1226        self.assertEqual(get("t", 1, 'n')['d'], 0)
1227        self.assertEqual(get("t0", 1, 'n')['d'], 0)
1228        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
1229        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
1230        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
1231        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
1232        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
1233        query("set search_path to s2,s4")
1234        self.assertRaises(PrgError, get, "t1", 1, 'n')
1235        self.assertEqual(get("t4", 1, 'n')['d'], 4)
1236        self.assertRaises(PrgError, get, "t3", 1, 'n')
1237        self.assertEqual(get("t", 1, 'n')['d'], 2)
1238        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
1239        query("set search_path to s1,s3")
1240        self.assertRaises(PrgError, get, "t2", 1, 'n')
1241        self.assertEqual(get("t3", 1, 'n')['d'], 3)
1242        self.assertRaises(PrgError, get, "t4", 1, 'n')
1243        self.assertEqual(get("t", 1, 'n')['d'], 1)
1244        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
1245
1246    def testMangling(self):
1247        get = self.db.get
1248        query = self.db.query
1249        r = get("t", 1, 'n')
1250        self.assertIn('oid(public.t)', r)
1251        query("set search_path to s2")
1252        r = get("t2", 1, 'n')
1253        self.assertIn('oid(s2.t2)', r)
1254        query("set search_path to s3")
1255        r = get("t", 1, 'n')
1256        self.assertIn('oid(s3.t)', r)
1257
1258
1259if __name__ == '__main__':
1260    unittest.main()
Note: See TracBrowser for help on using the repository browser.