source: trunk/tests/test_classic_dbwrapper.py @ 732

Last change on this file since 732 was 732, checked in by cito, 4 years ago

Better handling of quoted identifiers

Methods like get() and update() did not handle quoted identifiers properly
(i.e. identifiers containing spaces, mixed-case characters or special characters).
This has been fixed, and tests have been added to make sure this works.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 50.8 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError, SystemError):
    # When this file is not imported as part of a package, the relative
    # import fails with ValueError on Python 2, SystemError on Python
    # 3.3 to 3.5 and ImportError on newer versions; fall back to an
    # absolute import in all of these cases.
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
57
58
def DB():
    """Create a DB wrapper object connecting to the test database.

    The connection uses the module-level settings ``dbname``, ``dbhost``
    and ``dbport``; when ``debug`` is set, it is passed on to the wrapper.
    The session is configured with ``client_min_messages=warning``.
    If this setup fails, the new connection is closed before the error
    propagates, so that it is not leaked.
    """
    db = pg.DB(dbname, dbhost, dbport)
    try:
        if debug:
            db.debug = debug
        db.query("set client_min_messages=warning")
    except Exception:
        db.close()
        raise
    return db
66
67
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # some tests close the connection themselves,
        # in which case closing it again raises an InternalError
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        # all public attributes of the wrapper, in the sorted
        # order in which dir() returns them
        attributes = [
            'begin',
            'cancel',
            'clear',
            'close',
            'commit',
            'db',
            'dbname',
            'debug',
            'delete',
            'end',
            'endcopy',
            'error',
            'escape_bytea',
            'escape_identifier',
            'escape_literal',
            'escape_string',
            'fileno',
            'get',
            'get_attnames',
            'get_databases',
            'get_notice_receiver',
            'get_relations',
            'get_tables',
            'getline',
            'getlo',
            'getnotify',
            'has_table_privilege',
            'host',
            'insert',
            'inserttable',
            'locreate',
            'loimport',
            'notification_handler',
            'options',
            'parameter',
            'pkey',
            'port',
            'protocol_version',
            'putline',
            'query',
            'release',
            'reopen',
            'reset',
            'rollback',
            'savepoint',
            'server_version',
            'set_notice_receiver',
            'source',
            'start',
            'status',
            'transaction',
            'unescape_bytea',
            'update',
            'use_regtypes',
            'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        # Before PostgreSQL 10 the version number has the form
        # major1 * 10000 + major2 * 100 + minor (e.g. 90602 for 9.6.2);
        # since PostgreSQL 10 it is major * 10000 + minor (e.g. 100001
        # for 10.1), so the upper bound must lie well above 100000.
        self.assertTrue(70400 <= server_version < 1000000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # apart from a possible hex prefix and backslashes,
        # the escaped empty string must be empty
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        # parameters can be passed as separate arguments,
        # as a tuple or as a list
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        # the test must fail if no error is raised at all
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            self.assertEqual(error.sqlstate, '22012')
        else:
            self.fail('Division by zero should raise a ProgrammingError')

    def testMethodEndcopy(self):
        # presumably raises IOError since no copy is in progress;
        # the test only checks that nothing else goes wrong
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        # a wrapper can be created from an existing connection object;
        # its db attribute stays set after close() and reopen()
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        # a wrapper can also be created from another wrapper,
        # or by passing the connection as the db keyword argument
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        # any object with a _cnx attribute can serve as the source
        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
273
274
275class TestDBClass(unittest.TestCase):
276    """Test the methods of the DB class wrapped pg connection."""
277
278    @classmethod
279    def setUpClass(cls):
280        db = DB()
281        db.query("drop table if exists test cascade")
282        db.query("create table test ("
283            "i2 smallint, i4 integer, i8 bigint,"
284            "d numeric, f4 real, f8 double precision, m money, "
285            "v4 varchar(4), c4 char(4), t text)")
286        db.query("create or replace view test_view as"
287            " select i4, v4 from test")
288        db.close()
289
290    @classmethod
291    def tearDownClass(cls):
292        db = DB()
293        db.query("drop table test cascade")
294        db.close()
295
296    def setUp(self):
297        self.db = DB()
298        query = self.db.query
299        query('set client_encoding=utf8')
300        query('set standard_conforming_strings=on')
301        query("set lc_monetary='C'")
302        query("set datestyle='ISO,YMD'")
303        query('set bytea_output=hex')
304
305    def tearDown(self):
306        self.db.close()
307
308    def testClassName(self):
309        self.assertEqual(self.db.__class__.__name__, 'DB')
310
311    def testModuleName(self):
312        self.assertEqual(self.db.__module__, 'pg')
313        self.assertEqual(self.db.__class__.__module__, 'pg')
314
315    def testEscapeLiteral(self):
316        f = self.db.escape_literal
317        r = f(b"plain")
318        self.assertIsInstance(r, bytes)
319        self.assertEqual(r, b"'plain'")
320        r = f(u"plain")
321        self.assertIsInstance(r, unicode)
322        self.assertEqual(r, u"'plain'")
323        r = f(u"that's kÀse".encode('utf-8'))
324        self.assertIsInstance(r, bytes)
325        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
326        r = f(u"that's kÀse")
327        self.assertIsInstance(r, unicode)
328        self.assertEqual(r, u"'that''s kÀse'")
329        self.assertEqual(f(r"It's fine to have a \ inside."),
330            r" E'It''s fine to have a \\ inside.'")
331        self.assertEqual(f('No "quotes" must be escaped.'),
332            "'No \"quotes\" must be escaped.'")
333
334    def testEscapeIdentifier(self):
335        f = self.db.escape_identifier
336        r = f(b"plain")
337        self.assertIsInstance(r, bytes)
338        self.assertEqual(r, b'"plain"')
339        r = f(u"plain")
340        self.assertIsInstance(r, unicode)
341        self.assertEqual(r, u'"plain"')
342        r = f(u"that's kÀse".encode('utf-8'))
343        self.assertIsInstance(r, bytes)
344        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
345        r = f(u"that's kÀse")
346        self.assertIsInstance(r, unicode)
347        self.assertEqual(r, u'"that\'s kÀse"')
348        self.assertEqual(f(r"It's fine to have a \ inside."),
349            '"It\'s fine to have a \\ inside."')
350        self.assertEqual(f('All "quotes" must be escaped.'),
351            '"All ""quotes"" must be escaped."')
352
353    def testEscapeString(self):
354        f = self.db.escape_string
355        r = f(b"plain")
356        self.assertIsInstance(r, bytes)
357        self.assertEqual(r, b"plain")
358        r = f(u"plain")
359        self.assertIsInstance(r, unicode)
360        self.assertEqual(r, u"plain")
361        r = f(u"that's kÀse".encode('utf-8'))
362        self.assertIsInstance(r, bytes)
363        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
364        r = f(u"that's kÀse")
365        self.assertIsInstance(r, unicode)
366        self.assertEqual(r, u"that''s kÀse")
367        self.assertEqual(f(r"It's fine to have a \ inside."),
368            r"It''s fine to have a \ inside.")
369
370    def testEscapeBytea(self):
371        f = self.db.escape_bytea
372        # note that escape_byte always returns hex output since Pg 9.0,
373        # regardless of the bytea_output setting
374        r = f(b'plain')
375        self.assertIsInstance(r, bytes)
376        self.assertEqual(r, b'\\x706c61696e')
377        r = f(u'plain')
378        self.assertIsInstance(r, unicode)
379        self.assertEqual(r, u'\\x706c61696e')
380        r = f(u"das is' kÀse".encode('utf-8'))
381        self.assertIsInstance(r, bytes)
382        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
383        r = f(u"das is' kÀse")
384        self.assertIsInstance(r, unicode)
385        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
386        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
387
388    def testUnescapeBytea(self):
389        f = self.db.unescape_bytea
390        r = f(b'plain')
391        self.assertIsInstance(r, bytes)
392        self.assertEqual(r, b'plain')
393        r = f(u'plain')
394        self.assertIsInstance(r, bytes)
395        self.assertEqual(r, b'plain')
396        r = f(b"das is' k\\303\\244se")
397        self.assertIsInstance(r, bytes)
398        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
399        r = f(u"das is' k\\303\\244se")
400        self.assertIsInstance(r, bytes)
401        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
402        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
403        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
404        self.assertEqual(f(r'\\x746861742773206be47365'),
405            b'\\x746861742773206be47365')
406        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
407
    def testQuery(self):
        """Check the return values of query() for various statements.

        As asserted below, "create table" returns None, each single-row
        insert returns an int (the OID of the new row), and the multi-row
        insert, the update and the delete return the number of affected
        rows as a string.
        """
        query = self.db.query
        query("drop table if exists test_table")
        # created with OIDs so that the OID of an inserted row can be checked
        # NOTE(review): "with oids" is no longer accepted by newer
        # PostgreSQL versions (12+)
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        # the int returned by the insert must be the OID of the new row
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(r, oid)
        # three rows inserted at once
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        # rows with n = 1, 2, 3 and 4 match the condition
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        # all five rows are deleted
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')
        query("drop table test_table")
441
442    def testMultipleQueries(self):
443        self.assertEqual(self.db.query(
444            "create temporary table test_multi (n integer);"
445            "insert into test_multi values (4711);"
446            "select n from test_multi").getresult()[0][0], 4711)
447
    def testQueryWithParams(self):
        """Check query() with positional parameters $1, $2, ...

        The parameter values are passed as a tuple, as a list or as
        separate arguments.
        """
        query = self.db.query
        query("drop table if exists test_table")
        q = "create table test_table (n1 integer, n2 integer) with oids"
        query(q)
        q = "insert into test_table values ($1, $2)"
        # each single-row insert is expected to return an int
        r = query(q, (1, 2))
        self.assertIsInstance(r, int)
        r = query(q, [3, 4])
        self.assertIsInstance(r, int)
        r = query(q, [5, 6])
        self.assertIsInstance(r, int)
        q = "select * from test_table order by 1, 2"
        self.assertEqual(query(q).getresult(),
            [(1, 2), (3, 4), (5, 6)])
        q = "select * from test_table where n1=$1 and n2=$2"
        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
        # parameters can be referenced in a different order
        q = "update test_table set n2=$2 where n1=$1"
        r = query(q, 3, 7)
        self.assertEqual(r, '1')
        q = "select * from test_table order by 1, 2"
        self.assertEqual(query(q).getresult(),
            [(1, 2), (3, 7), (5, 6)])
        # after the update, no row has n2 = 4, so all three are deleted
        q = "delete from test_table where n2!=$1"
        r = query(q, 4)
        self.assertEqual(r, '3')
        query("drop table test_table")
475
476    def testEmptyQuery(self):
477        self.assertRaises(ValueError, self.db.query, '')
478
479    def testQueryProgrammingError(self):
480        try:
481            self.db.query("select 1/0")
482        except pg.ProgrammingError as error:
483            self.assertEqual(error.sqlstate, '22012')
484
    def testPkey(self):
        """Check the primary key detection of pkey(), including its cache.

        Every table layout is created twice, once with a plain name and
        once with a quoted name containing spaces.  As asserted below, a
        single-column key is returned as a string and a composite key as
        a frozenset of column names.
        """
        query = self.db.query
        pkey = self.db.pkey
        for t in ('pkeytest', 'primary key test'):
            for n in range(7):
                query('drop table if exists "%s%d"' % (t, n))
            # table 0 has no primary key
            query('create table "%s0" ('
                "a smallint)" % t)
            # tables 1 and 2 have a single-column primary key
            query('create table "%s1" ('
                "b smallint primary key)" % t)
            query('create table "%s2" ('
                "c smallint, d smallint primary key)" % t)
            # table 3 has a composite primary key
            query('create table "%s3" ('
                "e smallint, f smallint, g smallint, "
                "h smallint, i smallint, "
                "primary key (f, h))" % t)
            # tables 4 to 6 have key columns with long or special names
            query('create table "%s4" ('
                "more_than_one_letter varchar primary key)" % t)
            query('create table "%s5" ('
                '"with space" date primary key)' % t)
            query('create table "%s6" ('
                'a_very_long_column_name varchar, '
                '"with space" date, '
                '"42" int, '
                "primary key (a_very_long_column_name, "
                '"with space", "42"))' % t)
            # a table without a primary key raises a KeyError
            self.assertRaises(KeyError, pkey, '%s0' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s2' % t), 'd')
            r = pkey('%s3' % t)
            self.assertIsInstance(r, frozenset)
            self.assertEqual(r, frozenset('fh'))
            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s5' % t), 'with space')
            r = pkey('%s6' % t)
            self.assertIsInstance(r, frozenset)
            self.assertEqual(r, frozenset([
                'a_very_long_column_name', 'with space', '42']))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
            for n in range(7):
                query('drop table "%s%d"' % (t, n))
534
535    def testGetDatabases(self):
536        databases = self.db.get_databases()
537        self.assertIn('template0', databases)
538        self.assertIn('template1', databases)
539        self.assertNotIn('not existing database', databases)
540        self.assertIn('postgres', databases)
541        self.assertIn(dbname, databases)
542
543    def testGetTables(self):
544        get_tables = self.db.get_tables
545        result1 = get_tables()
546        self.assertIsInstance(result1, list)
547        for t in result1:
548            t = t.split('.', 1)
549            self.assertGreaterEqual(len(t), 2)
550            if len(t) > 2:
551                self.assertTrue(t[1].startswith('"'))
552            t = t[0]
553            self.assertNotEqual(t, 'information_schema')
554            self.assertFalse(t.startswith('pg_'))
555        tables = ('"A very Special Name"',
556            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
557            'A_MiXeD_NaMe', '"another special name"',
558            'averyveryveryveryveryveryverylongtablename',
559            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
560        for t in tables:
561            self.db.query('drop table if exists %s' % t)
562            self.db.query("create table %s"
563                " as select 0" % t)
564        result3 = get_tables()
565        result2 = []
566        for t in result3:
567            if t not in result1:
568                result2.append(t)
569        result3 = []
570        for t in tables:
571            if not t.startswith('"'):
572                t = t.lower()
573            result3.append('public.' + t)
574        self.assertEqual(result2, result3)
575        for t in result2:
576            self.db.query('drop table %s' % t)
577        result2 = get_tables()
578        self.assertEqual(result2, result1)
579
580    def testGetRelations(self):
581        get_relations = self.db.get_relations
582        result = get_relations()
583        self.assertIn('public.test', result)
584        self.assertIn('public.test_view', result)
585        result = get_relations('rv')
586        self.assertIn('public.test', result)
587        self.assertIn('public.test_view', result)
588        result = get_relations('r')
589        self.assertIn('public.test', result)
590        self.assertNotIn('public.test_view', result)
591        result = get_relations('v')
592        self.assertNotIn('public.test', result)
593        self.assertIn('public.test_view', result)
594        result = get_relations('cisSt')
595        self.assertNotIn('public.test', result)
596        self.assertNotIn('public.test_view', result)
597
    def testAttnames(self):
        """Check the column names and types reported by get_attnames().

        Invalid table names must raise a ProgrammingError.  The table is
        created once with a plain and once with a quoted name.
        """
        self.assertRaises(pg.ProgrammingError,
            self.db.get_attnames, 'does_not_exist')
        self.assertRaises(pg.ProgrammingError,
            self.db.get_attnames, 'has.too.many.dots')
        for table in ('attnames_test_table', 'test table for attnames'):
            self.db.query('drop table if exists "%s"' % table)
            self.db.query('create table "%s" ('
                'a smallint, b integer, c bigint, '
                'e numeric, f float, f2 double precision, m money, '
                'x smallint, y smallint, z smallint, '
                'Normal_NaMe smallint, "Special Name" smallint, '
                't text, u char(2), v varchar(2), '
                'primary key (y, u)) with oids' % table)
            attributes = self.db.get_attnames(table)
            # Column types are expected in simplified form (e.g. all
            # integer types as 'int', all character types as 'text').
            # The unquoted Normal_NaMe appears in lower case, and the
            # oid column of the table is included as well.
            result = {'a': 'int', 'c': 'int', 'b': 'int',
                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
                'normal_name': 'int', 'Special Name': 'int',
                'u': 'text', 't': 'text', 'v': 'text',
                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
            self.assertEqual(attributes, result)
            self.db.query('drop table "%s"' % table)
620
621    def testHasTablePrivilege(self):
622        can = self.db.has_table_privilege
623        self.assertEqual(can('test'), True)
624        self.assertEqual(can('test', 'select'), True)
625        self.assertEqual(can('test', 'SeLeCt'), True)
626        self.assertEqual(can('test', 'SELECT'), True)
627        self.assertEqual(can('test', 'insert'), True)
628        self.assertEqual(can('test', 'update'), True)
629        self.assertEqual(can('test', 'delete'), True)
630        self.assertEqual(can('pg_views', 'select'), True)
631        self.assertEqual(can('pg_views', 'delete'), False)
632        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
633        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
634
    def testGet(self):
        """Check fetching single rows with get(), with and without a key."""
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer, t text) with oids" % table)
        for n, t in enumerate('xyz'):
            query('insert into "%s" values('"%d, '%s')"
                % (table, n + 1, t))
        # the table has no primary key yet, so the key column is required
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        # the row also contains the OID under the key 'oid(<table>)'
        oid_table = 'oid(%s)' % table
        self.assertIn(oid_table, r)
        oid = r[oid_table]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, oid_table: oid}
        self.assertEqual(r, result)
        # the table name may be followed by ' *'
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        # rows can also be fetched by their OID
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        # there is no row with n = 4
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        # a dict can be passed instead of a key value
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        # once n is the primary key, the key column can be omitted
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        query('drop table "%s"' % table)
674
    def testGetWithCompositeKey(self):
        """Check get() with a single-column and a composite primary key."""
        get = self.db.get
        query = self.db.query
        table = 'get_test_table_1'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer, t text, primary key (n))" % table)
        for n, t in enumerate('abc'):
            query('insert into "%s" values('
                "%d, '%s')" % (table, n + 1, t))
        self.assertEqual(get(table, 2)['t'], 'b')
        query('drop table "%s"' % table)
        table = 'get_test_table_2'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer, m integer, t text, primary key (n, m))" % table)
        # rows (n, m, t): (1, 1, a), (1, 2, b), (2, 1, c),
        # (2, 2, d), (3, 1, e), (3, 2, f)
        for n in range(3):
            for m in range(2):
                t = chr(ord('a') + 2 * n + m)
                query('insert into "%s" values('
                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
        # a single value is not enough for a composite key
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
        # the key columns can also be passed as a tuple or a frozenset
        self.assertEqual(get(table, dict(n=1, m=2),
                             ('n', 'm'))['t'], 'b')
        self.assertEqual(get(table, dict(n=3, m=2),
                             frozenset(['n', 'm']))['t'], 'f')
        query('drop table "%s"' % table)
703
704    def testGetWithQuotedNames(self):
705        get = self.db.get
706        query = self.db.query
707        table = 'test table for get()'
708        query('drop table if exists "%s"' % table)
709        query('create table "%s" ('
710            '"Prime!" smallint primary key,'
711            '"much space" integer, "Questions?" text)' % table)
712        query('insert into "%s"'
713              " values(17, 1001, 'No!')" % table)
714        r = get(table, 17)
715        self.assertIsInstance(r, dict)
716        self.assertEqual(r['Prime!'], 17)
717        self.assertEqual(r['much space'], 1001)
718        self.assertEqual(r['Questions?'], 'No!')
719        query('drop table "%s"' % table)
720
721    def testGetFromView(self):
722        self.db.query('delete from test where i4=14')
723        self.db.query('insert into test (i4, v4) values('
724            "14, 'abc4')")
725        r = self.db.get('test_view', 14, 'i4')
726        self.assertIn('v4', r)
727        self.assertEqual(r['v4'], 'abc4')
728
    def testInsert(self):
        """Insert rows with values of various types and check the results.

        For every test case, a row is inserted with insert().  The dict
        passed to insert() must then contain the OID of the new row under
        the key 'oid(<table>)'.  Its values, and those of the row read
        back from the table, must match the expected values.
        """
        insert = self.db.insert
        query = self.db.query
        # current type conversion settings of the pg module; the
        # expected values below are adapted to them
        bool_on = pg.get_bool()
        decimal = pg.get_decimal()
        table = 'insert_test_table'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "i2 smallint, i4 integer, i8 bigint,"
            " d numeric, f4 real, f8 double precision, m money,"
            " v4 varchar(4), c4 char(4), t text,"
            " b boolean, ts timestamp) with oids" % table)
        oid_table = 'oid(%s)' % table
        # Each test is either a dict of values to insert that are expected
        # to be returned unchanged, or a pair of such a dict and a dict of
        # the values that are expected to differ after the insert.
        tests = [dict(i2=None, i4=None, i8=None),
            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
            dict(i2=42, i4=123456, i8=9876543210),
            dict(i2=2 ** 15 - 1,
                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
            dict(d=None), (dict(d=''), dict(d=None)),
            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
            dict(f4=None, f8=None), dict(f4=0, f8=0),
            (dict(f4='', f8=''), dict(f4=None, f8=None)),
            (dict(d=1234.5, f4=1234.5, f8=1234.5),
                  dict(d=Decimal('1234.5'))),
            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
            dict(d=Decimal('123456789.9876543212345678987654321')),
            dict(m=None), (dict(m=''), dict(m=None)),
            dict(m=Decimal('-1234.56')),
            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
            (dict(m=123456), dict(m=Decimal('123456'))),
            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
            dict(b=None), (dict(b=''), dict(b=None)),
            dict(b='f'), dict(b='t'),
            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
            dict(v4=None, c4=None, t=None),
            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
            dict(v4='1234', c4='1234', t='1234' * 10),
            dict(v4='abcd', c4='abcd', t='abcdefg'),
            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
            dict(ts=None), (dict(ts=''), dict(ts=None)),
            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
            dict(ts='2012-12-21 00:00:00'),
            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
            dict(ts='2012-12-21 12:21:12'),
            dict(ts='2013-01-05 12:13:14'),
            dict(ts='current_timestamp')]
        for test in tests:
            if isinstance(test, dict):
                data = test
                change = {}
            else:
                data, change = test
            expect = data.copy()
            expect.update(change)
            # when get_bool() is set, booleans are expected as Python
            # bools instead of 't' and 'f'
            if bool_on:
                b = expect.get('b')
                if b is not None:
                    expect['b'] = b == 't'
            # when another decimal type is configured, numeric and
            # money values are expected as instances of that type
            if decimal is not Decimal:
                d = expect.get('d')
                if d is not None:
                    expect['d'] = decimal(d)
                m = expect.get('m')
                if m is not None:
                    expect['m'] = decimal(m)
            self.assertEqual(insert(table, data), data)
            self.assertIn(oid_table, data)
            oid = data[oid_table]
            self.assertIsInstance(oid, int)
            # compare only the columns that are part of this test
            data = dict(item for item in data.items()
                if item[0] in expect)
            ts = expect.get('ts')
            if ts == 'current_timestamp':
                # the actual timestamp cannot be predicted, so take it
                # from the result and only check its general format,
                # allowing fractional seconds after position 19
                ts = expect['ts'] = data['ts']
                if len(ts) > 19:
                    self.assertEqual(ts[19], '.')
                    ts = ts[:19]
                else:
                    self.assertEqual(len(ts), 19)
                self.assertTrue(ts[:4].isdigit())
                self.assertEqual(ts[4], '-')
                self.assertEqual(ts[10], ' ')
                self.assertTrue(ts[11:13].isdigit())
                self.assertEqual(ts[13], ':')
            self.assertEqual(data, expect)
            # read the row back from the table and compare it as well
            data = query(
                'select oid,* from "%s"' % table).dictresult()[0]
            self.assertEqual(data['oid'], oid)
            data = dict(item for item in data.items()
                if item[0] in expect)
            self.assertEqual(data, expect)
            query('delete from "%s"' % table)
        query('drop table "%s"' % table)
832
833    def testInsertWithQuotedNames(self):
834        insert = self.db.insert
835        query = self.db.query
836        table = 'test table for insert()'
837        query('drop table if exists "%s"' % table)
838        query('create table "%s" ('
839            '"Prime!" smallint primary key,'
840            '"much space" integer, "Questions?" text)' % table)
841        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
842        r = insert(table, r)
843        self.assertIsInstance(r, dict)
844        self.assertEqual(r['Prime!'], 11)
845        self.assertEqual(r['much space'], 2002)
846        self.assertEqual(r['Questions?'], 'What?')
847        r = query('select * from "%s" limit 2' % table).dictresult()
848        self.assertEqual(len(r), 1)
849        r = r[0]
850        self.assertEqual(r['Prime!'], 11)
851        self.assertEqual(r['much space'], 2002)
852        self.assertEqual(r['Questions?'], 'What?')
853        query('drop table "%s"' % table)
854
855    def testUpdate(self):
856        update = self.db.update
857        query = self.db.query
858        table = 'update_test_table'
859        query('drop table if exists "%s"' % table)
860        query('create table "%s" ('
861            "n integer, t text) with oids" % table)
862        for n, t in enumerate('xyz'):
863            query('insert into "%s" values('
864                "%d, '%s')" % (table, n + 1, t))
865        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
866        r = self.db.get(table, 2, 'n')
867        r['t'] = 'u'
868        s = update(table, r)
869        self.assertEqual(s, r)
870        r = query('select t from "%s" where n=2' % table
871                  ).getresult()[0][0]
872        self.assertEqual(r, 'u')
873        query('drop table "%s"' % table)
874
875    def testUpdateWithCompositeKey(self):
876        update = self.db.update
877        query = self.db.query
878        table = 'update_test_table_1'
879        query('drop table if exists "%s"' % table)
880        query('create table "%s" ('
881            "n integer, t text, primary key (n))" % table)
882        for n, t in enumerate('abc'):
883            query('insert into "%s" values('
884                "%d, '%s')" % (table, n + 1, t))
885        self.assertRaises(pg.ProgrammingError, update,
886                          table, dict(t='b'))
887        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
888        r = query('select t from "%s" where n=2' % table
889                  ).getresult()[0][0]
890        self.assertEqual(r, 'd')
891        query('drop table "%s"' % table)
892        table = 'update_test_table_2'
893        query('drop table if exists "%s"' % table)
894        query('create table "%s" ('
895            "n integer, m integer, t text, primary key (n, m))" % table)
896        for n in range(3):
897            for m in range(2):
898                t = chr(ord('a') + 2 * n + m)
899                query('insert into "%s" values('
900                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
901        self.assertRaises(pg.ProgrammingError, update,
902                          table, dict(n=2, t='b'))
903        self.assertEqual(update(table,
904                                dict(n=2, m=2, t='x'))['t'], 'x')
905        r = [r[0] for r in query('select t from "%s" where n=2'
906            ' order by m' % table).getresult()]
907        self.assertEqual(r, ['c', 'x'])
908        query('drop table "%s"' % table)
909
910    def testUpdateWithQuotedNames(self):
911        update = self.db.update
912        query = self.db.query
913        table = 'test table for update()'
914        query('drop table if exists "%s"' % table)
915        query('create table "%s" ('
916            '"Prime!" smallint primary key,'
917            '"much space" integer, "Questions?" text)' % table)
918        query('insert into "%s"'
919              " values(13, 3003, 'Why!')" % table)
920        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
921        r = update(table, r)
922        self.assertIsInstance(r, dict)
923        self.assertEqual(r['Prime!'], 13)
924        self.assertEqual(r['much space'], 7007)
925        self.assertEqual(r['Questions?'], 'When?')
926        r = query('select * from "%s" limit 2' % table).dictresult()
927        self.assertEqual(len(r), 1)
928        r = r[0]
929        self.assertEqual(r['Prime!'], 13)
930        self.assertEqual(r['much space'], 7007)
931        self.assertEqual(r['Questions?'], 'When?')
932        query('drop table "%s"' % table)
933
934    def testClear(self):
935        clear = self.db.clear
936        query = self.db.query
937        f = False if pg.get_bool() else 'f'
938        table = 'clear_test_table'
939        query('drop table if exists "%s"' % table)
940        query('create table "%s" ('
941            "n integer, b boolean, d date, t text)" % table)
942        r = clear(table)
943        result = {'n': 0, 'b': f, 'd': '', 't': ''}
944        self.assertEqual(r, result)
945        r['a'] = r['n'] = 1
946        r['d'] = r['t'] = 'x'
947        r['b'] = 't'
948        r['oid'] = long(1)
949        r = clear(table, r)
950        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
951            'oid': long(1)}
952        self.assertEqual(r, result)
953        query('drop table "%s"' % table)
954
955    def testClearWithQuotedNames(self):
956        clear = self.db.clear
957        query = self.db.query
958        table = 'test table for clear()'
959        query('drop table if exists "%s"' % table)
960        query('create table "%s" ('
961            '"Prime!" smallint primary key,'
962            '"much space" integer, "Questions?" text)' % table)
963        r = clear(table)
964        self.assertIsInstance(r, dict)
965        self.assertEqual(r['Prime!'], 0)
966        self.assertEqual(r['much space'], 0)
967        self.assertEqual(r['Questions?'], '')
968        query('drop table "%s"' % table)
969
970    def testDelete(self):
971        delete = self.db.delete
972        query = self.db.query
973        table = 'delete_test_table'
974        query('drop table if exists "%s"' % table)
975        query('create table "%s" ('
976            "n integer, t text) with oids" % table)
977        for n, t in enumerate('xyz'):
978            query('insert into "%s" values('
979                "%d, '%s')" % (table, n + 1, t))
980        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
981        r = self.db.get(table, 1, 'n')
982        s = delete(table, r)
983        self.assertEqual(s, 1)
984        r = self.db.get(table, 3, 'n')
985        s = delete(table, r)
986        self.assertEqual(s, 1)
987        s = delete(table, r)
988        self.assertEqual(s, 0)
989        r = query('select * from "%s"' % table).dictresult()
990        self.assertEqual(len(r), 1)
991        r = r[0]
992        result = {'n': 2, 't': 'y'}
993        self.assertEqual(r, result)
994        r = self.db.get(table, 2, 'n')
995        s = delete(table, r)
996        self.assertEqual(s, 1)
997        s = delete(table, r)
998        self.assertEqual(s, 0)
999        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1000        query('drop table "%s"' % table)
1001
1002    def testDeleteWithCompositeKey(self):
1003        query = self.db.query
1004        table = 'delete_test_table_1'
1005        query('drop table if exists "%s"' % table)
1006        query('create table "%s" ('
1007            "n integer, t text, primary key (n))" % table)
1008        for n, t in enumerate('abc'):
1009            query("insert into %s values("
1010                "%d, '%s')" % (table, n + 1, t))
1011        self.assertRaises(pg.ProgrammingError, self.db.delete,
1012            table, dict(t='b'))
1013        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1014        r = query('select t from "%s" where n=2' % table
1015                  ).getresult()
1016        self.assertEqual(r, [])
1017        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1018        r = query('select t from "%s" where n=3' % table
1019                  ).getresult()[0][0]
1020        self.assertEqual(r, 'c')
1021        query('drop table "%s"' % table)
1022        table = 'delete_test_table_2'
1023        query('drop table if exists "%s"' % table)
1024        query('create table "%s" ('
1025            "n integer, m integer, t text, primary key (n, m))" % table)
1026        for n in range(3):
1027            for m in range(2):
1028                t = chr(ord('a') + 2 * n + m)
1029                query('insert into "%s" values('
1030                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1031        self.assertRaises(pg.ProgrammingError, self.db.delete,
1032            table, dict(n=2, t='b'))
1033        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1034        r = [r[0] for r in query('select t from "%s" where n=2'
1035            ' order by m' % table).getresult()]
1036        self.assertEqual(r, ['c'])
1037        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1038        r = [r[0] for r in query('select t from "%s" where n=3'
1039            ' order by m' % table).getresult()]
1040        self.assertEqual(r, ['e', 'f'])
1041        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1042        r = [r[0] for r in query('select t from "%s" where n=3'
1043            ' order by m' % table).getresult()]
1044        self.assertEqual(r, ['f'])
1045        query('drop table "%s"' % table)
1046
1047    def testDeleteWithQuotedNames(self):
1048        delete = self.db.delete
1049        query = self.db.query
1050        table = 'test table for delete()'
1051        query('drop table if exists "%s"' % table)
1052        query('create table "%s" ('
1053            '"Prime!" smallint primary key,'
1054            '"much space" integer, "Questions?" text)' % table)
1055        query('insert into "%s"'
1056              " values(19, 5005, 'Yes!')" % table)
1057        r = {'Prime!': 17}
1058        r = delete(table, r)
1059        self.assertEqual(r, 0)
1060        r = query('select count(*) from "%s"' % table).getresult()
1061        self.assertEqual(r[0][0], 1)
1062        r = {'Prime!': 19}
1063        r = delete(table, r)
1064        self.assertEqual(r, 1)
1065        r = query('select count(*) from "%s"' % table).getresult()
1066        self.assertEqual(r[0][0], 0)
1067        query('drop table "%s"' % table)
1068
1069    def testTransaction(self):
1070        query = self.db.query
1071        query("drop table if exists test_table")
1072        query("create table test_table (n integer)")
1073        self.db.begin()
1074        query("insert into test_table values (1)")
1075        query("insert into test_table values (2)")
1076        self.db.commit()
1077        self.db.begin()
1078        query("insert into test_table values (3)")
1079        query("insert into test_table values (4)")
1080        self.db.rollback()
1081        self.db.begin()
1082        query("insert into test_table values (5)")
1083        self.db.savepoint('before6')
1084        query("insert into test_table values (6)")
1085        self.db.rollback('before6')
1086        query("insert into test_table values (7)")
1087        self.db.commit()
1088        self.db.begin()
1089        self.db.savepoint('before8')
1090        query("insert into test_table values (8)")
1091        self.db.release('before8')
1092        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1093        self.db.commit()
1094        self.db.start()
1095        query("insert into test_table values (9)")
1096        self.db.end()
1097        r = [r[0] for r in query(
1098            "select * from test_table order by 1").getresult()]
1099        self.assertEqual(r, [1, 2, 5, 7, 9])
1100        query("drop table test_table")
1101
1102    def testContextManager(self):
1103        query = self.db.query
1104        query("drop table if exists test_table")
1105        query("create table test_table (n integer check(n>0))")
1106        with self.db:
1107            query("insert into test_table values (1)")
1108            query("insert into test_table values (2)")
1109        try:
1110            with self.db:
1111                query("insert into test_table values (3)")
1112                query("insert into test_table values (4)")
1113                raise ValueError('test transaction should rollback')
1114        except ValueError as error:
1115            self.assertEqual(str(error), 'test transaction should rollback')
1116        with self.db:
1117            query("insert into test_table values (5)")
1118        try:
1119            with self.db:
1120                query("insert into test_table values (6)")
1121                query("insert into test_table values (-1)")
1122        except pg.ProgrammingError as error:
1123            self.assertTrue('check' in str(error))
1124        with self.db:
1125            query("insert into test_table values (7)")
1126        r = [r[0] for r in query(
1127            "select * from test_table order by 1").getresult()]
1128        self.assertEqual(r, [1, 2, 5, 7])
1129        query("drop table test_table")
1130
1131    def testBytea(self):
1132        query = self.db.query
1133        query('drop table if exists bytea_test')
1134        query('create table bytea_test (n smallint primary key, data bytea)')
1135        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1136        r = self.db.escape_bytea(s)
1137        query('insert into bytea_test values(3,$1)', (r,))
1138        r = query('select * from bytea_test where n=3').getresult()
1139        self.assertEqual(len(r), 1)
1140        r = r[0]
1141        self.assertEqual(len(r), 2)
1142        self.assertEqual(r[0], 3)
1143        r = r[1]
1144        self.assertIsInstance(r, str)
1145        r = self.db.unescape_bytea(r)
1146        self.assertIsInstance(r, bytes)
1147        self.assertEqual(r, s)
1148        query('drop table bytea_test')
1149
1150    def testInsertUpdateGetBytea(self):
1151        query = self.db.query
1152        query('drop table if exists bytea_test')
1153        query('create table bytea_test (n smallint primary key, data bytea)')
1154        # insert null value
1155        r = self.db.insert('bytea_test', n=0, data=None)
1156        self.assertIsInstance(r, dict)
1157        self.assertIn('n', r)
1158        self.assertEqual(r['n'], 0)
1159        self.assertIn('data', r)
1160        self.assertIsNone(r['data'])
1161        s = b'None'
1162        r = self.db.update('bytea_test', n=0, data=s)
1163        self.assertIsInstance(r, dict)
1164        self.assertIn('n', r)
1165        self.assertEqual(r['n'], 0)
1166        self.assertIn('data', r)
1167        r = r['data']
1168        self.assertIsInstance(r, bytes)
1169        self.assertEqual(r, s)
1170        r = self.db.update('bytea_test', n=0, data=None)
1171        self.assertIsNone(r['data'])
1172        # insert as bytes
1173        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1174        r = self.db.insert('bytea_test', n=5, data=s)
1175        self.assertIsInstance(r, dict)
1176        self.assertIn('n', r)
1177        self.assertEqual(r['n'], 5)
1178        self.assertIn('data', r)
1179        r = r['data']
1180        self.assertIsInstance(r, bytes)
1181        self.assertEqual(r, s)
1182        # update as bytes
1183        s += b"and now even more \x00 nasty \t stuff!\f"
1184        r = self.db.update('bytea_test', n=5, data=s)
1185        self.assertIsInstance(r, dict)
1186        self.assertIn('n', r)
1187        self.assertEqual(r['n'], 5)
1188        self.assertIn('data', r)
1189        r = r['data']
1190        self.assertIsInstance(r, bytes)
1191        self.assertEqual(r, s)
1192        r = query('select * from bytea_test where n=5').getresult()
1193        self.assertEqual(len(r), 1)
1194        r = r[0]
1195        self.assertEqual(len(r), 2)
1196        self.assertEqual(r[0], 5)
1197        r = r[1]
1198        self.assertIsInstance(r, str)
1199        r = self.db.unescape_bytea(r)
1200        self.assertIsInstance(r, bytes)
1201        self.assertEqual(r, s)
1202        r = self.db.get('bytea_test', dict(n=5))
1203        self.assertIsInstance(r, dict)
1204        self.assertIn('n', r)
1205        self.assertEqual(r['n'], 5)
1206        self.assertIn('data', r)
1207        r = r['data']
1208        self.assertIsInstance(r, bytes)
1209        self.assertEqual(r, s)
1210        query('drop table bytea_test')
1211
1212    def testDebugWithCallable(self):
1213        if debug:
1214            self.assertEqual(self.db.debug, debug)
1215        else:
1216            self.assertIsNone(self.db.debug)
1217        s = []
1218        self.db.debug = s.append
1219        try:
1220            self.db.query("select 1")
1221            self.db.query("select 2")
1222            self.assertEqual(s, ["select 1", "select 2"])
1223        finally:
1224            self.db.debug = debug
1225
1226
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    # global pg options changed by this test case, in the order they are set
    _changed_options = ('decimal', 'bool', 'namedresult')

    @classmethod
    def setUpClass(cls):
        """Change the global pg options, then set up the base test case.

        unittest does not call tearDownClass() when setUpClass() raises,
        so in that case the options are restored here.  This keeps the
        non-standard settings from leaking into other test cases.
        """
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            unnamed_result = lambda q: q.getresult()
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            cls._reset_saved_options()
            raise

    @classmethod
    def tearDownClass(cls):
        """Tear down the base test case and restore the global options."""
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls._reset_saved_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option and set a new one."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _reset_saved_options(cls):
        """Restore all options saved so far, in reverse order of setting."""
        for option in reversed(cls._changed_options):
            if option in cls.saved_options:
                cls.reset_option(option)
1255
1256
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    @staticmethod
    def _schema_name(num_schema):
        """Return the name of schema number num_schema (0 is public)."""
        return "s%d" % num_schema if num_schema else "public"

    @classmethod
    def setUpClass(cls):
        """Create schemas s1 to s4 and tables t and t<num> in all schemas."""
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids"
                      " as select 1 as n, %d as d" % (schema, num_schema))
                query("create table %s.t%d with oids"
                      " as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # do not leak the connection when one of the queries fails
            db.close()

    @classmethod
    def tearDownClass(cls):
        """Drop the schemas and the tables created in setUpClass()."""
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema %s cascade" % (schema,))
                else:
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.db = DB()
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        self.db.close()

    def testGetTables(self):
        """All tables created in setUpClass() are listed with their schema."""
        tables = self.db.get_tables()
        for num_schema in range(5):
            schema = self._schema_name(num_schema)
            for t in (schema + ".t",
                      schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        """Attribute names are found for qualified and unqualified tables."""
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)
        query("drop table s3.t3m")

    def testGet(self):
        """get() respects the schema qualification and the search path."""
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        """The oid key of a fetched row is named after the table."""
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
1373
1374
if __name__ == "__main__":
    # run all test cases of this module when executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.