source: trunk/module/tests/test_classic_dbwrapper.py @ 690

Last change on this file since 690 was 690, checked in by cito, 4 years ago

Amend tests so that they can run with PostgreSQL < 9.0

Note that we do not need to make these amendments in the trunk,
because we assume PostgreSQL >= 9.0 for PyGreSQL version 5.0.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 48.1 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
24# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
25# get our information from that.  Otherwise we use the defaults.
26# The current user must have create schema privilege on the database.
27dbname = 'unittest'
28dbhost = None
29dbport = 5432
30
31debug = False  # let DB wrapper print debugging output
32
33try:
34    from .LOCAL_PyGreSQL import *
35except (ImportError, ValueError):
36    try:
37        from LOCAL_PyGreSQL import *
38    except ImportError:
39        pass
40
41try:
42    long
43except NameError:  # Python >= 3.0
44    long = int
45
46try:
47    unicode
48except NameError:  # Python >= 3.0
49    unicode = str
50
51windows = os.name == 'nt'
52
53# There is a known a bug in libpq under Windows which can cause
54# the interface to crash when calling PQhost():
55do_not_ask_for_host = windows
56do_not_ask_for_host_reason = 'libpq issue on Windows'
57
58
59def DB():
60    """Create a DB wrapper object connecting to the test database."""
61    db = pg.DB(dbname, dbhost, dbport)
62    if debug:
63        db.debug = debug
64    db.query("set client_min_messages=warning")
65    return db
66
67
68class TestDBClassBasic(unittest.TestCase):
69    """Test existence of the DB class wrapped pg connection methods."""
70
71    def setUp(self):
72        self.db = DB()
73
74    def tearDown(self):
75        try:
76            self.db.close()
77        except pg.InternalError:
78            pass
79
80    def testAllDBAttributes(self):
81        attributes = [
82            'begin',
83            'cancel',
84            'clear',
85            'close',
86            'commit',
87            'db',
88            'dbname',
89            'debug',
90            'delete',
91            'end',
92            'endcopy',
93            'error',
94            'escape_bytea',
95            'escape_identifier',
96            'escape_literal',
97            'escape_string',
98            'fileno',
99            'get',
100            'get_attnames',
101            'get_databases',
102            'get_notice_receiver',
103            'get_relations',
104            'get_tables',
105            'getline',
106            'getlo',
107            'getnotify',
108            'has_table_privilege',
109            'host',
110            'insert',
111            'inserttable',
112            'locreate',
113            'loimport',
114            'notification_handler',
115            'options',
116            'parameter',
117            'pkey',
118            'port',
119            'protocol_version',
120            'putline',
121            'query',
122            'release',
123            'reopen',
124            'reset',
125            'rollback',
126            'savepoint',
127            'server_version',
128            'set_notice_receiver',
129            'source',
130            'start',
131            'status',
132            'transaction',
133            'unescape_bytea',
134            'update',
135            'use_regtypes',
136            'user',
137        ]
138        db_attributes = [a for a in dir(self.db)
139            if not a.startswith('_')]
140        self.assertEqual(attributes, db_attributes)
141
142    def testAttributeDb(self):
143        self.assertEqual(self.db.db.db, dbname)
144
145    def testAttributeDbname(self):
146        self.assertEqual(self.db.dbname, dbname)
147
148    def testAttributeError(self):
149        error = self.db.error
150        self.assertTrue(not error or 'krb5_' in error)
151        self.assertEqual(self.db.error, self.db.db.error)
152
153    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
154    def testAttributeHost(self):
155        def_host = 'localhost'
156        host = self.db.host
157        self.assertIsInstance(host, str)
158        self.assertEqual(host, dbhost or def_host)
159        self.assertEqual(host, self.db.db.host)
160
161    def testAttributeOptions(self):
162        no_options = ''
163        options = self.db.options
164        self.assertEqual(options, no_options)
165        self.assertEqual(options, self.db.db.options)
166
167    def testAttributePort(self):
168        def_port = 5432
169        port = self.db.port
170        self.assertIsInstance(port, int)
171        self.assertEqual(port, dbport or def_port)
172        self.assertEqual(port, self.db.db.port)
173
174    def testAttributeProtocolVersion(self):
175        protocol_version = self.db.protocol_version
176        self.assertIsInstance(protocol_version, int)
177        self.assertTrue(2 <= protocol_version < 4)
178        self.assertEqual(protocol_version, self.db.db.protocol_version)
179
180    def testAttributeServerVersion(self):
181        server_version = self.db.server_version
182        self.assertIsInstance(server_version, int)
183        self.assertTrue(70400 <= server_version < 100000)
184        self.assertEqual(server_version, self.db.db.server_version)
185
186    def testAttributeStatus(self):
187        status_ok = 1
188        status = self.db.status
189        self.assertIsInstance(status, int)
190        self.assertEqual(status, status_ok)
191        self.assertEqual(status, self.db.db.status)
192
193    def testAttributeUser(self):
194        no_user = 'Deprecated facility'
195        user = self.db.user
196        self.assertTrue(user)
197        self.assertIsInstance(user, str)
198        self.assertNotEqual(user, no_user)
199        self.assertEqual(user, self.db.db.user)
200
201    def testMethodEscapeLiteral(self):
202        self.assertEqual(self.db.escape_literal(''), "''")
203
204    def testMethodEscapeIdentifier(self):
205        self.assertEqual(self.db.escape_identifier(''), '""')
206
207    def testMethodEscapeString(self):
208        self.assertEqual(self.db.escape_string(''), '')
209
210    def testMethodEscapeBytea(self):
211        self.assertEqual(self.db.escape_bytea('').replace(
212            '\\x', '').replace('\\', ''), '')
213
214    def testMethodUnescapeBytea(self):
215        self.assertEqual(self.db.unescape_bytea(''), b'')
216
217    def testMethodQuery(self):
218        query = self.db.query
219        query("select 1+1")
220        query("select 1+$1+$2", 2, 3)
221        query("select 1+$1+$2", (2, 3))
222        query("select 1+$1+$2", [2, 3])
223        query("select 1+$1", 1)
224
225    def testMethodQueryEmpty(self):
226        self.assertRaises(ValueError, self.db.query, '')
227
228    def testMethodQueryProgrammingError(self):
229        try:
230            self.db.query("select 1/0")
231        except pg.ProgrammingError as error:
232            self.assertEqual(error.sqlstate, '22012')
233
234    def testMethodEndcopy(self):
235        try:
236            self.db.endcopy()
237        except IOError:
238            pass
239
240    def testMethodClose(self):
241        self.db.close()
242        try:
243            self.db.reset()
244        except pg.Error:
245            pass
246        else:
247            self.fail('Reset should give an error for a closed connection')
248        self.assertRaises(pg.InternalError, self.db.close)
249        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
250
251    def testExistingConnection(self):
252        db = pg.DB(self.db.db)
253        self.assertEqual(self.db.db, db.db)
254        self.assertTrue(db.db)
255        db.close()
256        self.assertTrue(db.db)
257        db.reopen()
258        self.assertTrue(db.db)
259        db.close()
260        self.assertTrue(db.db)
261        db = pg.DB(self.db)
262        self.assertEqual(self.db.db, db.db)
263        db = pg.DB(db=self.db.db)
264        self.assertEqual(self.db.db, db.db)
265
266        class DB2:
267            pass
268
269        db2 = DB2()
270        db2._cnx = self.db.db
271        db = pg.DB(db2)
272        self.assertEqual(self.db.db, db.db)
273
274
275class TestDBClass(unittest.TestCase):
276    """Test the methods of the DB class wrapped pg connection."""
277
278    @classmethod
279    def setUpClass(cls):
280        db = DB()
281        db.query("drop table if exists test cascade")
282        db.query("create table test ("
283            "i2 smallint, i4 integer, i8 bigint,"
284            "d numeric, f4 real, f8 double precision, m money, "
285            "v4 varchar(4), c4 char(4), t text)")
286        db.query("create or replace view test_view as"
287            " select i4, v4 from test")
288        db.close()
289
290    @classmethod
291    def tearDownClass(cls):
292        db = DB()
293        db.query("drop table test cascade")
294        db.close()
295
296    def setUp(self):
297        self.db = DB()
298        query = self.db.query
299        query('set client_encoding=utf8')
300        query('set standard_conforming_strings=on')
301        query("set lc_monetary='C'")
302        query("set datestyle='ISO,YMD'")
303        query('set bytea_output=hex')
304
305    def tearDown(self):
306        self.db.close()
307
308    def testClassName(self):
309        self.assertEqual(self.db.__class__.__name__, 'DB')
310
311    def testModuleName(self):
312        self.assertEqual(self.db.__module__, 'pg')
313        self.assertEqual(self.db.__class__.__module__, 'pg')
314
315    def testEscapeLiteral(self):
316        f = self.db.escape_literal
317        r = f(b"plain")
318        self.assertIsInstance(r, bytes)
319        self.assertEqual(r, b"'plain'")
320        r = f(u"plain")
321        self.assertIsInstance(r, unicode)
322        self.assertEqual(r, u"'plain'")
323        r = f(u"that's kÀse".encode('utf-8'))
324        self.assertIsInstance(r, bytes)
325        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
326        r = f(u"that's kÀse")
327        self.assertIsInstance(r, unicode)
328        self.assertEqual(r, u"'that''s kÀse'")
329        self.assertEqual(f(r"It's fine to have a \ inside."),
330            r" E'It''s fine to have a \\ inside.'")
331        self.assertEqual(f('No "quotes" must be escaped.'),
332            "'No \"quotes\" must be escaped.'")
333
334    def testEscapeIdentifier(self):
335        f = self.db.escape_identifier
336        r = f(b"plain")
337        self.assertIsInstance(r, bytes)
338        self.assertEqual(r, b'"plain"')
339        r = f(u"plain")
340        self.assertIsInstance(r, unicode)
341        self.assertEqual(r, u'"plain"')
342        r = f(u"that's kÀse".encode('utf-8'))
343        self.assertIsInstance(r, bytes)
344        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
345        r = f(u"that's kÀse")
346        self.assertIsInstance(r, unicode)
347        self.assertEqual(r, u'"that\'s kÀse"')
348        self.assertEqual(f(r"It's fine to have a \ inside."),
349            '"It\'s fine to have a \\ inside."')
350        self.assertEqual(f('All "quotes" must be escaped.'),
351            '"All ""quotes"" must be escaped."')
352
353    def testEscapeString(self):
354        f = self.db.escape_string
355        r = f(b"plain")
356        self.assertIsInstance(r, bytes)
357        self.assertEqual(r, b"plain")
358        r = f(u"plain")
359        self.assertIsInstance(r, unicode)
360        self.assertEqual(r, u"plain")
361        r = f(u"that's kÀse".encode('utf-8'))
362        self.assertIsInstance(r, bytes)
363        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
364        r = f(u"that's kÀse")
365        self.assertIsInstance(r, unicode)
366        self.assertEqual(r, u"that''s kÀse")
367        self.assertEqual(f(r"It's fine to have a \ inside."),
368            r"It''s fine to have a \ inside.")
369
370    def testEscapeBytea(self):
371        f = self.db.escape_bytea
372        # note that escape_byte always returns hex output since Pg 9.0,
373        # regardless of the bytea_output setting
374        r = f(b'plain')
375        self.assertIsInstance(r, bytes)
376        self.assertEqual(r, b'\\x706c61696e')
377        r = f(u'plain')
378        self.assertIsInstance(r, unicode)
379        self.assertEqual(r, u'\\x706c61696e')
380        r = f(u"das is' kÀse".encode('utf-8'))
381        self.assertIsInstance(r, bytes)
382        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
383        r = f(u"das is' kÀse")
384        self.assertIsInstance(r, unicode)
385        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
386        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
387
388    def testUnescapeBytea(self):
389        f = self.db.unescape_bytea
390        r = f(b'plain')
391        self.assertIsInstance(r, bytes)
392        self.assertEqual(r, b'plain')
393        r = f(u'plain')
394        self.assertIsInstance(r, bytes)
395        self.assertEqual(r, b'plain')
396        r = f(b"das is' k\\303\\244se")
397        self.assertIsInstance(r, bytes)
398        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
399        r = f(u"das is' k\\303\\244se")
400        self.assertIsInstance(r, bytes)
401        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
402        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
403        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
404        self.assertEqual(f(r'\\x746861742773206be47365'),
405            b'\\x746861742773206be47365')
406        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
407
408    def testQuote(self):
409        f = self.db._quote
410        self.assertEqual(f(None, None), 'NULL')
411        self.assertEqual(f(None, 'int'), 'NULL')
412        self.assertEqual(f(None, 'float'), 'NULL')
413        self.assertEqual(f(None, 'num'), 'NULL')
414        self.assertEqual(f(None, 'money'), 'NULL')
415        self.assertEqual(f(None, 'bool'), 'NULL')
416        self.assertEqual(f(None, 'date'), 'NULL')
417        self.assertEqual(f('', 'int'), 'NULL')
418        self.assertEqual(f('', 'float'), 'NULL')
419        self.assertEqual(f('', 'num'), 'NULL')
420        self.assertEqual(f('', 'money'), 'NULL')
421        self.assertEqual(f('', 'bool'), 'NULL')
422        self.assertEqual(f('', 'date'), 'NULL')
423        self.assertEqual(f('', 'text'), "''")
424        self.assertEqual(f(0, 'int'), '0')
425        self.assertEqual(f(0, 'num'), '0')
426        self.assertEqual(f(1, 'int'), '1')
427        self.assertEqual(f(1, 'num'), '1')
428        self.assertEqual(f(-1, 'int'), '-1')
429        self.assertEqual(f(-1, 'num'), '-1')
430        self.assertEqual(f(123456789, 'int'), '123456789')
431        self.assertEqual(f(123456987, 'num'), '123456987')
432        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
433        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
434        self.assertEqual(f('123456789', 'num'), '123456789')
435        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
436        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
437        self.assertEqual(f(123, 'money'), '123')
438        self.assertEqual(f('123', 'money'), '123')
439        self.assertEqual(f(123.45, 'money'), '123.45')
440        self.assertEqual(f('123.45', 'money'), '123.45')
441        self.assertEqual(f(123.454, 'money'), '123.454')
442        self.assertEqual(f('123.454', 'money'), '123.454')
443        self.assertEqual(f(123.456, 'money'), '123.456')
444        self.assertEqual(f('123.456', 'money'), '123.456')
445        self.assertEqual(f('f', 'bool'), "'f'")
446        self.assertEqual(f('F', 'bool'), "'f'")
447        self.assertEqual(f('false', 'bool'), "'f'")
448        self.assertEqual(f('False', 'bool'), "'f'")
449        self.assertEqual(f('FALSE', 'bool'), "'f'")
450        self.assertEqual(f(0, 'bool'), "'f'")
451        self.assertEqual(f('0', 'bool'), "'f'")
452        self.assertEqual(f('-', 'bool'), "'f'")
453        self.assertEqual(f('n', 'bool'), "'f'")
454        self.assertEqual(f('N', 'bool'), "'f'")
455        self.assertEqual(f('no', 'bool'), "'f'")
456        self.assertEqual(f('off', 'bool'), "'f'")
457        self.assertEqual(f('t', 'bool'), "'t'")
458        self.assertEqual(f('T', 'bool'), "'t'")
459        self.assertEqual(f('true', 'bool'), "'t'")
460        self.assertEqual(f('True', 'bool'), "'t'")
461        self.assertEqual(f('TRUE', 'bool'), "'t'")
462        self.assertEqual(f(1, 'bool'), "'t'")
463        self.assertEqual(f(2, 'bool'), "'t'")
464        self.assertEqual(f(-1, 'bool'), "'t'")
465        self.assertEqual(f(0.5, 'bool'), "'t'")
466        self.assertEqual(f('1', 'bool'), "'t'")
467        self.assertEqual(f('y', 'bool'), "'t'")
468        self.assertEqual(f('Y', 'bool'), "'t'")
469        self.assertEqual(f('yes', 'bool'), "'t'")
470        self.assertEqual(f('on', 'bool'), "'t'")
471        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
472        self.assertEqual(f(123, 'text'), "'123'")
473        self.assertEqual(f(1.23, 'text'), "'1.23'")
474        self.assertEqual(f('abc', 'text'), "'abc'")
475        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
476        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
477        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
478        self.db.query('set standard_conforming_strings=off')
479        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
480        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
481
482    def testQuery(self):
483        query = self.db.query
484        query("drop table if exists test_table")
485        q = "create table test_table (n integer) with oids"
486        r = query(q)
487        self.assertIsNone(r)
488        q = "insert into test_table values (1)"
489        r = query(q)
490        self.assertIsInstance(r, int)
491        q = "insert into test_table select 2"
492        r = query(q)
493        self.assertIsInstance(r, int)
494        oid = r
495        q = "select oid from test_table where n=2"
496        r = query(q).getresult()
497        self.assertEqual(len(r), 1)
498        r = r[0]
499        self.assertEqual(len(r), 1)
500        r = r[0]
501        self.assertEqual(r, oid)
502        q = "insert into test_table select 3 union select 4 union select 5"
503        r = query(q)
504        self.assertIsInstance(r, str)
505        self.assertEqual(r, '3')
506        q = "update test_table set n=4 where n<5"
507        r = query(q)
508        self.assertIsInstance(r, str)
509        self.assertEqual(r, '4')
510        q = "delete from test_table"
511        r = query(q)
512        self.assertIsInstance(r, str)
513        self.assertEqual(r, '5')
514        query("drop table test_table")
515
516    def testMultipleQueries(self):
517        self.assertEqual(self.db.query(
518            "create temporary table test_multi (n integer);"
519            "insert into test_multi values (4711);"
520            "select n from test_multi").getresult()[0][0], 4711)
521
522    def testQueryWithParams(self):
523        query = self.db.query
524        query("drop table if exists test_table")
525        q = "create table test_table (n1 integer, n2 integer) with oids"
526        query(q)
527        q = "insert into test_table values ($1, $2)"
528        r = query(q, (1, 2))
529        self.assertIsInstance(r, int)
530        r = query(q, [3, 4])
531        self.assertIsInstance(r, int)
532        r = query(q, [5, 6])
533        self.assertIsInstance(r, int)
534        q = "select * from test_table order by 1, 2"
535        self.assertEqual(query(q).getresult(),
536            [(1, 2), (3, 4), (5, 6)])
537        q = "select * from test_table where n1=$1 and n2=$2"
538        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
539        q = "update test_table set n2=$2 where n1=$1"
540        r = query(q, 3, 7)
541        self.assertEqual(r, '1')
542        q = "select * from test_table order by 1, 2"
543        self.assertEqual(query(q).getresult(),
544            [(1, 2), (3, 7), (5, 6)])
545        q = "delete from test_table where n2!=$1"
546        r = query(q, 4)
547        self.assertEqual(r, '3')
548        query("drop table test_table")
549
550    def testEmptyQuery(self):
551        self.assertRaises(ValueError, self.db.query, '')
552
553    def testQueryProgrammingError(self):
554        try:
555            self.db.query("select 1/0")
556        except pg.ProgrammingError as error:
557            self.assertEqual(error.sqlstate, '22012')
558
559    def testPkey(self):
560        query = self.db.query
561        for n in range(4):
562            query("drop table if exists pkeytest%d" % n)
563        query("create table pkeytest0 ("
564            "a smallint)")
565        query("create table pkeytest1 ("
566            "b smallint primary key)")
567        query("create table pkeytest2 ("
568            "c smallint, d smallint primary key)")
569        query("create table pkeytest3 ("
570            "e smallint, f smallint, g smallint, "
571            "h smallint, i smallint, "
572            "primary key (f,h))")
573        pkey = self.db.pkey
574        self.assertRaises(KeyError, pkey, 'pkeytest0')
575        self.assertEqual(pkey('pkeytest1'), 'b')
576        self.assertEqual(pkey('pkeytest2'), 'd')
577        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
578        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
579        self.assertEqual(pkey('pkeytest0'), 'none')
580        pkey(None, {'t': 'a', 'n.t': 'b'})
581        self.assertEqual(pkey('t'), 'a')
582        self.assertEqual(pkey('n.t'), 'b')
583        self.assertRaises(KeyError, pkey, 'pkeytest0')
584        for n in range(4):
585            query("drop table pkeytest%d" % n)
586
587    def testGetDatabases(self):
588        databases = self.db.get_databases()
589        self.assertIn('template0', databases)
590        self.assertIn('template1', databases)
591        self.assertNotIn('not existing database', databases)
592        self.assertIn('postgres', databases)
593        self.assertIn(dbname, databases)
594
595    def testGetTables(self):
596        get_tables = self.db.get_tables
597        result1 = get_tables()
598        tables = ('"A very Special Name"',
599            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
600            'A_MiXeD_NaMe', '"another special name"',
601            'averyveryveryveryveryveryverylongtablename',
602            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
603        for t in tables:
604            self.db.query('drop table if exists %s' % t)
605            self.db.query("create table %s"
606                " as select 0" % t)
607        result3 = get_tables()
608        result2 = []
609        for t in result3:
610            if t not in result1:
611                result2.append(t)
612        result3 = []
613        for t in tables:
614            if not t.startswith('"'):
615                t = t.lower()
616            result3.append('public.' + t)
617        self.assertEqual(result2, result3)
618        for t in result2:
619            self.db.query('drop table %s' % t)
620        result2 = get_tables()
621        self.assertEqual(result2, result1)
622
623    def testGetRelations(self):
624        get_relations = self.db.get_relations
625        result = get_relations()
626        self.assertIn('public.test', result)
627        self.assertIn('public.test_view', result)
628        result = get_relations('rv')
629        self.assertIn('public.test', result)
630        self.assertIn('public.test_view', result)
631        result = get_relations('r')
632        self.assertIn('public.test', result)
633        self.assertNotIn('public.test_view', result)
634        result = get_relations('v')
635        self.assertNotIn('public.test', result)
636        self.assertIn('public.test_view', result)
637        result = get_relations('cisSt')
638        self.assertNotIn('public.test', result)
639        self.assertNotIn('public.test_view', result)
640
641    def testAttnames(self):
642        self.assertRaises(pg.ProgrammingError,
643            self.db.get_attnames, 'does_not_exist')
644        self.assertRaises(pg.ProgrammingError,
645            self.db.get_attnames, 'has.too.many.dots')
646        for table in ('attnames_test_table', 'test table for attnames'):
647            self.db.query('drop table if exists "%s"' % table)
648            self.db.query('create table "%s" ('
649                'a smallint, b integer, c bigint, '
650                'e numeric, f float, f2 double precision, m money, '
651                'x smallint, y smallint, z smallint, '
652                'Normal_NaMe smallint, "Special Name" smallint, '
653                't text, u char(2), v varchar(2), '
654                'primary key (y, u)) with oids' % table)
655            attributes = self.db.get_attnames(table)
656            result = {'a': 'int', 'c': 'int', 'b': 'int',
657                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
658                'normal_name': 'int', 'Special Name': 'int',
659                'u': 'text', 't': 'text', 'v': 'text',
660                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
661            self.assertEqual(attributes, result)
662            self.db.query('drop table "%s"' % table)
663
664    def testHasTablePrivilege(self):
665        can = self.db.has_table_privilege
666        self.assertEqual(can('test'), True)
667        self.assertEqual(can('test', 'select'), True)
668        self.assertEqual(can('test', 'SeLeCt'), True)
669        self.assertEqual(can('test', 'SELECT'), True)
670        self.assertEqual(can('test', 'insert'), True)
671        self.assertEqual(can('test', 'update'), True)
672        self.assertEqual(can('test', 'delete'), True)
673        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
674        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
675
676    def testGet(self):
677        get = self.db.get
678        query = self.db.query
679        for table in ('get_test_table', 'test table for get'):
680            query('drop table if exists "%s"' % table)
681            query('create table "%s" ('
682                "n integer, t text) with oids" % table)
683            for n, t in enumerate('xyz'):
684                query('insert into "%s" values('"%d, '%s')"
685                    % (table, n + 1, t))
686            self.assertRaises(pg.ProgrammingError, get, table, 2)
687            r = get(table, 2, 'n')
688            oid_table = table
689            if ' ' in table:
690                oid_table = '"%s"' % oid_table
691            oid_table = 'oid(public.%s)' % oid_table
692            self.assertIn(oid_table, r)
693            oid = r[oid_table]
694            self.assertIsInstance(oid, int)
695            result = {'t': 'y', 'n': 2, oid_table: oid}
696            self.assertEqual(r, result)
697            self.assertEqual(get(table + ' *', 2, 'n'), r)
698            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
699            self.assertEqual(get(table, 1, 'n')['t'], 'x')
700            self.assertEqual(get(table, 3, 'n')['t'], 'z')
701            self.assertEqual(get(table, 2, 'n')['t'], 'y')
702            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
703            r['n'] = 3
704            self.assertEqual(get(table, r, 'n')['t'], 'z')
705            self.assertEqual(get(table, 1, 'n')['t'], 'x')
706            query('alter table "%s" alter n set not null' % table)
707            query('alter table "%s" add primary key (n)' % table)
708            self.assertEqual(get(table, 3)['t'], 'z')
709            self.assertEqual(get(table, 1)['t'], 'x')
710            self.assertEqual(get(table, 2)['t'], 'y')
711            r['n'] = 1
712            self.assertEqual(get(table, r)['t'], 'x')
713            r['n'] = 3
714            self.assertEqual(get(table, r)['t'], 'z')
715            r['n'] = 2
716            self.assertEqual(get(table, r)['t'], 'y')
717            query('drop table "%s"' % table)
718
719    def testGetWithCompositeKey(self):
720        get = self.db.get
721        query = self.db.query
722        table = 'get_test_table_1'
723        query("drop table if exists %s" % table)
724        query("create table %s ("
725            "n integer, t text, primary key (n))" % table)
726        for n, t in enumerate('abc'):
727            query("insert into %s values("
728                "%d, '%s')" % (table, n + 1, t))
729        self.assertEqual(get(table, 2)['t'], 'b')
730        query("drop table %s" % table)
731        table = 'get_test_table_2'
732        query("drop table if exists %s" % table)
733        query("create table %s ("
734            "n integer, m integer, t text, primary key (n, m))" % table)
735        for n in range(3):
736            for m in range(2):
737                t = chr(ord('a') + 2 * n + m)
738                query("insert into %s values("
739                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
740        self.assertRaises(pg.ProgrammingError, get, table, 2)
741        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
742        self.assertEqual(get(table, dict(n=1, m=2),
743                             ('n', 'm'))['t'], 'b')
744        self.assertEqual(get(table, dict(n=3, m=2),
745                             frozenset(['n', 'm']))['t'], 'f')
746        query("drop table %s" % table)
747
748    def testGetFromView(self):
749        self.db.query('delete from test where i4=14')
750        self.db.query('insert into test (i4, v4) values('
751            "14, 'abc4')")
752        r = self.db.get('test_view', 14, 'i4')
753        self.assertIn('v4', r)
754        self.assertEqual(r['v4'], 'abc4')
755
756    def testInsert(self):
757        insert = self.db.insert
758        query = self.db.query
759        server_version = self.db.server_version
760        for table in ('insert_test_table', 'test table for insert'):
761            query('drop table if exists "%s"' % table)
762            query('create table "%s" ('
763                "i2 smallint, i4 integer, i8 bigint,"
764                " d numeric, f4 real, f8 double precision, m money,"
765                " v4 varchar(4), c4 char(4), t text,"
766                " b boolean, ts timestamp) with oids" % table)
767            oid_table = table
768            if ' ' in table:
769                oid_table = '"%s"' % oid_table
770            oid_table = 'oid(public.%s)' % oid_table
771            tests = [dict(i2=None, i4=None, i8=None),
772                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
773                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
774                dict(i2=42, i4=123456, i8=9876543210),
775                dict(i2=2 ** 15 - 1,
776                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
777                dict(d=None), (dict(d=''), dict(d=None)),
778                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
779                dict(f4=None, f8=None), dict(f4=0, f8=0),
780                (dict(f4='', f8=''), dict(f4=None, f8=None)),
781                (dict(d=1234.5, f4=1234.5, f8=1234.5),
782                      dict(d=Decimal('1234.5'))),
783                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
784                dict(d=Decimal('123456789.9876543212345678987654321')),
785                dict(m=None), (dict(m=''), dict(m=None)),
786                dict(m=Decimal('-1234.56')),
787                (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
788                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
789                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
790                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
791                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
792                (dict(m=123456), dict(m=Decimal('123456'))),
793                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
794                dict(b=None), (dict(b=''), dict(b=None)),
795                dict(b='f'), dict(b='t'),
796                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
797                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
798                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
799                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
800                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
801                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
802                dict(v4=None, c4=None, t=None),
803                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
804                dict(v4='1234', c4='1234', t='1234' * 10),
805                dict(v4='abcd', c4='abcd', t='abcdefg'),
806                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
807                dict(ts=None), (dict(ts=''), dict(ts=None)),
808                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
809                dict(ts='2012-12-21 00:00:00'),
810                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
811                dict(ts='2012-12-21 12:21:12'),
812                dict(ts='2013-01-05 12:13:14'),
813                dict(ts='current_timestamp')]
814            for test in tests:
815                if isinstance(test, dict):
816                    data = test
817                    change = {}
818                else:
819                    data, change = test
820                expect = data.copy()
821                expect.update(change)
822                if data.get('m') and server_version < 910000:
823                    # PostgreSQL < 9.1 cannot directly convert numbers to money
824                    data['m'] = "'%s'::money" % data['m']
825                self.assertEqual(insert(table, data), data)
826                self.assertIn(oid_table, data)
827                oid = data[oid_table]
828                self.assertIsInstance(oid, int)
829                data = dict(item for item in data.items()
830                    if item[0] in expect)
831                ts = expect.get('ts')
832                if ts == 'current_timestamp':
833                    ts = expect['ts'] = data['ts']
834                    if len(ts) > 19:
835                        self.assertEqual(ts[19], '.')
836                        ts = ts[:19]
837                    else:
838                        self.assertEqual(len(ts), 19)
839                    self.assertTrue(ts[:4].isdigit())
840                    self.assertEqual(ts[4], '-')
841                    self.assertEqual(ts[10], ' ')
842                    self.assertTrue(ts[11:13].isdigit())
843                    self.assertEqual(ts[13], ':')
844                self.assertEqual(data, expect)
845                data = query(
846                    'select oid,* from "%s"' % table).dictresult()[0]
847                self.assertEqual(data['oid'], oid)
848                data = dict(item for item in data.items()
849                    if item[0] in expect)
850                self.assertEqual(data, expect)
851                query('delete from "%s"' % table)
852            query('drop table "%s"' % table)
853
854    def testUpdate(self):
855        update = self.db.update
856        query = self.db.query
857        for table in ('update_test_table', 'test table for update'):
858            query('drop table if exists "%s"' % table)
859            query('create table "%s" ('
860                "n integer, t text) with oids" % table)
861            for n, t in enumerate('xyz'):
862                query('insert into "%s" values('
863                    "%d, '%s')" % (table, n + 1, t))
864            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
865            r = self.db.get(table, 2, 'n')
866            r['t'] = 'u'
867            s = update(table, r)
868            self.assertEqual(s, r)
869            r = query('select t from "%s" where n=2' % table
870                      ).getresult()[0][0]
871            self.assertEqual(r, 'u')
872            query('drop table "%s"' % table)
873
874    def testUpdateWithCompositeKey(self):
875        update = self.db.update
876        query = self.db.query
877        table = 'update_test_table_1'
878        query("drop table if exists %s" % table)
879        query("create table %s ("
880            "n integer, t text, primary key (n))" % table)
881        for n, t in enumerate('abc'):
882            query("insert into %s values("
883                "%d, '%s')" % (table, n + 1, t))
884        self.assertRaises(pg.ProgrammingError, update,
885                          table, dict(t='b'))
886        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
887        r = query('select t from "%s" where n=2' % table
888                  ).getresult()[0][0]
889        self.assertEqual(r, 'd')
890        query("drop table %s" % table)
891        table = 'update_test_table_2'
892        query("drop table if exists %s" % table)
893        query("create table %s ("
894            "n integer, m integer, t text, primary key (n, m))" % table)
895        for n in range(3):
896            for m in range(2):
897                t = chr(ord('a') + 2 * n + m)
898                query("insert into %s values("
899                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
900        self.assertRaises(pg.ProgrammingError, update,
901                          table, dict(n=2, t='b'))
902        self.assertEqual(update(table,
903                                dict(n=2, m=2, t='x'))['t'], 'x')
904        r = [r[0] for r in query('select t from "%s" where n=2'
905            ' order by m' % table).getresult()]
906        self.assertEqual(r, ['c', 'x'])
907        query("drop table %s" % table)
908
909    def testClear(self):
910        clear = self.db.clear
911        query = self.db.query
912        for table in ('clear_test_table', 'test table for clear'):
913            query('drop table if exists "%s"' % table)
914            query('create table "%s" ('
915                "n integer, b boolean, d date, t text)" % table)
916            r = clear(table)
917            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
918            self.assertEqual(r, result)
919            r['a'] = r['n'] = 1
920            r['d'] = r['t'] = 'x'
921            r['b'] = 't'
922            r['oid'] = long(1)
923            r = clear(table, r)
924            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '',
925                'oid': long(1)}
926            self.assertEqual(r, result)
927            query('drop table "%s"' % table)
928
929    def testDelete(self):
930        delete = self.db.delete
931        query = self.db.query
932        for table in ('delete_test_table', 'test table for delete'):
933            query('drop table if exists "%s"' % table)
934            query('create table "%s" ('
935                "n integer, t text) with oids" % table)
936            for n, t in enumerate('xyz'):
937                query('insert into "%s" values('
938                    "%d, '%s')" % (table, n + 1, t))
939            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
940            r = self.db.get(table, 1, 'n')
941            s = delete(table, r)
942            self.assertEqual(s, 1)
943            r = self.db.get(table, 3, 'n')
944            s = delete(table, r)
945            self.assertEqual(s, 1)
946            s = delete(table, r)
947            self.assertEqual(s, 0)
948            r = query('select * from "%s"' % table).dictresult()
949            self.assertEqual(len(r), 1)
950            r = r[0]
951            result = {'n': 2, 't': 'y'}
952            self.assertEqual(r, result)
953            r = self.db.get(table, 2, 'n')
954            s = delete(table, r)
955            self.assertEqual(s, 1)
956            s = delete(table, r)
957            self.assertEqual(s, 0)
958            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
959            query('drop table "%s"' % table)
960
961    def testDeleteWithCompositeKey(self):
962        query = self.db.query
963        table = 'delete_test_table_1'
964        query("drop table if exists %s" % table)
965        query("create table %s ("
966            "n integer, t text, primary key (n))" % table)
967        for n, t in enumerate('abc'):
968            query("insert into %s values("
969                "%d, '%s')" % (table, n + 1, t))
970        self.assertRaises(pg.ProgrammingError, self.db.delete,
971            table, dict(t='b'))
972        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
973        r = query('select t from "%s" where n=2' % table
974                  ).getresult()
975        self.assertEqual(r, [])
976        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
977        r = query('select t from "%s" where n=3' % table
978                  ).getresult()[0][0]
979        self.assertEqual(r, 'c')
980        query("drop table %s" % table)
981        table = 'delete_test_table_2'
982        query("drop table if exists %s" % table)
983        query("create table %s ("
984            "n integer, m integer, t text, primary key (n, m))" % table)
985        for n in range(3):
986            for m in range(2):
987                t = chr(ord('a') + 2 * n + m)
988                query("insert into %s values("
989                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
990        self.assertRaises(pg.ProgrammingError, self.db.delete,
991            table, dict(n=2, t='b'))
992        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
993        r = [r[0] for r in query('select t from "%s" where n=2'
994            ' order by m' % table).getresult()]
995        self.assertEqual(r, ['c'])
996        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
997        r = [r[0] for r in query('select t from "%s" where n=3'
998            ' order by m' % table).getresult()]
999        self.assertEqual(r, ['e', 'f'])
1000        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1001        r = [r[0] for r in query('select t from "%s" where n=3'
1002            ' order by m' % table).getresult()]
1003        self.assertEqual(r, ['f'])
1004        query("drop table %s" % table)
1005
1006    def testTransaction(self):
1007        query = self.db.query
1008        query("drop table if exists test_table")
1009        query("create table test_table (n integer)")
1010        self.db.begin()
1011        query("insert into test_table values (1)")
1012        query("insert into test_table values (2)")
1013        self.db.commit()
1014        self.db.begin()
1015        query("insert into test_table values (3)")
1016        query("insert into test_table values (4)")
1017        self.db.rollback()
1018        self.db.begin()
1019        query("insert into test_table values (5)")
1020        self.db.savepoint('before6')
1021        query("insert into test_table values (6)")
1022        self.db.rollback('before6')
1023        query("insert into test_table values (7)")
1024        self.db.commit()
1025        self.db.begin()
1026        self.db.savepoint('before8')
1027        query("insert into test_table values (8)")
1028        self.db.release('before8')
1029        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1030        self.db.commit()
1031        self.db.start()
1032        query("insert into test_table values (9)")
1033        self.db.end()
1034        r = [r[0] for r in query(
1035            "select * from test_table order by 1").getresult()]
1036        self.assertEqual(r, [1, 2, 5, 7, 9])
1037        query("drop table test_table")
1038
1039    def testContextManager(self):
1040        query = self.db.query
1041        query("drop table if exists test_table")
1042        query("create table test_table (n integer check(n>0))")
1043        with self.db:
1044            query("insert into test_table values (1)")
1045            query("insert into test_table values (2)")
1046        try:
1047            with self.db:
1048                query("insert into test_table values (3)")
1049                query("insert into test_table values (4)")
1050                raise ValueError('test transaction should rollback')
1051        except ValueError as error:
1052            self.assertEqual(str(error), 'test transaction should rollback')
1053        with self.db:
1054            query("insert into test_table values (5)")
1055        try:
1056            with self.db:
1057                query("insert into test_table values (6)")
1058                query("insert into test_table values (-1)")
1059        except pg.ProgrammingError as error:
1060            self.assertTrue('check' in str(error))
1061        with self.db:
1062            query("insert into test_table values (7)")
1063        r = [r[0] for r in query(
1064            "select * from test_table order by 1").getresult()]
1065        self.assertEqual(r, [1, 2, 5, 7])
1066        query("drop table test_table")
1067
1068    def testBytea(self):
1069        query = self.db.query
1070        query('drop table if exists bytea_test')
1071        query('create table bytea_test (n smallint primary key, data bytea)')
1072        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1073        r = self.db.escape_bytea(s)
1074        query('insert into bytea_test values(3,$1)', (r,))
1075        r = query('select * from bytea_test where n=3').getresult()
1076        self.assertEqual(len(r), 1)
1077        r = r[0]
1078        self.assertEqual(len(r), 2)
1079        self.assertEqual(r[0], 3)
1080        r = r[1]
1081        self.assertIsInstance(r, str)
1082        r = self.db.unescape_bytea(r)
1083        self.assertIsInstance(r, bytes)
1084        self.assertEqual(r, s)
1085        query('drop table bytea_test')
1086
1087    def testInsertUpdateGetBytea(self):
1088        query = self.db.query
1089        query('drop table if exists bytea_test')
1090        query('create table bytea_test (n smallint primary key, data bytea)')
1091        # insert as bytes
1092        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1093        r = self.db.insert('bytea_test', n=5, data=s)
1094        self.assertIsInstance(r, dict)
1095        self.assertIn('n', r)
1096        self.assertEqual(r['n'], 5)
1097        self.assertIn('data', r)
1098        r = r['data']
1099        self.assertIsInstance(r, bytes)
1100        self.assertEqual(r, s)
1101        # update as bytes
1102        s += b"and now even more \x00 nasty \t stuff!\f"
1103        r = self.db.update('bytea_test', n=5, data=s)
1104        self.assertIsInstance(r, dict)
1105        self.assertIn('n', r)
1106        self.assertEqual(r['n'], 5)
1107        self.assertIn('data', r)
1108        r = r['data']
1109        self.assertIsInstance(r, bytes)
1110        self.assertEqual(r, s)
1111        r = query('select * from bytea_test where n=5').getresult()
1112        self.assertEqual(len(r), 1)
1113        r = r[0]
1114        self.assertEqual(len(r), 2)
1115        self.assertEqual(r[0], 5)
1116        r = r[1]
1117        self.assertIsInstance(r, str)
1118        r = self.db.unescape_bytea(r)
1119        self.assertIsInstance(r, bytes)
1120        self.assertEqual(r, s)
1121        r = self.db.get('bytea_test', dict(n=5))
1122        self.assertIsInstance(r, dict)
1123        self.assertIn('n', r)
1124        self.assertEqual(r['n'], 5)
1125        self.assertIn('data', r)
1126        r = r['data']
1127        self.assertIsInstance(r, bytes)
1128        self.assertEqual(r, s)
1129        query('drop table bytea_test')
1130
1131    def testDebugWithCallable(self):
1132        if debug:
1133            self.assertEqual(self.db.debug, debug)
1134        else:
1135            self.assertIsNone(self.db.debug)
1136        s = []
1137        self.db.debug = s.append
1138        try:
1139            self.db.query("select 1")
1140            self.db.query("select 2")
1141            self.assertEqual(s, ["select 1", "select 2"])
1142        finally:
1143            self.db.debug = debug
1144
1145
1146class TestSchemas(unittest.TestCase):
1147    """Test correct handling of schemas (namespaces)."""
1148
1149    @classmethod
1150    def setUpClass(cls):
1151        db = DB()
1152        query = db.query
1153        query("set client_min_messages=warning")
1154        for num_schema in range(5):
1155            if num_schema:
1156                schema = "s%d" % num_schema
1157                query("drop schema if exists %s cascade" % (schema,))
1158                try:
1159                    query("create schema %s" % (schema,))
1160                except pg.ProgrammingError:
1161                    raise RuntimeError("The test user cannot create schemas.\n"
1162                        "Grant create on database %s to the user"
1163                        " for running these tests." % dbname)
1164            else:
1165                schema = "public"
1166                query("drop table if exists %s.t" % (schema,))
1167                query("drop table if exists %s.t%d" % (schema, num_schema))
1168            query("create table %s.t with oids as select 1 as n, %d as d"
1169                  % (schema, num_schema))
1170            query("create table %s.t%d with oids as select 1 as n, %d as d"
1171                  % (schema, num_schema, num_schema))
1172        db.close()
1173
1174    @classmethod
1175    def tearDownClass(cls):
1176        db = DB()
1177        query = db.query
1178        query("set client_min_messages=warning")
1179        for num_schema in range(5):
1180            if num_schema:
1181                schema = "s%d" % num_schema
1182                query("drop schema %s cascade" % (schema,))
1183            else:
1184                schema = "public"
1185                query("drop table %s.t" % (schema,))
1186                query("drop table %s.t%d" % (schema, num_schema))
1187        db.close()
1188
1189    def setUp(self):
1190        self.db = DB()
1191        self.db.query("set client_min_messages=warning")
1192
1193    def tearDown(self):
1194        self.db.close()
1195
1196    def testGetTables(self):
1197        tables = self.db.get_tables()
1198        for num_schema in range(5):
1199            if num_schema:
1200                schema = "s" + str(num_schema)
1201            else:
1202                schema = "public"
1203            for t in (schema + ".t",
1204                    schema + ".t" + str(num_schema)):
1205                self.assertIn(t, tables)
1206
1207    def testGetAttnames(self):
1208        get_attnames = self.db.get_attnames
1209        query = self.db.query
1210        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
1211        r = get_attnames("t")
1212        self.assertEqual(r, result)
1213        r = get_attnames("s4.t4")
1214        self.assertEqual(r, result)
1215        query("drop table if exists s3.t3m")
1216        query("create table s3.t3m with oids as select 1 as m")
1217        result_m = {'oid': 'int', 'm': 'int'}
1218        r = get_attnames("s3.t3m")
1219        self.assertEqual(r, result_m)
1220        query("set search_path to s1,s3")
1221        r = get_attnames("t3")
1222        self.assertEqual(r, result)
1223        r = get_attnames("t3m")
1224        self.assertEqual(r, result_m)
1225        query("drop table s3.t3m")
1226
1227    def testGet(self):
1228        get = self.db.get
1229        query = self.db.query
1230        PrgError = pg.ProgrammingError
1231        self.assertEqual(get("t", 1, 'n')['d'], 0)
1232        self.assertEqual(get("t0", 1, 'n')['d'], 0)
1233        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
1234        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
1235        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
1236        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
1237        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
1238        query("set search_path to s2,s4")
1239        self.assertRaises(PrgError, get, "t1", 1, 'n')
1240        self.assertEqual(get("t4", 1, 'n')['d'], 4)
1241        self.assertRaises(PrgError, get, "t3", 1, 'n')
1242        self.assertEqual(get("t", 1, 'n')['d'], 2)
1243        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
1244        query("set search_path to s1,s3")
1245        self.assertRaises(PrgError, get, "t2", 1, 'n')
1246        self.assertEqual(get("t3", 1, 'n')['d'], 3)
1247        self.assertRaises(PrgError, get, "t4", 1, 'n')
1248        self.assertEqual(get("t", 1, 'n')['d'], 1)
1249        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
1250
1251    def testMangling(self):
1252        get = self.db.get
1253        query = self.db.query
1254        r = get("t", 1, 'n')
1255        self.assertIn('oid(public.t)', r)
1256        query("set search_path to s2")
1257        r = get("t2", 1, 'n')
1258        self.assertIn('oid(s2.t2)', r)
1259        query("set search_path to s3")
1260        r = get("t", 1, 'n')
1261        self.assertIn('oid(s3.t)', r)
1262
1263
1264if __name__ == '__main__':
1265    unittest.main()
Note: See TracBrowser for help on using the repository browser.