source: branches/4.x/module/tests/test_classic_dbwrapper.py @ 690

Last change on this file since 690 was 690, checked in by cito, 4 years ago

Amend tests so that they can run with PostgreSQL < 9.0

Note that we do not need to make these amendments in the trunk,
because we assume PostgreSQL >= 9.0 for PyGreSQL version 5.0.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 45.3 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18
19import sys
20
21import pg  # the module under test
22
23from decimal import Decimal
24
25# check whether the "with" statement is supported
26no_with = sys.version_info[:2] < (2, 5)
27
28# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
29# get our information from that.  Otherwise we use the defaults.
30# The current user must have create schema privilege on the database.
31dbname = 'unittest'
32dbhost = None
33dbport = 5432
34
35debug = False  # let DB wrapper print debugging output
36
37try:
38    from LOCAL_PyGreSQL import *
39except ImportError:
40    pass
41
42
43def DB():
44    """Create a DB wrapper object connecting to the test database."""
45    db = pg.DB(dbname, dbhost, dbport)
46    if debug:
47        db.debug = debug
48    db.query("set client_min_messages=warning")
49    return db
50
51
52class TestDBClassBasic(unittest.TestCase):
53    """Test existence of the DB class wrapped pg connection methods."""
54
55    def setUp(self):
56        self.db = DB()
57
58    def tearDown(self):
59        try:
60            self.db.close()
61        except pg.InternalError:
62            pass
63
64    def testAllDBAttributes(self):
65        attributes = [
66            'begin',
67            'cancel',
68            'clear',
69            'close',
70            'commit',
71            'db',
72            'dbname',
73            'debug',
74            'delete',
75            'end',
76            'endcopy',
77            'error',
78            'escape_bytea',
79            'escape_identifier',
80            'escape_literal',
81            'escape_string',
82            'fileno',
83            'get',
84            'get_attnames',
85            'get_databases',
86            'get_notice_receiver',
87            'get_relations',
88            'get_tables',
89            'getline',
90            'getlo',
91            'getnotify',
92            'has_table_privilege',
93            'host',
94            'insert',
95            'inserttable',
96            'locreate',
97            'loimport',
98            'notification_handler',
99            'options',
100            'parameter',
101            'pkey',
102            'port',
103            'protocol_version',
104            'putline',
105            'query',
106            'release',
107            'reopen',
108            'reset',
109            'rollback',
110            'savepoint',
111            'server_version',
112            'set_notice_receiver',
113            'source',
114            'start',
115            'status',
116            'transaction',
117            'tty',
118            'unescape_bytea',
119            'update',
120            'use_regtypes',
121            'user',
122        ]
123        if self.db.server_version < 90000:  # PostgreSQL < 9.0
124            attributes.remove('escape_identifier')
125            attributes.remove('escape_literal')
126        db_attributes = [a for a in dir(self.db)
127            if not a.startswith('_')]
128        self.assertEqual(attributes, db_attributes)
129
130    def testAttributeDb(self):
131        self.assertEqual(self.db.db.db, dbname)
132
133    def testAttributeDbname(self):
134        self.assertEqual(self.db.dbname, dbname)
135
136    def testAttributeError(self):
137        error = self.db.error
138        self.assertTrue(not error or 'krb5_' in error)
139        self.assertEqual(self.db.error, self.db.db.error)
140
141    def testAttributeHost(self):
142        def_host = 'localhost'
143        host = self.db.host
144        self.assertIsInstance(host, str)
145        self.assertEqual(host, dbhost or def_host)
146        self.assertEqual(host, self.db.db.host)
147
148    def testAttributeOptions(self):
149        no_options = ''
150        options = self.db.options
151        self.assertEqual(options, no_options)
152        self.assertEqual(options, self.db.db.options)
153
154    def testAttributePort(self):
155        def_port = 5432
156        port = self.db.port
157        self.assertIsInstance(port, int)
158        self.assertEqual(port, dbport or def_port)
159        self.assertEqual(port, self.db.db.port)
160
161    def testAttributeProtocolVersion(self):
162        protocol_version = self.db.protocol_version
163        self.assertIsInstance(protocol_version, int)
164        self.assertTrue(2 <= protocol_version < 4)
165        self.assertEqual(protocol_version, self.db.db.protocol_version)
166
167    def testAttributeServerVersion(self):
168        server_version = self.db.server_version
169        self.assertIsInstance(server_version, int)
170        self.assertTrue(70400 <= server_version < 100000)
171        self.assertEqual(server_version, self.db.db.server_version)
172
173    def testAttributeStatus(self):
174        status_ok = 1
175        status = self.db.status
176        self.assertIsInstance(status, int)
177        self.assertEqual(status, status_ok)
178        self.assertEqual(status, self.db.db.status)
179
180    def testAttributeTty(self):
181        def_tty = ''
182        tty = self.db.tty
183        self.assertIsInstance(tty, str)
184        self.assertEqual(tty, def_tty)
185        self.assertEqual(tty, self.db.db.tty)
186
187    def testAttributeUser(self):
188        no_user = 'Deprecated facility'
189        user = self.db.user
190        self.assertTrue(user)
191        self.assertIsInstance(user, str)
192        self.assertNotEqual(user, no_user)
193        self.assertEqual(user, self.db.db.user)
194
195    def testMethodEscapeLiteral(self):
196        if self.db.server_version < 90000:  # PostgreSQL < 9.0
197            self.skipTest('Escaping functions not supported')
198        self.assertEqual(self.db.escape_literal(''), "''")
199
200    def testMethodEscapeIdentifier(self):
201        if self.db.server_version < 90000:  # PostgreSQL < 9.0
202            self.skipTest('Escaping functions not supported')
203        self.assertEqual(self.db.escape_identifier(''), '""')
204
205    def testMethodEscapeString(self):
206        self.assertEqual(self.db.escape_string(''), '')
207
208    def testMethodEscapeBytea(self):
209        self.assertEqual(self.db.escape_bytea('').replace(
210            '\\x', '').replace('\\', ''), '')
211
212    def testMethodUnescapeBytea(self):
213        self.assertEqual(self.db.unescape_bytea(''), '')
214
215    def testMethodQuery(self):
216        query = self.db.query
217        query("select 1+1")
218        query("select 1+$1+$2", 2, 3)
219        query("select 1+$1+$2", (2, 3))
220        query("select 1+$1+$2", [2, 3])
221        query("select 1+$1", 1)
222
223    def testMethodQueryEmpty(self):
224        self.assertRaises(ValueError, self.db.query, '')
225
226    def testMethodQueryProgrammingError(self):
227        try:
228            self.db.query("select 1/0")
229        except pg.ProgrammingError, error:
230            self.assertEqual(error.sqlstate, '22012')
231
232    def testMethodEndcopy(self):
233        try:
234            self.db.endcopy()
235        except IOError:
236            pass
237
238    def testMethodClose(self):
239        self.db.close()
240        try:
241            self.db.reset()
242        except pg.Error:
243            pass
244        else:
245            self.fail('Reset should give an error for a closed connection')
246        self.assertRaises(pg.InternalError, self.db.close)
247        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
248
249    def testExistingConnection(self):
250        db = pg.DB(self.db.db)
251        self.assertEqual(self.db.db, db.db)
252        self.assertTrue(db.db)
253        db.close()
254        self.assertTrue(db.db)
255        db.reopen()
256        self.assertTrue(db.db)
257        db.close()
258        self.assertTrue(db.db)
259        db = pg.DB(self.db)
260        self.assertEqual(self.db.db, db.db)
261        db = pg.DB(db=self.db.db)
262        self.assertEqual(self.db.db, db.db)
263
264        class DB2:
265            pass
266
267        db2 = DB2()
268        db2._cnx = self.db.db
269        db = pg.DB(db2)
270        self.assertEqual(self.db.db, db.db)
271
272
273class TestDBClass(unittest.TestCase):
274    """Test the methods of the DB class wrapped pg connection."""
275
276    @classmethod
277    def setUpClass(cls):
278        db = DB()
279        db.query("drop table if exists test cascade")
280        db.query("create table test ("
281            "i2 smallint, i4 integer, i8 bigint,"
282            "d numeric, f4 real, f8 double precision, m money, "
283            "v4 varchar(4), c4 char(4), t text)")
284        db.query("create or replace view test_view as"
285            " select i4, v4 from test")
286        db.close()
287
288    @classmethod
289    def tearDownClass(cls):
290        db = DB()
291        db.query("drop table test cascade")
292        db.close()
293
294    def setUp(self):
295        self.db = DB()
296        query = self.db.query
297        query('set client_encoding=utf8')
298        query('set standard_conforming_strings=on')
299        query("set lc_monetary='C'")
300        query("set datestyle='ISO,YMD'")
301        try:
302            query('set bytea_output=hex')
303        except pg.ProgrammingError:  # PostgreSQL < 9.0
304            pass
305
306    def tearDown(self):
307        self.db.close()
308
309    def testEscapeLiteral(self):
310        if self.db.server_version < 90000:  # PostgreSQL < 9.0
311            self.skipTest('Escaping functions not supported')
312        f = self.db.escape_literal
313        self.assertEqual(f("plain"), "'plain'")
314        self.assertEqual(f("that's k\xe4se"), "'that''s k\xe4se'")
315        self.assertEqual(f(r"It's fine to have a \ inside."),
316            r" E'It''s fine to have a \\ inside.'")
317        self.assertEqual(f('No "quotes" must be escaped.'),
318            "'No \"quotes\" must be escaped.'")
319
320    def testEscapeIdentifier(self):
321        if self.db.server_version < 90000:  # PostgreSQL < 9.0
322            self.skipTest('Escaping functions not supported')
323        f = self.db.escape_identifier
324        self.assertEqual(f("plain"), '"plain"')
325        self.assertEqual(f("that's k\xe4se"), '"that\'s k\xe4se"')
326        self.assertEqual(f(r"It's fine to have a \ inside."),
327            '"It\'s fine to have a \\ inside."')
328        self.assertEqual(f('All "quotes" must be escaped.'),
329            '"All ""quotes"" must be escaped."')
330
331    def testEscapeString(self):
332        f = self.db.escape_string
333        self.assertEqual(f("plain"), "plain")
334        self.assertEqual(f("that's k\xe4se"), "that''s k\xe4se")
335        self.assertEqual(f(r"It's fine to have a \ inside."),
336            r"It''s fine to have a \ inside.")
337
338    def testEscapeBytea(self):
339        f = self.db.escape_bytea
340        # note that escape_byte always returns hex output since PostgreSQL 9.0,
341        # regardless of the bytea_output setting
342        if self.db.server_version < 90000:
343            self.assertEqual(f("plain"), r"plain")
344            self.assertEqual(f("that's k\xe4se"), r"that''s k\344se")
345            self.assertEqual(f('O\x00ps\xff!'), r"O\000ps\377!")
346        else:
347            self.assertEqual(f("plain"), r"\x706c61696e")
348            self.assertEqual(f("that's k\xe4se"), r"\x746861742773206be47365")
349            self.assertEqual(f('O\x00ps\xff!'), r"\x4f007073ff21")
350
351    def testUnescapeBytea(self):
352        f = self.db.unescape_bytea
353        self.assertEqual(f("plain"), "plain")
354        self.assertEqual(f("that's k\\344se"), "that's k\xe4se")
355        self.assertEqual(f(r'O\000ps\377!'), 'O\x00ps\xff!')
356        self.assertEqual(f(r"\\x706c61696e"), r"\x706c61696e")
357        self.assertEqual(f(r"\\x746861742773206be47365"),
358            r"\x746861742773206be47365")
359        self.assertEqual(f(r"\\x4f007073ff21"), r"\x4f007073ff21")
360
361    def testQuote(self):
362        f = self.db._quote
363        self.assertEqual(f(None, None), 'NULL')
364        self.assertEqual(f(None, 'int'), 'NULL')
365        self.assertEqual(f(None, 'float'), 'NULL')
366        self.assertEqual(f(None, 'num'), 'NULL')
367        self.assertEqual(f(None, 'money'), 'NULL')
368        self.assertEqual(f(None, 'bool'), 'NULL')
369        self.assertEqual(f(None, 'date'), 'NULL')
370        self.assertEqual(f('', 'int'), 'NULL')
371        self.assertEqual(f('', 'float'), 'NULL')
372        self.assertEqual(f('', 'num'), 'NULL')
373        self.assertEqual(f('', 'money'), 'NULL')
374        self.assertEqual(f('', 'bool'), 'NULL')
375        self.assertEqual(f('', 'date'), 'NULL')
376        self.assertEqual(f('', 'text'), "''")
377        self.assertEqual(f(0, 'int'), '0')
378        self.assertEqual(f(0, 'num'), '0')
379        self.assertEqual(f(1, 'int'), '1')
380        self.assertEqual(f(1, 'num'), '1')
381        self.assertEqual(f(-1, 'int'), '-1')
382        self.assertEqual(f(-1, 'num'), '-1')
383        self.assertEqual(f(123456789, 'int'), '123456789')
384        self.assertEqual(f(123456987, 'num'), '123456987')
385        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
386        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
387        self.assertEqual(f('123456789', 'num'), '123456789')
388        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
389        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
390        self.assertEqual(f(123, 'money'), '123')
391        self.assertEqual(f('123', 'money'), '123')
392        self.assertEqual(f(123.45, 'money'), '123.45')
393        self.assertEqual(f('123.45', 'money'), '123.45')
394        self.assertEqual(f(123.454, 'money'), '123.454')
395        self.assertEqual(f('123.454', 'money'), '123.454')
396        self.assertEqual(f(123.456, 'money'), '123.456')
397        self.assertEqual(f('123.456', 'money'), '123.456')
398        self.assertEqual(f('f', 'bool'), "'f'")
399        self.assertEqual(f('F', 'bool'), "'f'")
400        self.assertEqual(f('false', 'bool'), "'f'")
401        self.assertEqual(f('False', 'bool'), "'f'")
402        self.assertEqual(f('FALSE', 'bool'), "'f'")
403        self.assertEqual(f(0, 'bool'), "'f'")
404        self.assertEqual(f('0', 'bool'), "'f'")
405        self.assertEqual(f('-', 'bool'), "'f'")
406        self.assertEqual(f('n', 'bool'), "'f'")
407        self.assertEqual(f('N', 'bool'), "'f'")
408        self.assertEqual(f('no', 'bool'), "'f'")
409        self.assertEqual(f('off', 'bool'), "'f'")
410        self.assertEqual(f('t', 'bool'), "'t'")
411        self.assertEqual(f('T', 'bool'), "'t'")
412        self.assertEqual(f('true', 'bool'), "'t'")
413        self.assertEqual(f('True', 'bool'), "'t'")
414        self.assertEqual(f('TRUE', 'bool'), "'t'")
415        self.assertEqual(f(1, 'bool'), "'t'")
416        self.assertEqual(f(2, 'bool'), "'t'")
417        self.assertEqual(f(-1, 'bool'), "'t'")
418        self.assertEqual(f(0.5, 'bool'), "'t'")
419        self.assertEqual(f('1', 'bool'), "'t'")
420        self.assertEqual(f('y', 'bool'), "'t'")
421        self.assertEqual(f('Y', 'bool'), "'t'")
422        self.assertEqual(f('yes', 'bool'), "'t'")
423        self.assertEqual(f('on', 'bool'), "'t'")
424        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
425        self.assertEqual(f(123, 'text'), "'123'")
426        self.assertEqual(f(1.23, 'text'), "'1.23'")
427        self.assertEqual(f('abc', 'text'), "'abc'")
428        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
429        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
430        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
431        self.db.query('set standard_conforming_strings=off')
432        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
433        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
434
435    def testQuery(self):
436        query = self.db.query
437        query("drop table if exists test_table")
438        q = "create table test_table (n integer) with oids"
439        r = query(q)
440        self.assertIsNone(r)
441        q = "insert into test_table values (1)"
442        r = query(q)
443        self.assertIsInstance(r, int)
444        q = "insert into test_table select 2"
445        r = query(q)
446        self.assertIsInstance(r, int)
447        oid = r
448        q = "select oid from test_table where n=2"
449        r = query(q).getresult()
450        self.assertEqual(len(r), 1)
451        r = r[0]
452        self.assertEqual(len(r), 1)
453        r = r[0]
454        self.assertEqual(r, oid)
455        q = "insert into test_table select 3 union select 4 union select 5"
456        r = query(q)
457        self.assertIsInstance(r, str)
458        self.assertEqual(r, '3')
459        q = "update test_table set n=4 where n<5"
460        r = query(q)
461        self.assertIsInstance(r, str)
462        self.assertEqual(r, '4')
463        q = "delete from test_table"
464        r = query(q)
465        self.assertIsInstance(r, str)
466        self.assertEqual(r, '5')
467        query("drop table test_table")
468
469    def testMultipleQueries(self):
470        self.assertEqual(self.db.query(
471            "create temporary table test_multi (n integer);"
472            "insert into test_multi values (4711);"
473            "select n from test_multi").getresult()[0][0], 4711)
474
475    def testQueryWithParams(self):
476        query = self.db.query
477        query("drop table if exists test_table")
478        q = "create table test_table (n1 integer, n2 integer) with oids"
479        query(q)
480        q = "insert into test_table values ($1, $2)"
481        r = query(q, (1, 2))
482        self.assertIsInstance(r, int)
483        r = query(q, [3, 4])
484        self.assertIsInstance(r, int)
485        r = query(q, [5, 6])
486        self.assertIsInstance(r, int)
487        q = "select * from test_table order by 1, 2"
488        self.assertEqual(query(q).getresult(),
489            [(1, 2), (3, 4), (5, 6)])
490        q = "select * from test_table where n1=$1 and n2=$2"
491        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
492        q = "update test_table set n2=$2 where n1=$1"
493        r = query(q, 3, 7)
494        self.assertEqual(r, '1')
495        q = "select * from test_table order by 1, 2"
496        self.assertEqual(query(q).getresult(),
497            [(1, 2), (3, 7), (5, 6)])
498        q = "delete from test_table where n2!=$1"
499        r = query(q, 4)
500        self.assertEqual(r, '3')
501        query("drop table test_table")
502
503    def testEmptyQuery(self):
504        self.assertRaises(ValueError, self.db.query, '')
505
506    def testQueryProgrammingError(self):
507        try:
508            self.db.query("select 1/0")
509        except pg.ProgrammingError, error:
510            self.assertEqual(error.sqlstate, '22012')
511
512    def testPkey(self):
513        query = self.db.query
514        for n in range(4):
515            query("drop table if exists pkeytest%d" % n)
516        query("create table pkeytest0 ("
517            "a smallint)")
518        query("create table pkeytest1 ("
519            "b smallint primary key)")
520        query("create table pkeytest2 ("
521            "c smallint, d smallint primary key)")
522        query("create table pkeytest3 ("
523            "e smallint, f smallint, g smallint, "
524            "h smallint, i smallint, "
525            "primary key (f,h))")
526        pkey = self.db.pkey
527        self.assertRaises(KeyError, pkey, 'pkeytest0')
528        self.assertEqual(pkey('pkeytest1'), 'b')
529        self.assertEqual(pkey('pkeytest2'), 'd')
530        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
531        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
532        self.assertEqual(pkey('pkeytest0'), 'none')
533        pkey(None, {'t': 'a', 'n.t': 'b'})
534        self.assertEqual(pkey('t'), 'a')
535        self.assertEqual(pkey('n.t'), 'b')
536        self.assertRaises(KeyError, pkey, 'pkeytest0')
537        for n in range(4):
538            query("drop table pkeytest%d" % n)
539
540    def testGetDatabases(self):
541        databases = self.db.get_databases()
542        self.assertIn('template0', databases)
543        self.assertIn('template1', databases)
544        self.assertNotIn('not existing database', databases)
545        self.assertIn('postgres', databases)
546        self.assertIn(dbname, databases)
547
548    def testGetTables(self):
549        get_tables = self.db.get_tables
550        result1 = get_tables()
551        tables = ('"A very Special Name"',
552            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
553            'A_MiXeD_NaMe', '"another special name"',
554            'averyveryveryveryveryveryverylongtablename',
555            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
556        for t in tables:
557            self.db.query('drop table if exists %s' % t)
558            self.db.query("create table %s"
559                " as select 0" % t)
560        result3 = get_tables()
561        result2 = []
562        for t in result3:
563            if t not in result1:
564                result2.append(t)
565        result3 = []
566        for t in tables:
567            if not t.startswith('"'):
568                t = t.lower()
569            result3.append('public.' + t)
570        self.assertEqual(result2, result3)
571        for t in result2:
572            self.db.query('drop table %s' % t)
573        result2 = get_tables()
574        self.assertEqual(result2, result1)
575
576    def testGetRelations(self):
577        get_relations = self.db.get_relations
578        result = get_relations()
579        self.assertIn('public.test', result)
580        self.assertIn('public.test_view', result)
581        result = get_relations('rv')
582        self.assertIn('public.test', result)
583        self.assertIn('public.test_view', result)
584        result = get_relations('r')
585        self.assertIn('public.test', result)
586        self.assertNotIn('public.test_view', result)
587        result = get_relations('v')
588        self.assertNotIn('public.test', result)
589        self.assertIn('public.test_view', result)
590        result = get_relations('cisSt')
591        self.assertNotIn('public.test', result)
592        self.assertNotIn('public.test_view', result)
593
594    def testAttnames(self):
595        self.assertRaises(pg.ProgrammingError,
596            self.db.get_attnames, 'does_not_exist')
597        self.assertRaises(pg.ProgrammingError,
598            self.db.get_attnames, 'has.too.many.dots')
599        for table in ('attnames_test_table', 'test table for attnames'):
600            self.db.query('drop table if exists "%s"' % table)
601            self.db.query('create table "%s" ('
602                'a smallint, b integer, c bigint, '
603                'e numeric, f float, f2 double precision, m money, '
604                'x smallint, y smallint, z smallint, '
605                'Normal_NaMe smallint, "Special Name" smallint, '
606                't text, u char(2), v varchar(2), '
607                'primary key (y, u)) with oids' % table)
608            attributes = self.db.get_attnames(table)
609            result = {'a': 'int', 'c': 'int', 'b': 'int',
610                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
611                'normal_name': 'int', 'Special Name': 'int',
612                'u': 'text', 't': 'text', 'v': 'text',
613                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
614            self.assertEqual(attributes, result)
615            self.db.query('drop table "%s"' % table)
616
617    def testHasTablePrivilege(self):
618        can = self.db.has_table_privilege
619        self.assertEqual(can('test'), True)
620        self.assertEqual(can('test', 'select'), True)
621        self.assertEqual(can('test', 'SeLeCt'), True)
622        self.assertEqual(can('test', 'SELECT'), True)
623        self.assertEqual(can('test', 'insert'), True)
624        self.assertEqual(can('test', 'update'), True)
625        self.assertEqual(can('test', 'delete'), True)
626        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
627        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
628
629    def testGet(self):
630        get = self.db.get
631        query = self.db.query
632        for table in ('get_test_table', 'test table for get'):
633            query('drop table if exists "%s"' % table)
634            query('create table "%s" ('
635                "n integer, t text) with oids" % table)
636            for n, t in enumerate('xyz'):
637                query('insert into "%s" values('"%d, '%s')"
638                    % (table, n + 1, t))
639            self.assertRaises(pg.ProgrammingError, get, table, 2)
640            r = get(table, 2, 'n')
641            oid_table = table
642            if ' ' in table:
643                oid_table = '"%s"' % oid_table
644            oid_table = 'oid(public.%s)' % oid_table
645            self.assertIn(oid_table, r)
646            oid = r[oid_table]
647            self.assertIsInstance(oid, int)
648            result = {'t': 'y', 'n': 2, oid_table: oid}
649            self.assertEqual(r, result)
650            self.assertEqual(get(table + ' *', 2, 'n'), r)
651            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
652            self.assertEqual(get(table, 1, 'n')['t'], 'x')
653            self.assertEqual(get(table, 3, 'n')['t'], 'z')
654            self.assertEqual(get(table, 2, 'n')['t'], 'y')
655            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
656            r['n'] = 3
657            self.assertEqual(get(table, r, 'n')['t'], 'z')
658            self.assertEqual(get(table, 1, 'n')['t'], 'x')
659            query('alter table "%s" alter n set not null' % table)
660            query('alter table "%s" add primary key (n)' % table)
661            self.assertEqual(get(table, 3)['t'], 'z')
662            self.assertEqual(get(table, 1)['t'], 'x')
663            self.assertEqual(get(table, 2)['t'], 'y')
664            r['n'] = 1
665            self.assertEqual(get(table, r)['t'], 'x')
666            r['n'] = 3
667            self.assertEqual(get(table, r)['t'], 'z')
668            r['n'] = 2
669            self.assertEqual(get(table, r)['t'], 'y')
670            query('drop table "%s"' % table)
671
672    def testGetWithCompositeKey(self):
673        get = self.db.get
674        query = self.db.query
675        table = 'get_test_table_1'
676        query("drop table if exists %s" % table)
677        query("create table %s ("
678            "n integer, t text, primary key (n))" % table)
679        for n, t in enumerate('abc'):
680            query("insert into %s values("
681                "%d, '%s')" % (table, n + 1, t))
682        self.assertEqual(get(table, 2)['t'], 'b')
683        query("drop table %s" % table)
684        table = 'get_test_table_2'
685        query("drop table if exists %s" % table)
686        query("create table %s ("
687            "n integer, m integer, t text, primary key (n, m))" % table)
688        for n in range(3):
689            for m in range(2):
690                t = chr(ord('a') + 2 * n + m)
691                query("insert into %s values("
692                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
693        self.assertRaises(pg.ProgrammingError, get, table, 2)
694        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
695        self.assertEqual(get(table, dict(n=1, m=2),
696                             ('n', 'm'))['t'], 'b')
697        self.assertEqual(get(table, dict(n=3, m=2),
698                             frozenset(['n', 'm']))['t'], 'f')
699        query("drop table %s" % table)
700
701    def testGetFromView(self):
702        self.db.query('delete from test where i4=14')
703        self.db.query('insert into test (i4, v4) values('
704            "14, 'abc4')")
705        r = self.db.get('test_view', 14, 'i4')
706        self.assertIn('v4', r)
707        self.assertEqual(r['v4'], 'abc4')
708
709    def testInsert(self):
710        insert = self.db.insert
711        query = self.db.query
712        server_version = self.db.server_version
713        for table in ('insert_test_table', 'test table for insert'):
714            query('drop table if exists "%s"' % table)
715            query('create table "%s" ('
716                "i2 smallint, i4 integer, i8 bigint,"
717                " d numeric, f4 real, f8 double precision, m money,"
718                " v4 varchar(4), c4 char(4), t text,"
719                " b boolean, ts timestamp) with oids" % table)
720            oid_table = table
721            if ' ' in table:
722                oid_table = '"%s"' % oid_table
723            oid_table = 'oid(public.%s)' % oid_table
724            tests = [dict(i2=None, i4=None, i8=None),
725                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
726                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
727                dict(i2=42, i4=123456, i8=9876543210),
728                dict(i2=2 ** 15 - 1,
729                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
730                dict(d=None), (dict(d=''), dict(d=None)),
731                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
732                dict(f4=None, f8=None), dict(f4=0, f8=0),
733                (dict(f4='', f8=''), dict(f4=None, f8=None)),
734                (dict(d=1234.5, f4=1234.5, f8=1234.5),
735                      dict(d=Decimal('1234.5'))),
736                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
737                dict(d=Decimal('123456789.9876543212345678987654321')),
738                dict(m=None), (dict(m=''), dict(m=None)),
739                dict(m=Decimal('-1234.56')),
740                (dict(m='-1234.56'), dict(m=Decimal('-1234.56'))),
741                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
742                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
743                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
744                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
745                (dict(m=123456), dict(m=Decimal('123456'))),
746                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
747                dict(b=None), (dict(b=''), dict(b=None)),
748                dict(b='f'), dict(b='t'),
749                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
750                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
751                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
752                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
753                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
754                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
755                dict(v4=None, c4=None, t=None),
756                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
757                dict(v4='1234', c4='1234', t='1234' * 10),
758                dict(v4='abcd', c4='abcd', t='abcdefg'),
759                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
760                dict(ts=None), (dict(ts=''), dict(ts=None)),
761                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
762                dict(ts='2012-12-21 00:00:00'),
763                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
764                dict(ts='2012-12-21 12:21:12'),
765                dict(ts='2013-01-05 12:13:14'),
766                dict(ts='current_timestamp')]
767            for test in tests:
768                if isinstance(test, dict):
769                    data = test
770                    change = {}
771                else:
772                    data, change = test
773                expect = data.copy()
774                expect.update(change)
775                if data.get('m') and server_version < 910000:
776                    # PostgreSQL < 9.1 cannot directly convert numbers to money
777                    data['m'] = "'%s'::money" % data['m']
778                self.assertEqual(insert(table, data), data)
779                self.assertIn(oid_table, data)
780                oid = data[oid_table]
781                self.assertIsInstance(oid, int)
782                data = dict(item for item in data.iteritems()
783                    if item[0] in expect)
784                ts = expect.get('ts')
785                if ts == 'current_timestamp':
786                    ts = expect['ts'] = data['ts']
787                    if len(ts) > 19:
788                        self.assertEqual(ts[19], '.')
789                        ts = ts[:19]
790                    else:
791                        self.assertEqual(len(ts), 19)
792                    self.assertTrue(ts[:4].isdigit())
793                    self.assertEqual(ts[4], '-')
794                    self.assertEqual(ts[10], ' ')
795                    self.assertTrue(ts[11:13].isdigit())
796                    self.assertEqual(ts[13], ':')
797                self.assertEqual(data, expect)
798                data = query(
799                    'select oid,* from "%s"' % table).dictresult()[0]
800                self.assertEqual(data['oid'], oid)
801                data = dict(item for item in data.iteritems()
802                    if item[0] in expect)
803                self.assertEqual(data, expect)
804                query('delete from "%s"' % table)
805            query('drop table "%s"' % table)
806
807    def testUpdate(self):
808        update = self.db.update
809        query = self.db.query
810        for table in ('update_test_table', 'test table for update'):
811            query('drop table if exists "%s"' % table)
812            query('create table "%s" ('
813                "n integer, t text) with oids" % table)
814            for n, t in enumerate('xyz'):
815                query('insert into "%s" values('
816                    "%d, '%s')" % (table, n + 1, t))
817            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
818            r = self.db.get(table, 2, 'n')
819            r['t'] = 'u'
820            s = update(table, r)
821            self.assertEqual(s, r)
822            r = query('select t from "%s" where n=2' % table
823                      ).getresult()[0][0]
824            self.assertEqual(r, 'u')
825            query('drop table "%s"' % table)
826
827    def testUpdateWithCompositeKey(self):
828        update = self.db.update
829        query = self.db.query
830        table = 'update_test_table_1'
831        query("drop table if exists %s" % table)
832        query("create table %s ("
833            "n integer, t text, primary key (n))" % table)
834        for n, t in enumerate('abc'):
835            query("insert into %s values("
836                "%d, '%s')" % (table, n + 1, t))
837        self.assertRaises(pg.ProgrammingError, update,
838                          table, dict(t='b'))
839        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
840        r = query('select t from "%s" where n=2' % table
841                  ).getresult()[0][0]
842        self.assertEqual(r, 'd')
843        query("drop table %s" % table)
844        table = 'update_test_table_2'
845        query("drop table if exists %s" % table)
846        query("create table %s ("
847            "n integer, m integer, t text, primary key (n, m))" % table)
848        for n in range(3):
849            for m in range(2):
850                t = chr(ord('a') + 2 * n + m)
851                query("insert into %s values("
852                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
853        self.assertRaises(pg.ProgrammingError, update,
854                          table, dict(n=2, t='b'))
855        self.assertEqual(update(table,
856                                dict(n=2, m=2, t='x'))['t'], 'x')
857        r = [r[0] for r in query('select t from "%s" where n=2'
858            ' order by m' % table).getresult()]
859        self.assertEqual(r, ['c', 'x'])
860        query("drop table %s" % table)
861
862    def testClear(self):
863        clear = self.db.clear
864        query = self.db.query
865        for table in ('clear_test_table', 'test table for clear'):
866            query('drop table if exists "%s"' % table)
867            query('create table "%s" ('
868                "n integer, b boolean, d date, t text)" % table)
869            r = clear(table)
870            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
871            self.assertEqual(r, result)
872            r['a'] = r['n'] = 1
873            r['d'] = r['t'] = 'x'
874            r['b'] = 't'
875            r['oid'] = 1L
876            r = clear(table, r)
877            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '', 'oid': 1L}
878            self.assertEqual(r, result)
879            query('drop table "%s"' % table)
880
881    def testDelete(self):
882        delete = self.db.delete
883        query = self.db.query
884        for table in ('delete_test_table', 'test table for delete'):
885            query('drop table if exists "%s"' % table)
886            query('create table "%s" ('
887                "n integer, t text) with oids" % table)
888            for n, t in enumerate('xyz'):
889                query('insert into "%s" values('
890                    "%d, '%s')" % (table, n + 1, t))
891            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
892            r = self.db.get(table, 1, 'n')
893            s = delete(table, r)
894            self.assertEqual(s, 1)
895            r = self.db.get(table, 3, 'n')
896            s = delete(table, r)
897            self.assertEqual(s, 1)
898            s = delete(table, r)
899            self.assertEqual(s, 0)
900            r = query('select * from "%s"' % table).dictresult()
901            self.assertEqual(len(r), 1)
902            r = r[0]
903            result = {'n': 2, 't': 'y'}
904            self.assertEqual(r, result)
905            r = self.db.get(table, 2, 'n')
906            s = delete(table, r)
907            self.assertEqual(s, 1)
908            s = delete(table, r)
909            self.assertEqual(s, 0)
910            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
911            query('drop table "%s"' % table)
912
913    def testDeleteWithCompositeKey(self):
914        query = self.db.query
915        table = 'delete_test_table_1'
916        query("drop table if exists %s" % table)
917        query("create table %s ("
918            "n integer, t text, primary key (n))" % table)
919        for n, t in enumerate('abc'):
920            query("insert into %s values("
921                "%d, '%s')" % (table, n + 1, t))
922        self.assertRaises(pg.ProgrammingError, self.db.delete,
923            table, dict(t='b'))
924        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
925        r = query('select t from "%s" where n=2' % table
926                  ).getresult()
927        self.assertEqual(r, [])
928        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
929        r = query('select t from "%s" where n=3' % table
930                  ).getresult()[0][0]
931        self.assertEqual(r, 'c')
932        query("drop table %s" % table)
933        table = 'delete_test_table_2'
934        query("drop table if exists %s" % table)
935        query("create table %s ("
936            "n integer, m integer, t text, primary key (n, m))" % table)
937        for n in range(3):
938            for m in range(2):
939                t = chr(ord('a') + 2 * n + m)
940                query("insert into %s values("
941                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
942        self.assertRaises(pg.ProgrammingError, self.db.delete,
943            table, dict(n=2, t='b'))
944        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
945        r = [r[0] for r in query('select t from "%s" where n=2'
946            ' order by m' % table).getresult()]
947        self.assertEqual(r, ['c'])
948        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
949        r = [r[0] for r in query('select t from "%s" where n=3'
950            ' order by m' % table).getresult()]
951        self.assertEqual(r, ['e', 'f'])
952        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
953        r = [r[0] for r in query('select t from "%s" where n=3'
954            ' order by m' % table).getresult()]
955        self.assertEqual(r, ['f'])
956        query("drop table %s" % table)
957
958    def testTransaction(self):
959        query = self.db.query
960        query("drop table if exists test_table")
961        query("create table test_table (n integer)")
962        self.db.begin()
963        query("insert into test_table values (1)")
964        query("insert into test_table values (2)")
965        self.db.commit()
966        self.db.begin()
967        query("insert into test_table values (3)")
968        query("insert into test_table values (4)")
969        self.db.rollback()
970        self.db.begin()
971        query("insert into test_table values (5)")
972        self.db.savepoint('before6')
973        query("insert into test_table values (6)")
974        self.db.rollback('before6')
975        query("insert into test_table values (7)")
976        self.db.commit()
977        self.db.begin()
978        self.db.savepoint('before8')
979        query("insert into test_table values (8)")
980        self.db.release('before8')
981        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
982        self.db.commit()
983        self.db.start()
984        query("insert into test_table values (9)")
985        self.db.end()
986        r = [r[0] for r in query(
987            "select * from test_table order by 1").getresult()]
988        self.assertEqual(r, [1, 2, 5, 7, 9])
989        query("drop table test_table")
990
991    @unittest.skipIf(no_with, 'context managers not supported')
992    def testContextManager(self):
993        query = self.db.query
994        query("drop table if exists test_table")
995        query("create table test_table (n integer check(n>0))")
996        # wrap "with" statements to avoid SyntaxError in Python < 2.5
997        exec """from __future__ import with_statement\nif True:
998        with self.db:
999            query("insert into test_table values (1)")
1000            query("insert into test_table values (2)")
1001        try:
1002            with self.db:
1003                query("insert into test_table values (3)")
1004                query("insert into test_table values (4)")
1005                raise ValueError('test transaction should rollback')
1006        except ValueError, error:
1007            self.assertEqual(str(error), 'test transaction should rollback')
1008        with self.db:
1009            query("insert into test_table values (5)")
1010        try:
1011            with self.db:
1012                query("insert into test_table values (6)")
1013                query("insert into test_table values (-1)")
1014        except pg.ProgrammingError, error:
1015            self.assertTrue('check' in str(error))
1016        with self.db:
1017            query("insert into test_table values (7)")\n"""
1018        r = [r[0] for r in query(
1019            "select * from test_table order by 1").getresult()]
1020        self.assertEqual(r, [1, 2, 5, 7])
1021        query("drop table test_table")
1022
1023    def testBytea(self):
1024        query = self.db.query
1025        query('drop table if exists bytea_test')
1026        query('create table bytea_test ('
1027            'data bytea)')
1028        s = "It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1029        r = self.db.escape_bytea(s)
1030        query('insert into bytea_test values('
1031            "'%s')" % r)
1032        r = query('select * from bytea_test').getresult()
1033        self.assertTrue(len(r) == 1)
1034        r = r[0]
1035        self.assertTrue(len(r) == 1)
1036        r = r[0]
1037        r = self.db.unescape_bytea(r)
1038        self.assertEqual(r, s)
1039        query('drop table bytea_test')
1040
1041    def testDebugWithCallable(self):
1042        if debug:
1043            self.assertEqual(self.db.debug, debug)
1044        else:
1045            self.assertIsNone(self.db.debug)
1046        s = []
1047        self.db.debug = s.append
1048        try:
1049            self.db.query("select 1")
1050            self.db.query("select 2")
1051            self.assertEqual(s, ["select 1", "select 2"])
1052        finally:
1053            self.db.debug = debug
1054
1055
1056class TestSchemas(unittest.TestCase):
1057    """Test correct handling of schemas (namespaces)."""
1058
1059    @classmethod
1060    def setUpClass(cls):
1061        db = DB()
1062        query = db.query
1063        query("set client_min_messages=warning")
1064        for num_schema in range(5):
1065            if num_schema:
1066                schema = "s%d" % num_schema
1067                query("drop schema if exists %s cascade" % (schema,))
1068                try:
1069                    query("create schema %s" % (schema,))
1070                except pg.ProgrammingError:
1071                    raise RuntimeError("The test user cannot create schemas.\n"
1072                        "Grant create on database %s to the user"
1073                        " for running these tests." % dbname)
1074            else:
1075                schema = "public"
1076                query("drop table if exists %s.t" % (schema,))
1077                query("drop table if exists %s.t%d" % (schema, num_schema))
1078            query("create table %s.t with oids as select 1 as n, %d as d"
1079                  % (schema, num_schema))
1080            query("create table %s.t%d with oids as select 1 as n, %d as d"
1081                  % (schema, num_schema, num_schema))
1082        db.close()
1083
1084    @classmethod
1085    def tearDownClass(cls):
1086        db = DB()
1087        query = db.query
1088        query("set client_min_messages=warning")
1089        for num_schema in range(5):
1090            if num_schema:
1091                schema = "s%d" % num_schema
1092                query("drop schema %s cascade" % (schema,))
1093            else:
1094                schema = "public"
1095                query("drop table %s.t" % (schema,))
1096                query("drop table %s.t%d" % (schema, num_schema))
1097        db.close()
1098
1099    def setUp(self):
1100        self.db = DB()
1101        self.db.query("set client_min_messages=warning")
1102
1103    def tearDown(self):
1104        self.db.close()
1105
1106    def testGetTables(self):
1107        tables = self.db.get_tables()
1108        for num_schema in range(5):
1109            if num_schema:
1110                schema = "s" + str(num_schema)
1111            else:
1112                schema = "public"
1113            for t in (schema + ".t",
1114                    schema + ".t" + str(num_schema)):
1115                self.assertIn(t, tables)
1116
1117    def testGetAttnames(self):
1118        get_attnames = self.db.get_attnames
1119        query = self.db.query
1120        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
1121        r = get_attnames("t")
1122        self.assertEqual(r, result)
1123        r = get_attnames("s4.t4")
1124        self.assertEqual(r, result)
1125        query("drop table if exists s3.t3m")
1126        query("create table s3.t3m with oids as select 1 as m")
1127        result_m = {'oid': 'int', 'm': 'int'}
1128        r = get_attnames("s3.t3m")
1129        self.assertEqual(r, result_m)
1130        query("set search_path to s1,s3")
1131        r = get_attnames("t3")
1132        self.assertEqual(r, result)
1133        r = get_attnames("t3m")
1134        self.assertEqual(r, result_m)
1135        query("drop table s3.t3m")
1136
1137    def testGet(self):
1138        get = self.db.get
1139        query = self.db.query
1140        PrgError = pg.ProgrammingError
1141        self.assertEqual(get("t", 1, 'n')['d'], 0)
1142        self.assertEqual(get("t0", 1, 'n')['d'], 0)
1143        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
1144        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
1145        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
1146        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
1147        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
1148        query("set search_path to s2,s4")
1149        self.assertRaises(PrgError, get, "t1", 1, 'n')
1150        self.assertEqual(get("t4", 1, 'n')['d'], 4)
1151        self.assertRaises(PrgError, get, "t3", 1, 'n')
1152        self.assertEqual(get("t", 1, 'n')['d'], 2)
1153        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
1154        query("set search_path to s1,s3")
1155        self.assertRaises(PrgError, get, "t2", 1, 'n')
1156        self.assertEqual(get("t3", 1, 'n')['d'], 3)
1157        self.assertRaises(PrgError, get, "t4", 1, 'n')
1158        self.assertEqual(get("t", 1, 'n')['d'], 1)
1159        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
1160
1161    def testMangling(self):
1162        get = self.db.get
1163        query = self.db.query
1164        r = get("t", 1, 'n')
1165        self.assertIn('oid(public.t)', r)
1166        query("set search_path to s2")
1167        r = get("t2", 1, 'n')
1168        self.assertIn('oid(s2.t2)', r)
1169        query("set search_path to s3")
1170        r = get("t", 1, 'n')
1171        self.assertIn('oid(s3.t)', r)
1172
1173
1174if __name__ == '__main__':
1175    unittest.main()
Note: See TracBrowser for help on using the repository browser.