source: branches/4.x/tests/test_classic_dbwrapper.py @ 763

Last change on this file since 763 was 763, checked in by cito, 4 years ago

Achieve 100% test coverage for pg module on the trunk

Note that some lines are only covered with certain PostgreSQL or Python versions,
so you need to run the tests with several versions to be sure.

Also added another synonym for the transaction methods;
you can now pick your favorite name for all three of them.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 68.2 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import sys
21
22import pg  # the module under test
23
24from decimal import Decimal
25
# Whether the "with" statement is available (Python 2.5 and newer).
no_with = sys.version_info < (2, 5)

# Connection parameters of the database used for testing.  The current
# user must have the privilege to create schemas in that database.
# These defaults can be overridden in a module named LOCAL_PyGreSQL.py.
dbname, dbhost, dbport = 'unittest', None, 5432

debug = False  # when true, the DB wrapper prints debugging output

try:
    from LOCAL_PyGreSQL import *
except ImportError:  # no local configuration, keep the defaults
    pass

windows = (os.name == 'nt')

# There is a known bug in libpq under Windows which can crash the
# interface when PQhost() is called, so host checks are skipped there.
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
49
50
def DB():
    """Return a DB wrapper object connected to the test database.

    The wrapper's debug output is switched on when the module-level
    ``debug`` flag is set, and client_min_messages is set to warning.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
58
59
60class TestDBClassBasic(unittest.TestCase):
61    """Test existence of the DB class wrapped pg connection methods."""
62
63    def setUp(self):
64        self.db = DB()
65
66    def tearDown(self):
67        try:
68            self.db.close()
69        except pg.InternalError:
70            pass
71
72    def testAllDBAttributes(self):
73        attributes = [
74            'abort',
75            'begin',
76            'cancel', 'clear', 'close', 'commit',
77            'db', 'dbname', 'debug', 'delete',
78            'end', 'endcopy', 'error',
79            'escape_bytea', 'escape_identifier',
80            'escape_literal', 'escape_string',
81            'fileno',
82            'get', 'get_attnames', 'get_databases',
83            'get_notice_receiver', 'get_parameter',
84            'get_relations', 'get_tables',
85            'getline', 'getlo', 'getnotify',
86            'has_table_privilege', 'host',
87            'insert', 'inserttable',
88            'locreate', 'loimport',
89            'notification_handler',
90            'options',
91            'parameter', 'pkey', 'port',
92            'protocol_version', 'putline',
93            'query',
94            'release', 'reopen', 'reset', 'rollback',
95            'savepoint', 'server_version',
96            'set_notice_receiver', 'set_parameter',
97            'source', 'start', 'status',
98            'transaction', 'truncate', 'tty',
99            'unescape_bytea', 'update',
100            'use_regtypes', 'user',
101        ]
102        if self.db.server_version < 90000:  # PostgreSQL < 9.0
103            attributes.remove('escape_identifier')
104            attributes.remove('escape_literal')
105        db_attributes = [a for a in dir(self.db)
106            if not a.startswith('_')]
107        self.assertEqual(attributes, db_attributes)
108
109    def testAttributeDb(self):
110        self.assertEqual(self.db.db.db, dbname)
111
112    def testAttributeDbname(self):
113        self.assertEqual(self.db.dbname, dbname)
114
115    def testAttributeError(self):
116        error = self.db.error
117        self.assertTrue(not error or 'krb5_' in error)
118        self.assertEqual(self.db.error, self.db.db.error)
119
120    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
121    def testAttributeHost(self):
122        def_host = 'localhost'
123        host = self.db.host
124        self.assertIsInstance(host, str)
125        self.assertEqual(host, dbhost or def_host)
126        self.assertEqual(host, self.db.db.host)
127
128    def testAttributeOptions(self):
129        no_options = ''
130        options = self.db.options
131        self.assertEqual(options, no_options)
132        self.assertEqual(options, self.db.db.options)
133
134    def testAttributePort(self):
135        def_port = 5432
136        port = self.db.port
137        self.assertIsInstance(port, int)
138        self.assertEqual(port, dbport or def_port)
139        self.assertEqual(port, self.db.db.port)
140
141    def testAttributeProtocolVersion(self):
142        protocol_version = self.db.protocol_version
143        self.assertIsInstance(protocol_version, int)
144        self.assertTrue(2 <= protocol_version < 4)
145        self.assertEqual(protocol_version, self.db.db.protocol_version)
146
147    def testAttributeServerVersion(self):
148        server_version = self.db.server_version
149        self.assertIsInstance(server_version, int)
150        self.assertTrue(70400 <= server_version < 100000)
151        self.assertEqual(server_version, self.db.db.server_version)
152
153    def testAttributeStatus(self):
154        status_ok = 1
155        status = self.db.status
156        self.assertIsInstance(status, int)
157        self.assertEqual(status, status_ok)
158        self.assertEqual(status, self.db.db.status)
159
160    def testAttributeTty(self):
161        def_tty = ''
162        tty = self.db.tty
163        self.assertIsInstance(tty, str)
164        self.assertEqual(tty, def_tty)
165        self.assertEqual(tty, self.db.db.tty)
166
167    def testAttributeUser(self):
168        no_user = 'Deprecated facility'
169        user = self.db.user
170        self.assertTrue(user)
171        self.assertIsInstance(user, str)
172        self.assertNotEqual(user, no_user)
173        self.assertEqual(user, self.db.db.user)
174
175    def testMethodEscapeLiteral(self):
176        if self.db.server_version < 90000:  # PostgreSQL < 9.0
177            self.skipTest('Escaping functions not supported')
178        self.assertEqual(self.db.escape_literal(''), "''")
179
180    def testMethodEscapeIdentifier(self):
181        if self.db.server_version < 90000:  # PostgreSQL < 9.0
182            self.skipTest('Escaping functions not supported')
183        self.assertEqual(self.db.escape_identifier(''), '""')
184
185    def testMethodEscapeString(self):
186        self.assertEqual(self.db.escape_string(''), '')
187
188    def testMethodEscapeBytea(self):
189        self.assertEqual(self.db.escape_bytea('').replace(
190            '\\x', '').replace('\\', ''), '')
191
192    def testMethodUnescapeBytea(self):
193        self.assertEqual(self.db.unescape_bytea(''), '')
194
195    def testMethodQuery(self):
196        query = self.db.query
197        query("select 1+1")
198        query("select 1+$1+$2", 2, 3)
199        query("select 1+$1+$2", (2, 3))
200        query("select 1+$1+$2", [2, 3])
201        query("select 1+$1", 1)
202
203    def testMethodQueryEmpty(self):
204        self.assertRaises(ValueError, self.db.query, '')
205
206    def testMethodQueryProgrammingError(self):
207        try:
208            self.db.query("select 1/0")
209        except pg.ProgrammingError, error:
210            self.assertEqual(error.sqlstate, '22012')
211
212    def testMethodEndcopy(self):
213        try:
214            self.db.endcopy()
215        except IOError:
216            pass
217
218    def testMethodClose(self):
219        self.db.close()
220        try:
221            self.db.reset()
222        except pg.Error:
223            pass
224        else:
225            self.fail('Reset should give an error for a closed connection')
226        self.assertIsNone(self.db.db)
227        self.assertRaises(pg.InternalError, self.db.close)
228        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
229        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
230        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
231        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')
232
233    def testMethodReset(self):
234        con = self.db.db
235        self.db.reset()
236        self.assertIs(self.db.db, con)
237        self.db.query("select 1+1")
238        self.db.close()
239        self.assertRaises(pg.InternalError, self.db.reset)
240
241    def testMethodReopen(self):
242        con = self.db.db
243        self.db.reopen()
244        self.assertIsNot(self.db.db, con)
245        con = self.db.db
246        self.db.query("select 1+1")
247        self.db.close()
248        self.db.reopen()
249        self.assertIsNot(self.db.db, con)
250        self.db.query("select 1+1")
251        self.db.close()
252
253    def testExistingConnection(self):
254        db = pg.DB(self.db.db)
255        self.assertEqual(self.db.db, db.db)
256        self.assertTrue(db.db)
257        db.close()
258        self.assertTrue(db.db)
259        db.reopen()
260        self.assertTrue(db.db)
261        db.close()
262        self.assertTrue(db.db)
263        db = pg.DB(self.db)
264        self.assertEqual(self.db.db, db.db)
265        db = pg.DB(db=self.db.db)
266        self.assertEqual(self.db.db, db.db)
267
268        class DB2:
269            pass
270
271        db2 = DB2()
272        db2._cnx = self.db.db
273        db = pg.DB(db2)
274        self.assertEqual(self.db.db, db.db)
275
276
277class TestDBClass(unittest.TestCase):
278    """Test the methods of the DB class wrapped pg connection."""
279
280    @classmethod
281    def setUpClass(cls):
282        db = DB()
283        db.query("drop table if exists test cascade")
284        db.query("create table test ("
285            "i2 smallint, i4 integer, i8 bigint,"
286            " d numeric, f4 real, f8 double precision, m money,"
287            " v4 varchar(4), c4 char(4), t text)")
288        db.query("create or replace view test_view as"
289            " select i4, v4 from test")
290        db.close()
291
292    @classmethod
293    def tearDownClass(cls):
294        db = DB()
295        db.query("drop table test cascade")
296        db.close()
297
298    def setUp(self):
299        self.db = DB()
300        query = self.db.query
301        query('set client_encoding=utf8')
302        query('set standard_conforming_strings=on')
303        query("set lc_monetary='C'")
304        query("set datestyle='ISO,YMD'")
305        try:
306            query('set bytea_output=hex')
307        except pg.ProgrammingError:  # PostgreSQL < 9.0
308            pass
309
310    def tearDown(self):
311        self.db.close()
312
313    def testEscapeLiteral(self):
314        if self.db.server_version < 90000:  # PostgreSQL < 9.0
315            self.skipTest('Escaping functions not supported')
316        f = self.db.escape_literal
317        self.assertEqual(f("plain"), "'plain'")
318        self.assertEqual(f("that's k\xe4se"), "'that''s k\xe4se'")
319        self.assertEqual(f(r"It's fine to have a \ inside."),
320            r" E'It''s fine to have a \\ inside.'")
321        self.assertEqual(f('No "quotes" must be escaped.'),
322            "'No \"quotes\" must be escaped.'")
323
324    def testEscapeIdentifier(self):
325        if self.db.server_version < 90000:  # PostgreSQL < 9.0
326            self.skipTest('Escaping functions not supported')
327        f = self.db.escape_identifier
328        self.assertEqual(f("plain"), '"plain"')
329        self.assertEqual(f("that's k\xe4se"), '"that\'s k\xe4se"')
330        self.assertEqual(f(r"It's fine to have a \ inside."),
331            '"It\'s fine to have a \\ inside."')
332        self.assertEqual(f('All "quotes" must be escaped.'),
333            '"All ""quotes"" must be escaped."')
334
335    def testEscapeString(self):
336        f = self.db.escape_string
337        self.assertEqual(f("plain"), "plain")
338        self.assertEqual(f("that's k\xe4se"), "that''s k\xe4se")
339        self.assertEqual(f(r"It's fine to have a \ inside."),
340            r"It''s fine to have a \ inside.")
341
342    def testEscapeBytea(self):
343        f = self.db.escape_bytea
344        # note that escape_byte always returns hex output since PostgreSQL 9.0,
345        # regardless of the bytea_output setting
346        if self.db.server_version < 90000:
347            self.assertEqual(f("plain"), r"plain")
348            self.assertEqual(f("that's k\xe4se"), r"that''s k\344se")
349            self.assertEqual(f('O\x00ps\xff!'), r"O\000ps\377!")
350        else:
351            self.assertEqual(f("plain"), r"\x706c61696e")
352            self.assertEqual(f("that's k\xe4se"), r"\x746861742773206be47365")
353            self.assertEqual(f('O\x00ps\xff!'), r"\x4f007073ff21")
354
355    def testUnescapeBytea(self):
356        f = self.db.unescape_bytea
357        self.assertEqual(f("plain"), "plain")
358        self.assertEqual(f("that's k\\344se"), "that's k\xe4se")
359        self.assertEqual(f(r'O\000ps\377!'), 'O\x00ps\xff!')
360        self.assertEqual(f(r"\\x706c61696e"), r"\x706c61696e")
361        self.assertEqual(f(r"\\x746861742773206be47365"),
362            r"\x746861742773206be47365")
363        self.assertEqual(f(r"\\x4f007073ff21"), r"\x4f007073ff21")
364
365    def testQuote(self):
366        f = self.db._quote
367        self.assertEqual(f(None, None), 'NULL')
368        self.assertEqual(f(None, 'int'), 'NULL')
369        self.assertEqual(f(None, 'float'), 'NULL')
370        self.assertEqual(f(None, 'num'), 'NULL')
371        self.assertEqual(f(None, 'money'), 'NULL')
372        self.assertEqual(f(None, 'bool'), 'NULL')
373        self.assertEqual(f(None, 'date'), 'NULL')
374        self.assertEqual(f('', 'int'), 'NULL')
375        self.assertEqual(f('', 'float'), 'NULL')
376        self.assertEqual(f('', 'num'), 'NULL')
377        self.assertEqual(f('', 'money'), 'NULL')
378        self.assertEqual(f('', 'bool'), 'NULL')
379        self.assertEqual(f('', 'date'), 'NULL')
380        self.assertEqual(f('', 'text'), "''")
381        self.assertEqual(f(0, 'int'), '0')
382        self.assertEqual(f(0, 'num'), '0')
383        self.assertEqual(f(1, 'int'), '1')
384        self.assertEqual(f(1, 'num'), '1')
385        self.assertEqual(f(-1, 'int'), '-1')
386        self.assertEqual(f(-1, 'num'), '-1')
387        self.assertEqual(f(123456789, 'int'), '123456789')
388        self.assertEqual(f(123456987, 'num'), '123456987')
389        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
390        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
391        self.assertEqual(f('123456789', 'num'), '123456789')
392        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
393        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
394        self.assertEqual(f(123, 'money'), '123')
395        self.assertEqual(f('123', 'money'), '123')
396        self.assertEqual(f(123.45, 'money'), '123.45')
397        self.assertEqual(f('123.45', 'money'), '123.45')
398        self.assertEqual(f(123.454, 'money'), '123.454')
399        self.assertEqual(f('123.454', 'money'), '123.454')
400        self.assertEqual(f(123.456, 'money'), '123.456')
401        self.assertEqual(f('123.456', 'money'), '123.456')
402        self.assertEqual(f('f', 'bool'), "'f'")
403        self.assertEqual(f('F', 'bool'), "'f'")
404        self.assertEqual(f('false', 'bool'), "'f'")
405        self.assertEqual(f('False', 'bool'), "'f'")
406        self.assertEqual(f('FALSE', 'bool'), "'f'")
407        self.assertEqual(f(0, 'bool'), "'f'")
408        self.assertEqual(f('0', 'bool'), "'f'")
409        self.assertEqual(f('-', 'bool'), "'f'")
410        self.assertEqual(f('n', 'bool'), "'f'")
411        self.assertEqual(f('N', 'bool'), "'f'")
412        self.assertEqual(f('no', 'bool'), "'f'")
413        self.assertEqual(f('off', 'bool'), "'f'")
414        self.assertEqual(f('t', 'bool'), "'t'")
415        self.assertEqual(f('T', 'bool'), "'t'")
416        self.assertEqual(f('true', 'bool'), "'t'")
417        self.assertEqual(f('True', 'bool'), "'t'")
418        self.assertEqual(f('TRUE', 'bool'), "'t'")
419        self.assertEqual(f(1, 'bool'), "'t'")
420        self.assertEqual(f(2, 'bool'), "'t'")
421        self.assertEqual(f(-1, 'bool'), "'t'")
422        self.assertEqual(f(0.5, 'bool'), "'t'")
423        self.assertEqual(f('1', 'bool'), "'t'")
424        self.assertEqual(f('y', 'bool'), "'t'")
425        self.assertEqual(f('Y', 'bool'), "'t'")
426        self.assertEqual(f('yes', 'bool'), "'t'")
427        self.assertEqual(f('on', 'bool'), "'t'")
428        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
429        self.assertEqual(f(123, 'text'), "'123'")
430        self.assertEqual(f(1.23, 'text'), "'1.23'")
431        self.assertEqual(f('abc', 'text'), "'abc'")
432        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
433        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
434        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
435        self.db.query('set standard_conforming_strings=off')
436        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
437        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
438
439    def testGetParameter(self):
440        f = self.db.get_parameter
441        self.assertRaises(TypeError, f)
442        self.assertRaises(TypeError, f, None)
443        self.assertRaises(TypeError, f, 42)
444        self.assertRaises(TypeError, f, '')
445        self.assertRaises(TypeError, f, [])
446        self.assertRaises(TypeError, f, [''])
447        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
448        r = f('standard_conforming_strings')
449        self.assertEqual(r, 'on')
450        r = f('lc_monetary')
451        self.assertEqual(r, 'C')
452        r = f('datestyle')
453        self.assertEqual(r, 'ISO, YMD')
454        r = f('bytea_output')
455        self.assertEqual(r, 'hex')
456        r = f(['bytea_output', 'lc_monetary'])
457        self.assertIsInstance(r, list)
458        self.assertEqual(r, ['hex', 'C'])
459        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
460        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
461        r = f(set(['bytea_output', 'lc_monetary']))
462        self.assertIsInstance(r, dict)
463        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
464        r = f(set(['Bytea_Output', ' LC_Monetary ']))
465        self.assertIsInstance(r, dict)
466        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
467        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
468        r = f(s)
469        self.assertIs(r, s)
470        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
471        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
472        r = f(s)
473        self.assertIs(r, s)
474        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
475
476    def testGetParameterServerVersion(self):
477        r = self.db.get_parameter('server_version_num')
478        self.assertIsInstance(r, str)
479        s = self.db.server_version
480        self.assertIsInstance(s, int)
481        self.assertEqual(r, str(s))
482
483    def testGetParameterAll(self):
484        f = self.db.get_parameter
485        r = f('all')
486        self.assertIsInstance(r, dict)
487        self.assertEqual(r['standard_conforming_strings'], 'on')
488        self.assertEqual(r['lc_monetary'], 'C')
489        self.assertEqual(r['DateStyle'], 'ISO, YMD')
490        self.assertEqual(r['bytea_output'], 'hex')
491
492    def testSetParameter(self):
493        f = self.db.set_parameter
494        g = self.db.get_parameter
495        self.assertRaises(TypeError, f)
496        self.assertRaises(TypeError, f, None)
497        self.assertRaises(TypeError, f, 42)
498        self.assertRaises(TypeError, f, '')
499        self.assertRaises(TypeError, f, [])
500        self.assertRaises(TypeError, f, [''])
501        self.assertRaises(ValueError, f, 'all', 'invalid')
502        self.assertRaises(ValueError, f, {
503            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
504        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
505        f('standard_conforming_strings', 'off')
506        self.assertEqual(g('standard_conforming_strings'), 'off')
507        f('datestyle', 'ISO, DMY')
508        self.assertEqual(g('datestyle'), 'ISO, DMY')
509        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
510        self.assertEqual(g('standard_conforming_strings'), 'on')
511        self.assertEqual(g('datestyle'), 'ISO, DMY')
512        f(['default_with_oids', 'standard_conforming_strings'], 'off')
513        self.assertEqual(g('default_with_oids'), 'off')
514        self.assertEqual(g('standard_conforming_strings'), 'off')
515        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
516        self.assertEqual(g('standard_conforming_strings'), 'on')
517        self.assertEqual(g('datestyle'), 'ISO, YMD')
518        f(('default_with_oids', 'standard_conforming_strings'), 'off')
519        self.assertEqual(g('default_with_oids'), 'off')
520        self.assertEqual(g('standard_conforming_strings'), 'off')
521        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
522        self.assertEqual(g('default_with_oids'), 'on')
523        self.assertEqual(g('standard_conforming_strings'), 'on')
524        self.assertRaises(ValueError, f, set([ 'default_with_oids',
525            'standard_conforming_strings']), ['off', 'on'])
526        f(set(['default_with_oids', 'standard_conforming_strings']),
527            ['off', 'off'])
528        self.assertEqual(g('default_with_oids'), 'off')
529        self.assertEqual(g('standard_conforming_strings'), 'off')
530        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
531        self.assertEqual(g('standard_conforming_strings'), 'on')
532        self.assertEqual(g('datestyle'), 'ISO, YMD')
533
534    def testResetParameter(self):
535        db = DB()
536        f = db.set_parameter
537        g = db.get_parameter
538        r = g('default_with_oids')
539        self.assertIn(r, ('on', 'off'))
540        dwi, not_dwi = r, r == 'on' and 'off' or 'on'
541        r = g('standard_conforming_strings')
542        self.assertIn(r, ('on', 'off'))
543        scs, not_scs = r, r == 'on' and 'off' or 'on'
544        f('default_with_oids', not_dwi)
545        f('standard_conforming_strings', not_scs)
546        self.assertEqual(g('default_with_oids'), not_dwi)
547        self.assertEqual(g('standard_conforming_strings'), not_scs)
548        f('default_with_oids')
549        f('standard_conforming_strings', None)
550        self.assertEqual(g('default_with_oids'), dwi)
551        self.assertEqual(g('standard_conforming_strings'), scs)
552        f('default_with_oids', not_dwi)
553        f('standard_conforming_strings', not_scs)
554        self.assertEqual(g('default_with_oids'), not_dwi)
555        self.assertEqual(g('standard_conforming_strings'), not_scs)
556        f(['default_with_oids', 'standard_conforming_strings'], None)
557        self.assertEqual(g('default_with_oids'), dwi)
558        self.assertEqual(g('standard_conforming_strings'), scs)
559        f('default_with_oids', not_dwi)
560        f('standard_conforming_strings', not_scs)
561        self.assertEqual(g('default_with_oids'), not_dwi)
562        self.assertEqual(g('standard_conforming_strings'), not_scs)
563        f(('default_with_oids', 'standard_conforming_strings'))
564        self.assertEqual(g('default_with_oids'), dwi)
565        self.assertEqual(g('standard_conforming_strings'), scs)
566        f('default_with_oids', not_dwi)
567        f('standard_conforming_strings', not_scs)
568        self.assertEqual(g('default_with_oids'), not_dwi)
569        self.assertEqual(g('standard_conforming_strings'), not_scs)
570        f(set(['default_with_oids', 'standard_conforming_strings']), None)
571        self.assertEqual(g('default_with_oids'), dwi)
572        self.assertEqual(g('standard_conforming_strings'), scs)
573
574    def testResetParameterAll(self):
575        db = DB()
576        f = db.set_parameter
577        self.assertRaises(ValueError, f, 'all', 0)
578        self.assertRaises(ValueError, f, 'all', 'off')
579        g = db.get_parameter
580        r = g('default_with_oids')
581        self.assertIn(r, ('on', 'off'))
582        dwi, not_dwi = r, r == 'on' and 'off' or 'on'
583        r = g('standard_conforming_strings')
584        self.assertIn(r, ('on', 'off'))
585        scs, not_scs = r, r == 'on' and 'off' or 'on'
586        f('default_with_oids', not_dwi)
587        f('standard_conforming_strings', not_scs)
588        self.assertEqual(g('default_with_oids'), not_dwi)
589        self.assertEqual(g('standard_conforming_strings'), not_scs)
590        f('all')
591        self.assertEqual(g('default_with_oids'), dwi)
592        self.assertEqual(g('standard_conforming_strings'), scs)
593
594    def testSetParameterLocal(self):
595        f = self.db.set_parameter
596        g = self.db.get_parameter
597        self.assertEqual(g('standard_conforming_strings'), 'on')
598        self.db.begin()
599        f('standard_conforming_strings', 'off', local=True)
600        self.assertEqual(g('standard_conforming_strings'), 'off')
601        self.db.end()
602        self.assertEqual(g('standard_conforming_strings'), 'on')
603
604    def testSetParameterSession(self):
605        f = self.db.set_parameter
606        g = self.db.get_parameter
607        self.assertEqual(g('standard_conforming_strings'), 'on')
608        self.db.begin()
609        f('standard_conforming_strings', 'off', local=False)
610        self.assertEqual(g('standard_conforming_strings'), 'off')
611        self.db.end()
612        self.assertEqual(g('standard_conforming_strings'), 'off')
613
614    def testQuery(self):
615        query = self.db.query
616        query("drop table if exists test_table")
617        q = "create table test_table (n integer) with oids"
618        r = query(q)
619        self.assertIsNone(r)
620        q = "insert into test_table values (1)"
621        r = query(q)
622        self.assertIsInstance(r, int)
623        q = "insert into test_table select 2"
624        r = query(q)
625        self.assertIsInstance(r, int)
626        oid = r
627        q = "select oid from test_table where n=2"
628        r = query(q).getresult()
629        self.assertEqual(len(r), 1)
630        r = r[0]
631        self.assertEqual(len(r), 1)
632        r = r[0]
633        self.assertEqual(r, oid)
634        q = "insert into test_table select 3 union select 4 union select 5"
635        r = query(q)
636        self.assertIsInstance(r, str)
637        self.assertEqual(r, '3')
638        q = "update test_table set n=4 where n<5"
639        r = query(q)
640        self.assertIsInstance(r, str)
641        self.assertEqual(r, '4')
642        q = "delete from test_table"
643        r = query(q)
644        self.assertIsInstance(r, str)
645        self.assertEqual(r, '5')
646        query("drop table test_table")
647
648    def testMultipleQueries(self):
649        self.assertEqual(self.db.query(
650            "create temporary table test_multi (n integer);"
651            "insert into test_multi values (4711);"
652            "select n from test_multi").getresult()[0][0], 4711)
653
654    def testQueryWithParams(self):
655        query = self.db.query
656        query("drop table if exists test_table")
657        q = "create table test_table (n1 integer, n2 integer) with oids"
658        query(q)
659        q = "insert into test_table values ($1, $2)"
660        r = query(q, (1, 2))
661        self.assertIsInstance(r, int)
662        r = query(q, [3, 4])
663        self.assertIsInstance(r, int)
664        r = query(q, [5, 6])
665        self.assertIsInstance(r, int)
666        q = "select * from test_table order by 1, 2"
667        self.assertEqual(query(q).getresult(),
668            [(1, 2), (3, 4), (5, 6)])
669        q = "select * from test_table where n1=$1 and n2=$2"
670        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
671        q = "update test_table set n2=$2 where n1=$1"
672        r = query(q, 3, 7)
673        self.assertEqual(r, '1')
674        q = "select * from test_table order by 1, 2"
675        self.assertEqual(query(q).getresult(),
676            [(1, 2), (3, 7), (5, 6)])
677        q = "delete from test_table where n2!=$1"
678        r = query(q, 4)
679        self.assertEqual(r, '3')
680        query("drop table test_table")
681
682    def testEmptyQuery(self):
683        self.assertRaises(ValueError, self.db.query, '')
684
685    def testQueryProgrammingError(self):
686        try:
687            self.db.query("select 1/0")
688        except pg.ProgrammingError, error:
689            self.assertEqual(error.sqlstate, '22012')
690
691    def testPkey(self):
692        query = self.db.query
693        for n in range(4):
694            query("drop table if exists pkeytest%d" % n)
695        query("create table pkeytest0 ("
696            "a smallint)")
697        query("create table pkeytest1 ("
698            "b smallint primary key)")
699        query("create table pkeytest2 ("
700            "c smallint, d smallint primary key)")
701        query("create table pkeytest3 ("
702            "e smallint, f smallint, g smallint,"
703            " h smallint, i smallint,"
704            " primary key (f,h))")
705        pkey = self.db.pkey
706        self.assertRaises(KeyError, pkey, 'pkeytest0')
707        self.assertEqual(pkey('pkeytest1'), 'b')
708        self.assertEqual(pkey('pkeytest2'), 'd')
709        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
710        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
711        self.assertEqual(pkey('pkeytest0'), 'none')
712        pkey(None, {'t': 'a', 'n.t': 'b'})
713        self.assertEqual(pkey('t'), 'a')
714        self.assertEqual(pkey('n.t'), 'b')
715        self.assertRaises(KeyError, pkey, 'pkeytest0')
716        for n in range(4):
717            query("drop table pkeytest%d" % n)
718
719    def testGetDatabases(self):
720        databases = self.db.get_databases()
721        self.assertIn('template0', databases)
722        self.assertIn('template1', databases)
723        self.assertNotIn('not existing database', databases)
724        self.assertIn('postgres', databases)
725        self.assertIn(dbname, databases)
726
727    def testGetTables(self):
728        get_tables = self.db.get_tables
729        result1 = get_tables()
730        self.assertIsInstance(result1, list)
731        for t in result1:
732            t = t.split('.', 1)
733            self.assertGreaterEqual(len(t), 2)
734            if len(t) > 2:
735                self.assertTrue(t[1].startswith('"'))
736            t = t[0]
737            self.assertNotEqual(t, 'information_schema')
738            self.assertFalse(t.startswith('pg_'))
739        tables = ('"A very Special Name"',
740            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
741            'A_MiXeD_NaMe', '"another special name"',
742            'averyveryveryveryveryveryverylongtablename',
743            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
744        for t in tables:
745            self.db.query('drop table if exists %s' % t)
746            self.db.query("create table %s"
747                " as select 0" % t)
748        result3 = get_tables()
749        result2 = []
750        for t in result3:
751            if t not in result1:
752                result2.append(t)
753        result3 = []
754        for t in tables:
755            if not t.startswith('"'):
756                t = t.lower()
757            result3.append('public.' + t)
758        self.assertEqual(result2, result3)
759        for t in result2:
760            self.db.query('drop table %s' % t)
761        result2 = get_tables()
762        self.assertEqual(result2, result1)
763
764    def testGetRelations(self):
765        get_relations = self.db.get_relations
766        result = get_relations()
767        self.assertIn('public.test', result)
768        self.assertIn('public.test_view', result)
769        result = get_relations('rv')
770        self.assertIn('public.test', result)
771        self.assertIn('public.test_view', result)
772        result = get_relations('r')
773        self.assertIn('public.test', result)
774        self.assertNotIn('public.test_view', result)
775        result = get_relations('v')
776        self.assertNotIn('public.test', result)
777        self.assertIn('public.test_view', result)
778        result = get_relations('cisSt')
779        self.assertNotIn('public.test', result)
780        self.assertNotIn('public.test_view', result)
781
782    def testGetAttnames(self):
783        self.assertRaises(pg.ProgrammingError,
784            self.db.get_attnames, 'does_not_exist')
785        self.assertRaises(pg.ProgrammingError,
786            self.db.get_attnames, 'has.too.many.dots')
787        attributes = self.db.get_attnames('test')
788        self.assertIsInstance(attributes, dict)
789        self.assertEqual(attributes, dict(
790            i2='int', i4='int', i8='int', d='num',
791            f4='float', f8='float', m='money',
792            v4='text', c4='text', t='text'))
793        for table in ('attnames_test_table', 'test table for attnames'):
794            self.db.query('drop table if exists "%s"' % table)
795            self.db.query('create table "%s" ('
796                ' a smallint, b integer, c bigint,'
797                ' e numeric, f float, f2 double precision, m money,'
798                ' x smallint, y smallint, z smallint,'
799                ' Normal_NaMe smallint, "Special Name" smallint,'
800                ' t text, u char(2), v varchar(2),'
801                ' primary key (y, u)) with oids' % table)
802            attributes = self.db.get_attnames(table)
803            result = {'a': 'int', 'c': 'int', 'b': 'int',
804                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
805                'normal_name': 'int', 'Special Name': 'int',
806                'u': 'text', 't': 'text', 'v': 'text',
807                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
808            self.assertEqual(attributes, result)
809            self.db.query('drop table "%s"' % table)
810
811    def testHasTablePrivilege(self):
812        can = self.db.has_table_privilege
813        self.assertEqual(can('test'), True)
814        self.assertEqual(can('test', 'select'), True)
815        self.assertEqual(can('test', 'SeLeCt'), True)
816        self.assertEqual(can('test', 'SELECT'), True)
817        self.assertEqual(can('test', 'insert'), True)
818        self.assertEqual(can('test', 'update'), True)
819        self.assertEqual(can('test', 'delete'), True)
820        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
821        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
822
823    def testGet(self):
824        get = self.db.get
825        query = self.db.query
826        for table in ('get_test_table', 'test table for get'):
827            query('drop table if exists "%s"' % table)
828            query('create table "%s" ('
829                "n integer, t text) with oids" % table)
830            for n, t in enumerate('xyz'):
831                query('insert into "%s" values('"%d, '%s')"
832                    % (table, n + 1, t))
833            self.assertRaises(pg.ProgrammingError, get, table, 2)
834            r = get(table, 2, 'n')
835            oid_table = table
836            if ' ' in table:
837                oid_table = '"%s"' % oid_table
838            oid_table = 'oid(public.%s)' % oid_table
839            self.assertIn(oid_table, r)
840            oid = r[oid_table]
841            self.assertIsInstance(oid, int)
842            result = {'t': 'y', 'n': 2, oid_table: oid}
843            self.assertEqual(r, result)
844            self.assertEqual(get(table + ' *', 2, 'n'), r)
845            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
846            self.assertEqual(get(table, 1, 'n')['t'], 'x')
847            self.assertEqual(get(table, 3, 'n')['t'], 'z')
848            self.assertEqual(get(table, 2, 'n')['t'], 'y')
849            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
850            r['n'] = 3
851            self.assertEqual(get(table, r, 'n')['t'], 'z')
852            self.assertEqual(get(table, 1, 'n')['t'], 'x')
853            query('alter table "%s" alter n set not null' % table)
854            query('alter table "%s" add primary key (n)' % table)
855            self.assertEqual(get(table, 3)['t'], 'z')
856            self.assertEqual(get(table, 1)['t'], 'x')
857            self.assertEqual(get(table, 2)['t'], 'y')
858            r['n'] = 1
859            self.assertEqual(get(table, r)['t'], 'x')
860            r['n'] = 3
861            self.assertEqual(get(table, r)['t'], 'z')
862            r['n'] = 2
863            self.assertEqual(get(table, r)['t'], 'y')
864            query('drop table "%s"' % table)
865
866    def testGetWithCompositeKey(self):
867        get = self.db.get
868        query = self.db.query
869        table = 'get_test_table_1'
870        query("drop table if exists %s" % table)
871        query("create table %s ("
872            "n integer, t text, primary key (n))" % table)
873        for n, t in enumerate('abc'):
874            query("insert into %s values("
875                "%d, '%s')" % (table, n + 1, t))
876        self.assertEqual(get(table, 2)['t'], 'b')
877        query("drop table %s" % table)
878        table = 'get_test_table_2'
879        query("drop table if exists %s" % table)
880        query("create table %s ("
881            "n integer, m integer, t text, primary key (n, m))" % table)
882        for n in range(3):
883            for m in range(2):
884                t = chr(ord('a') + 2 * n + m)
885                query("insert into %s values("
886                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
887        self.assertRaises(pg.ProgrammingError, get, table, 2)
888        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
889        self.assertEqual(get(table, dict(n=1, m=2),
890                             ('n', 'm'))['t'], 'b')
891        self.assertEqual(get(table, dict(n=3, m=2),
892                             frozenset(['n', 'm']))['t'], 'f')
893        query("drop table %s" % table)
894
895    def testGetFromView(self):
896        self.db.query('delete from test where i4=14')
897        self.db.query('insert into test (i4, v4) values('
898            "14, 'abc4')")
899        r = self.db.get('test_view', 14, 'i4')
900        self.assertIn('v4', r)
901        self.assertEqual(r['v4'], 'abc4')
902
903    def testGetLittleBobbyTables(self):
904        get = self.db.get
905        query = self.db.query
906        query("drop table if exists test_students")
907        query("create table test_students (firstname varchar primary key,"
908            " nickname varchar, grade char(2))")
909        query("insert into test_students values ("
910              "'D''Arcy', 'Darcey', 'A+')")
911        query("insert into test_students values ("
912              "'Sheldon', 'Moonpie', 'A+')")
913        query("insert into test_students values ("
914              "'Robert', 'Little Bobby Tables', 'D-')")
915        r = get('test_students', 'Sheldon')
916        self.assertEqual(r, dict(
917            firstname="Sheldon", nickname='Moonpie', grade='A+'))
918        r = get('test_students', 'Robert')
919        self.assertEqual(r, dict(
920            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
921        r = get('test_students', "D'Arcy")
922        self.assertEqual(r, dict(
923            firstname="D'Arcy", nickname='Darcey', grade='A+'))
924        try:
925            get('test_students', "D' Arcy")
926        except pg.DatabaseError, error:
927            self.assertEqual(str(error),
928                'No such record in public.test_students where firstname = '
929                "'D'' Arcy'")
930        try:
931            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
932        except pg.DatabaseError, error:
933            self.assertEqual(str(error),
934                'No such record in public.test_students where firstname = '
935                "'Robert''); TRUNCATE TABLE test_students;--'")
936        q = "select * from test_students order by 1 limit 4"
937        r = query(q).getresult()
938        self.assertEqual(len(r), 3)
939        self.assertEqual(r[1][2], 'D-')
940        query('drop table test_students')
941
    def testInsert(self):
        """Test insert() with many column types and value representations.

        The test is run on a table with a plain name and on one whose name
        needs quoting.  For every entry in the tests list, a row is
        inserted, checked both in the dict passed to insert() and in a
        fresh query, and deleted again before the next entry.
        """
        insert = self.db.insert
        query = self.db.query
        server_version = self.db.server_version
        for table in ('insert_test_table', 'test table for insert'):
            query('drop table if exists "%s"' % table)
            query('create table "%s" ('
                "i2 smallint, i4 integer, i8 bigint,"
                " d numeric, f4 real, f8 double precision, m money,"
                " v4 varchar(4), c4 char(4), t text,"
                " b boolean, ts timestamp) with oids" % table)
            # key under which insert() is expected to put the oid of the
            # new row; table names containing spaces appear quoted in it
            oid_table = table
            if ' ' in table:
                oid_table = '"%s"' % oid_table
            oid_table = 'oid(public.%s)' % oid_table
            # Each entry is either a dict of values expected to be read
            # back unchanged, or a pair (data, change) where change holds
            # the columns whose read-back value differs from the input.
            tests = [dict(i2=None, i4=None, i8=None),
                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
                dict(i2=42, i4=123456, i8=9876543210),
                dict(i2=2 ** 15 - 1,
                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
                dict(d=None), (dict(d=''), dict(d=None)),
                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
                dict(f4=None, f8=None), dict(f4=0, f8=0),
                (dict(f4='', f8=''), dict(f4=None, f8=None)),
                (dict(d=1234.5, f4=1234.5, f8=1234.5),
                      dict(d=Decimal('1234.5'))),
                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
                dict(d=Decimal('123456789.9876543212345678987654321')),
                dict(m=None), (dict(m=''), dict(m=None)),
                dict(m=Decimal('-1234.56')),
                (dict(m='-1234.56'), dict(m=Decimal('-1234.56'))),
                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
                (dict(m=123456), dict(m=Decimal('123456'))),
                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
                dict(b=None), (dict(b=''), dict(b=None)),
                dict(b='f'), dict(b='t'),
                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
                dict(v4=None, c4=None, t=None),
                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
                dict(v4='1234', c4='1234', t='1234' * 10),
                dict(v4='abcd', c4='abcd', t='abcdefg'),
                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
                dict(ts=None), (dict(ts=''), dict(ts=None)),
                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
                dict(ts='2012-12-21 00:00:00'),
                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
                dict(ts='2012-12-21 12:21:12'),
                dict(ts='2013-01-05 12:13:14'),
                dict(ts='current_timestamp')]
            for test in tests:
                if isinstance(test, dict):
                    data = test
                    change = {}
                else:
                    data, change = test
                # expected values must be computed before data is modified
                # for old servers below
                expect = data.copy()
                expect.update(change)
                if data.get('m') and server_version < 90100:
                    # PostgreSQL < 9.1 cannot directly convert numbers to money
                    data['m'] = "'%s'::money" % data['m']
                # insert() is expected to fill the passed dict in place
                # (including the oid) and return a dict equal to it
                self.assertEqual(insert(table, data), data)
                self.assertIn(oid_table, data)
                oid = data[oid_table]
                self.assertIsInstance(oid, int)
                # compare only the columns that were part of this test
                data = dict(item for item in data.iteritems()
                    if item[0] in expect)
                ts = expect.get('ts')
                if ts == 'current_timestamp':
                    # the actual value is not known in advance, so take it
                    # from the result and only check its general format
                    ts = expect['ts'] = data['ts']
                    if len(ts) > 19:
                        # strip fractional seconds
                        self.assertEqual(ts[19], '.')
                        ts = ts[:19]
                    else:
                        self.assertEqual(len(ts), 19)
                    # NOTE(review): only the year, the hour and some of the
                    # separators of 'YYYY-MM-DD HH:MM:SS' are checked here
                    self.assertTrue(ts[:4].isdigit())
                    self.assertEqual(ts[4], '-')
                    self.assertEqual(ts[10], ' ')
                    self.assertTrue(ts[11:13].isdigit())
                    self.assertEqual(ts[13], ':')
                self.assertEqual(data, expect)
                # also check the row as actually stored in the database
                data = query(
                    'select oid,* from "%s"' % table).dictresult()[0]
                self.assertEqual(data['oid'], oid)
                data = dict(item for item in data.iteritems()
                    if item[0] in expect)
                self.assertEqual(data, expect)
                query('delete from "%s"' % table)
            query('drop table "%s"' % table)
1039
1040    def testUpdate(self):
1041        update = self.db.update
1042        query = self.db.query
1043        for table in ('update_test_table', 'test table for update'):
1044            query('drop table if exists "%s"' % table)
1045            query('create table "%s" ('
1046                "n integer, t text) with oids" % table)
1047            for n, t in enumerate('xyz'):
1048                query('insert into "%s" values('
1049                    "%d, '%s')" % (table, n + 1, t))
1050            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1051            r = self.db.get(table, 2, 'n')
1052            r['t'] = 'u'
1053            s = update(table, r)
1054            self.assertEqual(s, r)
1055            r = query('select t from "%s" where n=2' % table
1056                      ).getresult()[0][0]
1057            self.assertEqual(r, 'u')
1058            query('drop table "%s"' % table)
1059
1060    def testUpdateWithCompositeKey(self):
1061        update = self.db.update
1062        query = self.db.query
1063        table = 'update_test_table_1'
1064        query("drop table if exists %s" % table)
1065        query("create table %s ("
1066            "n integer, t text, primary key (n))" % table)
1067        for n, t in enumerate('abc'):
1068            query("insert into %s values("
1069                "%d, '%s')" % (table, n + 1, t))
1070        self.assertRaises(pg.ProgrammingError, update,
1071                          table, dict(t='b'))
1072        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
1073        r = query('select t from "%s" where n=2' % table
1074                  ).getresult()[0][0]
1075        self.assertEqual(r, 'd')
1076        query("drop table %s" % table)
1077        table = 'update_test_table_2'
1078        query("drop table if exists %s" % table)
1079        query("create table %s ("
1080            "n integer, m integer, t text, primary key (n, m))" % table)
1081        for n in range(3):
1082            for m in range(2):
1083                t = chr(ord('a') + 2 * n + m)
1084                query("insert into %s values("
1085                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1086        self.assertRaises(pg.ProgrammingError, update,
1087                          table, dict(n=2, t='b'))
1088        self.assertEqual(update(table,
1089                                dict(n=2, m=2, t='x'))['t'], 'x')
1090        r = [r[0] for r in query('select t from "%s" where n=2'
1091            ' order by m' % table).getresult()]
1092        self.assertEqual(r, ['c', 'x'])
1093        query("drop table %s" % table)
1094
1095    def testClear(self):
1096        clear = self.db.clear
1097        query = self.db.query
1098        for table in ('clear_test_table', 'test table for clear'):
1099            query('drop table if exists "%s"' % table)
1100            query('create table "%s" ('
1101                "n integer, b boolean, d date, t text)" % table)
1102            r = clear(table)
1103            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
1104            self.assertEqual(r, result)
1105            r['a'] = r['n'] = 1
1106            r['d'] = r['t'] = 'x'
1107            r['b'] = 't'
1108            r['oid'] = 1L
1109            r = clear(table, r)
1110            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '', 'oid': 1L}
1111            self.assertEqual(r, result)
1112            query('drop table "%s"' % table)
1113
1114    def testDelete(self):
1115        delete = self.db.delete
1116        query = self.db.query
1117        for table in ('delete_test_table', 'test table for delete'):
1118            query('drop table if exists "%s"' % table)
1119            query('create table "%s" ('
1120                "n integer, t text) with oids" % table)
1121            for n, t in enumerate('xyz'):
1122                query('insert into "%s" values('
1123                    "%d, '%s')" % (table, n + 1, t))
1124            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1125            r = self.db.get(table, 1, 'n')
1126            s = delete(table, r)
1127            self.assertEqual(s, 1)
1128            r = self.db.get(table, 3, 'n')
1129            s = delete(table, r)
1130            self.assertEqual(s, 1)
1131            s = delete(table, r)
1132            self.assertEqual(s, 0)
1133            r = query('select * from "%s"' % table).dictresult()
1134            self.assertEqual(len(r), 1)
1135            r = r[0]
1136            result = {'n': 2, 't': 'y'}
1137            self.assertEqual(r, result)
1138            r = self.db.get(table, 2, 'n')
1139            s = delete(table, r)
1140            self.assertEqual(s, 1)
1141            s = delete(table, r)
1142            self.assertEqual(s, 0)
1143            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1144            query('drop table "%s"' % table)
1145
1146    def testDeleteWithCompositeKey(self):
1147        query = self.db.query
1148        table = 'delete_test_table_1'
1149        query("drop table if exists %s" % table)
1150        query("create table %s ("
1151            "n integer, t text, primary key (n))" % table)
1152        for n, t in enumerate('abc'):
1153            query("insert into %s values("
1154                "%d, '%s')" % (table, n + 1, t))
1155        self.assertRaises(pg.ProgrammingError, self.db.delete,
1156            table, dict(t='b'))
1157        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1158        r = query('select t from "%s" where n=2' % table
1159                  ).getresult()
1160        self.assertEqual(r, [])
1161        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1162        r = query('select t from "%s" where n=3' % table
1163                  ).getresult()[0][0]
1164        self.assertEqual(r, 'c')
1165        query("drop table %s" % table)
1166        table = 'delete_test_table_2'
1167        query("drop table if exists %s" % table)
1168        query("create table %s ("
1169            "n integer, m integer, t text, primary key (n, m))" % table)
1170        for n in range(3):
1171            for m in range(2):
1172                t = chr(ord('a') + 2 * n + m)
1173                query("insert into %s values("
1174                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1175        self.assertRaises(pg.ProgrammingError, self.db.delete,
1176            table, dict(n=2, t='b'))
1177        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1178        r = [r[0] for r in query('select t from "%s" where n=2'
1179            ' order by m' % table).getresult()]
1180        self.assertEqual(r, ['c'])
1181        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1182        r = [r[0] for r in query('select t from "%s" where n=3'
1183            ' order by m' % table).getresult()]
1184        self.assertEqual(r, ['e', 'f'])
1185        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1186        r = [r[0] for r in query('select t from "%s" where n=3'
1187            ' order by m' % table).getresult()]
1188        self.assertEqual(r, ['f'])
1189        query("drop table %s" % table)
1190
1191    def testTruncate(self):
1192        truncate = self.db.truncate
1193        self.assertRaises(TypeError, truncate, None)
1194        self.assertRaises(TypeError, truncate, 42)
1195        self.assertRaises(TypeError, truncate, dict(test_table=None))
1196        query = self.db.query
1197        query("drop table if exists test_table")
1198        query("create table test_table (n smallint)")
1199        for i in range(3):
1200            query("insert into test_table values (1)")
1201        q = "select count(*) from test_table"
1202        r = query(q).getresult()[0][0]
1203        self.assertEqual(r, 3)
1204        truncate('test_table')
1205        r = query(q).getresult()[0][0]
1206        self.assertEqual(r, 0)
1207        for i in range(3):
1208            query("insert into test_table values (1)")
1209        r = query(q).getresult()[0][0]
1210        self.assertEqual(r, 3)
1211        truncate('public.test_table')
1212        r = query(q).getresult()[0][0]
1213        self.assertEqual(r, 0)
1214        query("drop table if exists test_table_2")
1215        query('create table test_table_2 (n smallint)')
1216        for t in (list, tuple, set):
1217            for i in range(3):
1218                query("insert into test_table values (1)")
1219                query("insert into test_table_2 values (2)")
1220            q = ("select (select count(*) from test_table),"
1221                " (select count(*) from test_table_2)")
1222            r = query(q).getresult()[0]
1223            self.assertEqual(r, (3, 3))
1224            truncate(t(['test_table', 'test_table_2']))
1225            r = query(q).getresult()[0]
1226            self.assertEqual(r, (0, 0))
1227        query("drop table test_table_2")
1228        query("drop table test_table")
1229
1230    def testTruncateRestart(self):
1231        truncate = self.db.truncate
1232        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
1233        query = self.db.query
1234        query("drop table if exists test_table")
1235        query("create table test_table (n serial, t text)")
1236        for n in range(3):
1237            query("insert into test_table (t) values ('test')")
1238        q = "select count(n), min(n), max(n) from test_table"
1239        r = query(q).getresult()[0]
1240        self.assertEqual(r, (3, 1, 3))
1241        truncate('test_table')
1242        r = query(q).getresult()[0]
1243        self.assertEqual(r, (0, None, None))
1244        for n in range(3):
1245            query("insert into test_table (t) values ('test')")
1246        r = query(q).getresult()[0]
1247        self.assertEqual(r, (3, 4, 6))
1248        truncate('test_table', restart=True)
1249        r = query(q).getresult()[0]
1250        self.assertEqual(r, (0, None, None))
1251        for n in range(3):
1252            query("insert into test_table (t) values ('test')")
1253        r = query(q).getresult()[0]
1254        self.assertEqual(r, (3, 1, 3))
1255        query("drop table test_table")
1256
1257    def testTruncateCascade(self):
1258        truncate = self.db.truncate
1259        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
1260        query = self.db.query
1261        query("drop table if exists test_child")
1262        query("drop table if exists test_parent")
1263        query("create table test_parent (n smallint primary key)")
1264        query("create table test_child ("
1265            " n smallint primary key references test_parent (n))")
1266        for n in range(3):
1267            query("insert into test_parent (n) values (%d)" % n)
1268            query("insert into test_child (n) values (%d)" % n)
1269        q = ("select (select count(*) from test_parent),"
1270            " (select count(*) from test_child)")
1271        r = query(q).getresult()[0]
1272        self.assertEqual(r, (3, 3))
1273        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
1274        truncate(['test_parent', 'test_child'])
1275        r = query(q).getresult()[0]
1276        self.assertEqual(r, (0, 0))
1277        for n in range(3):
1278            query("insert into test_parent (n) values (%d)" % n)
1279            query("insert into test_child (n) values (%d)" % n)
1280        r = query(q).getresult()[0]
1281        self.assertEqual(r, (3, 3))
1282        truncate('test_parent', cascade=True)
1283        r = query(q).getresult()[0]
1284        self.assertEqual(r, (0, 0))
1285        for n in range(3):
1286            query("insert into test_parent (n) values (%d)" % n)
1287            query("insert into test_child (n) values (%d)" % n)
1288        r = query(q).getresult()[0]
1289        self.assertEqual(r, (3, 3))
1290        truncate('test_child')
1291        r = query(q).getresult()[0]
1292        self.assertEqual(r, (3, 0))
1293        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
1294        truncate('test_parent', cascade=True)
1295        r = query(q).getresult()[0]
1296        self.assertEqual(r, (0, 0))
1297        query("drop table test_child")
1298        query("drop table test_parent")
1299
1300    def testTruncateOnly(self):
1301        truncate = self.db.truncate
1302        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
1303        query = self.db.query
1304        query("drop table if exists test_child")
1305        query("drop table if exists test_parent")
1306        query("create table test_parent (n smallint)")
1307        query("create table test_child ("
1308            " m smallint) inherits (test_parent)")
1309        for n in range(3):
1310            query("insert into test_parent (n) values (1)")
1311            query("insert into test_child (n, m) values (2, 3)")
1312        q = ("select (select count(*) from test_parent),"
1313            " (select count(*) from test_child)")
1314        r = query(q).getresult()[0]
1315        self.assertEqual(r, (6, 3))
1316        truncate('test_parent')
1317        r = query(q).getresult()[0]
1318        self.assertEqual(r, (0, 0))
1319        for n in range(3):
1320            query("insert into test_parent (n) values (1)")
1321            query("insert into test_child (n, m) values (2, 3)")
1322        r = query(q).getresult()[0]
1323        self.assertEqual(r, (6, 3))
1324        truncate('test_parent*')
1325        r = query(q).getresult()[0]
1326        self.assertEqual(r, (0, 0))
1327        for n in range(3):
1328            query("insert into test_parent (n) values (1)")
1329            query("insert into test_child (n, m) values (2, 3)")
1330        r = query(q).getresult()[0]
1331        self.assertEqual(r, (6, 3))
1332        truncate('test_parent', only=True)
1333        r = query(q).getresult()[0]
1334        self.assertEqual(r, (3, 3))
1335        truncate('test_parent', only=False)
1336        r = query(q).getresult()[0]
1337        self.assertEqual(r, (0, 0))
1338        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
1339        truncate('test_parent*', only=False)
1340        query("drop table if exists test_parent_2")
1341        query("create table test_parent_2 (n smallint)")
1342        query("drop table if exists test_child_2")
1343        query("create table test_child_2 ("
1344            " m smallint) inherits (test_parent_2)")
1345        for n in range(3):
1346            query("insert into test_parent (n) values (1)")
1347            query("insert into test_child (n, m) values (2, 3)")
1348            query("insert into test_parent_2 (n) values (1)")
1349            query("insert into test_child_2 (n, m) values (2, 3)")
1350        q = ("select (select count(*) from test_parent),"
1351            " (select count(*) from test_child),"
1352            " (select count(*) from test_parent_2),"
1353            " (select count(*) from test_child_2)")
1354        r = query(q).getresult()[0]
1355        self.assertEqual(r, (6, 3, 6, 3))
1356        truncate(['test_parent', 'test_parent_2'], only=[False, True])
1357        r = query(q).getresult()[0]
1358        self.assertEqual(r, (0, 0, 3, 3))
1359        truncate(['test_parent', 'test_parent_2'], only=False)
1360        r = query(q).getresult()[0]
1361        self.assertEqual(r, (0, 0, 0, 0))
1362        self.assertRaises(ValueError, truncate,
1363            ['test_parent*', 'test_child'], only=[True, False])
1364        truncate(['test_parent*', 'test_child'], only=[False, True])
1365        query("drop table test_child_2")
1366        query("drop table test_parent_2")
1367        query("drop table test_child")
1368        query("drop table test_parent")
1369
1370    def testTruncateQuoted(self):
1371        truncate = self.db.truncate
1372        query = self.db.query
1373        table = "test table for truncate()"
1374        query('drop table if exists "%s"' % table)
1375        query('create table "%s" (n smallint)' % table)
1376        for i in range(3):
1377            query('insert into "%s" values (1)' % table)
1378        q = 'select count(*) from "%s"' % table
1379        r = query(q).getresult()[0][0]
1380        self.assertEqual(r, 3)
1381        truncate(table)
1382        r = query(q).getresult()[0][0]
1383        self.assertEqual(r, 0)
1384        for i in range(3):
1385            query('insert into "%s" values (1)' % table)
1386        r = query(q).getresult()[0][0]
1387        self.assertEqual(r, 3)
1388        truncate('public."%s"' % table)
1389        r = query(q).getresult()[0][0]
1390        self.assertEqual(r, 0)
1391        query('drop table "%s"' % table)
1392
1393    def testTransaction(self):
1394        query = self.db.query
1395        query("drop table if exists test_table")
1396        query("create table test_table (n integer)")
1397        self.db.begin()
1398        query("insert into test_table values (1)")
1399        query("insert into test_table values (2)")
1400        self.db.commit()
1401        self.db.begin()
1402        query("insert into test_table values (3)")
1403        query("insert into test_table values (4)")
1404        self.db.rollback()
1405        self.db.begin()
1406        query("insert into test_table values (5)")
1407        self.db.savepoint('before6')
1408        query("insert into test_table values (6)")
1409        self.db.rollback('before6')
1410        query("insert into test_table values (7)")
1411        self.db.commit()
1412        self.db.begin()
1413        self.db.savepoint('before8')
1414        query("insert into test_table values (8)")
1415        self.db.release('before8')
1416        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1417        self.db.commit()
1418        self.db.start()
1419        query("insert into test_table values (9)")
1420        self.db.end()
1421        r = [r[0] for r in query(
1422            "select * from test_table order by 1").getresult()]
1423        self.assertEqual(r, [1, 2, 5, 7, 9])
1424        query("drop table test_table")
1425
    @unittest.skipIf(no_with, 'context managers not supported')
    def testContextManager(self):
        # Use the DB object as a context manager.  Per the expected result
        # below, rows inserted in a "with" block that exits normally are
        # kept, while the rows of a block that raises (a ValueError, or a
        # violated check constraint) are rolled back, and the exception
        # propagates out of the "with" statement.
        query = self.db.query
        query("drop table if exists test_table")
        # the check constraint makes inserting -1 fail with ProgrammingError
        query("create table test_table (n integer check(n>0))")
        # wrap "with" statements to avoid SyntaxError in Python < 2.5
        # NOTE(review): if a "with" block ever swallowed its exception, the
        # except clauses inside the string would simply not run; the final
        # row check below would still catch a missing rollback.
        exec """from __future__ import with_statement\nif True:
        with self.db:
            query("insert into test_table values (1)")
            query("insert into test_table values (2)")
        try:
            with self.db:
                query("insert into test_table values (3)")
                query("insert into test_table values (4)")
                raise ValueError('test transaction should rollback')
        except ValueError, error:
            self.assertEqual(str(error), 'test transaction should rollback')
        with self.db:
            query("insert into test_table values (5)")
        try:
            with self.db:
                query("insert into test_table values (6)")
                query("insert into test_table values (-1)")
        except pg.ProgrammingError, error:
            self.assertTrue('check' in str(error))
        with self.db:
            query("insert into test_table values (7)")\n"""
        # only rows from blocks that exited normally must remain
        r = [r[0] for r in query(
            "select * from test_table order by 1").getresult()]
        self.assertEqual(r, [1, 2, 5, 7])
        query("drop table test_table")
1457
1458    def testBytea(self):
1459        query = self.db.query
1460        query('drop table if exists bytea_test')
1461        query('create table bytea_test ('
1462            'data bytea)')
1463        s = "It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1464        r = self.db.escape_bytea(s)
1465        query('insert into bytea_test values('
1466            "'%s')" % r)
1467        r = query('select * from bytea_test').getresult()
1468        self.assertTrue(len(r) == 1)
1469        r = r[0]
1470        self.assertTrue(len(r) == 1)
1471        r = r[0]
1472        r = self.db.unescape_bytea(r)
1473        self.assertEqual(r, s)
1474        query('drop table bytea_test')
1475
1476    def testNotificationHandler(self):
1477        # the notification handler itself is tested separately
1478        f = self.db.notification_handler
1479        callback = lambda arg_dict: None
1480        handler = f('test', callback)
1481        self.assertIsInstance(handler, pg.NotificationHandler)
1482        self.assertIs(handler.db, self.db)
1483        self.assertEqual(handler.event, 'test')
1484        self.assertEqual(handler.stop_event, 'stop_test')
1485        self.assertIs(handler.callback, callback)
1486        self.assertIsInstance(handler.arg_dict, dict)
1487        self.assertEqual(handler.arg_dict, {})
1488        self.assertIsNone(handler.timeout)
1489        self.assertFalse(handler.listening)
1490        handler.close()
1491        self.assertIsNone(handler.db)
1492        self.db.reopen()
1493        self.assertIsNone(handler.db)
1494        handler = f('test2', callback, timeout=2)
1495        self.assertIsInstance(handler, pg.NotificationHandler)
1496        self.assertIs(handler.db, self.db)
1497        self.assertEqual(handler.event, 'test2')
1498        self.assertEqual(handler.stop_event, 'stop_test2')
1499        self.assertIs(handler.callback, callback)
1500        self.assertIsInstance(handler.arg_dict, dict)
1501        self.assertEqual(handler.arg_dict, {})
1502        self.assertEqual(handler.timeout, 2)
1503        self.assertFalse(handler.listening)
1504        handler.close()
1505        self.assertIsNone(handler.db)
1506        self.db.reopen()
1507        self.assertIsNone(handler.db)
1508        arg_dict = {'testing': 3}
1509        handler = f('test3', callback, arg_dict=arg_dict)
1510        self.assertIsInstance(handler, pg.NotificationHandler)
1511        self.assertIs(handler.db, self.db)
1512        self.assertEqual(handler.event, 'test3')
1513        self.assertEqual(handler.stop_event, 'stop_test3')
1514        self.assertIs(handler.callback, callback)
1515        self.assertIs(handler.arg_dict, arg_dict)
1516        self.assertEqual(arg_dict['testing'], 3)
1517        self.assertIsNone(handler.timeout)
1518        self.assertFalse(handler.listening)
1519        handler.close()
1520        self.assertIsNone(handler.db)
1521        self.db.reopen()
1522        self.assertIsNone(handler.db)
1523        handler = f('test4', callback, stop_event='stop4')
1524        self.assertIsInstance(handler, pg.NotificationHandler)
1525        self.assertIs(handler.db, self.db)
1526        self.assertEqual(handler.event, 'test4')
1527        self.assertEqual(handler.stop_event, 'stop4')
1528        self.assertIs(handler.callback, callback)
1529        self.assertIsInstance(handler.arg_dict, dict)
1530        self.assertEqual(handler.arg_dict, {})
1531        self.assertIsNone(handler.timeout)
1532        self.assertFalse(handler.listening)
1533        handler.close()
1534        self.assertIsNone(handler.db)
1535        self.db.reopen()
1536        self.assertIsNone(handler.db)
1537        arg_dict = {'testing': 5}
1538        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
1539        self.assertIsInstance(handler, pg.NotificationHandler)
1540        self.assertIs(handler.db, self.db)
1541        self.assertEqual(handler.event, 'test5')
1542        self.assertEqual(handler.stop_event, 'stop5')
1543        self.assertIs(handler.callback, callback)
1544        self.assertIs(handler.arg_dict, arg_dict)
1545        self.assertEqual(arg_dict['testing'], 5)
1546        self.assertEqual(handler.timeout, 1.5)
1547        self.assertFalse(handler.listening)
1548        handler.close()
1549        self.assertIsNone(handler.db)
1550        self.db.reopen()
1551        self.assertIsNone(handler.db)
1552
1553    def testDebugWithCallable(self):
1554        if debug:
1555            self.assertEqual(self.db.debug, debug)
1556        else:
1557            self.assertIsNone(self.db.debug)
1558        s = []
1559        self.db.debug = s.append
1560        try:
1561            self.db.query("select 1")
1562            self.db.query("select 2")
1563            self.assertEqual(s, ["select 1", "select 2"])
1564        finally:
1565            self.db.debug = debug
1566
1567
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    # Number of test schemas: schema 0 is "public", the others are s1, s2...
    # Each schema gets a table "t" and a table "t<num>", both created with
    # oids and holding a single row with n=1 and d=<num>.
    num_schemas = 5

    @staticmethod
    def _schema_name(num_schema):
        """Return the name of the test schema with the given number."""
        if num_schema:
            return "s%d" % num_schema
        return "public"

    @classmethod
    def setUpClass(cls):
        db = DB()
        # close the connection even if setting up fails, e.g. because
        # the test user is not allowed to create schemas
        try:
            query = db.query
            for num_schema in range(cls.num_schemas):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    # public is kept; only its test tables are dropped
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids"
                      " as select 1 as n, %d as d" % (schema, num_schema))
                query("create table %s.t%d with oids"
                      " as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            db.close()

    @classmethod
    def tearDownClass(cls):
        db = DB()
        # close the connection even if dropping a schema or table fails
        try:
            query = db.query
            for num_schema in range(cls.num_schemas):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema %s cascade" % (schema,))
                else:
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(self.num_schemas):
            schema = self._schema_name(num_schema)
            for t in ("%s.t" % schema, "%s.t%d" % (schema, num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        # the tables were created with oids, so oid shows up as well
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        # unqualified names must be resolved via the search path
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)
        query("drop table s3.t3m")

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        # column d tells from which schema the row was fetched
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        # tables outside the search path are only found when qualified
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMangling(self):
        get = self.db.get
        query = self.db.query
        # the oid key of a fetched row is mangled with the qualified name
        # of the table it was actually found in
        r = get("t", 1, 'n')
        self.assertIn('oid(public.t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(s2.t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(s3.t)', r)
1682
if __name__ == '__main__':
    # run all test cases of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.