source: trunk/tests/test_classic_dbwrapper.py @ 729

Last change on this file since 729 was 729, checked in by cito, 4 years ago

Simplify caching and handling of class names

The caches now use the class names as keys as they are passed in.
We do not automatically calculate the qualified name any more,
since this causes too much overhead. Also, we no longer fill the pkey
cache pro-actively with tables from all possible schemas.

Most of the internal auxiliary functions for handling class names
could be discarded by making good use of quote_ident and regclass.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 51.0 KB
RevLine 
[521]1#! /usr/bin/python
[446]2# -*- coding: utf-8 -*-
[204]3
[539]4"""Test the classic PyGreSQL interface.
[204]5
[539]6Sub-tests for the DB wrapper object.
[204]7
[539]8Contributed by Christoph Zwerschke.
[204]9
[539]10These tests need a database to test against.
11
[204]12"""
13
[539]14try:
[553]15    import unittest2 as unittest  # for Python < 2.7
[539]16except ImportError:
17    import unittest
[607]18import os
[537]19
[539]20import pg  # the module under test
[204]21
[466]22from decimal import Decimal
23
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Override the defaults above with the names defined in a LOCAL_PyGreSQL
# module, if one exists: first looked up relative to this package, then
# as a top-level module.  ValueError is caught as well because the
# relative import presumably raises it when this module is not loaded as
# part of a package (assumption; depends on the Python version).
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings available, keep the defaults

# Compatibility aliases so that the tests run under Python 2 and 3.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(); these flags are used
# to skip the tests that read the host attribute on that platform:
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
57
58
def DB():
    """Create a DB wrapper object connecting to the test database.

    The connection uses the module-level settings dbname, dbhost and
    dbport.  If the module-level debug flag is set, it is passed on to the
    wrapper.  The session is configured with client_min_messages=warning
    so that only server messages of level warning and above are reported.

    If this initial session setup fails, the new connection is closed
    before the error is propagated, so that no connection is leaked.
    """
    db = pg.DB(dbname, dbhost, dbport)
    try:
        if debug:
            db.debug = debug
        db.query("set client_min_messages=warning")
    except Exception:
        # do not leak the connection when the setup fails
        db.close()
        raise
    return db
[328]66
[204]67
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods.

    Every test gets a fresh connection from DB() in setUp(), which is
    closed again in tearDown().
    """

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            # the test may already have closed the connection itself
            pass

    def testAllDBAttributes(self):
        # the complete list of public attributes, sorted like dir() output
        attributes = [
            'begin',
            'cancel',
            'clear',
            'close',
            'commit',
            'db',
            'dbname',
            'debug',
            'delete',
            'end',
            'endcopy',
            'error',
            'escape_bytea',
            'escape_identifier',
            'escape_literal',
            'escape_string',
            'fileno',
            'get',
            'get_attnames',
            'get_databases',
            'get_notice_receiver',
            'get_relations',
            'get_tables',
            'getline',
            'getlo',
            'getnotify',
            'has_table_privilege',
            'host',
            'insert',
            'inserttable',
            'locreate',
            'loimport',
            'notification_handler',
            'options',
            'parameter',
            'pkey',
            'port',
            'protocol_version',
            'putline',
            'query',
            'release',
            'reopen',
            'reset',
            'rollback',
            'savepoint',
            'server_version',
            'set_notice_receiver',
            'source',
            'start',
            'status',
            'transaction',
            'unescape_bytea',
            'update',
            'use_regtypes',
            'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        # the error attribute must be empty, apart from krb5_ messages
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        # only protocol versions 2 and 3 are accepted
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        # presumably PostgreSQL 7.4 up to 9.x in integer notation
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # strip a possible hex marker and backslashes from the result,
        # so that only the (empty) payload remains
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        # the test must fail when no error is raised at all
        with self.assertRaises(pg.ProgrammingError) as context:
            self.db.query("select 1/0")
        # SQLSTATE 22012 is division_by_zero
        self.assertEqual(context.exception.sqlstate, '22012')

    def testMethodEndcopy(self):
        # outside of a copy operation, endcopy() may raise an IOError
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        # a DB wrapper can be created from an existing connection
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        # the db attribute stays set across close() and reopen()
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        # ... or from another DB wrapper, or via the db keyword argument
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        # ... or from any object holding the connection in _cnx
        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
[328]273
[333]274
[204]275class TestDBClass(unittest.TestCase):
[641]276    """Test the methods of the DB class wrapped pg connection."""
[204]277
[539]278    @classmethod
279    def setUpClass(cls):
280        db = DB()
281        db.query("drop table if exists test cascade")
282        db.query("create table test ("
283            "i2 smallint, i4 integer, i8 bigint,"
284            "d numeric, f4 real, f8 double precision, m money, "
285            "v4 varchar(4), c4 char(4), t text)")
286        db.query("create or replace view test_view as"
287            " select i4, v4 from test")
288        db.close()
[204]289
[539]290    @classmethod
291    def tearDownClass(cls):
292        db = DB()
293        db.query("drop table test cascade")
294        db.close()
295
[346]296    def setUp(self):
[539]297        self.db = DB()
[602]298        query = self.db.query
299        query('set client_encoding=utf8')
300        query('set standard_conforming_strings=on')
[690]301        query("set lc_monetary='C'")
302        query("set datestyle='ISO,YMD'")
[602]303        query('set bytea_output=hex')
[204]304
[346]305    def tearDown(self):
306        self.db.close()
[204]307
[580]308    def testClassName(self):
309        self.assertEqual(self.db.__class__.__name__, 'DB')
310
311    def testModuleName(self):
312        self.assertEqual(self.db.__module__, 'pg')
[598]313        self.assertEqual(self.db.__class__.__module__, 'pg')
[580]314
[539]315    def testEscapeLiteral(self):
316        f = self.db.escape_literal
[629]317        r = f(b"plain")
318        self.assertIsInstance(r, bytes)
319        self.assertEqual(r, b"'plain'")
320        r = f(u"plain")
321        self.assertIsInstance(r, unicode)
322        self.assertEqual(r, u"'plain'")
323        r = f(u"that's kÀse".encode('utf-8'))
324        self.assertIsInstance(r, bytes)
325        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
326        r = f(u"that's kÀse")
327        self.assertIsInstance(r, unicode)
328        self.assertEqual(r, u"'that''s kÀse'")
[539]329        self.assertEqual(f(r"It's fine to have a \ inside."),
330            r" E'It''s fine to have a \\ inside.'")
331        self.assertEqual(f('No "quotes" must be escaped.'),
332            "'No \"quotes\" must be escaped.'")
333
334    def testEscapeIdentifier(self):
335        f = self.db.escape_identifier
[629]336        r = f(b"plain")
337        self.assertIsInstance(r, bytes)
338        self.assertEqual(r, b'"plain"')
339        r = f(u"plain")
340        self.assertIsInstance(r, unicode)
341        self.assertEqual(r, u'"plain"')
342        r = f(u"that's kÀse".encode('utf-8'))
343        self.assertIsInstance(r, bytes)
344        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
345        r = f(u"that's kÀse")
346        self.assertIsInstance(r, unicode)
347        self.assertEqual(r, u'"that\'s kÀse"')
[539]348        self.assertEqual(f(r"It's fine to have a \ inside."),
349            '"It\'s fine to have a \\ inside."')
350        self.assertEqual(f('All "quotes" must be escaped.'),
351            '"All ""quotes"" must be escaped."')
352
[346]353    def testEscapeString(self):
[539]354        f = self.db.escape_string
[629]355        r = f(b"plain")
356        self.assertIsInstance(r, bytes)
357        self.assertEqual(r, b"plain")
358        r = f(u"plain")
359        self.assertIsInstance(r, unicode)
360        self.assertEqual(r, u"plain")
361        r = f(u"that's kÀse".encode('utf-8'))
362        self.assertIsInstance(r, bytes)
363        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
364        r = f(u"that's kÀse")
365        self.assertIsInstance(r, unicode)
366        self.assertEqual(r, u"that''s kÀse")
[539]367        self.assertEqual(f(r"It's fine to have a \ inside."),
368            r"It''s fine to have a \ inside.")
[287]369
[346]370    def testEscapeBytea(self):
[539]371        f = self.db.escape_bytea
372        # note that escape_byte always returns hex output since Pg 9.0,
373        # regardless of the bytea_output setting
[592]374        r = f(b'plain')
[629]375        self.assertIsInstance(r, bytes)
376        self.assertEqual(r, b'\\x706c61696e')
[592]377        r = f(u'plain')
[629]378        self.assertIsInstance(r, unicode)
379        self.assertEqual(r, u'\\x706c61696e')
[592]380        r = f(u"das is' kÀse".encode('utf-8'))
[629]381        self.assertIsInstance(r, bytes)
382        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
383        r = f(u"das is' kÀse")
384        self.assertIsInstance(r, unicode)
385        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
386        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
[287]387
[346]388    def testUnescapeBytea(self):
[539]389        f = self.db.unescape_bytea
[592]390        r = f(b'plain')
[629]391        self.assertIsInstance(r, bytes)
392        self.assertEqual(r, b'plain')
[592]393        r = f(u'plain')
[629]394        self.assertIsInstance(r, bytes)
395        self.assertEqual(r, b'plain')
[592]396        r = f(b"das is' k\\303\\244se")
[629]397        self.assertIsInstance(r, bytes)
398        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
[592]399        r = f(u"das is' k\\303\\244se")
[629]400        self.assertIsInstance(r, bytes)
401        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
402        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
403        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
[592]404        self.assertEqual(f(r'\\x746861742773206be47365'),
[629]405            b'\\x746861742773206be47365')
406        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
[287]407
[365]408    def testQuote(self):
409        f = self.db._quote
410        self.assertEqual(f(None, None), 'NULL')
411        self.assertEqual(f(None, 'int'), 'NULL')
412        self.assertEqual(f(None, 'float'), 'NULL')
413        self.assertEqual(f(None, 'num'), 'NULL')
414        self.assertEqual(f(None, 'money'), 'NULL')
415        self.assertEqual(f(None, 'bool'), 'NULL')
416        self.assertEqual(f(None, 'date'), 'NULL')
417        self.assertEqual(f('', 'int'), 'NULL')
418        self.assertEqual(f('', 'float'), 'NULL')
419        self.assertEqual(f('', 'num'), 'NULL')
420        self.assertEqual(f('', 'money'), 'NULL')
421        self.assertEqual(f('', 'bool'), 'NULL')
422        self.assertEqual(f('', 'date'), 'NULL')
423        self.assertEqual(f('', 'text'), "''")
[397]424        self.assertEqual(f(0, 'int'), '0')
425        self.assertEqual(f(0, 'num'), '0')
426        self.assertEqual(f(1, 'int'), '1')
427        self.assertEqual(f(1, 'num'), '1')
428        self.assertEqual(f(-1, 'int'), '-1')
429        self.assertEqual(f(-1, 'num'), '-1')
[365]430        self.assertEqual(f(123456789, 'int'), '123456789')
431        self.assertEqual(f(123456987, 'num'), '123456987')
432        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
433        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
434        self.assertEqual(f('123456789', 'num'), '123456789')
435        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
436        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
[493]437        self.assertEqual(f(123, 'money'), '123')
438        self.assertEqual(f('123', 'money'), '123')
439        self.assertEqual(f(123.45, 'money'), '123.45')
440        self.assertEqual(f('123.45', 'money'), '123.45')
441        self.assertEqual(f(123.454, 'money'), '123.454')
442        self.assertEqual(f('123.454', 'money'), '123.454')
443        self.assertEqual(f(123.456, 'money'), '123.456')
444        self.assertEqual(f('123.456', 'money'), '123.456')
[365]445        self.assertEqual(f('f', 'bool'), "'f'")
446        self.assertEqual(f('F', 'bool'), "'f'")
447        self.assertEqual(f('false', 'bool'), "'f'")
448        self.assertEqual(f('False', 'bool'), "'f'")
449        self.assertEqual(f('FALSE', 'bool'), "'f'")
450        self.assertEqual(f(0, 'bool'), "'f'")
451        self.assertEqual(f('0', 'bool'), "'f'")
452        self.assertEqual(f('-', 'bool'), "'f'")
453        self.assertEqual(f('n', 'bool'), "'f'")
454        self.assertEqual(f('N', 'bool'), "'f'")
455        self.assertEqual(f('no', 'bool'), "'f'")
456        self.assertEqual(f('off', 'bool'), "'f'")
457        self.assertEqual(f('t', 'bool'), "'t'")
458        self.assertEqual(f('T', 'bool'), "'t'")
459        self.assertEqual(f('true', 'bool'), "'t'")
460        self.assertEqual(f('True', 'bool'), "'t'")
461        self.assertEqual(f('TRUE', 'bool'), "'t'")
462        self.assertEqual(f(1, 'bool'), "'t'")
463        self.assertEqual(f(2, 'bool'), "'t'")
464        self.assertEqual(f(-1, 'bool'), "'t'")
465        self.assertEqual(f(0.5, 'bool'), "'t'")
466        self.assertEqual(f('1', 'bool'), "'t'")
467        self.assertEqual(f('y', 'bool'), "'t'")
468        self.assertEqual(f('Y', 'bool'), "'t'")
469        self.assertEqual(f('yes', 'bool'), "'t'")
470        self.assertEqual(f('on', 'bool'), "'t'")
471        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
472        self.assertEqual(f(123, 'text'), "'123'")
473        self.assertEqual(f(1.23, 'text'), "'1.23'")
474        self.assertEqual(f('abc', 'text'), "'abc'")
475        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
[539]476        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
477        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
478        self.db.query('set standard_conforming_strings=off')
[365]479        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
480        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
481
[368]482    def testQuery(self):
[539]483        query = self.db.query
484        query("drop table if exists test_table")
[368]485        q = "create table test_table (n integer) with oids"
[539]486        r = query(q)
487        self.assertIsNone(r)
[368]488        q = "insert into test_table values (1)"
[539]489        r = query(q)
490        self.assertIsInstance(r, int)
[368]491        q = "insert into test_table select 2"
[539]492        r = query(q)
493        self.assertIsInstance(r, int)
[368]494        oid = r
495        q = "select oid from test_table where n=2"
[539]496        r = query(q).getresult()
[368]497        self.assertEqual(len(r), 1)
498        r = r[0]
499        self.assertEqual(len(r), 1)
500        r = r[0]
501        self.assertEqual(r, oid)
502        q = "insert into test_table select 3 union select 4 union select 5"
[539]503        r = query(q)
504        self.assertIsInstance(r, str)
[368]505        self.assertEqual(r, '3')
506        q = "update test_table set n=4 where n<5"
[539]507        r = query(q)
508        self.assertIsInstance(r, str)
[368]509        self.assertEqual(r, '4')
510        q = "delete from test_table"
[539]511        r = query(q)
512        self.assertIsInstance(r, str)
[368]513        self.assertEqual(r, '5')
[539]514        query("drop table test_table")
[368]515
[446]516    def testMultipleQueries(self):
517        self.assertEqual(self.db.query(
518            "create temporary table test_multi (n integer);"
519            "insert into test_multi values (4711);"
520            "select n from test_multi").getresult()[0][0], 4711)
521
522    def testQueryWithParams(self):
[539]523        query = self.db.query
524        query("drop table if exists test_table")
[446]525        q = "create table test_table (n1 integer, n2 integer) with oids"
[539]526        query(q)
[446]527        q = "insert into test_table values ($1, $2)"
[539]528        r = query(q, (1, 2))
529        self.assertIsInstance(r, int)
530        r = query(q, [3, 4])
531        self.assertIsInstance(r, int)
532        r = query(q, [5, 6])
533        self.assertIsInstance(r, int)
[446]534        q = "select * from test_table order by 1, 2"
[539]535        self.assertEqual(query(q).getresult(),
[446]536            [(1, 2), (3, 4), (5, 6)])
537        q = "select * from test_table where n1=$1 and n2=$2"
[539]538        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
[446]539        q = "update test_table set n2=$2 where n1=$1"
[539]540        r = query(q, 3, 7)
[446]541        self.assertEqual(r, '1')
542        q = "select * from test_table order by 1, 2"
[539]543        self.assertEqual(query(q).getresult(),
[446]544            [(1, 2), (3, 7), (5, 6)])
545        q = "delete from test_table where n2!=$1"
[539]546        r = query(q, 4)
[446]547        self.assertEqual(r, '3')
[539]548        query("drop table test_table")
[446]549
550    def testEmptyQuery(self):
551        self.assertRaises(ValueError, self.db.query, '')
552
[434]553    def testQueryProgrammingError(self):
554        try:
555            self.db.query("select 1/0")
[556]556        except pg.ProgrammingError as error:
[434]557            self.assertEqual(error.sqlstate, '22012')
558
[346]559    def testPkey(self):
[539]560        query = self.db.query
561        pkey = self.db.pkey
[725]562        for t in ('pkeytest', 'primary key test'):
563            for n in range(7):
564                query('drop table if exists "%s%d"' % (t, n))
565            query('create table "%s0" ('
566                "a smallint)" % t)
567            query('create table "%s1" ('
568                "b smallint primary key)" % t)
569            query('create table "%s2" ('
570                "c smallint, d smallint primary key)" % t)
571            query('create table "%s3" ('
572                "e smallint, f smallint, g smallint, "
573                "h smallint, i smallint, "
574                "primary key (f, h))" % t)
575            query('create table "%s4" ('
576                "more_than_one_letter varchar primary key)" % t)
577            query('create table "%s5" ('
578                '"with space" date primary key)' % t)
579            query('create table "%s6" ('
580                'a_very_long_column_name varchar, '
581                '"with space" date, '
582                '"42" int, '
583                "primary key (a_very_long_column_name, "
584                '"with space", "42"))' % t)
585            self.assertRaises(KeyError, pkey, '%s0' % t)
586            self.assertEqual(pkey('%s1' % t), 'b')
587            self.assertEqual(pkey('%s2' % t), 'd')
588            r = pkey('%s3' % t)
589            self.assertIsInstance(r, frozenset)
590            self.assertEqual(r, frozenset('fh'))
591            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
592            self.assertEqual(pkey('%s5' % t), 'with space')
593            r = pkey('%s6' % t)
594            self.assertIsInstance(r, frozenset)
595            self.assertEqual(r, frozenset([
596                'a_very_long_column_name', 'with space', '42']))
[729]597            # a newly added primary key will be detected
598            query('alter table "%s0" add primary key (a)' % t)
[725]599            self.assertEqual(pkey('%s0' % t), 'a')
[729]600            # a changed primary key will not be detected,
601            # indicating that the internal cache is operating
602            query('alter table "%s1" rename column b to x' % t)
[725]603            self.assertEqual(pkey('%s1' % t), 'b')
[729]604            # we get the changed primary key when the cache is flushed
605            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
[725]606            for n in range(7):
607                query('drop table "%s%d"' % (t, n))
[204]608
[346]609    def testGetDatabases(self):
610        databases = self.db.get_databases()
[539]611        self.assertIn('template0', databases)
612        self.assertIn('template1', databases)
613        self.assertNotIn('not existing database', databases)
614        self.assertIn('postgres', databases)
615        self.assertIn(dbname, databases)
[204]616
[346]617    def testGetTables(self):
[539]618        get_tables = self.db.get_tables
619        result1 = get_tables()
[709]620        self.assertIsInstance(result1, list)
621        for t in result1:
622            t = t.split('.', 1)
623            self.assertGreaterEqual(len(t), 2)
624            if len(t) > 2:
625                self.assertTrue(t[1].startswith('"'))
626            t = t[0]
627            self.assertNotEqual(t, 'information_schema')
628            self.assertFalse(t.startswith('pg_'))
[346]629        tables = ('"A very Special Name"',
630            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
631            'A_MiXeD_NaMe', '"another special name"',
632            'averyveryveryveryveryveryverylongtablename',
633            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
634        for t in tables:
[539]635            self.db.query('drop table if exists %s' % t)
636            self.db.query("create table %s"
[346]637                " as select 0" % t)
[539]638        result3 = get_tables()
[346]639        result2 = []
640        for t in result3:
641            if t not in result1:
642                result2.append(t)
643        result3 = []
644        for t in tables:
645            if not t.startswith('"'):
646                t = t.lower()
647            result3.append('public.' + t)
648        self.assertEqual(result2, result3)
649        for t in result2:
[539]650            self.db.query('drop table %s' % t)
651        result2 = get_tables()
[346]652        self.assertEqual(result2, result1)
[204]653
[346]654    def testGetRelations(self):
[539]655        get_relations = self.db.get_relations
656        result = get_relations()
657        self.assertIn('public.test', result)
658        self.assertIn('public.test_view', result)
659        result = get_relations('rv')
660        self.assertIn('public.test', result)
661        self.assertIn('public.test_view', result)
662        result = get_relations('r')
663        self.assertIn('public.test', result)
664        self.assertNotIn('public.test_view', result)
665        result = get_relations('v')
666        self.assertNotIn('public.test', result)
667        self.assertIn('public.test_view', result)
668        result = get_relations('cisSt')
669        self.assertNotIn('public.test', result)
670        self.assertNotIn('public.test_view', result)
[250]671
[346]672    def testAttnames(self):
673        self.assertRaises(pg.ProgrammingError,
674            self.db.get_attnames, 'does_not_exist')
675        self.assertRaises(pg.ProgrammingError,
676            self.db.get_attnames, 'has.too.many.dots')
677        for table in ('attnames_test_table', 'test table for attnames'):
[539]678            self.db.query('drop table if exists "%s"' % table)
679            self.db.query('create table "%s" ('
[346]680                'a smallint, b integer, c bigint, '
681                'e numeric, f float, f2 double precision, m money, '
682                'x smallint, y smallint, z smallint, '
683                'Normal_NaMe smallint, "Special Name" smallint, '
684                't text, u char(2), v varchar(2), '
[539]685                'primary key (y, u)) with oids' % table)
[346]686            attributes = self.db.get_attnames(table)
687            result = {'a': 'int', 'c': 'int', 'b': 'int',
688                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
689                'normal_name': 'int', 'Special Name': 'int',
690                'u': 'text', 't': 'text', 'v': 'text',
[436]691                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
[346]692            self.assertEqual(attributes, result)
[539]693            self.db.query('drop table "%s"' % table)
[204]694
[391]695    def testHasTablePrivilege(self):
696        can = self.db.has_table_privilege
697        self.assertEqual(can('test'), True)
698        self.assertEqual(can('test', 'select'), True)
699        self.assertEqual(can('test', 'SeLeCt'), True)
700        self.assertEqual(can('test', 'SELECT'), True)
701        self.assertEqual(can('test', 'insert'), True)
702        self.assertEqual(can('test', 'update'), True)
703        self.assertEqual(can('test', 'delete'), True)
[706]704        self.assertEqual(can('pg_views', 'select'), True)
705        self.assertEqual(can('pg_views', 'delete'), False)
[391]706        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
707        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
708
[346]709    def testGet(self):
[539]710        get = self.db.get
711        query = self.db.query
[346]712        for table in ('get_test_table', 'test table for get'):
[539]713            query('drop table if exists "%s"' % table)
714            query('create table "%s" ('
715                "n integer, t text) with oids" % table)
[346]716            for n, t in enumerate('xyz'):
[539]717                query('insert into "%s" values('"%d, '%s')"
718                    % (table, n + 1, t))
719            self.assertRaises(pg.ProgrammingError, get, table, 2)
720            r = get(table, 2, 'n')
[729]721            oid_table = 'oid(%s)' % table
[539]722            self.assertIn(oid_table, r)
[389]723            oid = r[oid_table]
[539]724            self.assertIsInstance(oid, int)
[389]725            result = {'t': 'y', 'n': 2, oid_table: oid}
726            self.assertEqual(r, result)
[539]727            self.assertEqual(get(table + ' *', 2, 'n'), r)
728            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
729            self.assertEqual(get(table, 1, 'n')['t'], 'x')
730            self.assertEqual(get(table, 3, 'n')['t'], 'z')
731            self.assertEqual(get(table, 2, 'n')['t'], 'y')
732            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
[346]733            r['n'] = 3
[539]734            self.assertEqual(get(table, r, 'n')['t'], 'z')
735            self.assertEqual(get(table, 1, 'n')['t'], 'x')
736            query('alter table "%s" alter n set not null' % table)
737            query('alter table "%s" add primary key (n)' % table)
738            self.assertEqual(get(table, 3)['t'], 'z')
739            self.assertEqual(get(table, 1)['t'], 'x')
740            self.assertEqual(get(table, 2)['t'], 'y')
[346]741            r['n'] = 1
[539]742            self.assertEqual(get(table, r)['t'], 'x')
[346]743            r['n'] = 3
[539]744            self.assertEqual(get(table, r)['t'], 'z')
[346]745            r['n'] = 2
[539]746            self.assertEqual(get(table, r)['t'], 'y')
747            query('drop table "%s"' % table)
[204]748
[389]749    def testGetWithCompositeKey(self):
[539]750        get = self.db.get
751        query = self.db.query
[389]752        table = 'get_test_table_1'
[539]753        query("drop table if exists %s" % table)
754        query("create table %s ("
[389]755            "n integer, t text, primary key (n))" % table)
756        for n, t in enumerate('abc'):
[539]757            query("insert into %s values("
[464]758                "%d, '%s')" % (table, n + 1, t))
[539]759        self.assertEqual(get(table, 2)['t'], 'b')
760        query("drop table %s" % table)
[389]761        table = 'get_test_table_2'
[539]762        query("drop table if exists %s" % table)
763        query("create table %s ("
[389]764            "n integer, m integer, t text, primary key (n, m))" % table)
765        for n in range(3):
766            for m in range(2):
[464]767                t = chr(ord('a') + 2 * n + m)
[539]768                query("insert into %s values("
[464]769                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
[539]770        self.assertRaises(pg.ProgrammingError, get, table, 2)
771        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
772        self.assertEqual(get(table, dict(n=1, m=2),
773                             ('n', 'm'))['t'], 'b')
774        self.assertEqual(get(table, dict(n=3, m=2),
775                             frozenset(['n', 'm']))['t'], 'f')
776        query("drop table %s" % table)
[389]777
[346]778    def testGetFromView(self):
779        self.db.query('delete from test where i4=14')
780        self.db.query('insert into test (i4, v4) values('
781            "14, 'abc4')")
782        r = self.db.get('test_view', 14, 'i4')
[539]783        self.assertIn('v4', r)
[346]784        self.assertEqual(r['v4'], 'abc4')
[254]785
[346]786    def testInsert(self):
[539]787        insert = self.db.insert
788        query = self.db.query
[690]789        server_version = self.db.server_version
[706]790        bool_on = pg.get_bool()
791        decimal = pg.get_decimal()
[346]792        for table in ('insert_test_table', 'test table for insert'):
[539]793            query('drop table if exists "%s"' % table)
794            query('create table "%s" ('
[346]795                "i2 smallint, i4 integer, i8 bigint,"
[493]796                " d numeric, f4 real, f8 double precision, m money,"
797                " v4 varchar(4), c4 char(4), t text,"
[539]798                " b boolean, ts timestamp) with oids" % table)
[729]799            oid_table = 'oid(%s)' % table
[493]800            tests = [dict(i2=None, i4=None, i8=None),
801                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
802                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
803                dict(i2=42, i4=123456, i8=9876543210),
804                dict(i2=2 ** 15 - 1,
805                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
806                dict(d=None), (dict(d=''), dict(d=None)),
807                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
808                dict(f4=None, f8=None), dict(f4=0, f8=0),
809                (dict(f4='', f8=''), dict(f4=None, f8=None)),
[644]810                (dict(d=1234.5, f4=1234.5, f8=1234.5),
811                      dict(d=Decimal('1234.5'))),
[493]812                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
813                dict(d=Decimal('123456789.9876543212345678987654321')),
814                dict(m=None), (dict(m=''), dict(m=None)),
815                dict(m=Decimal('-1234.56')),
816                (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
817                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
818                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
819                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
820                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
821                (dict(m=123456), dict(m=Decimal('123456'))),
822                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
823                dict(b=None), (dict(b=''), dict(b=None)),
824                dict(b='f'), dict(b='t'),
825                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
826                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
827                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
828                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
829                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
830                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
831                dict(v4=None, c4=None, t=None),
832                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
833                dict(v4='1234', c4='1234', t='1234' * 10),
834                dict(v4='abcd', c4='abcd', t='abcdefg'),
835                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
836                dict(ts=None), (dict(ts=''), dict(ts=None)),
837                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
838                dict(ts='2012-12-21 00:00:00'),
839                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
840                dict(ts='2012-12-21 12:21:12'),
841                dict(ts='2013-01-05 12:13:14'),
842                dict(ts='current_timestamp')]
843            for test in tests:
844                if isinstance(test, dict):
845                    data = test
846                    change = {}
847                else:
848                    data, change = test
849                expect = data.copy()
850                expect.update(change)
[706]851                if bool_on:
852                    b = expect.get('b')
853                    if b is not None:
854                        expect['b'] = b == 't'
855                if decimal is not Decimal:
856                    d = expect.get('d')
857                    if d is not None:
858                        expect['d'] = decimal(d)
859                    m = expect.get('m')
860                    if m is not None:
861                        expect['m'] = decimal(m)
[690]862                if data.get('m') and server_version < 910000:
863                    # PostgreSQL < 9.1 cannot directly convert numbers to money
864                    data['m'] = "'%s'::money" % data['m']
[539]865                self.assertEqual(insert(table, data), data)
866                self.assertIn(oid_table, data)
[493]867                oid = data[oid_table]
[539]868                self.assertIsInstance(oid, int)
[591]869                data = dict(item for item in data.items()
[493]870                    if item[0] in expect)
871                ts = expect.get('ts')
872                if ts == 'current_timestamp':
[497]873                    ts = expect['ts'] = data['ts']
874                    if len(ts) > 19:
875                        self.assertEqual(ts[19], '.')
876                        ts = ts[:19]
877                    else:
878                        self.assertEqual(len(ts), 19)
879                    self.assertTrue(ts[:4].isdigit())
880                    self.assertEqual(ts[4], '-')
881                    self.assertEqual(ts[10], ' ')
882                    self.assertTrue(ts[11:13].isdigit())
883                    self.assertEqual(ts[13], ':')
[493]884                self.assertEqual(data, expect)
[539]885                data = query(
[493]886                    'select oid,* from "%s"' % table).dictresult()[0]
887                self.assertEqual(data['oid'], oid)
[591]888                data = dict(item for item in data.items()
[493]889                    if item[0] in expect)
890                self.assertEqual(data, expect)
[539]891                query('delete from "%s"' % table)
892            query('drop table "%s"' % table)
[204]893
[346]894    def testUpdate(self):
[539]895        update = self.db.update
896        query = self.db.query
[346]897        for table in ('update_test_table', 'test table for update'):
[539]898            query('drop table if exists "%s"' % table)
899            query('create table "%s" ('
900                "n integer, t text) with oids" % table)
[346]901            for n, t in enumerate('xyz'):
[539]902                query('insert into "%s" values('
[464]903                    "%d, '%s')" % (table, n + 1, t))
[389]904            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
[346]905            r = self.db.get(table, 2, 'n')
906            r['t'] = 'u'
[539]907            s = update(table, r)
[346]908            self.assertEqual(s, r)
[539]909            r = query('select t from "%s" where n=2' % table
910                      ).getresult()[0][0]
[346]911            self.assertEqual(r, 'u')
[539]912            query('drop table "%s"' % table)
[204]913
[389]914    def testUpdateWithCompositeKey(self):
[539]915        update = self.db.update
916        query = self.db.query
[391]917        table = 'update_test_table_1'
[539]918        query("drop table if exists %s" % table)
919        query("create table %s ("
[389]920            "n integer, t text, primary key (n))" % table)
921        for n, t in enumerate('abc'):
[539]922            query("insert into %s values("
[464]923                "%d, '%s')" % (table, n + 1, t))
[539]924        self.assertRaises(pg.ProgrammingError, update,
925                          table, dict(t='b'))
926        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
927        r = query('select t from "%s" where n=2' % table
928                  ).getresult()[0][0]
[389]929        self.assertEqual(r, 'd')
[539]930        query("drop table %s" % table)
[391]931        table = 'update_test_table_2'
[539]932        query("drop table if exists %s" % table)
933        query("create table %s ("
[389]934            "n integer, m integer, t text, primary key (n, m))" % table)
935        for n in range(3):
936            for m in range(2):
[464]937                t = chr(ord('a') + 2 * n + m)
[539]938                query("insert into %s values("
[464]939                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
[539]940        self.assertRaises(pg.ProgrammingError, update,
941                          table, dict(n=2, t='b'))
942        self.assertEqual(update(table,
943                                dict(n=2, m=2, t='x'))['t'], 'x')
944        r = [r[0] for r in query('select t from "%s" where n=2'
[389]945            ' order by m' % table).getresult()]
946        self.assertEqual(r, ['c', 'x'])
[539]947        query("drop table %s" % table)
[389]948
[346]949    def testClear(self):
[539]950        clear = self.db.clear
951        query = self.db.query
[706]952        f = False if pg.get_bool() else 'f'
[346]953        for table in ('clear_test_table', 'test table for clear'):
[539]954            query('drop table if exists "%s"' % table)
955            query('create table "%s" ('
[346]956                "n integer, b boolean, d date, t text)" % table)
[539]957            r = clear(table)
[706]958            result = {'n': 0, 'b': f, 'd': '', 't': ''}
[346]959            self.assertEqual(r, result)
960            r['a'] = r['n'] = 1
961            r['d'] = r['t'] = 'x'
[436]962            r['b'] = 't'
[571]963            r['oid'] = long(1)
[539]964            r = clear(table, r)
[706]965            result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
[571]966                'oid': long(1)}
[346]967            self.assertEqual(r, result)
[539]968            query('drop table "%s"' % table)
[204]969
[346]970    def testDelete(self):
[539]971        delete = self.db.delete
972        query = self.db.query
[346]973        for table in ('delete_test_table', 'test table for delete'):
[539]974            query('drop table if exists "%s"' % table)
975            query('create table "%s" ('
976                "n integer, t text) with oids" % table)
[346]977            for n, t in enumerate('xyz'):
[539]978                query('insert into "%s" values('
[464]979                    "%d, '%s')" % (table, n + 1, t))
[389]980            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
[346]981            r = self.db.get(table, 1, 'n')
[539]982            s = delete(table, r)
[391]983            self.assertEqual(s, 1)
[346]984            r = self.db.get(table, 3, 'n')
[539]985            s = delete(table, r)
[391]986            self.assertEqual(s, 1)
[539]987            s = delete(table, r)
[391]988            self.assertEqual(s, 0)
[539]989            r = query('select * from "%s"' % table).dictresult()
[346]990            self.assertEqual(len(r), 1)
991            r = r[0]
992            result = {'n': 2, 't': 'y'}
993            self.assertEqual(r, result)
994            r = self.db.get(table, 2, 'n')
[539]995            s = delete(table, r)
[391]996            self.assertEqual(s, 1)
[539]997            s = delete(table, r)
[391]998            self.assertEqual(s, 0)
[346]999            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
[539]1000            query('drop table "%s"' % table)
[204]1001
[391]1002    def testDeleteWithCompositeKey(self):
[539]1003        query = self.db.query
[391]1004        table = 'delete_test_table_1'
[539]1005        query("drop table if exists %s" % table)
1006        query("create table %s ("
[391]1007            "n integer, t text, primary key (n))" % table)
1008        for n, t in enumerate('abc'):
[539]1009            query("insert into %s values("
[464]1010                "%d, '%s')" % (table, n + 1, t))
[391]1011        self.assertRaises(pg.ProgrammingError, self.db.delete,
1012            table, dict(t='b'))
1013        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
[539]1014        r = query('select t from "%s" where n=2' % table
1015                  ).getresult()
[391]1016        self.assertEqual(r, [])
1017        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
[539]1018        r = query('select t from "%s" where n=3' % table
1019                  ).getresult()[0][0]
[391]1020        self.assertEqual(r, 'c')
[539]1021        query("drop table %s" % table)
[391]1022        table = 'delete_test_table_2'
[539]1023        query("drop table if exists %s" % table)
1024        query("create table %s ("
[391]1025            "n integer, m integer, t text, primary key (n, m))" % table)
1026        for n in range(3):
1027            for m in range(2):
[464]1028                t = chr(ord('a') + 2 * n + m)
[539]1029                query("insert into %s values("
[464]1030                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
[391]1031        self.assertRaises(pg.ProgrammingError, self.db.delete,
1032            table, dict(n=2, t='b'))
1033        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
[539]1034        r = [r[0] for r in query('select t from "%s" where n=2'
[391]1035            ' order by m' % table).getresult()]
1036        self.assertEqual(r, ['c'])
1037        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
[539]1038        r = [r[0] for r in query('select t from "%s" where n=3'
[391]1039            ' order by m' % table).getresult()]
1040        self.assertEqual(r, ['e', 'f'])
1041        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
[539]1042        r = [r[0] for r in query('select t from "%s" where n=3'
[391]1043            ' order by m' % table).getresult()]
1044        self.assertEqual(r, ['f'])
[539]1045        query("drop table %s" % table)
[391]1046
[464]1047    def testTransaction(self):
[539]1048        query = self.db.query
1049        query("drop table if exists test_table")
1050        query("create table test_table (n integer)")
[464]1051        self.db.begin()
[539]1052        query("insert into test_table values (1)")
1053        query("insert into test_table values (2)")
[464]1054        self.db.commit()
1055        self.db.begin()
[539]1056        query("insert into test_table values (3)")
1057        query("insert into test_table values (4)")
[464]1058        self.db.rollback()
1059        self.db.begin()
[539]1060        query("insert into test_table values (5)")
[464]1061        self.db.savepoint('before6')
[539]1062        query("insert into test_table values (6)")
[464]1063        self.db.rollback('before6')
[539]1064        query("insert into test_table values (7)")
[464]1065        self.db.commit()
1066        self.db.begin()
1067        self.db.savepoint('before8')
[539]1068        query("insert into test_table values (8)")
[464]1069        self.db.release('before8')
1070        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1071        self.db.commit()
1072        self.db.start()
[539]1073        query("insert into test_table values (9)")
[464]1074        self.db.end()
[539]1075        r = [r[0] for r in query(
[464]1076            "select * from test_table order by 1").getresult()]
1077        self.assertEqual(r, [1, 2, 5, 7, 9])
[539]1078        query("drop table test_table")
[464]1079
1080    def testContextManager(self):
[539]1081        query = self.db.query
1082        query("drop table if exists test_table")
1083        query("create table test_table (n integer check(n>0))")
[464]1084        with self.db:
[539]1085            query("insert into test_table values (1)")
1086            query("insert into test_table values (2)")
[464]1087        try:
1088            with self.db:
[539]1089                query("insert into test_table values (3)")
1090                query("insert into test_table values (4)")
[464]1091                raise ValueError('test transaction should rollback')
[556]1092        except ValueError as error:
[464]1093            self.assertEqual(str(error), 'test transaction should rollback')
1094        with self.db:
[539]1095            query("insert into test_table values (5)")
[464]1096        try:
1097            with self.db:
[539]1098                query("insert into test_table values (6)")
1099                query("insert into test_table values (-1)")
[556]1100        except pg.ProgrammingError as error:
[464]1101            self.assertTrue('check' in str(error))
1102        with self.db:
[539]1103            query("insert into test_table values (7)")
1104        r = [r[0] for r in query(
[464]1105            "select * from test_table order by 1").getresult()]
1106        self.assertEqual(r, [1, 2, 5, 7])
[539]1107        query("drop table test_table")
[464]1108
[346]1109    def testBytea(self):
[539]1110        query = self.db.query
1111        query('drop table if exists bytea_test')
[662]1112        query('create table bytea_test (n smallint primary key, data bytea)')
[629]1113        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
[346]1114        r = self.db.escape_bytea(s)
[662]1115        query('insert into bytea_test values(3,$1)', (r,))
1116        r = query('select * from bytea_test where n=3').getresult()
1117        self.assertEqual(len(r), 1)
[346]1118        r = r[0]
[662]1119        self.assertEqual(len(r), 2)
1120        self.assertEqual(r[0], 3)
1121        r = r[1]
1122        self.assertIsInstance(r, str)
1123        r = self.db.unescape_bytea(r)
1124        self.assertIsInstance(r, bytes)
1125        self.assertEqual(r, s)
1126        query('drop table bytea_test')
1127
[664]1128    def testInsertUpdateGetBytea(self):
[662]1129        query = self.db.query
1130        query('drop table if exists bytea_test')
1131        query('create table bytea_test (n smallint primary key, data bytea)')
1132        # insert as bytes
1133        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1134        r = self.db.insert('bytea_test', n=5, data=s)
1135        self.assertIsInstance(r, dict)
1136        self.assertIn('n', r)
1137        self.assertEqual(r['n'], 5)
1138        self.assertIn('data', r)
1139        r = r['data']
1140        self.assertIsInstance(r, bytes)
1141        self.assertEqual(r, s)
1142        # update as bytes
1143        s += b"and now even more \x00 nasty \t stuff!\f"
1144        r = self.db.update('bytea_test', n=5, data=s)
1145        self.assertIsInstance(r, dict)
1146        self.assertIn('n', r)
1147        self.assertEqual(r['n'], 5)
1148        self.assertIn('data', r)
1149        r = r['data']
1150        self.assertIsInstance(r, bytes)
1151        self.assertEqual(r, s)
1152        r = query('select * from bytea_test where n=5').getresult()
1153        self.assertEqual(len(r), 1)
[346]1154        r = r[0]
[662]1155        self.assertEqual(len(r), 2)
1156        self.assertEqual(r[0], 5)
1157        r = r[1]
1158        self.assertIsInstance(r, str)
[346]1159        r = self.db.unescape_bytea(r)
[629]1160        self.assertIsInstance(r, bytes)
[346]1161        self.assertEqual(r, s)
[664]1162        r = self.db.get('bytea_test', dict(n=5))
1163        self.assertIsInstance(r, dict)
1164        self.assertIn('n', r)
1165        self.assertEqual(r['n'], 5)
1166        self.assertIn('data', r)
1167        r = r['data']
1168        self.assertIsInstance(r, bytes)
1169        self.assertEqual(r, s)
[539]1170        query('drop table bytea_test')
[204]1171
[466]1172    def testDebugWithCallable(self):
[539]1173        if debug:
1174            self.assertEqual(self.db.debug, debug)
1175        else:
1176            self.assertIsNone(self.db.debug)
[466]1177        s = []
1178        self.db.debug = s.append
1179        try:
1180            self.db.query("select 1")
1181            self.db.query("select 2")
1182            self.assertEqual(s, ["select 1", "select 2"])
1183        finally:
[539]1184            self.db.debug = debug
[328]1185
[466]1186
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    The global pg options 'decimal', 'bool' and 'namedresult' are changed
    for the whole class and restored afterwards, even if the setup or the
    teardown of the base class fails, so other test classes are not
    affected by the modified global state.
    """

    @classmethod
    def setUpClass(cls):
        cls.saved_options = {}

        def unnamed_result(q):
            return q.getresult()

        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except BaseException:
            # tearDownClass is not run when setUpClass fails,
            # so the global options must be restored here
            cls._restore_options()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls._restore_options()

    @classmethod
    def set_option(cls, option, value):
        """Remember the current value of a global option, then set it."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _restore_options(cls):
        """Restore all saved options in the reverse order of setting them."""
        for option in ('namedresult', 'bool', 'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
1215
1216
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces).

    The fixture uses the public schema plus the schemas s1 to s4.  Every
    schema sN (with public standing in for N=0) gets two tables with oids,
    ``t`` and ``tN``, each holding one row with n=1 and d=N.  The value of
    d therefore reveals in which schema a table name has been resolved.
    """

    @classmethod
    def setUpClass(cls):
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # close the connection even if the fixture cannot be built
            db.close()

    @classmethod
    def tearDownClass(cls):
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.db = DB()
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        try:
            result_m = {'oid': 'int', 'm': 'int'}
            r = get_attnames("s3.t3m")
            self.assertEqual(r, result_m)
            # unqualified names are now resolved via the search path
            query("set search_path to s1,s3")
            r = get_attnames("t3")
            self.assertEqual(r, result)
            r = get_attnames("t3m")
            self.assertEqual(r, result_m)
        finally:
            # do not leave the extra table behind if an assertion fails
            query("drop table s3.t3m")

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        # the oid key in the result uses the class name as passed in
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
[204]1333
1334
if __name__ == '__main__':
    # run all test cases of this module when executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.