source: branches/4.x/module/tests/test_classic_dbwrapper.py @ 650

Last change on this file since 650 was 650, checked in by cito, 4 years ago

Make tests compatible with Python 2.4

We want to be able to test the 4.x branch with older Python versions
so that we can specify eligible minimum requirements.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 44.0 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18
19import sys
20
21import pg  # the module under test
22
23from decimal import Decimal
24
25# check whether the "with" statement is supported
26no_with = sys.version_info[:2] < (2, 5)
27
28# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
29# get our information from that.  Otherwise we use the defaults.
30# The current user must have create schema privilege on the database.
31dbname = 'unittest'
32dbhost = None
33dbport = 5432
34
35debug = False  # let DB wrapper print debugging output
36
37try:
38    from LOCAL_PyGreSQL import *
39except ImportError:
40    pass
41
42
43def DB():
44    """Create a DB wrapper object connecting to the test database."""
45    db = pg.DB(dbname, dbhost, dbport)
46    if debug:
47        db.debug = debug
48    db.query("set client_min_messages=warning")
49    return db
50
51
52class TestDBClassBasic(unittest.TestCase):
53    """Test existence of the DB class wrapped pg connection methods."""
54
55    def setUp(self):
56        self.db = DB()
57
58    def tearDown(self):
59        try:
60            self.db.close()
61        except pg.InternalError:
62            pass
63
64    def testAllDBAttributes(self):
65        attributes = [
66            'begin',
67            'cancel',
68            'clear',
69            'close',
70            'commit',
71            'db',
72            'dbname',
73            'debug',
74            'delete',
75            'end',
76            'endcopy',
77            'error',
78            'escape_bytea',
79            'escape_identifier',
80            'escape_literal',
81            'escape_string',
82            'fileno',
83            'get',
84            'get_attnames',
85            'get_databases',
86            'get_notice_receiver',
87            'get_relations',
88            'get_tables',
89            'getline',
90            'getlo',
91            'getnotify',
92            'has_table_privilege',
93            'host',
94            'insert',
95            'inserttable',
96            'locreate',
97            'loimport',
98            'notification_handler',
99            'options',
100            'parameter',
101            'pkey',
102            'port',
103            'protocol_version',
104            'putline',
105            'query',
106            'release',
107            'reopen',
108            'reset',
109            'rollback',
110            'savepoint',
111            'server_version',
112            'set_notice_receiver',
113            'source',
114            'start',
115            'status',
116            'transaction',
117            'tty',
118            'unescape_bytea',
119            'update',
120            'use_regtypes',
121            'user',
122        ]
123        db_attributes = [a for a in dir(self.db)
124            if not a.startswith('_')]
125        self.assertEqual(attributes, db_attributes)
126
127    def testAttributeDb(self):
128        self.assertEqual(self.db.db.db, dbname)
129
130    def testAttributeDbname(self):
131        self.assertEqual(self.db.dbname, dbname)
132
133    def testAttributeError(self):
134        error = self.db.error
135        self.assertTrue(not error or 'krb5_' in error)
136        self.assertEqual(self.db.error, self.db.db.error)
137
138    def testAttributeHost(self):
139        def_host = 'localhost'
140        host = self.db.host
141        self.assertIsInstance(host, str)
142        self.assertEqual(host, dbhost or def_host)
143        self.assertEqual(host, self.db.db.host)
144
145    def testAttributeOptions(self):
146        no_options = ''
147        options = self.db.options
148        self.assertEqual(options, no_options)
149        self.assertEqual(options, self.db.db.options)
150
151    def testAttributePort(self):
152        def_port = 5432
153        port = self.db.port
154        self.assertIsInstance(port, int)
155        self.assertEqual(port, dbport or def_port)
156        self.assertEqual(port, self.db.db.port)
157
158    def testAttributeProtocolVersion(self):
159        protocol_version = self.db.protocol_version
160        self.assertIsInstance(protocol_version, int)
161        self.assertTrue(2 <= protocol_version < 4)
162        self.assertEqual(protocol_version, self.db.db.protocol_version)
163
164    def testAttributeServerVersion(self):
165        server_version = self.db.server_version
166        self.assertIsInstance(server_version, int)
167        self.assertTrue(70400 <= server_version < 100000)
168        self.assertEqual(server_version, self.db.db.server_version)
169
170    def testAttributeStatus(self):
171        status_ok = 1
172        status = self.db.status
173        self.assertIsInstance(status, int)
174        self.assertEqual(status, status_ok)
175        self.assertEqual(status, self.db.db.status)
176
177    def testAttributeTty(self):
178        def_tty = ''
179        tty = self.db.tty
180        self.assertIsInstance(tty, str)
181        self.assertEqual(tty, def_tty)
182        self.assertEqual(tty, self.db.db.tty)
183
184    def testAttributeUser(self):
185        no_user = 'Deprecated facility'
186        user = self.db.user
187        self.assertTrue(user)
188        self.assertIsInstance(user, str)
189        self.assertNotEqual(user, no_user)
190        self.assertEqual(user, self.db.db.user)
191
192    def testMethodEscapeLiteral(self):
193        self.assertEqual(self.db.escape_literal(''), "''")
194
195    def testMethodEscapeIdentifier(self):
196        self.assertEqual(self.db.escape_identifier(''), '""')
197
198    def testMethodEscapeString(self):
199        self.assertEqual(self.db.escape_string(''), '')
200
201    def testMethodEscapeBytea(self):
202        self.assertEqual(self.db.escape_bytea('').replace(
203            '\\x', '').replace('\\', ''), '')
204
205    def testMethodUnescapeBytea(self):
206        self.assertEqual(self.db.unescape_bytea(''), '')
207
208    def testMethodQuery(self):
209        query = self.db.query
210        query("select 1+1")
211        query("select 1+$1+$2", 2, 3)
212        query("select 1+$1+$2", (2, 3))
213        query("select 1+$1+$2", [2, 3])
214        query("select 1+$1", 1)
215
216    def testMethodQueryEmpty(self):
217        self.assertRaises(ValueError, self.db.query, '')
218
219    def testMethodQueryProgrammingError(self):
220        try:
221            self.db.query("select 1/0")
222        except pg.ProgrammingError, error:
223            self.assertEqual(error.sqlstate, '22012')
224
225    def testMethodEndcopy(self):
226        try:
227            self.db.endcopy()
228        except IOError:
229            pass
230
231    def testMethodClose(self):
232        self.db.close()
233        try:
234            self.db.reset()
235        except pg.Error:
236            pass
237        else:
238            self.fail('Reset should give an error for a closed connection')
239        self.assertRaises(pg.InternalError, self.db.close)
240        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
241
242    def testExistingConnection(self):
243        db = pg.DB(self.db.db)
244        self.assertEqual(self.db.db, db.db)
245        self.assertTrue(db.db)
246        db.close()
247        self.assertTrue(db.db)
248        db.reopen()
249        self.assertTrue(db.db)
250        db.close()
251        self.assertTrue(db.db)
252        db = pg.DB(self.db)
253        self.assertEqual(self.db.db, db.db)
254        db = pg.DB(db=self.db.db)
255        self.assertEqual(self.db.db, db.db)
256
257        class DB2:
258            pass
259
260        db2 = DB2()
261        db2._cnx = self.db.db
262        db = pg.DB(db2)
263        self.assertEqual(self.db.db, db.db)
264
265
266class TestDBClass(unittest.TestCase):
267    """Test the methods of the DB class wrapped pg connection."""
268
269    @classmethod
270    def setUpClass(cls):
271        db = DB()
272        db.query("drop table if exists test cascade")
273        db.query("create table test ("
274            "i2 smallint, i4 integer, i8 bigint,"
275            "d numeric, f4 real, f8 double precision, m money, "
276            "v4 varchar(4), c4 char(4), t text)")
277        db.query("create or replace view test_view as"
278            " select i4, v4 from test")
279        db.close()
280
281    @classmethod
282    def tearDownClass(cls):
283        db = DB()
284        db.query("drop table test cascade")
285        db.close()
286
287    def setUp(self):
288        self.db = DB()
289        query = self.db.query
290        query('set client_encoding=utf8')
291        query('set standard_conforming_strings=on')
292        query('set bytea_output=hex')
293        query("set lc_monetary='C'")
294
295    def tearDown(self):
296        self.db.close()
297
298    def testEscapeLiteral(self):
299        f = self.db.escape_literal
300        self.assertEqual(f("plain"), "'plain'")
301        self.assertEqual(f("that's k\xe4se"), "'that''s k\xe4se'")
302        self.assertEqual(f(r"It's fine to have a \ inside."),
303            r" E'It''s fine to have a \\ inside.'")
304        self.assertEqual(f('No "quotes" must be escaped.'),
305            "'No \"quotes\" must be escaped.'")
306
307    def testEscapeIdentifier(self):
308        f = self.db.escape_identifier
309        self.assertEqual(f("plain"), '"plain"')
310        self.assertEqual(f("that's k\xe4se"), '"that\'s k\xe4se"')
311        self.assertEqual(f(r"It's fine to have a \ inside."),
312            '"It\'s fine to have a \\ inside."')
313        self.assertEqual(f('All "quotes" must be escaped.'),
314            '"All ""quotes"" must be escaped."')
315
316    def testEscapeString(self):
317        f = self.db.escape_string
318        self.assertEqual(f("plain"), "plain")
319        self.assertEqual(f("that's k\xe4se"), "that''s k\xe4se")
320        self.assertEqual(f(r"It's fine to have a \ inside."),
321            r"It''s fine to have a \ inside.")
322
323    def testEscapeBytea(self):
324        f = self.db.escape_bytea
325        # note that escape_byte always returns hex output since Pg 9.0,
326        # regardless of the bytea_output setting
327        self.assertEqual(f("plain"), r"\x706c61696e")
328        self.assertEqual(f("that's k\xe4se"), r"\x746861742773206be47365")
329        self.assertEqual(f('O\x00ps\xff!'), r"\x4f007073ff21")
330
331    def testUnescapeBytea(self):
332        f = self.db.unescape_bytea
333        self.assertEqual(f("plain"), "plain")
334        self.assertEqual(f("that's k\\344se"), "that's k\xe4se")
335        self.assertEqual(f(r'O\000ps\377!'), 'O\x00ps\xff!')
336        self.assertEqual(f(r"\\x706c61696e"), r"\x706c61696e")
337        self.assertEqual(f(r"\\x746861742773206be47365"),
338            r"\x746861742773206be47365")
339        self.assertEqual(f(r"\\x4f007073ff21"), r"\x4f007073ff21")
340
341    def testQuote(self):
342        f = self.db._quote
343        self.assertEqual(f(None, None), 'NULL')
344        self.assertEqual(f(None, 'int'), 'NULL')
345        self.assertEqual(f(None, 'float'), 'NULL')
346        self.assertEqual(f(None, 'num'), 'NULL')
347        self.assertEqual(f(None, 'money'), 'NULL')
348        self.assertEqual(f(None, 'bool'), 'NULL')
349        self.assertEqual(f(None, 'date'), 'NULL')
350        self.assertEqual(f('', 'int'), 'NULL')
351        self.assertEqual(f('', 'float'), 'NULL')
352        self.assertEqual(f('', 'num'), 'NULL')
353        self.assertEqual(f('', 'money'), 'NULL')
354        self.assertEqual(f('', 'bool'), 'NULL')
355        self.assertEqual(f('', 'date'), 'NULL')
356        self.assertEqual(f('', 'text'), "''")
357        self.assertEqual(f(0, 'int'), '0')
358        self.assertEqual(f(0, 'num'), '0')
359        self.assertEqual(f(1, 'int'), '1')
360        self.assertEqual(f(1, 'num'), '1')
361        self.assertEqual(f(-1, 'int'), '-1')
362        self.assertEqual(f(-1, 'num'), '-1')
363        self.assertEqual(f(123456789, 'int'), '123456789')
364        self.assertEqual(f(123456987, 'num'), '123456987')
365        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
366        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
367        self.assertEqual(f('123456789', 'num'), '123456789')
368        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
369        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
370        self.assertEqual(f(123, 'money'), '123')
371        self.assertEqual(f('123', 'money'), '123')
372        self.assertEqual(f(123.45, 'money'), '123.45')
373        self.assertEqual(f('123.45', 'money'), '123.45')
374        self.assertEqual(f(123.454, 'money'), '123.454')
375        self.assertEqual(f('123.454', 'money'), '123.454')
376        self.assertEqual(f(123.456, 'money'), '123.456')
377        self.assertEqual(f('123.456', 'money'), '123.456')
378        self.assertEqual(f('f', 'bool'), "'f'")
379        self.assertEqual(f('F', 'bool'), "'f'")
380        self.assertEqual(f('false', 'bool'), "'f'")
381        self.assertEqual(f('False', 'bool'), "'f'")
382        self.assertEqual(f('FALSE', 'bool'), "'f'")
383        self.assertEqual(f(0, 'bool'), "'f'")
384        self.assertEqual(f('0', 'bool'), "'f'")
385        self.assertEqual(f('-', 'bool'), "'f'")
386        self.assertEqual(f('n', 'bool'), "'f'")
387        self.assertEqual(f('N', 'bool'), "'f'")
388        self.assertEqual(f('no', 'bool'), "'f'")
389        self.assertEqual(f('off', 'bool'), "'f'")
390        self.assertEqual(f('t', 'bool'), "'t'")
391        self.assertEqual(f('T', 'bool'), "'t'")
392        self.assertEqual(f('true', 'bool'), "'t'")
393        self.assertEqual(f('True', 'bool'), "'t'")
394        self.assertEqual(f('TRUE', 'bool'), "'t'")
395        self.assertEqual(f(1, 'bool'), "'t'")
396        self.assertEqual(f(2, 'bool'), "'t'")
397        self.assertEqual(f(-1, 'bool'), "'t'")
398        self.assertEqual(f(0.5, 'bool'), "'t'")
399        self.assertEqual(f('1', 'bool'), "'t'")
400        self.assertEqual(f('y', 'bool'), "'t'")
401        self.assertEqual(f('Y', 'bool'), "'t'")
402        self.assertEqual(f('yes', 'bool'), "'t'")
403        self.assertEqual(f('on', 'bool'), "'t'")
404        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
405        self.assertEqual(f(123, 'text'), "'123'")
406        self.assertEqual(f(1.23, 'text'), "'1.23'")
407        self.assertEqual(f('abc', 'text'), "'abc'")
408        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
409        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
410        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
411        self.db.query('set standard_conforming_strings=off')
412        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
413        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
414
415    def testQuery(self):
416        query = self.db.query
417        query("drop table if exists test_table")
418        q = "create table test_table (n integer) with oids"
419        r = query(q)
420        self.assertIsNone(r)
421        q = "insert into test_table values (1)"
422        r = query(q)
423        self.assertIsInstance(r, int)
424        q = "insert into test_table select 2"
425        r = query(q)
426        self.assertIsInstance(r, int)
427        oid = r
428        q = "select oid from test_table where n=2"
429        r = query(q).getresult()
430        self.assertEqual(len(r), 1)
431        r = r[0]
432        self.assertEqual(len(r), 1)
433        r = r[0]
434        self.assertEqual(r, oid)
435        q = "insert into test_table select 3 union select 4 union select 5"
436        r = query(q)
437        self.assertIsInstance(r, str)
438        self.assertEqual(r, '3')
439        q = "update test_table set n=4 where n<5"
440        r = query(q)
441        self.assertIsInstance(r, str)
442        self.assertEqual(r, '4')
443        q = "delete from test_table"
444        r = query(q)
445        self.assertIsInstance(r, str)
446        self.assertEqual(r, '5')
447        query("drop table test_table")
448
449    def testMultipleQueries(self):
450        self.assertEqual(self.db.query(
451            "create temporary table test_multi (n integer);"
452            "insert into test_multi values (4711);"
453            "select n from test_multi").getresult()[0][0], 4711)
454
455    def testQueryWithParams(self):
456        query = self.db.query
457        query("drop table if exists test_table")
458        q = "create table test_table (n1 integer, n2 integer) with oids"
459        query(q)
460        q = "insert into test_table values ($1, $2)"
461        r = query(q, (1, 2))
462        self.assertIsInstance(r, int)
463        r = query(q, [3, 4])
464        self.assertIsInstance(r, int)
465        r = query(q, [5, 6])
466        self.assertIsInstance(r, int)
467        q = "select * from test_table order by 1, 2"
468        self.assertEqual(query(q).getresult(),
469            [(1, 2), (3, 4), (5, 6)])
470        q = "select * from test_table where n1=$1 and n2=$2"
471        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
472        q = "update test_table set n2=$2 where n1=$1"
473        r = query(q, 3, 7)
474        self.assertEqual(r, '1')
475        q = "select * from test_table order by 1, 2"
476        self.assertEqual(query(q).getresult(),
477            [(1, 2), (3, 7), (5, 6)])
478        q = "delete from test_table where n2!=$1"
479        r = query(q, 4)
480        self.assertEqual(r, '3')
481        query("drop table test_table")
482
483    def testEmptyQuery(self):
484        self.assertRaises(ValueError, self.db.query, '')
485
486    def testQueryProgrammingError(self):
487        try:
488            self.db.query("select 1/0")
489        except pg.ProgrammingError, error:
490            self.assertEqual(error.sqlstate, '22012')
491
492    def testPkey(self):
493        query = self.db.query
494        for n in range(4):
495            query("drop table if exists pkeytest%d" % n)
496        query("create table pkeytest0 ("
497            "a smallint)")
498        query("create table pkeytest1 ("
499            "b smallint primary key)")
500        query("create table pkeytest2 ("
501            "c smallint, d smallint primary key)")
502        query("create table pkeytest3 ("
503            "e smallint, f smallint, g smallint, "
504            "h smallint, i smallint, "
505            "primary key (f,h))")
506        pkey = self.db.pkey
507        self.assertRaises(KeyError, pkey, 'pkeytest0')
508        self.assertEqual(pkey('pkeytest1'), 'b')
509        self.assertEqual(pkey('pkeytest2'), 'd')
510        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
511        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
512        self.assertEqual(pkey('pkeytest0'), 'none')
513        pkey(None, {'t': 'a', 'n.t': 'b'})
514        self.assertEqual(pkey('t'), 'a')
515        self.assertEqual(pkey('n.t'), 'b')
516        self.assertRaises(KeyError, pkey, 'pkeytest0')
517        for n in range(4):
518            query("drop table pkeytest%d" % n)
519
520    def testGetDatabases(self):
521        databases = self.db.get_databases()
522        self.assertIn('template0', databases)
523        self.assertIn('template1', databases)
524        self.assertNotIn('not existing database', databases)
525        self.assertIn('postgres', databases)
526        self.assertIn(dbname, databases)
527
528    def testGetTables(self):
529        get_tables = self.db.get_tables
530        result1 = get_tables()
531        tables = ('"A very Special Name"',
532            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
533            'A_MiXeD_NaMe', '"another special name"',
534            'averyveryveryveryveryveryverylongtablename',
535            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
536        for t in tables:
537            self.db.query('drop table if exists %s' % t)
538            self.db.query("create table %s"
539                " as select 0" % t)
540        result3 = get_tables()
541        result2 = []
542        for t in result3:
543            if t not in result1:
544                result2.append(t)
545        result3 = []
546        for t in tables:
547            if not t.startswith('"'):
548                t = t.lower()
549            result3.append('public.' + t)
550        self.assertEqual(result2, result3)
551        for t in result2:
552            self.db.query('drop table %s' % t)
553        result2 = get_tables()
554        self.assertEqual(result2, result1)
555
556    def testGetRelations(self):
557        get_relations = self.db.get_relations
558        result = get_relations()
559        self.assertIn('public.test', result)
560        self.assertIn('public.test_view', result)
561        result = get_relations('rv')
562        self.assertIn('public.test', result)
563        self.assertIn('public.test_view', result)
564        result = get_relations('r')
565        self.assertIn('public.test', result)
566        self.assertNotIn('public.test_view', result)
567        result = get_relations('v')
568        self.assertNotIn('public.test', result)
569        self.assertIn('public.test_view', result)
570        result = get_relations('cisSt')
571        self.assertNotIn('public.test', result)
572        self.assertNotIn('public.test_view', result)
573
574    def testAttnames(self):
575        self.assertRaises(pg.ProgrammingError,
576            self.db.get_attnames, 'does_not_exist')
577        self.assertRaises(pg.ProgrammingError,
578            self.db.get_attnames, 'has.too.many.dots')
579        for table in ('attnames_test_table', 'test table for attnames'):
580            self.db.query('drop table if exists "%s"' % table)
581            self.db.query('create table "%s" ('
582                'a smallint, b integer, c bigint, '
583                'e numeric, f float, f2 double precision, m money, '
584                'x smallint, y smallint, z smallint, '
585                'Normal_NaMe smallint, "Special Name" smallint, '
586                't text, u char(2), v varchar(2), '
587                'primary key (y, u)) with oids' % table)
588            attributes = self.db.get_attnames(table)
589            result = {'a': 'int', 'c': 'int', 'b': 'int',
590                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
591                'normal_name': 'int', 'Special Name': 'int',
592                'u': 'text', 't': 'text', 'v': 'text',
593                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
594            self.assertEqual(attributes, result)
595            self.db.query('drop table "%s"' % table)
596
597    def testHasTablePrivilege(self):
598        can = self.db.has_table_privilege
599        self.assertEqual(can('test'), True)
600        self.assertEqual(can('test', 'select'), True)
601        self.assertEqual(can('test', 'SeLeCt'), True)
602        self.assertEqual(can('test', 'SELECT'), True)
603        self.assertEqual(can('test', 'insert'), True)
604        self.assertEqual(can('test', 'update'), True)
605        self.assertEqual(can('test', 'delete'), True)
606        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
607        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
608
609    def testGet(self):
610        get = self.db.get
611        query = self.db.query
612        for table in ('get_test_table', 'test table for get'):
613            query('drop table if exists "%s"' % table)
614            query('create table "%s" ('
615                "n integer, t text) with oids" % table)
616            for n, t in enumerate('xyz'):
617                query('insert into "%s" values('"%d, '%s')"
618                    % (table, n + 1, t))
619            self.assertRaises(pg.ProgrammingError, get, table, 2)
620            r = get(table, 2, 'n')
621            oid_table = table
622            if ' ' in table:
623                oid_table = '"%s"' % oid_table
624            oid_table = 'oid(public.%s)' % oid_table
625            self.assertIn(oid_table, r)
626            oid = r[oid_table]
627            self.assertIsInstance(oid, int)
628            result = {'t': 'y', 'n': 2, oid_table: oid}
629            self.assertEqual(r, result)
630            self.assertEqual(get(table + ' *', 2, 'n'), r)
631            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
632            self.assertEqual(get(table, 1, 'n')['t'], 'x')
633            self.assertEqual(get(table, 3, 'n')['t'], 'z')
634            self.assertEqual(get(table, 2, 'n')['t'], 'y')
635            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
636            r['n'] = 3
637            self.assertEqual(get(table, r, 'n')['t'], 'z')
638            self.assertEqual(get(table, 1, 'n')['t'], 'x')
639            query('alter table "%s" alter n set not null' % table)
640            query('alter table "%s" add primary key (n)' % table)
641            self.assertEqual(get(table, 3)['t'], 'z')
642            self.assertEqual(get(table, 1)['t'], 'x')
643            self.assertEqual(get(table, 2)['t'], 'y')
644            r['n'] = 1
645            self.assertEqual(get(table, r)['t'], 'x')
646            r['n'] = 3
647            self.assertEqual(get(table, r)['t'], 'z')
648            r['n'] = 2
649            self.assertEqual(get(table, r)['t'], 'y')
650            query('drop table "%s"' % table)
651
652    def testGetWithCompositeKey(self):
653        get = self.db.get
654        query = self.db.query
655        table = 'get_test_table_1'
656        query("drop table if exists %s" % table)
657        query("create table %s ("
658            "n integer, t text, primary key (n))" % table)
659        for n, t in enumerate('abc'):
660            query("insert into %s values("
661                "%d, '%s')" % (table, n + 1, t))
662        self.assertEqual(get(table, 2)['t'], 'b')
663        query("drop table %s" % table)
664        table = 'get_test_table_2'
665        query("drop table if exists %s" % table)
666        query("create table %s ("
667            "n integer, m integer, t text, primary key (n, m))" % table)
668        for n in range(3):
669            for m in range(2):
670                t = chr(ord('a') + 2 * n + m)
671                query("insert into %s values("
672                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
673        self.assertRaises(pg.ProgrammingError, get, table, 2)
674        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
675        self.assertEqual(get(table, dict(n=1, m=2),
676                             ('n', 'm'))['t'], 'b')
677        self.assertEqual(get(table, dict(n=3, m=2),
678                             frozenset(['n', 'm']))['t'], 'f')
679        query("drop table %s" % table)
680
681    def testGetFromView(self):
682        self.db.query('delete from test where i4=14')
683        self.db.query('insert into test (i4, v4) values('
684            "14, 'abc4')")
685        r = self.db.get('test_view', 14, 'i4')
686        self.assertIn('v4', r)
687        self.assertEqual(r['v4'], 'abc4')
688
689    def testInsert(self):
690        insert = self.db.insert
691        query = self.db.query
692        for table in ('insert_test_table', 'test table for insert'):
693            query('drop table if exists "%s"' % table)
694            query('create table "%s" ('
695                "i2 smallint, i4 integer, i8 bigint,"
696                " d numeric, f4 real, f8 double precision, m money,"
697                " v4 varchar(4), c4 char(4), t text,"
698                " b boolean, ts timestamp) with oids" % table)
699            oid_table = table
700            if ' ' in table:
701                oid_table = '"%s"' % oid_table
702            oid_table = 'oid(public.%s)' % oid_table
703            tests = [dict(i2=None, i4=None, i8=None),
704                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
705                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
706                dict(i2=42, i4=123456, i8=9876543210),
707                dict(i2=2 ** 15 - 1,
708                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
709                dict(d=None), (dict(d=''), dict(d=None)),
710                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
711                dict(f4=None, f8=None), dict(f4=0, f8=0),
712                (dict(f4='', f8=''), dict(f4=None, f8=None)),
713                (dict(d=1234.5, f4=1234.5, f8=1234.5),
714                      dict(d=Decimal('1234.5'))),
715                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
716                dict(d=Decimal('123456789.9876543212345678987654321')),
717                dict(m=None), (dict(m=''), dict(m=None)),
718                dict(m=Decimal('-1234.56')),
719                (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
720                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
721                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
722                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
723                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
724                (dict(m=123456), dict(m=Decimal('123456'))),
725                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
726                dict(b=None), (dict(b=''), dict(b=None)),
727                dict(b='f'), dict(b='t'),
728                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
729                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
730                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
731                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
732                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
733                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
734                dict(v4=None, c4=None, t=None),
735                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
736                dict(v4='1234', c4='1234', t='1234' * 10),
737                dict(v4='abcd', c4='abcd', t='abcdefg'),
738                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
739                dict(ts=None), (dict(ts=''), dict(ts=None)),
740                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
741                dict(ts='2012-12-21 00:00:00'),
742                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
743                dict(ts='2012-12-21 12:21:12'),
744                dict(ts='2013-01-05 12:13:14'),
745                dict(ts='current_timestamp')]
746            for test in tests:
747                if isinstance(test, dict):
748                    data = test
749                    change = {}
750                else:
751                    data, change = test
752                expect = data.copy()
753                expect.update(change)
754                self.assertEqual(insert(table, data), data)
755                self.assertIn(oid_table, data)
756                oid = data[oid_table]
757                self.assertIsInstance(oid, int)
758                data = dict(item for item in data.iteritems()
759                    if item[0] in expect)
760                ts = expect.get('ts')
761                if ts == 'current_timestamp':
762                    ts = expect['ts'] = data['ts']
763                    if len(ts) > 19:
764                        self.assertEqual(ts[19], '.')
765                        ts = ts[:19]
766                    else:
767                        self.assertEqual(len(ts), 19)
768                    self.assertTrue(ts[:4].isdigit())
769                    self.assertEqual(ts[4], '-')
770                    self.assertEqual(ts[10], ' ')
771                    self.assertTrue(ts[11:13].isdigit())
772                    self.assertEqual(ts[13], ':')
773                self.assertEqual(data, expect)
774                data = query(
775                    'select oid,* from "%s"' % table).dictresult()[0]
776                self.assertEqual(data['oid'], oid)
777                data = dict(item for item in data.iteritems()
778                    if item[0] in expect)
779                self.assertEqual(data, expect)
780                query('delete from "%s"' % table)
781            query('drop table "%s"' % table)
782
783    def testUpdate(self):
784        update = self.db.update
785        query = self.db.query
786        for table in ('update_test_table', 'test table for update'):
787            query('drop table if exists "%s"' % table)
788            query('create table "%s" ('
789                "n integer, t text) with oids" % table)
790            for n, t in enumerate('xyz'):
791                query('insert into "%s" values('
792                    "%d, '%s')" % (table, n + 1, t))
793            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
794            r = self.db.get(table, 2, 'n')
795            r['t'] = 'u'
796            s = update(table, r)
797            self.assertEqual(s, r)
798            r = query('select t from "%s" where n=2' % table
799                      ).getresult()[0][0]
800            self.assertEqual(r, 'u')
801            query('drop table "%s"' % table)
802
803    def testUpdateWithCompositeKey(self):
804        update = self.db.update
805        query = self.db.query
806        table = 'update_test_table_1'
807        query("drop table if exists %s" % table)
808        query("create table %s ("
809            "n integer, t text, primary key (n))" % table)
810        for n, t in enumerate('abc'):
811            query("insert into %s values("
812                "%d, '%s')" % (table, n + 1, t))
813        self.assertRaises(pg.ProgrammingError, update,
814                          table, dict(t='b'))
815        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
816        r = query('select t from "%s" where n=2' % table
817                  ).getresult()[0][0]
818        self.assertEqual(r, 'd')
819        query("drop table %s" % table)
820        table = 'update_test_table_2'
821        query("drop table if exists %s" % table)
822        query("create table %s ("
823            "n integer, m integer, t text, primary key (n, m))" % table)
824        for n in range(3):
825            for m in range(2):
826                t = chr(ord('a') + 2 * n + m)
827                query("insert into %s values("
828                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
829        self.assertRaises(pg.ProgrammingError, update,
830                          table, dict(n=2, t='b'))
831        self.assertEqual(update(table,
832                                dict(n=2, m=2, t='x'))['t'], 'x')
833        r = [r[0] for r in query('select t from "%s" where n=2'
834            ' order by m' % table).getresult()]
835        self.assertEqual(r, ['c', 'x'])
836        query("drop table %s" % table)
837
838    def testClear(self):
839        clear = self.db.clear
840        query = self.db.query
841        for table in ('clear_test_table', 'test table for clear'):
842            query('drop table if exists "%s"' % table)
843            query('create table "%s" ('
844                "n integer, b boolean, d date, t text)" % table)
845            r = clear(table)
846            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
847            self.assertEqual(r, result)
848            r['a'] = r['n'] = 1
849            r['d'] = r['t'] = 'x'
850            r['b'] = 't'
851            r['oid'] = 1L
852            r = clear(table, r)
853            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '', 'oid': 1L}
854            self.assertEqual(r, result)
855            query('drop table "%s"' % table)
856
857    def testDelete(self):
858        delete = self.db.delete
859        query = self.db.query
860        for table in ('delete_test_table', 'test table for delete'):
861            query('drop table if exists "%s"' % table)
862            query('create table "%s" ('
863                "n integer, t text) with oids" % table)
864            for n, t in enumerate('xyz'):
865                query('insert into "%s" values('
866                    "%d, '%s')" % (table, n + 1, t))
867            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
868            r = self.db.get(table, 1, 'n')
869            s = delete(table, r)
870            self.assertEqual(s, 1)
871            r = self.db.get(table, 3, 'n')
872            s = delete(table, r)
873            self.assertEqual(s, 1)
874            s = delete(table, r)
875            self.assertEqual(s, 0)
876            r = query('select * from "%s"' % table).dictresult()
877            self.assertEqual(len(r), 1)
878            r = r[0]
879            result = {'n': 2, 't': 'y'}
880            self.assertEqual(r, result)
881            r = self.db.get(table, 2, 'n')
882            s = delete(table, r)
883            self.assertEqual(s, 1)
884            s = delete(table, r)
885            self.assertEqual(s, 0)
886            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
887            query('drop table "%s"' % table)
888
889    def testDeleteWithCompositeKey(self):
890        query = self.db.query
891        table = 'delete_test_table_1'
892        query("drop table if exists %s" % table)
893        query("create table %s ("
894            "n integer, t text, primary key (n))" % table)
895        for n, t in enumerate('abc'):
896            query("insert into %s values("
897                "%d, '%s')" % (table, n + 1, t))
898        self.assertRaises(pg.ProgrammingError, self.db.delete,
899            table, dict(t='b'))
900        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
901        r = query('select t from "%s" where n=2' % table
902                  ).getresult()
903        self.assertEqual(r, [])
904        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
905        r = query('select t from "%s" where n=3' % table
906                  ).getresult()[0][0]
907        self.assertEqual(r, 'c')
908        query("drop table %s" % table)
909        table = 'delete_test_table_2'
910        query("drop table if exists %s" % table)
911        query("create table %s ("
912            "n integer, m integer, t text, primary key (n, m))" % table)
913        for n in range(3):
914            for m in range(2):
915                t = chr(ord('a') + 2 * n + m)
916                query("insert into %s values("
917                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
918        self.assertRaises(pg.ProgrammingError, self.db.delete,
919            table, dict(n=2, t='b'))
920        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
921        r = [r[0] for r in query('select t from "%s" where n=2'
922            ' order by m' % table).getresult()]
923        self.assertEqual(r, ['c'])
924        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
925        r = [r[0] for r in query('select t from "%s" where n=3'
926            ' order by m' % table).getresult()]
927        self.assertEqual(r, ['e', 'f'])
928        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
929        r = [r[0] for r in query('select t from "%s" where n=3'
930            ' order by m' % table).getresult()]
931        self.assertEqual(r, ['f'])
932        query("drop table %s" % table)
933
934    def testTransaction(self):
935        query = self.db.query
936        query("drop table if exists test_table")
937        query("create table test_table (n integer)")
938        self.db.begin()
939        query("insert into test_table values (1)")
940        query("insert into test_table values (2)")
941        self.db.commit()
942        self.db.begin()
943        query("insert into test_table values (3)")
944        query("insert into test_table values (4)")
945        self.db.rollback()
946        self.db.begin()
947        query("insert into test_table values (5)")
948        self.db.savepoint('before6')
949        query("insert into test_table values (6)")
950        self.db.rollback('before6')
951        query("insert into test_table values (7)")
952        self.db.commit()
953        self.db.begin()
954        self.db.savepoint('before8')
955        query("insert into test_table values (8)")
956        self.db.release('before8')
957        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
958        self.db.commit()
959        self.db.start()
960        query("insert into test_table values (9)")
961        self.db.end()
962        r = [r[0] for r in query(
963            "select * from test_table order by 1").getresult()]
964        self.assertEqual(r, [1, 2, 5, 7, 9])
965        query("drop table test_table")
966
967    @unittest.skipIf(no_with, 'context managers not supported')
968    def testContextManager(self):
969        query = self.db.query
970        query("drop table if exists test_table")
971        query("create table test_table (n integer check(n>0))")
972        # wrap "with" statements to avoid SyntaxError in Python < 2.5
973        exec """from __future__ import with_statement\nif True:
974        with self.db:
975            query("insert into test_table values (1)")
976            query("insert into test_table values (2)")
977        try:
978            with self.db:
979                query("insert into test_table values (3)")
980                query("insert into test_table values (4)")
981                raise ValueError('test transaction should rollback')
982        except ValueError, error:
983            self.assertEqual(str(error), 'test transaction should rollback')
984        with self.db:
985            query("insert into test_table values (5)")
986        try:
987            with self.db:
988                query("insert into test_table values (6)")
989                query("insert into test_table values (-1)")
990        except pg.ProgrammingError, error:
991            self.assertTrue('check' in str(error))
992        with self.db:
993            query("insert into test_table values (7)")\n"""
994        r = [r[0] for r in query(
995            "select * from test_table order by 1").getresult()]
996        self.assertEqual(r, [1, 2, 5, 7])
997        query("drop table test_table")
998
999    def testBytea(self):
1000        query = self.db.query
1001        query('drop table if exists bytea_test')
1002        query('create table bytea_test ('
1003            'data bytea)')
1004        s = "It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1005        r = self.db.escape_bytea(s)
1006        query('insert into bytea_test values('
1007            "'%s')" % r)
1008        r = query('select * from bytea_test').getresult()
1009        self.assertTrue(len(r) == 1)
1010        r = r[0]
1011        self.assertTrue(len(r) == 1)
1012        r = r[0]
1013        r = self.db.unescape_bytea(r)
1014        self.assertEqual(r, s)
1015        query('drop table bytea_test')
1016
1017    def testDebugWithCallable(self):
1018        if debug:
1019            self.assertEqual(self.db.debug, debug)
1020        else:
1021            self.assertIsNone(self.db.debug)
1022        s = []
1023        self.db.debug = s.append
1024        try:
1025            self.db.query("select 1")
1026            self.db.query("select 2")
1027            self.assertEqual(s, ["select 1", "select 2"])
1028        finally:
1029            self.db.debug = debug
1030
1031
1032class TestSchemas(unittest.TestCase):
1033    """Test correct handling of schemas (namespaces)."""
1034
1035    @classmethod
1036    def setUpClass(cls):
1037        db = DB()
1038        query = db.query
1039        query("set client_min_messages=warning")
1040        for num_schema in range(5):
1041            if num_schema:
1042                schema = "s%d" % num_schema
1043                query("drop schema if exists %s cascade" % (schema,))
1044                try:
1045                    query("create schema %s" % (schema,))
1046                except pg.ProgrammingError:
1047                    raise RuntimeError("The test user cannot create schemas.\n"
1048                        "Grant create on database %s to the user"
1049                        " for running these tests." % dbname)
1050            else:
1051                schema = "public"
1052                query("drop table if exists %s.t" % (schema,))
1053                query("drop table if exists %s.t%d" % (schema, num_schema))
1054            query("create table %s.t with oids as select 1 as n, %d as d"
1055                  % (schema, num_schema))
1056            query("create table %s.t%d with oids as select 1 as n, %d as d"
1057                  % (schema, num_schema, num_schema))
1058        db.close()
1059
1060    @classmethod
1061    def tearDownClass(cls):
1062        db = DB()
1063        query = db.query
1064        query("set client_min_messages=warning")
1065        for num_schema in range(5):
1066            if num_schema:
1067                schema = "s%d" % num_schema
1068                query("drop schema %s cascade" % (schema,))
1069            else:
1070                schema = "public"
1071                query("drop table %s.t" % (schema,))
1072                query("drop table %s.t%d" % (schema, num_schema))
1073        db.close()
1074
1075    def setUp(self):
1076        self.db = DB()
1077        self.db.query("set client_min_messages=warning")
1078
1079    def tearDown(self):
1080        self.db.close()
1081
1082    def testGetTables(self):
1083        tables = self.db.get_tables()
1084        for num_schema in range(5):
1085            if num_schema:
1086                schema = "s" + str(num_schema)
1087            else:
1088                schema = "public"
1089            for t in (schema + ".t",
1090                    schema + ".t" + str(num_schema)):
1091                self.assertIn(t, tables)
1092
1093    def testGetAttnames(self):
1094        get_attnames = self.db.get_attnames
1095        query = self.db.query
1096        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
1097        r = get_attnames("t")
1098        self.assertEqual(r, result)
1099        r = get_attnames("s4.t4")
1100        self.assertEqual(r, result)
1101        query("drop table if exists s3.t3m")
1102        query("create table s3.t3m with oids as select 1 as m")
1103        result_m = {'oid': 'int', 'm': 'int'}
1104        r = get_attnames("s3.t3m")
1105        self.assertEqual(r, result_m)
1106        query("set search_path to s1,s3")
1107        r = get_attnames("t3")
1108        self.assertEqual(r, result)
1109        r = get_attnames("t3m")
1110        self.assertEqual(r, result_m)
1111        query("drop table s3.t3m")
1112
1113    def testGet(self):
1114        get = self.db.get
1115        query = self.db.query
1116        PrgError = pg.ProgrammingError
1117        self.assertEqual(get("t", 1, 'n')['d'], 0)
1118        self.assertEqual(get("t0", 1, 'n')['d'], 0)
1119        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
1120        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
1121        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
1122        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
1123        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
1124        query("set search_path to s2,s4")
1125        self.assertRaises(PrgError, get, "t1", 1, 'n')
1126        self.assertEqual(get("t4", 1, 'n')['d'], 4)
1127        self.assertRaises(PrgError, get, "t3", 1, 'n')
1128        self.assertEqual(get("t", 1, 'n')['d'], 2)
1129        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
1130        query("set search_path to s1,s3")
1131        self.assertRaises(PrgError, get, "t2", 1, 'n')
1132        self.assertEqual(get("t3", 1, 'n')['d'], 3)
1133        self.assertRaises(PrgError, get, "t4", 1, 'n')
1134        self.assertEqual(get("t", 1, 'n')['d'], 1)
1135        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
1136
1137    def testMangling(self):
1138        get = self.db.get
1139        query = self.db.query
1140        r = get("t", 1, 'n')
1141        self.assertIn('oid(public.t)', r)
1142        query("set search_path to s2")
1143        r = get("t2", 1, 'n')
1144        self.assertIn('oid(s2.t2)', r)
1145        query("set search_path to s3")
1146        r = get("t", 1, 'n')
1147        self.assertIn('oid(s3.t)', r)
1148
1149
1150if __name__ == '__main__':
1151    unittest.main()
Note: See TracBrowser for help on using the repository browser.