source: branches/4.x/tests/test_classic_dbwrapper.py @ 745

Last change on this file since 745 was 745, checked in by cito, 4 years ago

Add methods get/set_parameter to DB wrapper class

These methods can be used to get/set/reset run-time parameters,
even several at once.

Since this is pretty useful and will not break anything, I have
also backported these additions to the 4.x branch.

Everything is well documented and tested, of course.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 53.7 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import sys
21
22import pg  # the module under test
23
24from decimal import Decimal
25
# Python 2.4 and older do not support the "with" statement.
no_with = sys.version_info[:2] < (2, 5)

# Connection parameters of the database the tests run against.
# A LOCAL_PyGreSQL.py module, if importable, may override these defaults.
# The current user needs the privilege to create schemas in that database.
dbname, dbhost, dbport = 'unittest', None, 5432

# Set this to True to make the DB wrapper print debugging output.
debug = False

try:  # optional local overrides of the settings above
    from LOCAL_PyGreSQL import *
except ImportError:
    pass

windows = (os.name == 'nt')

# Under Windows, a known libpq bug can crash the interface when PQhost()
# is called, so tests asking for the host are skipped on that platform.
do_not_ask_for_host_reason = 'libpq issue on Windows'
do_not_ask_for_host = windows
49
50
def DB():
    """Return a new DB wrapper connected to the test database.

    The module-level ``debug`` flag is passed on to the wrapper when set,
    and client messages below the warning level are suppressed.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
58
59
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # Some tests close the connection themselves; closing it again
        # raises InternalError, which can safely be ignored here.
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        # Expected public attributes, in the sorted order returned by dir().
        attributes = [
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'delete',
            'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction', 'tty',
            'unescape_bytea', 'update',
            'use_regtypes', 'user',
        ]
        if self.db.server_version < 90000:  # PostgreSQL < 9.0
            # these escaping functions are not supported before 9.0
            attributes.remove('escape_identifier')
            attributes.remove('escape_literal')
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        # no error expected, except possibly a message mentioning krb5_
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeTty(self):
        def_tty = ''
        tty = self.db.tty
        self.assertIsInstance(tty, str)
        self.assertEqual(tty, def_tty)
        self.assertEqual(tty, self.db.db.tty)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        if self.db.server_version < 90000:  # PostgreSQL < 9.0
            self.skipTest('Escaping functions not supported')
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        if self.db.server_version < 90000:  # PostgreSQL < 9.0
            self.skipTest('Escaping functions not supported')
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # strip the hex marker and backslashes so that both the hex and
        # the escape output format of an empty value compare as ''
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), '')

    def testMethodQuery(self):
        # parameters may be passed as separate arguments, a tuple or a list
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        # Division by zero must raise ProgrammingError carrying SQLSTATE
        # 22012 (division_by_zero).  Without the else clause, the test
        # would silently pass if no exception were raised at all.
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError:
            error = sys.exc_info()[1]
            self.assertEqual(error.sqlstate, '22012')
        else:
            self.fail('Division by zero should raise ProgrammingError')

    def testMethodEndcopy(self):
        # endcopy outside of a copy operation may raise IOError
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        # A DB wrapper can be built from an existing connection, from
        # another DB wrapper, or from any object with a _cnx attribute.
        # Closing or reopening such a wrapper leaves db.db set.
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
250
251
252class TestDBClass(unittest.TestCase):
253    """Test the methods of the DB class wrapped pg connection."""
254
255    @classmethod
256    def setUpClass(cls):
257        db = DB()
258        db.query("drop table if exists test cascade")
259        db.query("create table test ("
260            "i2 smallint, i4 integer, i8 bigint,"
261            " d numeric, f4 real, f8 double precision, m money,"
262            " v4 varchar(4), c4 char(4), t text)")
263        db.query("create or replace view test_view as"
264            " select i4, v4 from test")
265        db.close()
266
267    @classmethod
268    def tearDownClass(cls):
269        db = DB()
270        db.query("drop table test cascade")
271        db.close()
272
273    def setUp(self):
274        self.db = DB()
275        query = self.db.query
276        query('set client_encoding=utf8')
277        query('set standard_conforming_strings=on')
278        query("set lc_monetary='C'")
279        query("set datestyle='ISO,YMD'")
280        try:
281            query('set bytea_output=hex')
282        except pg.ProgrammingError:  # PostgreSQL < 9.0
283            pass
284
285    def tearDown(self):
286        self.db.close()
287
288    def testEscapeLiteral(self):
289        if self.db.server_version < 90000:  # PostgreSQL < 9.0
290            self.skipTest('Escaping functions not supported')
291        f = self.db.escape_literal
292        self.assertEqual(f("plain"), "'plain'")
293        self.assertEqual(f("that's k\xe4se"), "'that''s k\xe4se'")
294        self.assertEqual(f(r"It's fine to have a \ inside."),
295            r" E'It''s fine to have a \\ inside.'")
296        self.assertEqual(f('No "quotes" must be escaped.'),
297            "'No \"quotes\" must be escaped.'")
298
299    def testEscapeIdentifier(self):
300        if self.db.server_version < 90000:  # PostgreSQL < 9.0
301            self.skipTest('Escaping functions not supported')
302        f = self.db.escape_identifier
303        self.assertEqual(f("plain"), '"plain"')
304        self.assertEqual(f("that's k\xe4se"), '"that\'s k\xe4se"')
305        self.assertEqual(f(r"It's fine to have a \ inside."),
306            '"It\'s fine to have a \\ inside."')
307        self.assertEqual(f('All "quotes" must be escaped.'),
308            '"All ""quotes"" must be escaped."')
309
310    def testEscapeString(self):
311        f = self.db.escape_string
312        self.assertEqual(f("plain"), "plain")
313        self.assertEqual(f("that's k\xe4se"), "that''s k\xe4se")
314        self.assertEqual(f(r"It's fine to have a \ inside."),
315            r"It''s fine to have a \ inside.")
316
317    def testEscapeBytea(self):
318        f = self.db.escape_bytea
319        # note that escape_byte always returns hex output since PostgreSQL 9.0,
320        # regardless of the bytea_output setting
321        if self.db.server_version < 90000:
322            self.assertEqual(f("plain"), r"plain")
323            self.assertEqual(f("that's k\xe4se"), r"that''s k\344se")
324            self.assertEqual(f('O\x00ps\xff!'), r"O\000ps\377!")
325        else:
326            self.assertEqual(f("plain"), r"\x706c61696e")
327            self.assertEqual(f("that's k\xe4se"), r"\x746861742773206be47365")
328            self.assertEqual(f('O\x00ps\xff!'), r"\x4f007073ff21")
329
330    def testUnescapeBytea(self):
331        f = self.db.unescape_bytea
332        self.assertEqual(f("plain"), "plain")
333        self.assertEqual(f("that's k\\344se"), "that's k\xe4se")
334        self.assertEqual(f(r'O\000ps\377!'), 'O\x00ps\xff!')
335        self.assertEqual(f(r"\\x706c61696e"), r"\x706c61696e")
336        self.assertEqual(f(r"\\x746861742773206be47365"),
337            r"\x746861742773206be47365")
338        self.assertEqual(f(r"\\x4f007073ff21"), r"\x4f007073ff21")
339
340    def testQuote(self):
341        f = self.db._quote
342        self.assertEqual(f(None, None), 'NULL')
343        self.assertEqual(f(None, 'int'), 'NULL')
344        self.assertEqual(f(None, 'float'), 'NULL')
345        self.assertEqual(f(None, 'num'), 'NULL')
346        self.assertEqual(f(None, 'money'), 'NULL')
347        self.assertEqual(f(None, 'bool'), 'NULL')
348        self.assertEqual(f(None, 'date'), 'NULL')
349        self.assertEqual(f('', 'int'), 'NULL')
350        self.assertEqual(f('', 'float'), 'NULL')
351        self.assertEqual(f('', 'num'), 'NULL')
352        self.assertEqual(f('', 'money'), 'NULL')
353        self.assertEqual(f('', 'bool'), 'NULL')
354        self.assertEqual(f('', 'date'), 'NULL')
355        self.assertEqual(f('', 'text'), "''")
356        self.assertEqual(f(0, 'int'), '0')
357        self.assertEqual(f(0, 'num'), '0')
358        self.assertEqual(f(1, 'int'), '1')
359        self.assertEqual(f(1, 'num'), '1')
360        self.assertEqual(f(-1, 'int'), '-1')
361        self.assertEqual(f(-1, 'num'), '-1')
362        self.assertEqual(f(123456789, 'int'), '123456789')
363        self.assertEqual(f(123456987, 'num'), '123456987')
364        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
365        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
366        self.assertEqual(f('123456789', 'num'), '123456789')
367        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
368        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
369        self.assertEqual(f(123, 'money'), '123')
370        self.assertEqual(f('123', 'money'), '123')
371        self.assertEqual(f(123.45, 'money'), '123.45')
372        self.assertEqual(f('123.45', 'money'), '123.45')
373        self.assertEqual(f(123.454, 'money'), '123.454')
374        self.assertEqual(f('123.454', 'money'), '123.454')
375        self.assertEqual(f(123.456, 'money'), '123.456')
376        self.assertEqual(f('123.456', 'money'), '123.456')
377        self.assertEqual(f('f', 'bool'), "'f'")
378        self.assertEqual(f('F', 'bool'), "'f'")
379        self.assertEqual(f('false', 'bool'), "'f'")
380        self.assertEqual(f('False', 'bool'), "'f'")
381        self.assertEqual(f('FALSE', 'bool'), "'f'")
382        self.assertEqual(f(0, 'bool'), "'f'")
383        self.assertEqual(f('0', 'bool'), "'f'")
384        self.assertEqual(f('-', 'bool'), "'f'")
385        self.assertEqual(f('n', 'bool'), "'f'")
386        self.assertEqual(f('N', 'bool'), "'f'")
387        self.assertEqual(f('no', 'bool'), "'f'")
388        self.assertEqual(f('off', 'bool'), "'f'")
389        self.assertEqual(f('t', 'bool'), "'t'")
390        self.assertEqual(f('T', 'bool'), "'t'")
391        self.assertEqual(f('true', 'bool'), "'t'")
392        self.assertEqual(f('True', 'bool'), "'t'")
393        self.assertEqual(f('TRUE', 'bool'), "'t'")
394        self.assertEqual(f(1, 'bool'), "'t'")
395        self.assertEqual(f(2, 'bool'), "'t'")
396        self.assertEqual(f(-1, 'bool'), "'t'")
397        self.assertEqual(f(0.5, 'bool'), "'t'")
398        self.assertEqual(f('1', 'bool'), "'t'")
399        self.assertEqual(f('y', 'bool'), "'t'")
400        self.assertEqual(f('Y', 'bool'), "'t'")
401        self.assertEqual(f('yes', 'bool'), "'t'")
402        self.assertEqual(f('on', 'bool'), "'t'")
403        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
404        self.assertEqual(f(123, 'text'), "'123'")
405        self.assertEqual(f(1.23, 'text'), "'1.23'")
406        self.assertEqual(f('abc', 'text'), "'abc'")
407        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
408        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
409        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
410        self.db.query('set standard_conforming_strings=off')
411        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
412        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
413
414    def testGetParameter(self):
415        f = self.db.get_parameter
416        self.assertRaises(TypeError, f)
417        self.assertRaises(TypeError, f, None)
418        self.assertRaises(TypeError, f, 42)
419        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
420        r = f('standard_conforming_strings')
421        self.assertEqual(r, 'on')
422        r = f('lc_monetary')
423        self.assertEqual(r, 'C')
424        r = f('datestyle')
425        self.assertEqual(r, 'ISO, YMD')
426        r = f('bytea_output')
427        self.assertEqual(r, 'hex')
428        r = f(('bytea_output', 'lc_monetary'))
429        self.assertIsInstance(r, list)
430        self.assertEqual(r, ['hex', 'C'])
431        r = f(['standard_conforming_strings', 'datestyle', 'bytea_output'])
432        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
433        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
434        r = f(s)
435        self.assertIs(r, s)
436        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
437        s = dict.fromkeys(('Bytea_Output', 'LC_Monetary'))
438        r = f(s)
439        self.assertIs(r, s)
440        self.assertEqual(r, {'Bytea_Output': 'hex', 'LC_Monetary': 'C'})
441
442    def testGetParameterServerVersion(self):
443        r = self.db.get_parameter('server_version_num')
444        self.assertIsInstance(r, str)
445        s = self.db.server_version
446        self.assertIsInstance(s, int)
447        self.assertEqual(r, str(s))
448
449    def testGetParameterAll(self):
450        f = self.db.get_parameter
451        r = f('all')
452        self.assertIsInstance(r, dict)
453        self.assertEqual(r['standard_conforming_strings'], 'on')
454        self.assertEqual(r['lc_monetary'], 'C')
455        self.assertEqual(r['DateStyle'], 'ISO, YMD')
456        self.assertEqual(r['bytea_output'], 'hex')
457
458    def testSetParameter(self):
459        f = self.db.set_parameter
460        g = self.db.get_parameter
461        self.assertRaises(TypeError, f)
462        self.assertRaises(TypeError, f, None)
463        self.assertRaises(TypeError, f, 42)
464        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
465        f('standard_conforming_strings', 'off')
466        self.assertEqual(g('standard_conforming_strings'), 'off')
467        f('datestyle', 'ISO, DMY')
468        self.assertEqual(g('datestyle'), 'ISO, DMY')
469        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
470        self.assertEqual(g('standard_conforming_strings'), 'on')
471        self.assertEqual(g('datestyle'), 'ISO, YMD')
472        f(['standard_conforming_strings', 'datestyle'], ['off', 'ISO, DMY'])
473        self.assertEqual(g('standard_conforming_strings'), 'off')
474        self.assertEqual(g('datestyle'), 'ISO, DMY')
475        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
476        self.assertEqual(g('standard_conforming_strings'), 'on')
477        self.assertEqual(g('datestyle'), 'ISO, YMD')
478        f(('default_with_oids', 'standard_conforming_strings'), 'off')
479        self.assertEqual(g('default_with_oids'), 'off')
480        self.assertEqual(g('standard_conforming_strings'), 'off')
481        f(['default_with_oids', 'standard_conforming_strings'], 'on')
482        self.assertEqual(g('default_with_oids'), 'on')
483        self.assertEqual(g('standard_conforming_strings'), 'on')
484
485    def testResetParameter(self):
486        db = DB()
487        f = db.set_parameter
488        g = db.get_parameter
489        r = g('default_with_oids')
490        self.assertIn(r, ('on', 'off'))
491        dwi, not_dwi = r, r == 'on' and 'off' or 'on'
492        r = g('standard_conforming_strings')
493        self.assertIn(r, ('on', 'off'))
494        scs, not_scs = r, r == 'on' and 'off' or 'on'
495        f('default_with_oids', not_dwi)
496        f('standard_conforming_strings', not_scs)
497        self.assertEqual(g('default_with_oids'), not_dwi)
498        self.assertEqual(g('standard_conforming_strings'), not_scs)
499        f('default_with_oids')
500        f('standard_conforming_strings', None)
501        self.assertEqual(g('default_with_oids'), dwi)
502        self.assertEqual(g('standard_conforming_strings'), scs)
503        f('default_with_oids', not_dwi)
504        f('standard_conforming_strings', not_scs)
505        self.assertEqual(g('default_with_oids'), not_dwi)
506        self.assertEqual(g('standard_conforming_strings'), not_scs)
507        f(('default_with_oids', 'standard_conforming_strings'))
508        self.assertEqual(g('default_with_oids'), dwi)
509        self.assertEqual(g('standard_conforming_strings'), scs)
510        f('default_with_oids', not_dwi)
511        f('standard_conforming_strings', not_scs)
512        self.assertEqual(g('default_with_oids'), not_dwi)
513        self.assertEqual(g('standard_conforming_strings'), not_scs)
514        f(['default_with_oids', 'standard_conforming_strings'], None)
515        self.assertEqual(g('default_with_oids'), dwi)
516        self.assertEqual(g('standard_conforming_strings'), scs)
517
518    def testResetParameterAll(self):
519        db = DB()
520        f = db.set_parameter
521        self.assertRaises(ValueError, f, 'all', 0)
522        self.assertRaises(ValueError, f, 'all', 'off')
523        g = db.get_parameter
524        r = g('default_with_oids')
525        self.assertIn(r, ('on', 'off'))
526        dwi, not_dwi = r, r == 'on' and 'off' or 'on'
527        r = g('standard_conforming_strings')
528        self.assertIn(r, ('on', 'off'))
529        scs, not_scs = r, r == 'on' and 'off' or 'on'
530        f('default_with_oids', not_dwi)
531        f('standard_conforming_strings', not_scs)
532        self.assertEqual(g('default_with_oids'), not_dwi)
533        self.assertEqual(g('standard_conforming_strings'), not_scs)
534        f('all')
535        self.assertEqual(g('default_with_oids'), dwi)
536        self.assertEqual(g('standard_conforming_strings'), scs)
537
538    def testSetParameterLocal(self):
539        f = self.db.set_parameter
540        g = self.db.get_parameter
541        self.assertEqual(g('standard_conforming_strings'), 'on')
542        self.db.begin()
543        f('standard_conforming_strings', 'off', local=True)
544        self.assertEqual(g('standard_conforming_strings'), 'off')
545        self.db.end()
546        self.assertEqual(g('standard_conforming_strings'), 'on')
547
548    def testSetParameterSession(self):
549        f = self.db.set_parameter
550        g = self.db.get_parameter
551        self.assertEqual(g('standard_conforming_strings'), 'on')
552        self.db.begin()
553        f('standard_conforming_strings', 'off', local=False)
554        self.assertEqual(g('standard_conforming_strings'), 'off')
555        self.db.end()
556        self.assertEqual(g('standard_conforming_strings'), 'off')
557
558    def testQuery(self):
559        query = self.db.query
560        query("drop table if exists test_table")
561        q = "create table test_table (n integer) with oids"
562        r = query(q)
563        self.assertIsNone(r)
564        q = "insert into test_table values (1)"
565        r = query(q)
566        self.assertIsInstance(r, int)
567        q = "insert into test_table select 2"
568        r = query(q)
569        self.assertIsInstance(r, int)
570        oid = r
571        q = "select oid from test_table where n=2"
572        r = query(q).getresult()
573        self.assertEqual(len(r), 1)
574        r = r[0]
575        self.assertEqual(len(r), 1)
576        r = r[0]
577        self.assertEqual(r, oid)
578        q = "insert into test_table select 3 union select 4 union select 5"
579        r = query(q)
580        self.assertIsInstance(r, str)
581        self.assertEqual(r, '3')
582        q = "update test_table set n=4 where n<5"
583        r = query(q)
584        self.assertIsInstance(r, str)
585        self.assertEqual(r, '4')
586        q = "delete from test_table"
587        r = query(q)
588        self.assertIsInstance(r, str)
589        self.assertEqual(r, '5')
590        query("drop table test_table")
591
592    def testMultipleQueries(self):
593        self.assertEqual(self.db.query(
594            "create temporary table test_multi (n integer);"
595            "insert into test_multi values (4711);"
596            "select n from test_multi").getresult()[0][0], 4711)
597
598    def testQueryWithParams(self):
599        query = self.db.query
600        query("drop table if exists test_table")
601        q = "create table test_table (n1 integer, n2 integer) with oids"
602        query(q)
603        q = "insert into test_table values ($1, $2)"
604        r = query(q, (1, 2))
605        self.assertIsInstance(r, int)
606        r = query(q, [3, 4])
607        self.assertIsInstance(r, int)
608        r = query(q, [5, 6])
609        self.assertIsInstance(r, int)
610        q = "select * from test_table order by 1, 2"
611        self.assertEqual(query(q).getresult(),
612            [(1, 2), (3, 4), (5, 6)])
613        q = "select * from test_table where n1=$1 and n2=$2"
614        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
615        q = "update test_table set n2=$2 where n1=$1"
616        r = query(q, 3, 7)
617        self.assertEqual(r, '1')
618        q = "select * from test_table order by 1, 2"
619        self.assertEqual(query(q).getresult(),
620            [(1, 2), (3, 7), (5, 6)])
621        q = "delete from test_table where n2!=$1"
622        r = query(q, 4)
623        self.assertEqual(r, '3')
624        query("drop table test_table")
625
626    def testEmptyQuery(self):
627        self.assertRaises(ValueError, self.db.query, '')
628
629    def testQueryProgrammingError(self):
630        try:
631            self.db.query("select 1/0")
632        except pg.ProgrammingError, error:
633            self.assertEqual(error.sqlstate, '22012')
634
635    def testPkey(self):
636        query = self.db.query
637        for n in range(4):
638            query("drop table if exists pkeytest%d" % n)
639        query("create table pkeytest0 ("
640            "a smallint)")
641        query("create table pkeytest1 ("
642            "b smallint primary key)")
643        query("create table pkeytest2 ("
644            "c smallint, d smallint primary key)")
645        query("create table pkeytest3 ("
646            "e smallint, f smallint, g smallint,"
647            " h smallint, i smallint,"
648            " primary key (f,h))")
649        pkey = self.db.pkey
650        self.assertRaises(KeyError, pkey, 'pkeytest0')
651        self.assertEqual(pkey('pkeytest1'), 'b')
652        self.assertEqual(pkey('pkeytest2'), 'd')
653        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
654        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
655        self.assertEqual(pkey('pkeytest0'), 'none')
656        pkey(None, {'t': 'a', 'n.t': 'b'})
657        self.assertEqual(pkey('t'), 'a')
658        self.assertEqual(pkey('n.t'), 'b')
659        self.assertRaises(KeyError, pkey, 'pkeytest0')
660        for n in range(4):
661            query("drop table pkeytest%d" % n)
662
663    def testGetDatabases(self):
664        databases = self.db.get_databases()
665        self.assertIn('template0', databases)
666        self.assertIn('template1', databases)
667        self.assertNotIn('not existing database', databases)
668        self.assertIn('postgres', databases)
669        self.assertIn(dbname, databases)
670
671    def testGetTables(self):
672        get_tables = self.db.get_tables
673        result1 = get_tables()
674        self.assertIsInstance(result1, list)
675        for t in result1:
676            t = t.split('.', 1)
677            self.assertGreaterEqual(len(t), 2)
678            if len(t) > 2:
679                self.assertTrue(t[1].startswith('"'))
680            t = t[0]
681            self.assertNotEqual(t, 'information_schema')
682            self.assertFalse(t.startswith('pg_'))
683        tables = ('"A very Special Name"',
684            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
685            'A_MiXeD_NaMe', '"another special name"',
686            'averyveryveryveryveryveryverylongtablename',
687            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
688        for t in tables:
689            self.db.query('drop table if exists %s' % t)
690            self.db.query("create table %s"
691                " as select 0" % t)
692        result3 = get_tables()
693        result2 = []
694        for t in result3:
695            if t not in result1:
696                result2.append(t)
697        result3 = []
698        for t in tables:
699            if not t.startswith('"'):
700                t = t.lower()
701            result3.append('public.' + t)
702        self.assertEqual(result2, result3)
703        for t in result2:
704            self.db.query('drop table %s' % t)
705        result2 = get_tables()
706        self.assertEqual(result2, result1)
707
708    def testGetRelations(self):
709        get_relations = self.db.get_relations
710        result = get_relations()
711        self.assertIn('public.test', result)
712        self.assertIn('public.test_view', result)
713        result = get_relations('rv')
714        self.assertIn('public.test', result)
715        self.assertIn('public.test_view', result)
716        result = get_relations('r')
717        self.assertIn('public.test', result)
718        self.assertNotIn('public.test_view', result)
719        result = get_relations('v')
720        self.assertNotIn('public.test', result)
721        self.assertIn('public.test_view', result)
722        result = get_relations('cisSt')
723        self.assertNotIn('public.test', result)
724        self.assertNotIn('public.test_view', result)
725
726    def testGetAttnames(self):
727        self.assertRaises(pg.ProgrammingError,
728            self.db.get_attnames, 'does_not_exist')
729        self.assertRaises(pg.ProgrammingError,
730            self.db.get_attnames, 'has.too.many.dots')
731        for table in ('attnames_test_table', 'test table for attnames'):
732            self.db.query('drop table if exists "%s"' % table)
733            self.db.query('create table "%s" ('
734                ' a smallint, b integer, c bigint,'
735                ' e numeric, f float, f2 double precision, m money,'
736                ' x smallint, y smallint, z smallint,'
737                ' Normal_NaMe smallint, "Special Name" smallint,'
738                ' t text, u char(2), v varchar(2),'
739                ' primary key (y, u)) with oids' % table)
740            attributes = self.db.get_attnames(table)
741            result = {'a': 'int', 'c': 'int', 'b': 'int',
742                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
743                'normal_name': 'int', 'Special Name': 'int',
744                'u': 'text', 't': 'text', 'v': 'text',
745                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
746            self.assertEqual(attributes, result)
747            self.db.query('drop table "%s"' % table)
748
749    def testHasTablePrivilege(self):
750        can = self.db.has_table_privilege
751        self.assertEqual(can('test'), True)
752        self.assertEqual(can('test', 'select'), True)
753        self.assertEqual(can('test', 'SeLeCt'), True)
754        self.assertEqual(can('test', 'SELECT'), True)
755        self.assertEqual(can('test', 'insert'), True)
756        self.assertEqual(can('test', 'update'), True)
757        self.assertEqual(can('test', 'delete'), True)
758        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
759        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
760
761    def testGet(self):
762        get = self.db.get
763        query = self.db.query
764        for table in ('get_test_table', 'test table for get'):
765            query('drop table if exists "%s"' % table)
766            query('create table "%s" ('
767                "n integer, t text) with oids" % table)
768            for n, t in enumerate('xyz'):
769                query('insert into "%s" values('"%d, '%s')"
770                    % (table, n + 1, t))
771            self.assertRaises(pg.ProgrammingError, get, table, 2)
772            r = get(table, 2, 'n')
773            oid_table = table
774            if ' ' in table:
775                oid_table = '"%s"' % oid_table
776            oid_table = 'oid(public.%s)' % oid_table
777            self.assertIn(oid_table, r)
778            oid = r[oid_table]
779            self.assertIsInstance(oid, int)
780            result = {'t': 'y', 'n': 2, oid_table: oid}
781            self.assertEqual(r, result)
782            self.assertEqual(get(table + ' *', 2, 'n'), r)
783            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
784            self.assertEqual(get(table, 1, 'n')['t'], 'x')
785            self.assertEqual(get(table, 3, 'n')['t'], 'z')
786            self.assertEqual(get(table, 2, 'n')['t'], 'y')
787            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
788            r['n'] = 3
789            self.assertEqual(get(table, r, 'n')['t'], 'z')
790            self.assertEqual(get(table, 1, 'n')['t'], 'x')
791            query('alter table "%s" alter n set not null' % table)
792            query('alter table "%s" add primary key (n)' % table)
793            self.assertEqual(get(table, 3)['t'], 'z')
794            self.assertEqual(get(table, 1)['t'], 'x')
795            self.assertEqual(get(table, 2)['t'], 'y')
796            r['n'] = 1
797            self.assertEqual(get(table, r)['t'], 'x')
798            r['n'] = 3
799            self.assertEqual(get(table, r)['t'], 'z')
800            r['n'] = 2
801            self.assertEqual(get(table, r)['t'], 'y')
802            query('drop table "%s"' % table)
803
804    def testGetWithCompositeKey(self):
805        get = self.db.get
806        query = self.db.query
807        table = 'get_test_table_1'
808        query("drop table if exists %s" % table)
809        query("create table %s ("
810            "n integer, t text, primary key (n))" % table)
811        for n, t in enumerate('abc'):
812            query("insert into %s values("
813                "%d, '%s')" % (table, n + 1, t))
814        self.assertEqual(get(table, 2)['t'], 'b')
815        query("drop table %s" % table)
816        table = 'get_test_table_2'
817        query("drop table if exists %s" % table)
818        query("create table %s ("
819            "n integer, m integer, t text, primary key (n, m))" % table)
820        for n in range(3):
821            for m in range(2):
822                t = chr(ord('a') + 2 * n + m)
823                query("insert into %s values("
824                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
825        self.assertRaises(pg.ProgrammingError, get, table, 2)
826        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
827        self.assertEqual(get(table, dict(n=1, m=2),
828                             ('n', 'm'))['t'], 'b')
829        self.assertEqual(get(table, dict(n=3, m=2),
830                             frozenset(['n', 'm']))['t'], 'f')
831        query("drop table %s" % table)
832
833    def testGetFromView(self):
834        self.db.query('delete from test where i4=14')
835        self.db.query('insert into test (i4, v4) values('
836            "14, 'abc4')")
837        r = self.db.get('test_view', 14, 'i4')
838        self.assertIn('v4', r)
839        self.assertEqual(r['v4'], 'abc4')
840
841    def testGetLittleBobbyTables(self):
842        get = self.db.get
843        query = self.db.query
844        query("drop table if exists test_students")
845        query("create table test_students (firstname varchar primary key,"
846            " nickname varchar, grade char(2))")
847        query("insert into test_students values ("
848              "'D''Arcy', 'Darcey', 'A+')")
849        query("insert into test_students values ("
850              "'Sheldon', 'Moonpie', 'A+')")
851        query("insert into test_students values ("
852              "'Robert', 'Little Bobby Tables', 'D-')")
853        r = get('test_students', 'Sheldon')
854        self.assertEqual(r, dict(
855            firstname="Sheldon", nickname='Moonpie', grade='A+'))
856        r = get('test_students', 'Robert')
857        self.assertEqual(r, dict(
858            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
859        r = get('test_students', "D'Arcy")
860        self.assertEqual(r, dict(
861            firstname="D'Arcy", nickname='Darcey', grade='A+'))
862        try:
863            get('test_students', "D' Arcy")
864        except pg.DatabaseError, error:
865            self.assertEqual(str(error),
866                'No such record in public.test_students where firstname = '
867                "'D'' Arcy'")
868        try:
869            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
870        except pg.DatabaseError, error:
871            self.assertEqual(str(error),
872                'No such record in public.test_students where firstname = '
873                "'Robert''); TRUNCATE TABLE test_students;--'")
874        q = "select * from test_students order by 1 limit 4"
875        r = query(q).getresult()
876        self.assertEqual(len(r), 3)
877        self.assertEqual(r[1][2], 'D-')
878        query('drop table test_students')
879
    def testInsert(self):
        """Insert rows covering many column types and verify the results.

        For every test case, the dict passed to insert() and the row read
        back with a plain query must both match the expected values, which
        may differ from the input where the database converts them.
        """
        insert = self.db.insert
        query = self.db.query
        server_version = self.db.server_version
        for table in ('insert_test_table', 'test table for insert'):
            query('drop table if exists "%s"' % table)
            query('create table "%s" ('
                "i2 smallint, i4 integer, i8 bigint,"
                " d numeric, f4 real, f8 double precision, m money,"
                " v4 varchar(4), c4 char(4), t text,"
                " b boolean, ts timestamp) with oids" % table)
            # key under which the oid of the inserted row is expected,
            # with the table name quoted when it contains a space
            oid_table = table
            if ' ' in table:
                oid_table = '"%s"' % oid_table
            oid_table = 'oid(public.%s)' % oid_table
            # Each entry is either a dict expected to round-trip unchanged,
            # or a pair (data, change) where change holds the expected
            # values that differ from data after storing (e.g. '' -> None).
            # The list is rebuilt per table because the loop below mutates
            # the dicts in place.
            tests = [dict(i2=None, i4=None, i8=None),
                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
                dict(i2=42, i4=123456, i8=9876543210),
                dict(i2=2 ** 15 - 1,
                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
                dict(d=None), (dict(d=''), dict(d=None)),
                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
                dict(f4=None, f8=None), dict(f4=0, f8=0),
                (dict(f4='', f8=''), dict(f4=None, f8=None)),
                (dict(d=1234.5, f4=1234.5, f8=1234.5),
                      dict(d=Decimal('1234.5'))),
                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
                dict(d=Decimal('123456789.9876543212345678987654321')),
                dict(m=None), (dict(m=''), dict(m=None)),
                dict(m=Decimal('-1234.56')),
                (dict(m='-1234.56'), dict(m=Decimal('-1234.56'))),
                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
                (dict(m=123456), dict(m=Decimal('123456'))),
                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
                dict(b=None), (dict(b=''), dict(b=None)),
                dict(b='f'), dict(b='t'),
                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
                dict(v4=None, c4=None, t=None),
                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
                dict(v4='1234', c4='1234', t='1234' * 10),
                dict(v4='abcd', c4='abcd', t='abcdefg'),
                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
                dict(ts=None), (dict(ts=''), dict(ts=None)),
                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
                dict(ts='2012-12-21 00:00:00'),
                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
                dict(ts='2012-12-21 12:21:12'),
                dict(ts='2013-01-05 12:13:14'),
                dict(ts='current_timestamp')]
            for test in tests:
                if isinstance(test, dict):
                    data = test
                    change = {}
                else:
                    data, change = test
                expect = data.copy()
                expect.update(change)
                if data.get('m') and server_version < 90100:
                    # PostgreSQL < 9.1 cannot directly convert numbers to money
                    data['m'] = "'%s'::money" % data['m']
                # insert() updates data in place with the oid and the stored
                # values (checked below).  NOTE(review): if insert() returns
                # the very same dict object, this comparison is trivially true.
                self.assertEqual(insert(table, data), data)
                self.assertIn(oid_table, data)
                oid = data[oid_table]
                self.assertIsInstance(oid, int)
                # keep only the columns under test (drops the oid key)
                data = dict(item for item in data.iteritems()
                    if item[0] in expect)
                ts = expect.get('ts')
                if ts == 'current_timestamp':
                    # the actual time is unknown: accept whatever was stored
                    # and only check its 'YYYY-MM-DD HH:MM:SS' shape, where
                    # fractional seconds (if any) follow a '.' at index 19
                    ts = expect['ts'] = data['ts']
                    if len(ts) > 19:
                        self.assertEqual(ts[19], '.')
                        ts = ts[:19]
                    else:
                        self.assertEqual(len(ts), 19)
                    self.assertTrue(ts[:4].isdigit())
                    self.assertEqual(ts[4], '-')
                    self.assertEqual(ts[10], ' ')
                    self.assertTrue(ts[11:13].isdigit())
                    self.assertEqual(ts[13], ':')
                self.assertEqual(data, expect)
                # read the row back independently and compare again
                data = query(
                    'select oid,* from "%s"' % table).dictresult()[0]
                self.assertEqual(data['oid'], oid)
                data = dict(item for item in data.iteritems()
                    if item[0] in expect)
                self.assertEqual(data, expect)
                query('delete from "%s"' % table)
            query('drop table "%s"' % table)
977
978    def testUpdate(self):
979        update = self.db.update
980        query = self.db.query
981        for table in ('update_test_table', 'test table for update'):
982            query('drop table if exists "%s"' % table)
983            query('create table "%s" ('
984                "n integer, t text) with oids" % table)
985            for n, t in enumerate('xyz'):
986                query('insert into "%s" values('
987                    "%d, '%s')" % (table, n + 1, t))
988            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
989            r = self.db.get(table, 2, 'n')
990            r['t'] = 'u'
991            s = update(table, r)
992            self.assertEqual(s, r)
993            r = query('select t from "%s" where n=2' % table
994                      ).getresult()[0][0]
995            self.assertEqual(r, 'u')
996            query('drop table "%s"' % table)
997
998    def testUpdateWithCompositeKey(self):
999        update = self.db.update
1000        query = self.db.query
1001        table = 'update_test_table_1'
1002        query("drop table if exists %s" % table)
1003        query("create table %s ("
1004            "n integer, t text, primary key (n))" % table)
1005        for n, t in enumerate('abc'):
1006            query("insert into %s values("
1007                "%d, '%s')" % (table, n + 1, t))
1008        self.assertRaises(pg.ProgrammingError, update,
1009                          table, dict(t='b'))
1010        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
1011        r = query('select t from "%s" where n=2' % table
1012                  ).getresult()[0][0]
1013        self.assertEqual(r, 'd')
1014        query("drop table %s" % table)
1015        table = 'update_test_table_2'
1016        query("drop table if exists %s" % table)
1017        query("create table %s ("
1018            "n integer, m integer, t text, primary key (n, m))" % table)
1019        for n in range(3):
1020            for m in range(2):
1021                t = chr(ord('a') + 2 * n + m)
1022                query("insert into %s values("
1023                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1024        self.assertRaises(pg.ProgrammingError, update,
1025                          table, dict(n=2, t='b'))
1026        self.assertEqual(update(table,
1027                                dict(n=2, m=2, t='x'))['t'], 'x')
1028        r = [r[0] for r in query('select t from "%s" where n=2'
1029            ' order by m' % table).getresult()]
1030        self.assertEqual(r, ['c', 'x'])
1031        query("drop table %s" % table)
1032
1033    def testClear(self):
1034        clear = self.db.clear
1035        query = self.db.query
1036        for table in ('clear_test_table', 'test table for clear'):
1037            query('drop table if exists "%s"' % table)
1038            query('create table "%s" ('
1039                "n integer, b boolean, d date, t text)" % table)
1040            r = clear(table)
1041            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
1042            self.assertEqual(r, result)
1043            r['a'] = r['n'] = 1
1044            r['d'] = r['t'] = 'x'
1045            r['b'] = 't'
1046            r['oid'] = 1L
1047            r = clear(table, r)
1048            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '', 'oid': 1L}
1049            self.assertEqual(r, result)
1050            query('drop table "%s"' % table)
1051
1052    def testDelete(self):
1053        delete = self.db.delete
1054        query = self.db.query
1055        for table in ('delete_test_table', 'test table for delete'):
1056            query('drop table if exists "%s"' % table)
1057            query('create table "%s" ('
1058                "n integer, t text) with oids" % table)
1059            for n, t in enumerate('xyz'):
1060                query('insert into "%s" values('
1061                    "%d, '%s')" % (table, n + 1, t))
1062            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1063            r = self.db.get(table, 1, 'n')
1064            s = delete(table, r)
1065            self.assertEqual(s, 1)
1066            r = self.db.get(table, 3, 'n')
1067            s = delete(table, r)
1068            self.assertEqual(s, 1)
1069            s = delete(table, r)
1070            self.assertEqual(s, 0)
1071            r = query('select * from "%s"' % table).dictresult()
1072            self.assertEqual(len(r), 1)
1073            r = r[0]
1074            result = {'n': 2, 't': 'y'}
1075            self.assertEqual(r, result)
1076            r = self.db.get(table, 2, 'n')
1077            s = delete(table, r)
1078            self.assertEqual(s, 1)
1079            s = delete(table, r)
1080            self.assertEqual(s, 0)
1081            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1082            query('drop table "%s"' % table)
1083
1084    def testDeleteWithCompositeKey(self):
1085        query = self.db.query
1086        table = 'delete_test_table_1'
1087        query("drop table if exists %s" % table)
1088        query("create table %s ("
1089            "n integer, t text, primary key (n))" % table)
1090        for n, t in enumerate('abc'):
1091            query("insert into %s values("
1092                "%d, '%s')" % (table, n + 1, t))
1093        self.assertRaises(pg.ProgrammingError, self.db.delete,
1094            table, dict(t='b'))
1095        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1096        r = query('select t from "%s" where n=2' % table
1097                  ).getresult()
1098        self.assertEqual(r, [])
1099        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1100        r = query('select t from "%s" where n=3' % table
1101                  ).getresult()[0][0]
1102        self.assertEqual(r, 'c')
1103        query("drop table %s" % table)
1104        table = 'delete_test_table_2'
1105        query("drop table if exists %s" % table)
1106        query("create table %s ("
1107            "n integer, m integer, t text, primary key (n, m))" % table)
1108        for n in range(3):
1109            for m in range(2):
1110                t = chr(ord('a') + 2 * n + m)
1111                query("insert into %s values("
1112                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1113        self.assertRaises(pg.ProgrammingError, self.db.delete,
1114            table, dict(n=2, t='b'))
1115        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1116        r = [r[0] for r in query('select t from "%s" where n=2'
1117            ' order by m' % table).getresult()]
1118        self.assertEqual(r, ['c'])
1119        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1120        r = [r[0] for r in query('select t from "%s" where n=3'
1121            ' order by m' % table).getresult()]
1122        self.assertEqual(r, ['e', 'f'])
1123        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1124        r = [r[0] for r in query('select t from "%s" where n=3'
1125            ' order by m' % table).getresult()]
1126        self.assertEqual(r, ['f'])
1127        query("drop table %s" % table)
1128
1129    def testTransaction(self):
1130        query = self.db.query
1131        query("drop table if exists test_table")
1132        query("create table test_table (n integer)")
1133        self.db.begin()
1134        query("insert into test_table values (1)")
1135        query("insert into test_table values (2)")
1136        self.db.commit()
1137        self.db.begin()
1138        query("insert into test_table values (3)")
1139        query("insert into test_table values (4)")
1140        self.db.rollback()
1141        self.db.begin()
1142        query("insert into test_table values (5)")
1143        self.db.savepoint('before6')
1144        query("insert into test_table values (6)")
1145        self.db.rollback('before6')
1146        query("insert into test_table values (7)")
1147        self.db.commit()
1148        self.db.begin()
1149        self.db.savepoint('before8')
1150        query("insert into test_table values (8)")
1151        self.db.release('before8')
1152        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1153        self.db.commit()
1154        self.db.start()
1155        query("insert into test_table values (9)")
1156        self.db.end()
1157        r = [r[0] for r in query(
1158            "select * from test_table order by 1").getresult()]
1159        self.assertEqual(r, [1, 2, 5, 7, 9])
1160        query("drop table test_table")
1161
    @unittest.skipIf(no_with, 'context managers not supported')
    def testContextManager(self):
        """Check using the DB object as a context manager ("with db:").

        The final table contents must be 1, 2, 5 and 7.  The rows from the
        block raising ValueError (3, 4) and from the block violating the
        check constraint (6, -1) must not be kept.
        """
        query = self.db.query
        query("drop table if exists test_table")
        query("create table test_table (n integer check(n>0))")
        # wrap "with" statements to avoid SyntaxError in Python < 2.5
        # (the exec'd code refers to the local names self, query and pg)
        exec """from __future__ import with_statement\nif True:
        with self.db:
            query("insert into test_table values (1)")
            query("insert into test_table values (2)")
        try:
            with self.db:
                query("insert into test_table values (3)")
                query("insert into test_table values (4)")
                raise ValueError('test transaction should rollback')
        except ValueError, error:
            self.assertEqual(str(error), 'test transaction should rollback')
        with self.db:
            query("insert into test_table values (5)")
        try:
            with self.db:
                query("insert into test_table values (6)")
                query("insert into test_table values (-1)")
        except pg.ProgrammingError, error:
            self.assertTrue('check' in str(error))
        with self.db:
            query("insert into test_table values (7)")\n"""
        r = [r[0] for r in query(
            "select * from test_table order by 1").getresult()]
        self.assertEqual(r, [1, 2, 5, 7])
        query("drop table test_table")
1193
1194    def testBytea(self):
1195        query = self.db.query
1196        query('drop table if exists bytea_test')
1197        query('create table bytea_test ('
1198            'data bytea)')
1199        s = "It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1200        r = self.db.escape_bytea(s)
1201        query('insert into bytea_test values('
1202            "'%s')" % r)
1203        r = query('select * from bytea_test').getresult()
1204        self.assertTrue(len(r) == 1)
1205        r = r[0]
1206        self.assertTrue(len(r) == 1)
1207        r = r[0]
1208        r = self.db.unescape_bytea(r)
1209        self.assertEqual(r, s)
1210        query('drop table bytea_test')
1211
1212    def testDebugWithCallable(self):
1213        if debug:
1214            self.assertEqual(self.db.debug, debug)
1215        else:
1216            self.assertIsNone(self.db.debug)
1217        s = []
1218        self.db.debug = s.append
1219        try:
1220            self.db.query("select 1")
1221            self.db.query("select 2")
1222            self.assertEqual(s, ["select 1", "select 2"])
1223        finally:
1224            self.db.debug = debug
1225
1226
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    @classmethod
    def setUpClass(cls):
        """Create the schemas s1 to s4 with tables t and t<num> in each.

        The public schema gets the tables t and t0.  Every table has one
        row with n = 1 and d = the number of its schema (0 for public).
        """
        # DB() already sets client_min_messages=warning.  The connection
        # is closed even on failure, because tearDownClass is not run
        # when setUpClass raises.
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            db.close()

    @classmethod
    def tearDownClass(cls):
        """Drop everything that setUpClass created."""
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        # DB() already sets client_min_messages=warning
        self.db = DB()

    def tearDown(self):
        # closing the connection also discards any search_path changes
        self.db.close()

    def testGetTables(self):
        """get_tables() lists the tables of all schemas, schema-qualified."""
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        """get_attnames() resolves qualified names and the search path."""
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)
        query("drop table s3.t3m")

    def testGet(self):
        """get() finds unqualified tables according to the search path."""
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMangling(self):
        """The oid key names the schema the table was actually found in."""
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(public.t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(s2.t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(s3.t)', r)
1343
1344
if __name__ == '__main__':
    # run all test cases of this module when executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.