source: trunk/module/tests/test_classic_dbwrapper.py @ 706

Last change on this file since 706 was 706, checked in by cito, 4 years ago

Make sure DB methods respect the new bool option

Two DB methods assumed that booleans are always returned as strings,
which is no longer true when the set_bool() option is activated.
Added a test run with different global options to make sure that
no DB methods make such tacit assumptions about these options.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 49.7 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
24# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
25# get our information from that.  Otherwise we use the defaults.
26# The current user must have create schema privilege on the database.
27dbname = 'unittest'
28dbhost = None
29dbport = 5432
30
31debug = False  # let DB wrapper print debugging output
32
33try:
34    from .LOCAL_PyGreSQL import *
35except (ImportError, ValueError):
36    try:
37        from LOCAL_PyGreSQL import *
38    except ImportError:
39        pass
40
41try:
42    long
43except NameError:  # Python >= 3.0
44    long = int
45
46try:
47    unicode
48except NameError:  # Python >= 3.0
49    unicode = str
50
51windows = os.name == 'nt'
52
53# There is a known a bug in libpq under Windows which can cause
54# the interface to crash when calling PQhost():
55do_not_ask_for_host = windows
56do_not_ask_for_host_reason = 'libpq issue on Windows'
57
58
59def DB():
60    """Create a DB wrapper object connecting to the test database."""
61    db = pg.DB(dbname, dbhost, dbport)
62    if debug:
63        db.debug = debug
64    db.query("set client_min_messages=warning")
65    return db
66
67
68class TestDBClassBasic(unittest.TestCase):
69    """Test existence of the DB class wrapped pg connection methods."""
70
71    def setUp(self):
72        self.db = DB()
73
74    def tearDown(self):
75        try:
76            self.db.close()
77        except pg.InternalError:
78            pass
79
80    def testAllDBAttributes(self):
81        attributes = [
82            'begin',
83            'cancel',
84            'clear',
85            'close',
86            'commit',
87            'db',
88            'dbname',
89            'debug',
90            'delete',
91            'end',
92            'endcopy',
93            'error',
94            'escape_bytea',
95            'escape_identifier',
96            'escape_literal',
97            'escape_string',
98            'fileno',
99            'get',
100            'get_attnames',
101            'get_databases',
102            'get_notice_receiver',
103            'get_relations',
104            'get_tables',
105            'getline',
106            'getlo',
107            'getnotify',
108            'has_table_privilege',
109            'host',
110            'insert',
111            'inserttable',
112            'locreate',
113            'loimport',
114            'notification_handler',
115            'options',
116            'parameter',
117            'pkey',
118            'port',
119            'protocol_version',
120            'putline',
121            'query',
122            'release',
123            'reopen',
124            'reset',
125            'rollback',
126            'savepoint',
127            'server_version',
128            'set_notice_receiver',
129            'source',
130            'start',
131            'status',
132            'transaction',
133            'unescape_bytea',
134            'update',
135            'use_regtypes',
136            'user',
137        ]
138        db_attributes = [a for a in dir(self.db)
139            if not a.startswith('_')]
140        self.assertEqual(attributes, db_attributes)
141
142    def testAttributeDb(self):
143        self.assertEqual(self.db.db.db, dbname)
144
145    def testAttributeDbname(self):
146        self.assertEqual(self.db.dbname, dbname)
147
148    def testAttributeError(self):
149        error = self.db.error
150        self.assertTrue(not error or 'krb5_' in error)
151        self.assertEqual(self.db.error, self.db.db.error)
152
153    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
154    def testAttributeHost(self):
155        def_host = 'localhost'
156        host = self.db.host
157        self.assertIsInstance(host, str)
158        self.assertEqual(host, dbhost or def_host)
159        self.assertEqual(host, self.db.db.host)
160
161    def testAttributeOptions(self):
162        no_options = ''
163        options = self.db.options
164        self.assertEqual(options, no_options)
165        self.assertEqual(options, self.db.db.options)
166
167    def testAttributePort(self):
168        def_port = 5432
169        port = self.db.port
170        self.assertIsInstance(port, int)
171        self.assertEqual(port, dbport or def_port)
172        self.assertEqual(port, self.db.db.port)
173
174    def testAttributeProtocolVersion(self):
175        protocol_version = self.db.protocol_version
176        self.assertIsInstance(protocol_version, int)
177        self.assertTrue(2 <= protocol_version < 4)
178        self.assertEqual(protocol_version, self.db.db.protocol_version)
179
180    def testAttributeServerVersion(self):
181        server_version = self.db.server_version
182        self.assertIsInstance(server_version, int)
183        self.assertTrue(70400 <= server_version < 100000)
184        self.assertEqual(server_version, self.db.db.server_version)
185
186    def testAttributeStatus(self):
187        status_ok = 1
188        status = self.db.status
189        self.assertIsInstance(status, int)
190        self.assertEqual(status, status_ok)
191        self.assertEqual(status, self.db.db.status)
192
193    def testAttributeUser(self):
194        no_user = 'Deprecated facility'
195        user = self.db.user
196        self.assertTrue(user)
197        self.assertIsInstance(user, str)
198        self.assertNotEqual(user, no_user)
199        self.assertEqual(user, self.db.db.user)
200
201    def testMethodEscapeLiteral(self):
202        self.assertEqual(self.db.escape_literal(''), "''")
203
204    def testMethodEscapeIdentifier(self):
205        self.assertEqual(self.db.escape_identifier(''), '""')
206
207    def testMethodEscapeString(self):
208        self.assertEqual(self.db.escape_string(''), '')
209
210    def testMethodEscapeBytea(self):
211        self.assertEqual(self.db.escape_bytea('').replace(
212            '\\x', '').replace('\\', ''), '')
213
214    def testMethodUnescapeBytea(self):
215        self.assertEqual(self.db.unescape_bytea(''), b'')
216
217    def testMethodQuery(self):
218        query = self.db.query
219        query("select 1+1")
220        query("select 1+$1+$2", 2, 3)
221        query("select 1+$1+$2", (2, 3))
222        query("select 1+$1+$2", [2, 3])
223        query("select 1+$1", 1)
224
225    def testMethodQueryEmpty(self):
226        self.assertRaises(ValueError, self.db.query, '')
227
228    def testMethodQueryProgrammingError(self):
229        try:
230            self.db.query("select 1/0")
231        except pg.ProgrammingError as error:
232            self.assertEqual(error.sqlstate, '22012')
233
234    def testMethodEndcopy(self):
235        try:
236            self.db.endcopy()
237        except IOError:
238            pass
239
240    def testMethodClose(self):
241        self.db.close()
242        try:
243            self.db.reset()
244        except pg.Error:
245            pass
246        else:
247            self.fail('Reset should give an error for a closed connection')
248        self.assertRaises(pg.InternalError, self.db.close)
249        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
250
251    def testExistingConnection(self):
252        db = pg.DB(self.db.db)
253        self.assertEqual(self.db.db, db.db)
254        self.assertTrue(db.db)
255        db.close()
256        self.assertTrue(db.db)
257        db.reopen()
258        self.assertTrue(db.db)
259        db.close()
260        self.assertTrue(db.db)
261        db = pg.DB(self.db)
262        self.assertEqual(self.db.db, db.db)
263        db = pg.DB(db=self.db.db)
264        self.assertEqual(self.db.db, db.db)
265
266        class DB2:
267            pass
268
269        db2 = DB2()
270        db2._cnx = self.db.db
271        db = pg.DB(db2)
272        self.assertEqual(self.db.db, db.db)
273
274
275class TestDBClass(unittest.TestCase):
276    """Test the methods of the DB class wrapped pg connection."""
277
278    @classmethod
279    def setUpClass(cls):
280        db = DB()
281        db.query("drop table if exists test cascade")
282        db.query("create table test ("
283            "i2 smallint, i4 integer, i8 bigint,"
284            "d numeric, f4 real, f8 double precision, m money, "
285            "v4 varchar(4), c4 char(4), t text)")
286        db.query("create or replace view test_view as"
287            " select i4, v4 from test")
288        db.close()
289
290    @classmethod
291    def tearDownClass(cls):
292        db = DB()
293        db.query("drop table test cascade")
294        db.close()
295
296    def setUp(self):
297        self.db = DB()
298        query = self.db.query
299        query('set client_encoding=utf8')
300        query('set standard_conforming_strings=on')
301        query("set lc_monetary='C'")
302        query("set datestyle='ISO,YMD'")
303        query('set bytea_output=hex')
304
305    def tearDown(self):
306        self.db.close()
307
308    def testClassName(self):
309        self.assertEqual(self.db.__class__.__name__, 'DB')
310
311    def testModuleName(self):
312        self.assertEqual(self.db.__module__, 'pg')
313        self.assertEqual(self.db.__class__.__module__, 'pg')
314
315    def testEscapeLiteral(self):
316        f = self.db.escape_literal
317        r = f(b"plain")
318        self.assertIsInstance(r, bytes)
319        self.assertEqual(r, b"'plain'")
320        r = f(u"plain")
321        self.assertIsInstance(r, unicode)
322        self.assertEqual(r, u"'plain'")
323        r = f(u"that's kÀse".encode('utf-8'))
324        self.assertIsInstance(r, bytes)
325        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
326        r = f(u"that's kÀse")
327        self.assertIsInstance(r, unicode)
328        self.assertEqual(r, u"'that''s kÀse'")
329        self.assertEqual(f(r"It's fine to have a \ inside."),
330            r" E'It''s fine to have a \\ inside.'")
331        self.assertEqual(f('No "quotes" must be escaped.'),
332            "'No \"quotes\" must be escaped.'")
333
334    def testEscapeIdentifier(self):
335        f = self.db.escape_identifier
336        r = f(b"plain")
337        self.assertIsInstance(r, bytes)
338        self.assertEqual(r, b'"plain"')
339        r = f(u"plain")
340        self.assertIsInstance(r, unicode)
341        self.assertEqual(r, u'"plain"')
342        r = f(u"that's kÀse".encode('utf-8'))
343        self.assertIsInstance(r, bytes)
344        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
345        r = f(u"that's kÀse")
346        self.assertIsInstance(r, unicode)
347        self.assertEqual(r, u'"that\'s kÀse"')
348        self.assertEqual(f(r"It's fine to have a \ inside."),
349            '"It\'s fine to have a \\ inside."')
350        self.assertEqual(f('All "quotes" must be escaped.'),
351            '"All ""quotes"" must be escaped."')
352
353    def testEscapeString(self):
354        f = self.db.escape_string
355        r = f(b"plain")
356        self.assertIsInstance(r, bytes)
357        self.assertEqual(r, b"plain")
358        r = f(u"plain")
359        self.assertIsInstance(r, unicode)
360        self.assertEqual(r, u"plain")
361        r = f(u"that's kÀse".encode('utf-8'))
362        self.assertIsInstance(r, bytes)
363        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
364        r = f(u"that's kÀse")
365        self.assertIsInstance(r, unicode)
366        self.assertEqual(r, u"that''s kÀse")
367        self.assertEqual(f(r"It's fine to have a \ inside."),
368            r"It''s fine to have a \ inside.")
369
370    def testEscapeBytea(self):
371        f = self.db.escape_bytea
372        # note that escape_byte always returns hex output since Pg 9.0,
373        # regardless of the bytea_output setting
374        r = f(b'plain')
375        self.assertIsInstance(r, bytes)
376        self.assertEqual(r, b'\\x706c61696e')
377        r = f(u'plain')
378        self.assertIsInstance(r, unicode)
379        self.assertEqual(r, u'\\x706c61696e')
380        r = f(u"das is' kÀse".encode('utf-8'))
381        self.assertIsInstance(r, bytes)
382        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
383        r = f(u"das is' kÀse")
384        self.assertIsInstance(r, unicode)
385        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
386        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
387
388    def testUnescapeBytea(self):
389        f = self.db.unescape_bytea
390        r = f(b'plain')
391        self.assertIsInstance(r, bytes)
392        self.assertEqual(r, b'plain')
393        r = f(u'plain')
394        self.assertIsInstance(r, bytes)
395        self.assertEqual(r, b'plain')
396        r = f(b"das is' k\\303\\244se")
397        self.assertIsInstance(r, bytes)
398        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
399        r = f(u"das is' k\\303\\244se")
400        self.assertIsInstance(r, bytes)
401        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
402        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
403        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
404        self.assertEqual(f(r'\\x746861742773206be47365'),
405            b'\\x746861742773206be47365')
406        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
407
408    def testQuote(self):
409        f = self.db._quote
410        self.assertEqual(f(None, None), 'NULL')
411        self.assertEqual(f(None, 'int'), 'NULL')
412        self.assertEqual(f(None, 'float'), 'NULL')
413        self.assertEqual(f(None, 'num'), 'NULL')
414        self.assertEqual(f(None, 'money'), 'NULL')
415        self.assertEqual(f(None, 'bool'), 'NULL')
416        self.assertEqual(f(None, 'date'), 'NULL')
417        self.assertEqual(f('', 'int'), 'NULL')
418        self.assertEqual(f('', 'float'), 'NULL')
419        self.assertEqual(f('', 'num'), 'NULL')
420        self.assertEqual(f('', 'money'), 'NULL')
421        self.assertEqual(f('', 'bool'), 'NULL')
422        self.assertEqual(f('', 'date'), 'NULL')
423        self.assertEqual(f('', 'text'), "''")
424        self.assertEqual(f(0, 'int'), '0')
425        self.assertEqual(f(0, 'num'), '0')
426        self.assertEqual(f(1, 'int'), '1')
427        self.assertEqual(f(1, 'num'), '1')
428        self.assertEqual(f(-1, 'int'), '-1')
429        self.assertEqual(f(-1, 'num'), '-1')
430        self.assertEqual(f(123456789, 'int'), '123456789')
431        self.assertEqual(f(123456987, 'num'), '123456987')
432        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
433        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
434        self.assertEqual(f('123456789', 'num'), '123456789')
435        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
436        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
437        self.assertEqual(f(123, 'money'), '123')
438        self.assertEqual(f('123', 'money'), '123')
439        self.assertEqual(f(123.45, 'money'), '123.45')
440        self.assertEqual(f('123.45', 'money'), '123.45')
441        self.assertEqual(f(123.454, 'money'), '123.454')
442        self.assertEqual(f('123.454', 'money'), '123.454')
443        self.assertEqual(f(123.456, 'money'), '123.456')
444        self.assertEqual(f('123.456', 'money'), '123.456')
445        self.assertEqual(f('f', 'bool'), "'f'")
446        self.assertEqual(f('F', 'bool'), "'f'")
447        self.assertEqual(f('false', 'bool'), "'f'")
448        self.assertEqual(f('False', 'bool'), "'f'")
449        self.assertEqual(f('FALSE', 'bool'), "'f'")
450        self.assertEqual(f(0, 'bool'), "'f'")
451        self.assertEqual(f('0', 'bool'), "'f'")
452        self.assertEqual(f('-', 'bool'), "'f'")
453        self.assertEqual(f('n', 'bool'), "'f'")
454        self.assertEqual(f('N', 'bool'), "'f'")
455        self.assertEqual(f('no', 'bool'), "'f'")
456        self.assertEqual(f('off', 'bool'), "'f'")
457        self.assertEqual(f('t', 'bool'), "'t'")
458        self.assertEqual(f('T', 'bool'), "'t'")
459        self.assertEqual(f('true', 'bool'), "'t'")
460        self.assertEqual(f('True', 'bool'), "'t'")
461        self.assertEqual(f('TRUE', 'bool'), "'t'")
462        self.assertEqual(f(1, 'bool'), "'t'")
463        self.assertEqual(f(2, 'bool'), "'t'")
464        self.assertEqual(f(-1, 'bool'), "'t'")
465        self.assertEqual(f(0.5, 'bool'), "'t'")
466        self.assertEqual(f('1', 'bool'), "'t'")
467        self.assertEqual(f('y', 'bool'), "'t'")
468        self.assertEqual(f('Y', 'bool'), "'t'")
469        self.assertEqual(f('yes', 'bool'), "'t'")
470        self.assertEqual(f('on', 'bool'), "'t'")
471        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
472        self.assertEqual(f(123, 'text'), "'123'")
473        self.assertEqual(f(1.23, 'text'), "'1.23'")
474        self.assertEqual(f('abc', 'text'), "'abc'")
475        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
476        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
477        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
478        self.db.query('set standard_conforming_strings=off')
479        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
480        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
481
482    def testQuery(self):
483        query = self.db.query
484        query("drop table if exists test_table")
485        q = "create table test_table (n integer) with oids"
486        r = query(q)
487        self.assertIsNone(r)
488        q = "insert into test_table values (1)"
489        r = query(q)
490        self.assertIsInstance(r, int)
491        q = "insert into test_table select 2"
492        r = query(q)
493        self.assertIsInstance(r, int)
494        oid = r
495        q = "select oid from test_table where n=2"
496        r = query(q).getresult()
497        self.assertEqual(len(r), 1)
498        r = r[0]
499        self.assertEqual(len(r), 1)
500        r = r[0]
501        self.assertEqual(r, oid)
502        q = "insert into test_table select 3 union select 4 union select 5"
503        r = query(q)
504        self.assertIsInstance(r, str)
505        self.assertEqual(r, '3')
506        q = "update test_table set n=4 where n<5"
507        r = query(q)
508        self.assertIsInstance(r, str)
509        self.assertEqual(r, '4')
510        q = "delete from test_table"
511        r = query(q)
512        self.assertIsInstance(r, str)
513        self.assertEqual(r, '5')
514        query("drop table test_table")
515
516    def testMultipleQueries(self):
517        self.assertEqual(self.db.query(
518            "create temporary table test_multi (n integer);"
519            "insert into test_multi values (4711);"
520            "select n from test_multi").getresult()[0][0], 4711)
521
522    def testQueryWithParams(self):
523        query = self.db.query
524        query("drop table if exists test_table")
525        q = "create table test_table (n1 integer, n2 integer) with oids"
526        query(q)
527        q = "insert into test_table values ($1, $2)"
528        r = query(q, (1, 2))
529        self.assertIsInstance(r, int)
530        r = query(q, [3, 4])
531        self.assertIsInstance(r, int)
532        r = query(q, [5, 6])
533        self.assertIsInstance(r, int)
534        q = "select * from test_table order by 1, 2"
535        self.assertEqual(query(q).getresult(),
536            [(1, 2), (3, 4), (5, 6)])
537        q = "select * from test_table where n1=$1 and n2=$2"
538        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
539        q = "update test_table set n2=$2 where n1=$1"
540        r = query(q, 3, 7)
541        self.assertEqual(r, '1')
542        q = "select * from test_table order by 1, 2"
543        self.assertEqual(query(q).getresult(),
544            [(1, 2), (3, 7), (5, 6)])
545        q = "delete from test_table where n2!=$1"
546        r = query(q, 4)
547        self.assertEqual(r, '3')
548        query("drop table test_table")
549
550    def testEmptyQuery(self):
551        self.assertRaises(ValueError, self.db.query, '')
552
553    def testQueryProgrammingError(self):
554        try:
555            self.db.query("select 1/0")
556        except pg.ProgrammingError as error:
557            self.assertEqual(error.sqlstate, '22012')
558
559    def testPkey(self):
560        query = self.db.query
561        for n in range(4):
562            query("drop table if exists pkeytest%d" % n)
563        query("create table pkeytest0 ("
564            "a smallint)")
565        query("create table pkeytest1 ("
566            "b smallint primary key)")
567        query("create table pkeytest2 ("
568            "c smallint, d smallint primary key)")
569        query("create table pkeytest3 ("
570            "e smallint, f smallint, g smallint, "
571            "h smallint, i smallint, "
572            "primary key (f,h))")
573        pkey = self.db.pkey
574        self.assertRaises(KeyError, pkey, 'pkeytest0')
575        self.assertEqual(pkey('pkeytest1'), 'b')
576        self.assertEqual(pkey('pkeytest2'), 'd')
577        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
578        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
579        self.assertEqual(pkey('pkeytest0'), 'none')
580        pkey(None, {'t': 'a', 'n.t': 'b'})
581        self.assertEqual(pkey('t'), 'a')
582        self.assertEqual(pkey('n.t'), 'b')
583        self.assertRaises(KeyError, pkey, 'pkeytest0')
584        for n in range(4):
585            query("drop table pkeytest%d" % n)
586
587    def testGetDatabases(self):
588        databases = self.db.get_databases()
589        self.assertIn('template0', databases)
590        self.assertIn('template1', databases)
591        self.assertNotIn('not existing database', databases)
592        self.assertIn('postgres', databases)
593        self.assertIn(dbname, databases)
594
595    def testGetTables(self):
596        get_tables = self.db.get_tables
597        result1 = get_tables()
598        tables = ('"A very Special Name"',
599            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
600            'A_MiXeD_NaMe', '"another special name"',
601            'averyveryveryveryveryveryverylongtablename',
602            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
603        for t in tables:
604            self.db.query('drop table if exists %s' % t)
605            self.db.query("create table %s"
606                " as select 0" % t)
607        result3 = get_tables()
608        result2 = []
609        for t in result3:
610            if t not in result1:
611                result2.append(t)
612        result3 = []
613        for t in tables:
614            if not t.startswith('"'):
615                t = t.lower()
616            result3.append('public.' + t)
617        self.assertEqual(result2, result3)
618        for t in result2:
619            self.db.query('drop table %s' % t)
620        result2 = get_tables()
621        self.assertEqual(result2, result1)
622
623    def testGetRelations(self):
624        get_relations = self.db.get_relations
625        result = get_relations()
626        self.assertIn('public.test', result)
627        self.assertIn('public.test_view', result)
628        result = get_relations('rv')
629        self.assertIn('public.test', result)
630        self.assertIn('public.test_view', result)
631        result = get_relations('r')
632        self.assertIn('public.test', result)
633        self.assertNotIn('public.test_view', result)
634        result = get_relations('v')
635        self.assertNotIn('public.test', result)
636        self.assertIn('public.test_view', result)
637        result = get_relations('cisSt')
638        self.assertNotIn('public.test', result)
639        self.assertNotIn('public.test_view', result)
640
641    def testAttnames(self):
642        self.assertRaises(pg.ProgrammingError,
643            self.db.get_attnames, 'does_not_exist')
644        self.assertRaises(pg.ProgrammingError,
645            self.db.get_attnames, 'has.too.many.dots')
646        for table in ('attnames_test_table', 'test table for attnames'):
647            self.db.query('drop table if exists "%s"' % table)
648            self.db.query('create table "%s" ('
649                'a smallint, b integer, c bigint, '
650                'e numeric, f float, f2 double precision, m money, '
651                'x smallint, y smallint, z smallint, '
652                'Normal_NaMe smallint, "Special Name" smallint, '
653                't text, u char(2), v varchar(2), '
654                'primary key (y, u)) with oids' % table)
655            attributes = self.db.get_attnames(table)
656            result = {'a': 'int', 'c': 'int', 'b': 'int',
657                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
658                'normal_name': 'int', 'Special Name': 'int',
659                'u': 'text', 't': 'text', 'v': 'text',
660                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
661            self.assertEqual(attributes, result)
662            self.db.query('drop table "%s"' % table)
663
664    def testHasTablePrivilege(self):
665        can = self.db.has_table_privilege
666        self.assertEqual(can('test'), True)
667        self.assertEqual(can('test', 'select'), True)
668        self.assertEqual(can('test', 'SeLeCt'), True)
669        self.assertEqual(can('test', 'SELECT'), True)
670        self.assertEqual(can('test', 'insert'), True)
671        self.assertEqual(can('test', 'update'), True)
672        self.assertEqual(can('test', 'delete'), True)
673        self.assertEqual(can('pg_views', 'select'), True)
674        self.assertEqual(can('pg_views', 'delete'), False)
675        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
676        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
677
678    def testGet(self):
679        get = self.db.get
680        query = self.db.query
681        for table in ('get_test_table', 'test table for get'):
682            query('drop table if exists "%s"' % table)
683            query('create table "%s" ('
684                "n integer, t text) with oids" % table)
685            for n, t in enumerate('xyz'):
686                query('insert into "%s" values('"%d, '%s')"
687                    % (table, n + 1, t))
688            self.assertRaises(pg.ProgrammingError, get, table, 2)
689            r = get(table, 2, 'n')
690            oid_table = table
691            if ' ' in table:
692                oid_table = '"%s"' % oid_table
693            oid_table = 'oid(public.%s)' % oid_table
694            self.assertIn(oid_table, r)
695            oid = r[oid_table]
696            self.assertIsInstance(oid, int)
697            result = {'t': 'y', 'n': 2, oid_table: oid}
698            self.assertEqual(r, result)
699            self.assertEqual(get(table + ' *', 2, 'n'), r)
700            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
701            self.assertEqual(get(table, 1, 'n')['t'], 'x')
702            self.assertEqual(get(table, 3, 'n')['t'], 'z')
703            self.assertEqual(get(table, 2, 'n')['t'], 'y')
704            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
705            r['n'] = 3
706            self.assertEqual(get(table, r, 'n')['t'], 'z')
707            self.assertEqual(get(table, 1, 'n')['t'], 'x')
708            query('alter table "%s" alter n set not null' % table)
709            query('alter table "%s" add primary key (n)' % table)
710            self.assertEqual(get(table, 3)['t'], 'z')
711            self.assertEqual(get(table, 1)['t'], 'x')
712            self.assertEqual(get(table, 2)['t'], 'y')
713            r['n'] = 1
714            self.assertEqual(get(table, r)['t'], 'x')
715            r['n'] = 3
716            self.assertEqual(get(table, r)['t'], 'z')
717            r['n'] = 2
718            self.assertEqual(get(table, r)['t'], 'y')
719            query('drop table "%s"' % table)
720
721    def testGetWithCompositeKey(self):
722        get = self.db.get
723        query = self.db.query
724        table = 'get_test_table_1'
725        query("drop table if exists %s" % table)
726        query("create table %s ("
727            "n integer, t text, primary key (n))" % table)
728        for n, t in enumerate('abc'):
729            query("insert into %s values("
730                "%d, '%s')" % (table, n + 1, t))
731        self.assertEqual(get(table, 2)['t'], 'b')
732        query("drop table %s" % table)
733        table = 'get_test_table_2'
734        query("drop table if exists %s" % table)
735        query("create table %s ("
736            "n integer, m integer, t text, primary key (n, m))" % table)
737        for n in range(3):
738            for m in range(2):
739                t = chr(ord('a') + 2 * n + m)
740                query("insert into %s values("
741                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
742        self.assertRaises(pg.ProgrammingError, get, table, 2)
743        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
744        self.assertEqual(get(table, dict(n=1, m=2),
745                             ('n', 'm'))['t'], 'b')
746        self.assertEqual(get(table, dict(n=3, m=2),
747                             frozenset(['n', 'm']))['t'], 'f')
748        query("drop table %s" % table)
749
750    def testGetFromView(self):
751        self.db.query('delete from test where i4=14')
752        self.db.query('insert into test (i4, v4) values('
753            "14, 'abc4')")
754        r = self.db.get('test_view', 14, 'i4')
755        self.assertIn('v4', r)
756        self.assertEqual(r['v4'], 'abc4')
757
758    def testInsert(self):
759        insert = self.db.insert
760        query = self.db.query
761        server_version = self.db.server_version
762        bool_on = pg.get_bool()
763        decimal = pg.get_decimal()
764        for table in ('insert_test_table', 'test table for insert'):
765            query('drop table if exists "%s"' % table)
766            query('create table "%s" ('
767                "i2 smallint, i4 integer, i8 bigint,"
768                " d numeric, f4 real, f8 double precision, m money,"
769                " v4 varchar(4), c4 char(4), t text,"
770                " b boolean, ts timestamp) with oids" % table)
771            oid_table = table
772            if ' ' in table:
773                oid_table = '"%s"' % oid_table
774            oid_table = 'oid(public.%s)' % oid_table
775            tests = [dict(i2=None, i4=None, i8=None),
776                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
777                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
778                dict(i2=42, i4=123456, i8=9876543210),
779                dict(i2=2 ** 15 - 1,
780                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
781                dict(d=None), (dict(d=''), dict(d=None)),
782                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
783                dict(f4=None, f8=None), dict(f4=0, f8=0),
784                (dict(f4='', f8=''), dict(f4=None, f8=None)),
785                (dict(d=1234.5, f4=1234.5, f8=1234.5),
786                      dict(d=Decimal('1234.5'))),
787                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
788                dict(d=Decimal('123456789.9876543212345678987654321')),
789                dict(m=None), (dict(m=''), dict(m=None)),
790                dict(m=Decimal('-1234.56')),
791                (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
792                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
793                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
794                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
795                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
796                (dict(m=123456), dict(m=Decimal('123456'))),
797                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
798                dict(b=None), (dict(b=''), dict(b=None)),
799                dict(b='f'), dict(b='t'),
800                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
801                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
802                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
803                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
804                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
805                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
806                dict(v4=None, c4=None, t=None),
807                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
808                dict(v4='1234', c4='1234', t='1234' * 10),
809                dict(v4='abcd', c4='abcd', t='abcdefg'),
810                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
811                dict(ts=None), (dict(ts=''), dict(ts=None)),
812                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
813                dict(ts='2012-12-21 00:00:00'),
814                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
815                dict(ts='2012-12-21 12:21:12'),
816                dict(ts='2013-01-05 12:13:14'),
817                dict(ts='current_timestamp')]
818            for test in tests:
819                if isinstance(test, dict):
820                    data = test
821                    change = {}
822                else:
823                    data, change = test
824                expect = data.copy()
825                expect.update(change)
826                if bool_on:
827                    b = expect.get('b')
828                    if b is not None:
829                        expect['b'] = b == 't'
830                if decimal is not Decimal:
831                    d = expect.get('d')
832                    if d is not None:
833                        expect['d'] = decimal(d)
834                    m = expect.get('m')
835                    if m is not None:
836                        expect['m'] = decimal(m)
837                if data.get('m') and server_version < 910000:
838                    # PostgreSQL < 9.1 cannot directly convert numbers to money
839                    data['m'] = "'%s'::money" % data['m']
840                self.assertEqual(insert(table, data), data)
841                self.assertIn(oid_table, data)
842                oid = data[oid_table]
843                self.assertIsInstance(oid, int)
844                data = dict(item for item in data.items()
845                    if item[0] in expect)
846                ts = expect.get('ts')
847                if ts == 'current_timestamp':
848                    ts = expect['ts'] = data['ts']
849                    if len(ts) > 19:
850                        self.assertEqual(ts[19], '.')
851                        ts = ts[:19]
852                    else:
853                        self.assertEqual(len(ts), 19)
854                    self.assertTrue(ts[:4].isdigit())
855                    self.assertEqual(ts[4], '-')
856                    self.assertEqual(ts[10], ' ')
857                    self.assertTrue(ts[11:13].isdigit())
858                    self.assertEqual(ts[13], ':')
859                self.assertEqual(data, expect)
860                data = query(
861                    'select oid,* from "%s"' % table).dictresult()[0]
862                self.assertEqual(data['oid'], oid)
863                data = dict(item for item in data.items()
864                    if item[0] in expect)
865                self.assertEqual(data, expect)
866                query('delete from "%s"' % table)
867            query('drop table "%s"' % table)
868
869    def testUpdate(self):
870        update = self.db.update
871        query = self.db.query
872        for table in ('update_test_table', 'test table for update'):
873            query('drop table if exists "%s"' % table)
874            query('create table "%s" ('
875                "n integer, t text) with oids" % table)
876            for n, t in enumerate('xyz'):
877                query('insert into "%s" values('
878                    "%d, '%s')" % (table, n + 1, t))
879            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
880            r = self.db.get(table, 2, 'n')
881            r['t'] = 'u'
882            s = update(table, r)
883            self.assertEqual(s, r)
884            r = query('select t from "%s" where n=2' % table
885                      ).getresult()[0][0]
886            self.assertEqual(r, 'u')
887            query('drop table "%s"' % table)
888
889    def testUpdateWithCompositeKey(self):
890        update = self.db.update
891        query = self.db.query
892        table = 'update_test_table_1'
893        query("drop table if exists %s" % table)
894        query("create table %s ("
895            "n integer, t text, primary key (n))" % table)
896        for n, t in enumerate('abc'):
897            query("insert into %s values("
898                "%d, '%s')" % (table, n + 1, t))
899        self.assertRaises(pg.ProgrammingError, update,
900                          table, dict(t='b'))
901        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
902        r = query('select t from "%s" where n=2' % table
903                  ).getresult()[0][0]
904        self.assertEqual(r, 'd')
905        query("drop table %s" % table)
906        table = 'update_test_table_2'
907        query("drop table if exists %s" % table)
908        query("create table %s ("
909            "n integer, m integer, t text, primary key (n, m))" % table)
910        for n in range(3):
911            for m in range(2):
912                t = chr(ord('a') + 2 * n + m)
913                query("insert into %s values("
914                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
915        self.assertRaises(pg.ProgrammingError, update,
916                          table, dict(n=2, t='b'))
917        self.assertEqual(update(table,
918                                dict(n=2, m=2, t='x'))['t'], 'x')
919        r = [r[0] for r in query('select t from "%s" where n=2'
920            ' order by m' % table).getresult()]
921        self.assertEqual(r, ['c', 'x'])
922        query("drop table %s" % table)
923
924    def testClear(self):
925        clear = self.db.clear
926        query = self.db.query
927        f = False if pg.get_bool() else 'f'
928        for table in ('clear_test_table', 'test table for clear'):
929            query('drop table if exists "%s"' % table)
930            query('create table "%s" ('
931                "n integer, b boolean, d date, t text)" % table)
932            r = clear(table)
933            result = {'n': 0, 'b': f, 'd': '', 't': ''}
934            self.assertEqual(r, result)
935            r['a'] = r['n'] = 1
936            r['d'] = r['t'] = 'x'
937            r['b'] = 't'
938            r['oid'] = long(1)
939            r = clear(table, r)
940            result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
941                'oid': long(1)}
942            self.assertEqual(r, result)
943            query('drop table "%s"' % table)
944
945    def testDelete(self):
946        delete = self.db.delete
947        query = self.db.query
948        for table in ('delete_test_table', 'test table for delete'):
949            query('drop table if exists "%s"' % table)
950            query('create table "%s" ('
951                "n integer, t text) with oids" % table)
952            for n, t in enumerate('xyz'):
953                query('insert into "%s" values('
954                    "%d, '%s')" % (table, n + 1, t))
955            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
956            r = self.db.get(table, 1, 'n')
957            s = delete(table, r)
958            self.assertEqual(s, 1)
959            r = self.db.get(table, 3, 'n')
960            s = delete(table, r)
961            self.assertEqual(s, 1)
962            s = delete(table, r)
963            self.assertEqual(s, 0)
964            r = query('select * from "%s"' % table).dictresult()
965            self.assertEqual(len(r), 1)
966            r = r[0]
967            result = {'n': 2, 't': 'y'}
968            self.assertEqual(r, result)
969            r = self.db.get(table, 2, 'n')
970            s = delete(table, r)
971            self.assertEqual(s, 1)
972            s = delete(table, r)
973            self.assertEqual(s, 0)
974            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
975            query('drop table "%s"' % table)
976
977    def testDeleteWithCompositeKey(self):
978        query = self.db.query
979        table = 'delete_test_table_1'
980        query("drop table if exists %s" % table)
981        query("create table %s ("
982            "n integer, t text, primary key (n))" % table)
983        for n, t in enumerate('abc'):
984            query("insert into %s values("
985                "%d, '%s')" % (table, n + 1, t))
986        self.assertRaises(pg.ProgrammingError, self.db.delete,
987            table, dict(t='b'))
988        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
989        r = query('select t from "%s" where n=2' % table
990                  ).getresult()
991        self.assertEqual(r, [])
992        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
993        r = query('select t from "%s" where n=3' % table
994                  ).getresult()[0][0]
995        self.assertEqual(r, 'c')
996        query("drop table %s" % table)
997        table = 'delete_test_table_2'
998        query("drop table if exists %s" % table)
999        query("create table %s ("
1000            "n integer, m integer, t text, primary key (n, m))" % table)
1001        for n in range(3):
1002            for m in range(2):
1003                t = chr(ord('a') + 2 * n + m)
1004                query("insert into %s values("
1005                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1006        self.assertRaises(pg.ProgrammingError, self.db.delete,
1007            table, dict(n=2, t='b'))
1008        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1009        r = [r[0] for r in query('select t from "%s" where n=2'
1010            ' order by m' % table).getresult()]
1011        self.assertEqual(r, ['c'])
1012        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1013        r = [r[0] for r in query('select t from "%s" where n=3'
1014            ' order by m' % table).getresult()]
1015        self.assertEqual(r, ['e', 'f'])
1016        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1017        r = [r[0] for r in query('select t from "%s" where n=3'
1018            ' order by m' % table).getresult()]
1019        self.assertEqual(r, ['f'])
1020        query("drop table %s" % table)
1021
1022    def testTransaction(self):
1023        query = self.db.query
1024        query("drop table if exists test_table")
1025        query("create table test_table (n integer)")
1026        self.db.begin()
1027        query("insert into test_table values (1)")
1028        query("insert into test_table values (2)")
1029        self.db.commit()
1030        self.db.begin()
1031        query("insert into test_table values (3)")
1032        query("insert into test_table values (4)")
1033        self.db.rollback()
1034        self.db.begin()
1035        query("insert into test_table values (5)")
1036        self.db.savepoint('before6')
1037        query("insert into test_table values (6)")
1038        self.db.rollback('before6')
1039        query("insert into test_table values (7)")
1040        self.db.commit()
1041        self.db.begin()
1042        self.db.savepoint('before8')
1043        query("insert into test_table values (8)")
1044        self.db.release('before8')
1045        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1046        self.db.commit()
1047        self.db.start()
1048        query("insert into test_table values (9)")
1049        self.db.end()
1050        r = [r[0] for r in query(
1051            "select * from test_table order by 1").getresult()]
1052        self.assertEqual(r, [1, 2, 5, 7, 9])
1053        query("drop table test_table")
1054
1055    def testContextManager(self):
1056        query = self.db.query
1057        query("drop table if exists test_table")
1058        query("create table test_table (n integer check(n>0))")
1059        with self.db:
1060            query("insert into test_table values (1)")
1061            query("insert into test_table values (2)")
1062        try:
1063            with self.db:
1064                query("insert into test_table values (3)")
1065                query("insert into test_table values (4)")
1066                raise ValueError('test transaction should rollback')
1067        except ValueError as error:
1068            self.assertEqual(str(error), 'test transaction should rollback')
1069        with self.db:
1070            query("insert into test_table values (5)")
1071        try:
1072            with self.db:
1073                query("insert into test_table values (6)")
1074                query("insert into test_table values (-1)")
1075        except pg.ProgrammingError as error:
1076            self.assertTrue('check' in str(error))
1077        with self.db:
1078            query("insert into test_table values (7)")
1079        r = [r[0] for r in query(
1080            "select * from test_table order by 1").getresult()]
1081        self.assertEqual(r, [1, 2, 5, 7])
1082        query("drop table test_table")
1083
1084    def testBytea(self):
1085        query = self.db.query
1086        query('drop table if exists bytea_test')
1087        query('create table bytea_test (n smallint primary key, data bytea)')
1088        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1089        r = self.db.escape_bytea(s)
1090        query('insert into bytea_test values(3,$1)', (r,))
1091        r = query('select * from bytea_test where n=3').getresult()
1092        self.assertEqual(len(r), 1)
1093        r = r[0]
1094        self.assertEqual(len(r), 2)
1095        self.assertEqual(r[0], 3)
1096        r = r[1]
1097        self.assertIsInstance(r, str)
1098        r = self.db.unescape_bytea(r)
1099        self.assertIsInstance(r, bytes)
1100        self.assertEqual(r, s)
1101        query('drop table bytea_test')
1102
1103    def testInsertUpdateGetBytea(self):
1104        query = self.db.query
1105        query('drop table if exists bytea_test')
1106        query('create table bytea_test (n smallint primary key, data bytea)')
1107        # insert as bytes
1108        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1109        r = self.db.insert('bytea_test', n=5, data=s)
1110        self.assertIsInstance(r, dict)
1111        self.assertIn('n', r)
1112        self.assertEqual(r['n'], 5)
1113        self.assertIn('data', r)
1114        r = r['data']
1115        self.assertIsInstance(r, bytes)
1116        self.assertEqual(r, s)
1117        # update as bytes
1118        s += b"and now even more \x00 nasty \t stuff!\f"
1119        r = self.db.update('bytea_test', n=5, data=s)
1120        self.assertIsInstance(r, dict)
1121        self.assertIn('n', r)
1122        self.assertEqual(r['n'], 5)
1123        self.assertIn('data', r)
1124        r = r['data']
1125        self.assertIsInstance(r, bytes)
1126        self.assertEqual(r, s)
1127        r = query('select * from bytea_test where n=5').getresult()
1128        self.assertEqual(len(r), 1)
1129        r = r[0]
1130        self.assertEqual(len(r), 2)
1131        self.assertEqual(r[0], 5)
1132        r = r[1]
1133        self.assertIsInstance(r, str)
1134        r = self.db.unescape_bytea(r)
1135        self.assertIsInstance(r, bytes)
1136        self.assertEqual(r, s)
1137        r = self.db.get('bytea_test', dict(n=5))
1138        self.assertIsInstance(r, dict)
1139        self.assertIn('n', r)
1140        self.assertEqual(r['n'], 5)
1141        self.assertIn('data', r)
1142        r = r['data']
1143        self.assertIsInstance(r, bytes)
1144        self.assertEqual(r, s)
1145        query('drop table bytea_test')
1146
1147    def testDebugWithCallable(self):
1148        if debug:
1149            self.assertEqual(self.db.debug, debug)
1150        else:
1151            self.assertIsNone(self.db.debug)
1152        s = []
1153        self.db.debug = s.append
1154        try:
1155            self.db.query("select 1")
1156            self.db.query("select 2")
1157            self.assertEqual(s, ["select 1", "select 2"])
1158        finally:
1159            self.db.debug = debug
1160
1161
1162class TestDBClassNonStdOpts(TestDBClass):
1163    """Test the methods of the DB class with non-standard global options."""
1164
1165    @classmethod
1166    def setUpClass(cls):
1167        cls.saved_options = {}
1168        cls.set_option('decimal', float)
1169        not_bool = not pg.get_bool()
1170        cls.set_option('bool', not_bool)
1171        unnamed_result = lambda q: q.getresult()
1172        cls.set_option('namedresult', unnamed_result)
1173        super(TestDBClassNonStdOpts, cls).setUpClass()
1174
1175    @classmethod
1176    def tearDownClass(cls):
1177        super(TestDBClassNonStdOpts, cls).tearDownClass()
1178        cls.reset_option('namedresult')
1179        cls.reset_option('bool')
1180        cls.reset_option('decimal')
1181
1182    @classmethod
1183    def set_option(cls, option, value):
1184        cls.saved_options[option] = getattr(pg, 'get_' + option)()
1185        return getattr(pg, 'set_' + option)(value)
1186
1187    @classmethod
1188    def reset_option(cls, option):
1189        return getattr(pg, 'set_' + option)(cls.saved_options[option])
1190
1191
1192class TestSchemas(unittest.TestCase):
1193    """Test correct handling of schemas (namespaces)."""
1194
1195    @classmethod
1196    def setUpClass(cls):
1197        db = DB()
1198        query = db.query
1199        query("set client_min_messages=warning")
1200        for num_schema in range(5):
1201            if num_schema:
1202                schema = "s%d" % num_schema
1203                query("drop schema if exists %s cascade" % (schema,))
1204                try:
1205                    query("create schema %s" % (schema,))
1206                except pg.ProgrammingError:
1207                    raise RuntimeError("The test user cannot create schemas.\n"
1208                        "Grant create on database %s to the user"
1209                        " for running these tests." % dbname)
1210            else:
1211                schema = "public"
1212                query("drop table if exists %s.t" % (schema,))
1213                query("drop table if exists %s.t%d" % (schema, num_schema))
1214            query("create table %s.t with oids as select 1 as n, %d as d"
1215                  % (schema, num_schema))
1216            query("create table %s.t%d with oids as select 1 as n, %d as d"
1217                  % (schema, num_schema, num_schema))
1218        db.close()
1219
1220    @classmethod
1221    def tearDownClass(cls):
1222        db = DB()
1223        query = db.query
1224        query("set client_min_messages=warning")
1225        for num_schema in range(5):
1226            if num_schema:
1227                schema = "s%d" % num_schema
1228                query("drop schema %s cascade" % (schema,))
1229            else:
1230                schema = "public"
1231                query("drop table %s.t" % (schema,))
1232                query("drop table %s.t%d" % (schema, num_schema))
1233        db.close()
1234
1235    def setUp(self):
1236        self.db = DB()
1237        self.db.query("set client_min_messages=warning")
1238
1239    def tearDown(self):
1240        self.db.close()
1241
1242    def testGetTables(self):
1243        tables = self.db.get_tables()
1244        for num_schema in range(5):
1245            if num_schema:
1246                schema = "s" + str(num_schema)
1247            else:
1248                schema = "public"
1249            for t in (schema + ".t",
1250                    schema + ".t" + str(num_schema)):
1251                self.assertIn(t, tables)
1252
1253    def testGetAttnames(self):
1254        get_attnames = self.db.get_attnames
1255        query = self.db.query
1256        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
1257        r = get_attnames("t")
1258        self.assertEqual(r, result)
1259        r = get_attnames("s4.t4")
1260        self.assertEqual(r, result)
1261        query("drop table if exists s3.t3m")
1262        query("create table s3.t3m with oids as select 1 as m")
1263        result_m = {'oid': 'int', 'm': 'int'}
1264        r = get_attnames("s3.t3m")
1265        self.assertEqual(r, result_m)
1266        query("set search_path to s1,s3")
1267        r = get_attnames("t3")
1268        self.assertEqual(r, result)
1269        r = get_attnames("t3m")
1270        self.assertEqual(r, result_m)
1271        query("drop table s3.t3m")
1272
1273    def testGet(self):
1274        get = self.db.get
1275        query = self.db.query
1276        PrgError = pg.ProgrammingError
1277        self.assertEqual(get("t", 1, 'n')['d'], 0)
1278        self.assertEqual(get("t0", 1, 'n')['d'], 0)
1279        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
1280        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
1281        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
1282        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
1283        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
1284        query("set search_path to s2,s4")
1285        self.assertRaises(PrgError, get, "t1", 1, 'n')
1286        self.assertEqual(get("t4", 1, 'n')['d'], 4)
1287        self.assertRaises(PrgError, get, "t3", 1, 'n')
1288        self.assertEqual(get("t", 1, 'n')['d'], 2)
1289        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
1290        query("set search_path to s1,s3")
1291        self.assertRaises(PrgError, get, "t2", 1, 'n')
1292        self.assertEqual(get("t3", 1, 'n')['d'], 3)
1293        self.assertRaises(PrgError, get, "t4", 1, 'n')
1294        self.assertEqual(get("t", 1, 'n')['d'], 1)
1295        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
1296
1297    def testMangling(self):
1298        get = self.db.get
1299        query = self.db.query
1300        r = get("t", 1, 'n')
1301        self.assertIn('oid(public.t)', r)
1302        query("set search_path to s2")
1303        r = get("t2", 1, 'n')
1304        self.assertIn('oid(s2.t2)', r)
1305        query("set search_path to s3")
1306        r = get("t", 1, 'n')
1307        self.assertIn('oid(s3.t)', r)
1308
1309
1310if __name__ == '__main__':
1311    unittest.main()
Note: See TracBrowser for help on using the repository browser.