source: branches/4.x/tests/test_classic_dbwrapper.py @ 771

Last change on this file since 771 was 771, checked in by cito, 4 years ago

Back port some minor fixes from the trunk

This also gives better error message if test runner does not support unittest2.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 68.4 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import sys
21
22import pg  # the module under test
23
24from decimal import Decimal
25
26# check whether the "with" statement is supported
27no_with = sys.version_info[:2] < (2, 5)
28
29# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
30# get our information from that.  Otherwise we use the defaults.
31# The current user must have create schema privilege on the database.
32dbname = 'unittest'
33dbhost = None
34dbport = 5432
35
36debug = False  # let DB wrapper print debugging output
37
38try:
39    from LOCAL_PyGreSQL import *
40except ImportError:
41    pass
42
43windows = os.name == 'nt'
44
45# There is a known a bug in libpq under Windows which can cause
46# the interface to crash when calling PQhost():
47do_not_ask_for_host = windows
48do_not_ask_for_host_reason = 'libpq issue on Windows'
49
50
51def DB():
52    """Create a DB wrapper object connecting to the test database."""
53    db = pg.DB(dbname, dbhost, dbport)
54    if debug:
55        db.debug = debug
56    db.query("set client_min_messages=warning")
57    return db
58
59
60class TestDBClassBasic(unittest.TestCase):
61    """Test existence of the DB class wrapped pg connection methods."""
62
63    def setUp(self):
64        self.db = DB()
65
66    def tearDown(self):
67        try:
68            self.db.close()
69        except pg.InternalError:
70            pass
71
72    def testAllDBAttributes(self):
73        attributes = [
74            'abort',
75            'begin',
76            'cancel', 'clear', 'close', 'commit',
77            'db', 'dbname', 'debug', 'delete',
78            'end', 'endcopy', 'error',
79            'escape_bytea', 'escape_identifier',
80            'escape_literal', 'escape_string',
81            'fileno',
82            'get', 'get_attnames', 'get_databases',
83            'get_notice_receiver', 'get_parameter',
84            'get_relations', 'get_tables',
85            'getline', 'getlo', 'getnotify',
86            'has_table_privilege', 'host',
87            'insert', 'inserttable',
88            'locreate', 'loimport',
89            'notification_handler',
90            'options',
91            'parameter', 'pkey', 'port',
92            'protocol_version', 'putline',
93            'query',
94            'release', 'reopen', 'reset', 'rollback',
95            'savepoint', 'server_version',
96            'set_notice_receiver', 'set_parameter',
97            'source', 'start', 'status',
98            'transaction', 'truncate', 'tty',
99            'unescape_bytea', 'update',
100            'use_regtypes', 'user',
101        ]
102        if self.db.server_version < 90000:  # PostgreSQL < 9.0
103            attributes.remove('escape_identifier')
104            attributes.remove('escape_literal')
105        db_attributes = [a for a in dir(self.db)
106            if not a.startswith('_')]
107        self.assertEqual(attributes, db_attributes)
108
109    def testAttributeDb(self):
110        self.assertEqual(self.db.db.db, dbname)
111
112    def testAttributeDbname(self):
113        self.assertEqual(self.db.dbname, dbname)
114
115    def testAttributeError(self):
116        error = self.db.error
117        self.assertTrue(not error or 'krb5_' in error)
118        self.assertEqual(self.db.error, self.db.db.error)
119
120    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
121    def testAttributeHost(self):
122        def_host = 'localhost'
123        host = self.db.host
124        self.assertIsInstance(host, str)
125        self.assertEqual(host, dbhost or def_host)
126        self.assertEqual(host, self.db.db.host)
127
128    def testAttributeOptions(self):
129        no_options = ''
130        options = self.db.options
131        self.assertEqual(options, no_options)
132        self.assertEqual(options, self.db.db.options)
133
134    def testAttributePort(self):
135        def_port = 5432
136        port = self.db.port
137        self.assertIsInstance(port, int)
138        self.assertEqual(port, dbport or def_port)
139        self.assertEqual(port, self.db.db.port)
140
141    def testAttributeProtocolVersion(self):
142        protocol_version = self.db.protocol_version
143        self.assertIsInstance(protocol_version, int)
144        self.assertTrue(2 <= protocol_version < 4)
145        self.assertEqual(protocol_version, self.db.db.protocol_version)
146
147    def testAttributeServerVersion(self):
148        server_version = self.db.server_version
149        self.assertIsInstance(server_version, int)
150        self.assertTrue(70400 <= server_version < 100000)
151        self.assertEqual(server_version, self.db.db.server_version)
152
153    def testAttributeStatus(self):
154        status_ok = 1
155        status = self.db.status
156        self.assertIsInstance(status, int)
157        self.assertEqual(status, status_ok)
158        self.assertEqual(status, self.db.db.status)
159
160    def testAttributeTty(self):
161        def_tty = ''
162        tty = self.db.tty
163        self.assertIsInstance(tty, str)
164        self.assertEqual(tty, def_tty)
165        self.assertEqual(tty, self.db.db.tty)
166
167    def testAttributeUser(self):
168        no_user = 'Deprecated facility'
169        user = self.db.user
170        self.assertTrue(user)
171        self.assertIsInstance(user, str)
172        self.assertNotEqual(user, no_user)
173        self.assertEqual(user, self.db.db.user)
174
175    def testMethodEscapeLiteral(self):
176        if self.db.server_version < 90000:  # PostgreSQL < 9.0
177            self.skipTest('Escaping functions not supported')
178        self.assertEqual(self.db.escape_literal(''), "''")
179
180    def testMethodEscapeIdentifier(self):
181        if self.db.server_version < 90000:  # PostgreSQL < 9.0
182            self.skipTest('Escaping functions not supported')
183        self.assertEqual(self.db.escape_identifier(''), '""')
184
185    def testMethodEscapeString(self):
186        self.assertEqual(self.db.escape_string(''), '')
187
188    def testMethodEscapeBytea(self):
189        self.assertEqual(self.db.escape_bytea('').replace(
190            '\\x', '').replace('\\', ''), '')
191
192    def testMethodUnescapeBytea(self):
193        self.assertEqual(self.db.unescape_bytea(''), '')
194
195    def testMethodQuery(self):
196        query = self.db.query
197        query("select 1+1")
198        query("select 1+$1+$2", 2, 3)
199        query("select 1+$1+$2", (2, 3))
200        query("select 1+$1+$2", [2, 3])
201        query("select 1+$1", 1)
202
203    def testMethodQueryEmpty(self):
204        self.assertRaises(ValueError, self.db.query, '')
205
206    def testMethodQueryProgrammingError(self):
207        try:
208            self.db.query("select 1/0")
209        except pg.ProgrammingError, error:
210            self.assertEqual(error.sqlstate, '22012')
211
212    def testMethodEndcopy(self):
213        try:
214            self.db.endcopy()
215        except IOError:
216            pass
217
218    def testMethodClose(self):
219        self.db.close()
220        try:
221            self.db.reset()
222        except pg.Error:
223            pass
224        else:
225            self.fail('Reset should give an error for a closed connection')
226        self.assertIsNone(self.db.db)
227        self.assertRaises(pg.InternalError, self.db.close)
228        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
229        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
230        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
231        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')
232
233    def testMethodReset(self):
234        con = self.db.db
235        self.db.reset()
236        self.assertIs(self.db.db, con)
237        self.db.query("select 1+1")
238        self.db.close()
239        self.assertRaises(pg.InternalError, self.db.reset)
240
241    def testMethodReopen(self):
242        con = self.db.db
243        self.db.reopen()
244        self.assertIsNot(self.db.db, con)
245        con = self.db.db
246        self.db.query("select 1+1")
247        self.db.close()
248        self.db.reopen()
249        self.assertIsNot(self.db.db, con)
250        self.db.query("select 1+1")
251        self.db.close()
252
253    def testExistingConnection(self):
254        db = pg.DB(self.db.db)
255        self.assertEqual(self.db.db, db.db)
256        self.assertTrue(db.db)
257        db.close()
258        self.assertTrue(db.db)
259        db.reopen()
260        self.assertTrue(db.db)
261        db.close()
262        self.assertTrue(db.db)
263        db = pg.DB(self.db)
264        self.assertEqual(self.db.db, db.db)
265        db = pg.DB(db=self.db.db)
266        self.assertEqual(self.db.db, db.db)
267
268        class DB2:
269            pass
270
271        db2 = DB2()
272        db2._cnx = self.db.db
273        db = pg.DB(db2)
274        self.assertEqual(self.db.db, db.db)
275
276
277class TestDBClass(unittest.TestCase):
278    """Test the methods of the DB class wrapped pg connection."""
279
280    cls_set_up = False
281
282    @classmethod
283    def setUpClass(cls):
284        db = DB()
285        db.query("drop table if exists test cascade")
286        db.query("create table test ("
287            "i2 smallint, i4 integer, i8 bigint,"
288            " d numeric, f4 real, f8 double precision, m money,"
289            " v4 varchar(4), c4 char(4), t text)")
290        db.query("create or replace view test_view as"
291            " select i4, v4 from test")
292        db.close()
293        cls.cls_set_up = True
294
295    @classmethod
296    def tearDownClass(cls):
297        db = DB()
298        db.query("drop table test cascade")
299        db.close()
300
301    def setUp(self):
302        self.assertTrue(self.cls_set_up)
303        self.db = DB()
304        query = self.db.query
305        query('set client_encoding=utf8')
306        query('set standard_conforming_strings=on')
307        query("set lc_monetary='C'")
308        query("set datestyle='ISO,YMD'")
309        try:
310            query('set bytea_output=hex')
311        except pg.ProgrammingError:  # PostgreSQL < 9.0
312            pass
313
314    def tearDown(self):
315        self.db.close()
316
317    def testEscapeLiteral(self):
318        if self.db.server_version < 90000:  # PostgreSQL < 9.0
319            self.skipTest('Escaping functions not supported')
320        f = self.db.escape_literal
321        self.assertEqual(f("plain"), "'plain'")
322        self.assertEqual(f("that's k\xe4se"), "'that''s k\xe4se'")
323        self.assertEqual(f(r"It's fine to have a \ inside."),
324            r" E'It''s fine to have a \\ inside.'")
325        self.assertEqual(f('No "quotes" must be escaped.'),
326            "'No \"quotes\" must be escaped.'")
327
328    def testEscapeIdentifier(self):
329        if self.db.server_version < 90000:  # PostgreSQL < 9.0
330            self.skipTest('Escaping functions not supported')
331        f = self.db.escape_identifier
332        self.assertEqual(f("plain"), '"plain"')
333        self.assertEqual(f("that's k\xe4se"), '"that\'s k\xe4se"')
334        self.assertEqual(f(r"It's fine to have a \ inside."),
335            '"It\'s fine to have a \\ inside."')
336        self.assertEqual(f('All "quotes" must be escaped.'),
337            '"All ""quotes"" must be escaped."')
338
339    def testEscapeString(self):
340        f = self.db.escape_string
341        self.assertEqual(f("plain"), "plain")
342        self.assertEqual(f("that's k\xe4se"), "that''s k\xe4se")
343        self.assertEqual(f(r"It's fine to have a \ inside."),
344            r"It''s fine to have a \ inside.")
345
346    def testEscapeBytea(self):
347        f = self.db.escape_bytea
348        # note that escape_byte always returns hex output since PostgreSQL 9.0,
349        # regardless of the bytea_output setting
350        if self.db.server_version < 90000:
351            self.assertEqual(f("plain"), r"plain")
352            self.assertEqual(f("that's k\xe4se"), r"that''s k\344se")
353            self.assertEqual(f('O\x00ps\xff!'), r"O\000ps\377!")
354        else:
355            self.assertEqual(f("plain"), r"\x706c61696e")
356            self.assertEqual(f("that's k\xe4se"), r"\x746861742773206be47365")
357            self.assertEqual(f('O\x00ps\xff!'), r"\x4f007073ff21")
358
359    def testUnescapeBytea(self):
360        f = self.db.unescape_bytea
361        self.assertEqual(f("plain"), "plain")
362        self.assertEqual(f("that's k\\344se"), "that's k\xe4se")
363        self.assertEqual(f(r'O\000ps\377!'), 'O\x00ps\xff!')
364        self.assertEqual(f(r"\\x706c61696e"), r"\x706c61696e")
365        self.assertEqual(f(r"\\x746861742773206be47365"),
366            r"\x746861742773206be47365")
367        self.assertEqual(f(r"\\x4f007073ff21"), r"\x4f007073ff21")
368
369    def testQuote(self):
370        f = self.db._quote
371        self.assertEqual(f(None, None), 'NULL')
372        self.assertEqual(f(None, 'int'), 'NULL')
373        self.assertEqual(f(None, 'float'), 'NULL')
374        self.assertEqual(f(None, 'num'), 'NULL')
375        self.assertEqual(f(None, 'money'), 'NULL')
376        self.assertEqual(f(None, 'bool'), 'NULL')
377        self.assertEqual(f(None, 'date'), 'NULL')
378        self.assertEqual(f('', 'int'), 'NULL')
379        self.assertEqual(f('', 'float'), 'NULL')
380        self.assertEqual(f('', 'num'), 'NULL')
381        self.assertEqual(f('', 'money'), 'NULL')
382        self.assertEqual(f('', 'bool'), 'NULL')
383        self.assertEqual(f('', 'date'), 'NULL')
384        self.assertEqual(f('', 'text'), "''")
385        self.assertEqual(f(0, 'int'), '0')
386        self.assertEqual(f(0, 'num'), '0')
387        self.assertEqual(f(1, 'int'), '1')
388        self.assertEqual(f(1, 'num'), '1')
389        self.assertEqual(f(-1, 'int'), '-1')
390        self.assertEqual(f(-1, 'num'), '-1')
391        self.assertEqual(f(123456789, 'int'), '123456789')
392        self.assertEqual(f(123456987, 'num'), '123456987')
393        self.assertEqual(f(1.23654789, 'num'), '1.23654789')
394        self.assertEqual(f(12365478.9, 'num'), '12365478.9')
395        self.assertEqual(f('123456789', 'num'), '123456789')
396        self.assertEqual(f('1.23456789', 'num'), '1.23456789')
397        self.assertEqual(f('12345678.9', 'num'), '12345678.9')
398        self.assertEqual(f(123, 'money'), '123')
399        self.assertEqual(f('123', 'money'), '123')
400        self.assertEqual(f(123.45, 'money'), '123.45')
401        self.assertEqual(f('123.45', 'money'), '123.45')
402        self.assertEqual(f(123.454, 'money'), '123.454')
403        self.assertEqual(f('123.454', 'money'), '123.454')
404        self.assertEqual(f(123.456, 'money'), '123.456')
405        self.assertEqual(f('123.456', 'money'), '123.456')
406        self.assertEqual(f('f', 'bool'), "'f'")
407        self.assertEqual(f('F', 'bool'), "'f'")
408        self.assertEqual(f('false', 'bool'), "'f'")
409        self.assertEqual(f('False', 'bool'), "'f'")
410        self.assertEqual(f('FALSE', 'bool'), "'f'")
411        self.assertEqual(f(0, 'bool'), "'f'")
412        self.assertEqual(f('0', 'bool'), "'f'")
413        self.assertEqual(f('-', 'bool'), "'f'")
414        self.assertEqual(f('n', 'bool'), "'f'")
415        self.assertEqual(f('N', 'bool'), "'f'")
416        self.assertEqual(f('no', 'bool'), "'f'")
417        self.assertEqual(f('off', 'bool'), "'f'")
418        self.assertEqual(f('t', 'bool'), "'t'")
419        self.assertEqual(f('T', 'bool'), "'t'")
420        self.assertEqual(f('true', 'bool'), "'t'")
421        self.assertEqual(f('True', 'bool'), "'t'")
422        self.assertEqual(f('TRUE', 'bool'), "'t'")
423        self.assertEqual(f(1, 'bool'), "'t'")
424        self.assertEqual(f(2, 'bool'), "'t'")
425        self.assertEqual(f(-1, 'bool'), "'t'")
426        self.assertEqual(f(0.5, 'bool'), "'t'")
427        self.assertEqual(f('1', 'bool'), "'t'")
428        self.assertEqual(f('y', 'bool'), "'t'")
429        self.assertEqual(f('Y', 'bool'), "'t'")
430        self.assertEqual(f('yes', 'bool'), "'t'")
431        self.assertEqual(f('on', 'bool'), "'t'")
432        self.assertEqual(f('01.01.2000', 'date'), "'01.01.2000'")
433        self.assertEqual(f(123, 'text'), "'123'")
434        self.assertEqual(f(1.23, 'text'), "'1.23'")
435        self.assertEqual(f('abc', 'text'), "'abc'")
436        self.assertEqual(f("ab'c", 'text'), "'ab''c'")
437        self.assertEqual(f('ab\\c', 'text'), "'ab\\c'")
438        self.assertEqual(f("a\\b'c", 'text'), "'a\\b''c'")
439        self.db.query('set standard_conforming_strings=off')
440        self.assertEqual(f('ab\\c', 'text'), "'ab\\\\c'")
441        self.assertEqual(f("a\\b'c", 'text'), "'a\\\\b''c'")
442
443    def testGetParameter(self):
444        f = self.db.get_parameter
445        self.assertRaises(TypeError, f)
446        self.assertRaises(TypeError, f, None)
447        self.assertRaises(TypeError, f, 42)
448        self.assertRaises(TypeError, f, '')
449        self.assertRaises(TypeError, f, [])
450        self.assertRaises(TypeError, f, [''])
451        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
452        r = f('standard_conforming_strings')
453        self.assertEqual(r, 'on')
454        r = f('lc_monetary')
455        self.assertEqual(r, 'C')
456        r = f('datestyle')
457        self.assertEqual(r, 'ISO, YMD')
458        r = f('bytea_output')
459        self.assertEqual(r, 'hex')
460        r = f(['bytea_output', 'lc_monetary'])
461        self.assertIsInstance(r, list)
462        self.assertEqual(r, ['hex', 'C'])
463        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
464        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
465        r = f(set(['bytea_output', 'lc_monetary']))
466        self.assertIsInstance(r, dict)
467        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
468        r = f(set(['Bytea_Output', ' LC_Monetary ']))
469        self.assertIsInstance(r, dict)
470        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
471        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
472        r = f(s)
473        self.assertIs(r, s)
474        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
475        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
476        r = f(s)
477        self.assertIs(r, s)
478        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
479
480    def testGetParameterServerVersion(self):
481        r = self.db.get_parameter('server_version_num')
482        self.assertIsInstance(r, str)
483        s = self.db.server_version
484        self.assertIsInstance(s, int)
485        self.assertEqual(r, str(s))
486
487    def testGetParameterAll(self):
488        f = self.db.get_parameter
489        r = f('all')
490        self.assertIsInstance(r, dict)
491        self.assertEqual(r['standard_conforming_strings'], 'on')
492        self.assertEqual(r['lc_monetary'], 'C')
493        self.assertEqual(r['DateStyle'], 'ISO, YMD')
494        self.assertEqual(r['bytea_output'], 'hex')
495
496    def testSetParameter(self):
497        f = self.db.set_parameter
498        g = self.db.get_parameter
499        self.assertRaises(TypeError, f)
500        self.assertRaises(TypeError, f, None)
501        self.assertRaises(TypeError, f, 42)
502        self.assertRaises(TypeError, f, '')
503        self.assertRaises(TypeError, f, [])
504        self.assertRaises(TypeError, f, [''])
505        self.assertRaises(ValueError, f, 'all', 'invalid')
506        self.assertRaises(ValueError, f, {
507            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
508        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
509        f('standard_conforming_strings', 'off')
510        self.assertEqual(g('standard_conforming_strings'), 'off')
511        f('datestyle', 'ISO, DMY')
512        self.assertEqual(g('datestyle'), 'ISO, DMY')
513        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
514        self.assertEqual(g('standard_conforming_strings'), 'on')
515        self.assertEqual(g('datestyle'), 'ISO, DMY')
516        f(['default_with_oids', 'standard_conforming_strings'], 'off')
517        self.assertEqual(g('default_with_oids'), 'off')
518        self.assertEqual(g('standard_conforming_strings'), 'off')
519        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
520        self.assertEqual(g('standard_conforming_strings'), 'on')
521        self.assertEqual(g('datestyle'), 'ISO, YMD')
522        f(('default_with_oids', 'standard_conforming_strings'), 'off')
523        self.assertEqual(g('default_with_oids'), 'off')
524        self.assertEqual(g('standard_conforming_strings'), 'off')
525        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
526        self.assertEqual(g('default_with_oids'), 'on')
527        self.assertEqual(g('standard_conforming_strings'), 'on')
528        self.assertRaises(ValueError, f, set([ 'default_with_oids',
529            'standard_conforming_strings']), ['off', 'on'])
530        f(set(['default_with_oids', 'standard_conforming_strings']),
531            ['off', 'off'])
532        self.assertEqual(g('default_with_oids'), 'off')
533        self.assertEqual(g('standard_conforming_strings'), 'off')
534        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
535        self.assertEqual(g('standard_conforming_strings'), 'on')
536        self.assertEqual(g('datestyle'), 'ISO, YMD')
537
538    def testResetParameter(self):
539        db = DB()
540        f = db.set_parameter
541        g = db.get_parameter
542        r = g('default_with_oids')
543        self.assertIn(r, ('on', 'off'))
544        dwi, not_dwi = r, r == 'on' and 'off' or 'on'
545        r = g('standard_conforming_strings')
546        self.assertIn(r, ('on', 'off'))
547        scs, not_scs = r, r == 'on' and 'off' or 'on'
548        f('default_with_oids', not_dwi)
549        f('standard_conforming_strings', not_scs)
550        self.assertEqual(g('default_with_oids'), not_dwi)
551        self.assertEqual(g('standard_conforming_strings'), not_scs)
552        f('default_with_oids')
553        f('standard_conforming_strings', None)
554        self.assertEqual(g('default_with_oids'), dwi)
555        self.assertEqual(g('standard_conforming_strings'), scs)
556        f('default_with_oids', not_dwi)
557        f('standard_conforming_strings', not_scs)
558        self.assertEqual(g('default_with_oids'), not_dwi)
559        self.assertEqual(g('standard_conforming_strings'), not_scs)
560        f(['default_with_oids', 'standard_conforming_strings'], None)
561        self.assertEqual(g('default_with_oids'), dwi)
562        self.assertEqual(g('standard_conforming_strings'), scs)
563        f('default_with_oids', not_dwi)
564        f('standard_conforming_strings', not_scs)
565        self.assertEqual(g('default_with_oids'), not_dwi)
566        self.assertEqual(g('standard_conforming_strings'), not_scs)
567        f(('default_with_oids', 'standard_conforming_strings'))
568        self.assertEqual(g('default_with_oids'), dwi)
569        self.assertEqual(g('standard_conforming_strings'), scs)
570        f('default_with_oids', not_dwi)
571        f('standard_conforming_strings', not_scs)
572        self.assertEqual(g('default_with_oids'), not_dwi)
573        self.assertEqual(g('standard_conforming_strings'), not_scs)
574        f(set(['default_with_oids', 'standard_conforming_strings']), None)
575        self.assertEqual(g('default_with_oids'), dwi)
576        self.assertEqual(g('standard_conforming_strings'), scs)
577
578    def testResetParameterAll(self):
579        db = DB()
580        f = db.set_parameter
581        self.assertRaises(ValueError, f, 'all', 0)
582        self.assertRaises(ValueError, f, 'all', 'off')
583        g = db.get_parameter
584        r = g('default_with_oids')
585        self.assertIn(r, ('on', 'off'))
586        dwi, not_dwi = r, r == 'on' and 'off' or 'on'
587        r = g('standard_conforming_strings')
588        self.assertIn(r, ('on', 'off'))
589        scs, not_scs = r, r == 'on' and 'off' or 'on'
590        f('default_with_oids', not_dwi)
591        f('standard_conforming_strings', not_scs)
592        self.assertEqual(g('default_with_oids'), not_dwi)
593        self.assertEqual(g('standard_conforming_strings'), not_scs)
594        f('all')
595        self.assertEqual(g('default_with_oids'), dwi)
596        self.assertEqual(g('standard_conforming_strings'), scs)
597
598    def testSetParameterLocal(self):
599        f = self.db.set_parameter
600        g = self.db.get_parameter
601        self.assertEqual(g('standard_conforming_strings'), 'on')
602        self.db.begin()
603        f('standard_conforming_strings', 'off', local=True)
604        self.assertEqual(g('standard_conforming_strings'), 'off')
605        self.db.end()
606        self.assertEqual(g('standard_conforming_strings'), 'on')
607
608    def testSetParameterSession(self):
609        f = self.db.set_parameter
610        g = self.db.get_parameter
611        self.assertEqual(g('standard_conforming_strings'), 'on')
612        self.db.begin()
613        f('standard_conforming_strings', 'off', local=False)
614        self.assertEqual(g('standard_conforming_strings'), 'off')
615        self.db.end()
616        self.assertEqual(g('standard_conforming_strings'), 'off')
617
618    def testQuery(self):
619        query = self.db.query
620        query("drop table if exists test_table")
621        q = "create table test_table (n integer) with oids"
622        r = query(q)
623        self.assertIsNone(r)
624        q = "insert into test_table values (1)"
625        r = query(q)
626        self.assertIsInstance(r, int)
627        q = "insert into test_table select 2"
628        r = query(q)
629        self.assertIsInstance(r, int)
630        oid = r
631        q = "select oid from test_table where n=2"
632        r = query(q).getresult()
633        self.assertEqual(len(r), 1)
634        r = r[0]
635        self.assertEqual(len(r), 1)
636        r = r[0]
637        self.assertEqual(r, oid)
638        q = "insert into test_table select 3 union select 4 union select 5"
639        r = query(q)
640        self.assertIsInstance(r, str)
641        self.assertEqual(r, '3')
642        q = "update test_table set n=4 where n<5"
643        r = query(q)
644        self.assertIsInstance(r, str)
645        self.assertEqual(r, '4')
646        q = "delete from test_table"
647        r = query(q)
648        self.assertIsInstance(r, str)
649        self.assertEqual(r, '5')
650        query("drop table test_table")
651
652    def testMultipleQueries(self):
653        self.assertEqual(self.db.query(
654            "create temporary table test_multi (n integer);"
655            "insert into test_multi values (4711);"
656            "select n from test_multi").getresult()[0][0], 4711)
657
658    def testQueryWithParams(self):
659        query = self.db.query
660        query("drop table if exists test_table")
661        q = "create table test_table (n1 integer, n2 integer) with oids"
662        query(q)
663        q = "insert into test_table values ($1, $2)"
664        r = query(q, (1, 2))
665        self.assertIsInstance(r, int)
666        r = query(q, [3, 4])
667        self.assertIsInstance(r, int)
668        r = query(q, [5, 6])
669        self.assertIsInstance(r, int)
670        q = "select * from test_table order by 1, 2"
671        self.assertEqual(query(q).getresult(),
672            [(1, 2), (3, 4), (5, 6)])
673        q = "select * from test_table where n1=$1 and n2=$2"
674        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
675        q = "update test_table set n2=$2 where n1=$1"
676        r = query(q, 3, 7)
677        self.assertEqual(r, '1')
678        q = "select * from test_table order by 1, 2"
679        self.assertEqual(query(q).getresult(),
680            [(1, 2), (3, 7), (5, 6)])
681        q = "delete from test_table where n2!=$1"
682        r = query(q, 4)
683        self.assertEqual(r, '3')
684        query("drop table test_table")
685
686    def testEmptyQuery(self):
687        self.assertRaises(ValueError, self.db.query, '')
688
689    def testQueryProgrammingError(self):
690        try:
691            self.db.query("select 1/0")
692        except pg.ProgrammingError, error:
693            self.assertEqual(error.sqlstate, '22012')
694
695    def testPkey(self):
696        query = self.db.query
697        for n in range(4):
698            query("drop table if exists pkeytest%d" % n)
699        query("create table pkeytest0 ("
700            "a smallint)")
701        query("create table pkeytest1 ("
702            "b smallint primary key)")
703        query("create table pkeytest2 ("
704            "c smallint, d smallint primary key)")
705        query("create table pkeytest3 ("
706            "e smallint, f smallint, g smallint,"
707            " h smallint, i smallint,"
708            " primary key (f,h))")
709        pkey = self.db.pkey
710        self.assertRaises(KeyError, pkey, 'pkeytest0')
711        self.assertEqual(pkey('pkeytest1'), 'b')
712        self.assertEqual(pkey('pkeytest2'), 'd')
713        self.assertEqual(pkey('pkeytest3'), frozenset('fh'))
714        self.assertEqual(pkey('pkeytest0', 'none'), 'none')
715        self.assertEqual(pkey('pkeytest0'), 'none')
716        pkey(None, {'t': 'a', 'n.t': 'b'})
717        self.assertEqual(pkey('t'), 'a')
718        self.assertEqual(pkey('n.t'), 'b')
719        self.assertRaises(KeyError, pkey, 'pkeytest0')
720        for n in range(4):
721            query("drop table pkeytest%d" % n)
722
723    def testGetDatabases(self):
724        databases = self.db.get_databases()
725        self.assertIn('template0', databases)
726        self.assertIn('template1', databases)
727        self.assertNotIn('not existing database', databases)
728        self.assertIn('postgres', databases)
729        self.assertIn(dbname, databases)
730
731    def testGetTables(self):
732        get_tables = self.db.get_tables
733        result1 = get_tables()
734        self.assertIsInstance(result1, list)
735        for t in result1:
736            t = t.split('.', 1)
737            self.assertGreaterEqual(len(t), 2)
738            if len(t) > 2:
739                self.assertTrue(t[1].startswith('"'))
740            t = t[0]
741            self.assertNotEqual(t, 'information_schema')
742            self.assertFalse(t.startswith('pg_'))
743        tables = ('"A very Special Name"',
744            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
745            'A_MiXeD_NaMe', '"another special name"',
746            'averyveryveryveryveryveryverylongtablename',
747            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
748        for t in tables:
749            self.db.query('drop table if exists %s' % t)
750            self.db.query("create table %s"
751                " as select 0" % t)
752        result3 = get_tables()
753        result2 = []
754        for t in result3:
755            if t not in result1:
756                result2.append(t)
757        result3 = []
758        for t in tables:
759            if not t.startswith('"'):
760                t = t.lower()
761            result3.append('public.' + t)
762        self.assertEqual(result2, result3)
763        for t in result2:
764            self.db.query('drop table %s' % t)
765        result2 = get_tables()
766        self.assertEqual(result2, result1)
767
768    def testGetRelations(self):
769        get_relations = self.db.get_relations
770        result = get_relations()
771        self.assertIn('public.test', result)
772        self.assertIn('public.test_view', result)
773        result = get_relations('rv')
774        self.assertIn('public.test', result)
775        self.assertIn('public.test_view', result)
776        result = get_relations('r')
777        self.assertIn('public.test', result)
778        self.assertNotIn('public.test_view', result)
779        result = get_relations('v')
780        self.assertNotIn('public.test', result)
781        self.assertIn('public.test_view', result)
782        result = get_relations('cisSt')
783        self.assertNotIn('public.test', result)
784        self.assertNotIn('public.test_view', result)
785
786    def testGetAttnames(self):
787        self.assertRaises(pg.ProgrammingError,
788            self.db.get_attnames, 'does_not_exist')
789        self.assertRaises(pg.ProgrammingError,
790            self.db.get_attnames, 'has.too.many.dots')
791        attributes = self.db.get_attnames('test')
792        self.assertIsInstance(attributes, dict)
793        self.assertEqual(attributes, dict(
794            i2='int', i4='int', i8='int', d='num',
795            f4='float', f8='float', m='money',
796            v4='text', c4='text', t='text'))
797        for table in ('attnames_test_table', 'test table for attnames'):
798            self.db.query('drop table if exists "%s"' % table)
799            self.db.query('create table "%s" ('
800                ' a smallint, b integer, c bigint,'
801                ' e numeric, f float, f2 double precision, m money,'
802                ' x smallint, y smallint, z smallint,'
803                ' Normal_NaMe smallint, "Special Name" smallint,'
804                ' t text, u char(2), v varchar(2),'
805                ' primary key (y, u)) with oids' % table)
806            attributes = self.db.get_attnames(table)
807            result = {'a': 'int', 'c': 'int', 'b': 'int',
808                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
809                'normal_name': 'int', 'Special Name': 'int',
810                'u': 'text', 't': 'text', 'v': 'text',
811                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
812            self.assertEqual(attributes, result)
813            self.db.query('drop table "%s"' % table)
814
815    def testHasTablePrivilege(self):
816        can = self.db.has_table_privilege
817        self.assertEqual(can('test'), True)
818        self.assertEqual(can('test', 'select'), True)
819        self.assertEqual(can('test', 'SeLeCt'), True)
820        self.assertEqual(can('test', 'SELECT'), True)
821        self.assertEqual(can('test', 'insert'), True)
822        self.assertEqual(can('test', 'update'), True)
823        self.assertEqual(can('test', 'delete'), True)
824        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
825        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
826
827    def testGet(self):
828        get = self.db.get
829        query = self.db.query
830        for table in ('get_test_table', 'test table for get'):
831            query('drop table if exists "%s"' % table)
832            query('create table "%s" ('
833                "n integer, t text) with oids" % table)
834            for n, t in enumerate('xyz'):
835                query('insert into "%s" values('"%d, '%s')"
836                    % (table, n + 1, t))
837            self.assertRaises(pg.ProgrammingError, get, table, 2)
838            r = get(table, 2, 'n')
839            oid_table = table
840            if ' ' in table:
841                oid_table = '"%s"' % oid_table
842            oid_table = 'oid(public.%s)' % oid_table
843            self.assertIn(oid_table, r)
844            oid = r[oid_table]
845            self.assertIsInstance(oid, int)
846            result = {'t': 'y', 'n': 2, oid_table: oid}
847            self.assertEqual(r, result)
848            self.assertEqual(get(table + ' *', 2, 'n'), r)
849            self.assertEqual(get(table, oid, 'oid')['t'], 'y')
850            self.assertEqual(get(table, 1, 'n')['t'], 'x')
851            self.assertEqual(get(table, 3, 'n')['t'], 'z')
852            self.assertEqual(get(table, 2, 'n')['t'], 'y')
853            self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
854            r['n'] = 3
855            self.assertEqual(get(table, r, 'n')['t'], 'z')
856            self.assertEqual(get(table, 1, 'n')['t'], 'x')
857            query('alter table "%s" alter n set not null' % table)
858            query('alter table "%s" add primary key (n)' % table)
859            self.assertEqual(get(table, 3)['t'], 'z')
860            self.assertEqual(get(table, 1)['t'], 'x')
861            self.assertEqual(get(table, 2)['t'], 'y')
862            r['n'] = 1
863            self.assertEqual(get(table, r)['t'], 'x')
864            r['n'] = 3
865            self.assertEqual(get(table, r)['t'], 'z')
866            r['n'] = 2
867            self.assertEqual(get(table, r)['t'], 'y')
868            query('drop table "%s"' % table)
869
870    def testGetWithCompositeKey(self):
871        get = self.db.get
872        query = self.db.query
873        table = 'get_test_table_1'
874        query("drop table if exists %s" % table)
875        query("create table %s ("
876            "n integer, t text, primary key (n))" % table)
877        for n, t in enumerate('abc'):
878            query("insert into %s values("
879                "%d, '%s')" % (table, n + 1, t))
880        self.assertEqual(get(table, 2)['t'], 'b')
881        query("drop table %s" % table)
882        table = 'get_test_table_2'
883        query("drop table if exists %s" % table)
884        query("create table %s ("
885            "n integer, m integer, t text, primary key (n, m))" % table)
886        for n in range(3):
887            for m in range(2):
888                t = chr(ord('a') + 2 * n + m)
889                query("insert into %s values("
890                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
891        self.assertRaises(pg.ProgrammingError, get, table, 2)
892        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
893        self.assertEqual(get(table, dict(n=1, m=2),
894                             ('n', 'm'))['t'], 'b')
895        self.assertEqual(get(table, dict(n=3, m=2),
896                             frozenset(['n', 'm']))['t'], 'f')
897        query("drop table %s" % table)
898
899    def testGetFromView(self):
900        self.db.query('delete from test where i4=14')
901        self.db.query('insert into test (i4, v4) values('
902            "14, 'abc4')")
903        r = self.db.get('test_view', 14, 'i4')
904        self.assertIn('v4', r)
905        self.assertEqual(r['v4'], 'abc4')
906
907    def testGetLittleBobbyTables(self):
908        get = self.db.get
909        query = self.db.query
910        query("drop table if exists test_students")
911        query("create table test_students (firstname varchar primary key,"
912            " nickname varchar, grade char(2))")
913        query("insert into test_students values ("
914              "'D''Arcy', 'Darcey', 'A+')")
915        query("insert into test_students values ("
916              "'Sheldon', 'Moonpie', 'A+')")
917        query("insert into test_students values ("
918              "'Robert', 'Little Bobby Tables', 'D-')")
919        r = get('test_students', 'Sheldon')
920        self.assertEqual(r, dict(
921            firstname="Sheldon", nickname='Moonpie', grade='A+'))
922        r = get('test_students', 'Robert')
923        self.assertEqual(r, dict(
924            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
925        r = get('test_students', "D'Arcy")
926        self.assertEqual(r, dict(
927            firstname="D'Arcy", nickname='Darcey', grade='A+'))
928        try:
929            get('test_students', "D' Arcy")
930        except pg.DatabaseError, error:
931            self.assertEqual(str(error),
932                'No such record in public.test_students where firstname = '
933                "'D'' Arcy'")
934        try:
935            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
936        except pg.DatabaseError, error:
937            self.assertEqual(str(error),
938                'No such record in public.test_students where firstname = '
939                "'Robert''); TRUNCATE TABLE test_students;--'")
940        q = "select * from test_students order by 1 limit 4"
941        r = query(q).getresult()
942        self.assertEqual(len(r), 3)
943        self.assertEqual(r[1][2], 'D-')
944        query('drop table test_students')
945
946    def testInsert(self):
947        insert = self.db.insert
948        query = self.db.query
949        server_version = self.db.server_version
950        for table in ('insert_test_table', 'test table for insert'):
951            query('drop table if exists "%s"' % table)
952            query('create table "%s" ('
953                "i2 smallint, i4 integer, i8 bigint,"
954                " d numeric, f4 real, f8 double precision, m money,"
955                " v4 varchar(4), c4 char(4), t text,"
956                " b boolean, ts timestamp) with oids" % table)
957            oid_table = table
958            if ' ' in table:
959                oid_table = '"%s"' % oid_table
960            oid_table = 'oid(public.%s)' % oid_table
961            tests = [dict(i2=None, i4=None, i8=None),
962                (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
963                (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
964                dict(i2=42, i4=123456, i8=9876543210),
965                dict(i2=2 ** 15 - 1,
966                    i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
967                dict(d=None), (dict(d=''), dict(d=None)),
968                dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
969                dict(f4=None, f8=None), dict(f4=0, f8=0),
970                (dict(f4='', f8=''), dict(f4=None, f8=None)),
971                (dict(d=1234.5, f4=1234.5, f8=1234.5),
972                      dict(d=Decimal('1234.5'))),
973                dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
974                dict(d=Decimal('123456789.9876543212345678987654321')),
975                dict(m=None), (dict(m=''), dict(m=None)),
976                dict(m=Decimal('-1234.56')),
977                (dict(m='-1234.56'), dict(m=Decimal('-1234.56'))),
978                dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
979                (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
980                (dict(m=1234.5), dict(m=Decimal('1234.5'))),
981                (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
982                (dict(m=123456), dict(m=Decimal('123456'))),
983                (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
984                dict(b=None), (dict(b=''), dict(b=None)),
985                dict(b='f'), dict(b='t'),
986                (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
987                (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
988                (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
989                (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
990                (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
991                (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
992                dict(v4=None, c4=None, t=None),
993                (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
994                dict(v4='1234', c4='1234', t='1234' * 10),
995                dict(v4='abcd', c4='abcd', t='abcdefg'),
996                (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
997                dict(ts=None), (dict(ts=''), dict(ts=None)),
998                (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
999                dict(ts='2012-12-21 00:00:00'),
1000                (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1001                dict(ts='2012-12-21 12:21:12'),
1002                dict(ts='2013-01-05 12:13:14'),
1003                dict(ts='current_timestamp')]
1004            for test in tests:
1005                if isinstance(test, dict):
1006                    data = test
1007                    change = {}
1008                else:
1009                    data, change = test
1010                expect = data.copy()
1011                expect.update(change)
1012                if data.get('m') and server_version < 90100:
1013                    # PostgreSQL < 9.1 cannot directly convert numbers to money
1014                    data['m'] = "'%s'::money" % data['m']
1015                self.assertEqual(insert(table, data), data)
1016                self.assertIn(oid_table, data)
1017                oid = data[oid_table]
1018                self.assertIsInstance(oid, int)
1019                data = dict(item for item in data.iteritems()
1020                    if item[0] in expect)
1021                ts = expect.get('ts')
1022                if ts == 'current_timestamp':
1023                    ts = expect['ts'] = data['ts']
1024                    if len(ts) > 19:
1025                        self.assertEqual(ts[19], '.')
1026                        ts = ts[:19]
1027                    else:
1028                        self.assertEqual(len(ts), 19)
1029                    self.assertTrue(ts[:4].isdigit())
1030                    self.assertEqual(ts[4], '-')
1031                    self.assertEqual(ts[10], ' ')
1032                    self.assertTrue(ts[11:13].isdigit())
1033                    self.assertEqual(ts[13], ':')
1034                self.assertEqual(data, expect)
1035                data = query(
1036                    'select oid,* from "%s"' % table).dictresult()[0]
1037                self.assertEqual(data['oid'], oid)
1038                data = dict(item for item in data.iteritems()
1039                    if item[0] in expect)
1040                self.assertEqual(data, expect)
1041                query('delete from "%s"' % table)
1042            query('drop table "%s"' % table)
1043
1044    def testUpdate(self):
1045        update = self.db.update
1046        query = self.db.query
1047        for table in ('update_test_table', 'test table for update'):
1048            query('drop table if exists "%s"' % table)
1049            query('create table "%s" ('
1050                "n integer, t text) with oids" % table)
1051            for n, t in enumerate('xyz'):
1052                query('insert into "%s" values('
1053                    "%d, '%s')" % (table, n + 1, t))
1054            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1055            r = self.db.get(table, 2, 'n')
1056            r['t'] = 'u'
1057            s = update(table, r)
1058            self.assertEqual(s, r)
1059            r = query('select t from "%s" where n=2' % table
1060                      ).getresult()[0][0]
1061            self.assertEqual(r, 'u')
1062            query('drop table "%s"' % table)
1063
1064    def testUpdateWithCompositeKey(self):
1065        update = self.db.update
1066        query = self.db.query
1067        table = 'update_test_table_1'
1068        query("drop table if exists %s" % table)
1069        query("create table %s ("
1070            "n integer, t text, primary key (n))" % table)
1071        for n, t in enumerate('abc'):
1072            query("insert into %s values("
1073                "%d, '%s')" % (table, n + 1, t))
1074        self.assertRaises(pg.ProgrammingError, update,
1075                          table, dict(t='b'))
1076        self.assertEqual(update(table, dict(n=2, t='d'))['t'], 'd')
1077        r = query('select t from "%s" where n=2' % table
1078                  ).getresult()[0][0]
1079        self.assertEqual(r, 'd')
1080        query("drop table %s" % table)
1081        table = 'update_test_table_2'
1082        query("drop table if exists %s" % table)
1083        query("create table %s ("
1084            "n integer, m integer, t text, primary key (n, m))" % table)
1085        for n in range(3):
1086            for m in range(2):
1087                t = chr(ord('a') + 2 * n + m)
1088                query("insert into %s values("
1089                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1090        self.assertRaises(pg.ProgrammingError, update,
1091                          table, dict(n=2, t='b'))
1092        self.assertEqual(update(table,
1093                                dict(n=2, m=2, t='x'))['t'], 'x')
1094        r = [r[0] for r in query('select t from "%s" where n=2'
1095            ' order by m' % table).getresult()]
1096        self.assertEqual(r, ['c', 'x'])
1097        query("drop table %s" % table)
1098
1099    def testClear(self):
1100        clear = self.db.clear
1101        query = self.db.query
1102        for table in ('clear_test_table', 'test table for clear'):
1103            query('drop table if exists "%s"' % table)
1104            query('create table "%s" ('
1105                "n integer, b boolean, d date, t text)" % table)
1106            r = clear(table)
1107            result = {'n': 0, 'b': 'f', 'd': '', 't': ''}
1108            self.assertEqual(r, result)
1109            r['a'] = r['n'] = 1
1110            r['d'] = r['t'] = 'x'
1111            r['b'] = 't'
1112            r['oid'] = 1L
1113            r = clear(table, r)
1114            result = {'a': 1, 'n': 0, 'b': 'f', 'd': '', 't': '', 'oid': 1L}
1115            self.assertEqual(r, result)
1116            query('drop table "%s"' % table)
1117
1118    def testDelete(self):
1119        delete = self.db.delete
1120        query = self.db.query
1121        for table in ('delete_test_table', 'test table for delete'):
1122            query('drop table if exists "%s"' % table)
1123            query('create table "%s" ('
1124                "n integer, t text) with oids" % table)
1125            for n, t in enumerate('xyz'):
1126                query('insert into "%s" values('
1127                    "%d, '%s')" % (table, n + 1, t))
1128            self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1129            r = self.db.get(table, 1, 'n')
1130            s = delete(table, r)
1131            self.assertEqual(s, 1)
1132            r = self.db.get(table, 3, 'n')
1133            s = delete(table, r)
1134            self.assertEqual(s, 1)
1135            s = delete(table, r)
1136            self.assertEqual(s, 0)
1137            r = query('select * from "%s"' % table).dictresult()
1138            self.assertEqual(len(r), 1)
1139            r = r[0]
1140            result = {'n': 2, 't': 'y'}
1141            self.assertEqual(r, result)
1142            r = self.db.get(table, 2, 'n')
1143            s = delete(table, r)
1144            self.assertEqual(s, 1)
1145            s = delete(table, r)
1146            self.assertEqual(s, 0)
1147            self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1148            query('drop table "%s"' % table)
1149
1150    def testDeleteWithCompositeKey(self):
1151        query = self.db.query
1152        table = 'delete_test_table_1'
1153        query("drop table if exists %s" % table)
1154        query("create table %s ("
1155            "n integer, t text, primary key (n))" % table)
1156        for n, t in enumerate('abc'):
1157            query("insert into %s values("
1158                "%d, '%s')" % (table, n + 1, t))
1159        self.assertRaises(pg.ProgrammingError, self.db.delete,
1160            table, dict(t='b'))
1161        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1162        r = query('select t from "%s" where n=2' % table
1163                  ).getresult()
1164        self.assertEqual(r, [])
1165        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1166        r = query('select t from "%s" where n=3' % table
1167                  ).getresult()[0][0]
1168        self.assertEqual(r, 'c')
1169        query("drop table %s" % table)
1170        table = 'delete_test_table_2'
1171        query("drop table if exists %s" % table)
1172        query("create table %s ("
1173            "n integer, m integer, t text, primary key (n, m))" % table)
1174        for n in range(3):
1175            for m in range(2):
1176                t = chr(ord('a') + 2 * n + m)
1177                query("insert into %s values("
1178                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1179        self.assertRaises(pg.ProgrammingError, self.db.delete,
1180            table, dict(n=2, t='b'))
1181        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1182        r = [r[0] for r in query('select t from "%s" where n=2'
1183            ' order by m' % table).getresult()]
1184        self.assertEqual(r, ['c'])
1185        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1186        r = [r[0] for r in query('select t from "%s" where n=3'
1187            ' order by m' % table).getresult()]
1188        self.assertEqual(r, ['e', 'f'])
1189        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1190        r = [r[0] for r in query('select t from "%s" where n=3'
1191            ' order by m' % table).getresult()]
1192        self.assertEqual(r, ['f'])
1193        query("drop table %s" % table)
1194
1195    def testTruncate(self):
1196        truncate = self.db.truncate
1197        self.assertRaises(TypeError, truncate, None)
1198        self.assertRaises(TypeError, truncate, 42)
1199        self.assertRaises(TypeError, truncate, dict(test_table=None))
1200        query = self.db.query
1201        query("drop table if exists test_table")
1202        query("create table test_table (n smallint)")
1203        for i in range(3):
1204            query("insert into test_table values (1)")
1205        q = "select count(*) from test_table"
1206        r = query(q).getresult()[0][0]
1207        self.assertEqual(r, 3)
1208        truncate('test_table')
1209        r = query(q).getresult()[0][0]
1210        self.assertEqual(r, 0)
1211        for i in range(3):
1212            query("insert into test_table values (1)")
1213        r = query(q).getresult()[0][0]
1214        self.assertEqual(r, 3)
1215        truncate('public.test_table')
1216        r = query(q).getresult()[0][0]
1217        self.assertEqual(r, 0)
1218        query("drop table if exists test_table_2")
1219        query('create table test_table_2 (n smallint)')
1220        for t in (list, tuple, set):
1221            for i in range(3):
1222                query("insert into test_table values (1)")
1223                query("insert into test_table_2 values (2)")
1224            q = ("select (select count(*) from test_table),"
1225                " (select count(*) from test_table_2)")
1226            r = query(q).getresult()[0]
1227            self.assertEqual(r, (3, 3))
1228            truncate(t(['test_table', 'test_table_2']))
1229            r = query(q).getresult()[0]
1230            self.assertEqual(r, (0, 0))
1231        query("drop table test_table_2")
1232        query("drop table test_table")
1233
1234    def testTruncateRestart(self):
1235        truncate = self.db.truncate
1236        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
1237        query = self.db.query
1238        query("drop table if exists test_table")
1239        query("create table test_table (n serial, t text)")
1240        for n in range(3):
1241            query("insert into test_table (t) values ('test')")
1242        q = "select count(n), min(n), max(n) from test_table"
1243        r = query(q).getresult()[0]
1244        self.assertEqual(r, (3, 1, 3))
1245        truncate('test_table')
1246        r = query(q).getresult()[0]
1247        self.assertEqual(r, (0, None, None))
1248        for n in range(3):
1249            query("insert into test_table (t) values ('test')")
1250        r = query(q).getresult()[0]
1251        self.assertEqual(r, (3, 4, 6))
1252        truncate('test_table', restart=True)
1253        r = query(q).getresult()[0]
1254        self.assertEqual(r, (0, None, None))
1255        for n in range(3):
1256            query("insert into test_table (t) values ('test')")
1257        r = query(q).getresult()[0]
1258        self.assertEqual(r, (3, 1, 3))
1259        query("drop table test_table")
1260
1261    def testTruncateCascade(self):
1262        truncate = self.db.truncate
1263        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
1264        query = self.db.query
1265        query("drop table if exists test_child")
1266        query("drop table if exists test_parent")
1267        query("create table test_parent (n smallint primary key)")
1268        query("create table test_child ("
1269            " n smallint primary key references test_parent (n))")
1270        for n in range(3):
1271            query("insert into test_parent (n) values (%d)" % n)
1272            query("insert into test_child (n) values (%d)" % n)
1273        q = ("select (select count(*) from test_parent),"
1274            " (select count(*) from test_child)")
1275        r = query(q).getresult()[0]
1276        self.assertEqual(r, (3, 3))
1277        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
1278        truncate(['test_parent', 'test_child'])
1279        r = query(q).getresult()[0]
1280        self.assertEqual(r, (0, 0))
1281        for n in range(3):
1282            query("insert into test_parent (n) values (%d)" % n)
1283            query("insert into test_child (n) values (%d)" % n)
1284        r = query(q).getresult()[0]
1285        self.assertEqual(r, (3, 3))
1286        truncate('test_parent', cascade=True)
1287        r = query(q).getresult()[0]
1288        self.assertEqual(r, (0, 0))
1289        for n in range(3):
1290            query("insert into test_parent (n) values (%d)" % n)
1291            query("insert into test_child (n) values (%d)" % n)
1292        r = query(q).getresult()[0]
1293        self.assertEqual(r, (3, 3))
1294        truncate('test_child')
1295        r = query(q).getresult()[0]
1296        self.assertEqual(r, (3, 0))
1297        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
1298        truncate('test_parent', cascade=True)
1299        r = query(q).getresult()[0]
1300        self.assertEqual(r, (0, 0))
1301        query("drop table test_child")
1302        query("drop table test_parent")
1303
1304    def testTruncateOnly(self):
1305        truncate = self.db.truncate
1306        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
1307        query = self.db.query
1308        query("drop table if exists test_child")
1309        query("drop table if exists test_parent")
1310        query("create table test_parent (n smallint)")
1311        query("create table test_child ("
1312            " m smallint) inherits (test_parent)")
1313        for n in range(3):
1314            query("insert into test_parent (n) values (1)")
1315            query("insert into test_child (n, m) values (2, 3)")
1316        q = ("select (select count(*) from test_parent),"
1317            " (select count(*) from test_child)")
1318        r = query(q).getresult()[0]
1319        self.assertEqual(r, (6, 3))
1320        truncate('test_parent')
1321        r = query(q).getresult()[0]
1322        self.assertEqual(r, (0, 0))
1323        for n in range(3):
1324            query("insert into test_parent (n) values (1)")
1325            query("insert into test_child (n, m) values (2, 3)")
1326        r = query(q).getresult()[0]
1327        self.assertEqual(r, (6, 3))
1328        truncate('test_parent*')
1329        r = query(q).getresult()[0]
1330        self.assertEqual(r, (0, 0))
1331        for n in range(3):
1332            query("insert into test_parent (n) values (1)")
1333            query("insert into test_child (n, m) values (2, 3)")
1334        r = query(q).getresult()[0]
1335        self.assertEqual(r, (6, 3))
1336        truncate('test_parent', only=True)
1337        r = query(q).getresult()[0]
1338        self.assertEqual(r, (3, 3))
1339        truncate('test_parent', only=False)
1340        r = query(q).getresult()[0]
1341        self.assertEqual(r, (0, 0))
1342        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
1343        truncate('test_parent*', only=False)
1344        query("drop table if exists test_parent_2")
1345        query("create table test_parent_2 (n smallint)")
1346        query("drop table if exists test_child_2")
1347        query("create table test_child_2 ("
1348            " m smallint) inherits (test_parent_2)")
1349        for n in range(3):
1350            query("insert into test_parent (n) values (1)")
1351            query("insert into test_child (n, m) values (2, 3)")
1352            query("insert into test_parent_2 (n) values (1)")
1353            query("insert into test_child_2 (n, m) values (2, 3)")
1354        q = ("select (select count(*) from test_parent),"
1355            " (select count(*) from test_child),"
1356            " (select count(*) from test_parent_2),"
1357            " (select count(*) from test_child_2)")
1358        r = query(q).getresult()[0]
1359        self.assertEqual(r, (6, 3, 6, 3))
1360        truncate(['test_parent', 'test_parent_2'], only=[False, True])
1361        r = query(q).getresult()[0]
1362        self.assertEqual(r, (0, 0, 3, 3))
1363        truncate(['test_parent', 'test_parent_2'], only=False)
1364        r = query(q).getresult()[0]
1365        self.assertEqual(r, (0, 0, 0, 0))
1366        self.assertRaises(ValueError, truncate,
1367            ['test_parent*', 'test_child'], only=[True, False])
1368        truncate(['test_parent*', 'test_child'], only=[False, True])
1369        query("drop table test_child_2")
1370        query("drop table test_parent_2")
1371        query("drop table test_child")
1372        query("drop table test_parent")
1373
1374    def testTruncateQuoted(self):
1375        truncate = self.db.truncate
1376        query = self.db.query
1377        table = "test table for truncate()"
1378        query('drop table if exists "%s"' % table)
1379        query('create table "%s" (n smallint)' % table)
1380        for i in range(3):
1381            query('insert into "%s" values (1)' % table)
1382        q = 'select count(*) from "%s"' % table
1383        r = query(q).getresult()[0][0]
1384        self.assertEqual(r, 3)
1385        truncate(table)
1386        r = query(q).getresult()[0][0]
1387        self.assertEqual(r, 0)
1388        for i in range(3):
1389            query('insert into "%s" values (1)' % table)
1390        r = query(q).getresult()[0][0]
1391        self.assertEqual(r, 3)
1392        truncate('public."%s"' % table)
1393        r = query(q).getresult()[0][0]
1394        self.assertEqual(r, 0)
1395        query('drop table "%s"' % table)
1396
1397    def testTransaction(self):
1398        query = self.db.query
1399        query("drop table if exists test_table")
1400        query("create table test_table (n integer)")
1401        self.db.begin()
1402        query("insert into test_table values (1)")
1403        query("insert into test_table values (2)")
1404        self.db.commit()
1405        self.db.begin()
1406        query("insert into test_table values (3)")
1407        query("insert into test_table values (4)")
1408        self.db.rollback()
1409        self.db.begin()
1410        query("insert into test_table values (5)")
1411        self.db.savepoint('before6')
1412        query("insert into test_table values (6)")
1413        self.db.rollback('before6')
1414        query("insert into test_table values (7)")
1415        self.db.commit()
1416        self.db.begin()
1417        self.db.savepoint('before8')
1418        query("insert into test_table values (8)")
1419        self.db.release('before8')
1420        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1421        self.db.commit()
1422        self.db.start()
1423        query("insert into test_table values (9)")
1424        self.db.end()
1425        r = [r[0] for r in query(
1426            "select * from test_table order by 1").getresult()]
1427        self.assertEqual(r, [1, 2, 5, 7, 9])
1428        query("drop table test_table")
1429
1430    @unittest.skipIf(no_with, 'context managers not supported')
1431    def testContextManager(self):
1432        query = self.db.query
1433        query("drop table if exists test_table")
1434        query("create table test_table (n integer check(n>0))")
1435        # wrap "with" statements to avoid SyntaxError in Python < 2.5
1436        exec """from __future__ import with_statement\nif True:
1437        with self.db:
1438            query("insert into test_table values (1)")
1439            query("insert into test_table values (2)")
1440        try:
1441            with self.db:
1442                query("insert into test_table values (3)")
1443                query("insert into test_table values (4)")
1444                raise ValueError('test transaction should rollback')
1445        except ValueError, error:
1446            self.assertEqual(str(error), 'test transaction should rollback')
1447        with self.db:
1448            query("insert into test_table values (5)")
1449        try:
1450            with self.db:
1451                query("insert into test_table values (6)")
1452                query("insert into test_table values (-1)")
1453        except pg.ProgrammingError, error:
1454            self.assertTrue('check' in str(error))
1455        with self.db:
1456            query("insert into test_table values (7)")\n"""
1457        r = [r[0] for r in query(
1458            "select * from test_table order by 1").getresult()]
1459        self.assertEqual(r, [1, 2, 5, 7])
1460        query("drop table test_table")
1461
1462    def testBytea(self):
1463        query = self.db.query
1464        query('drop table if exists bytea_test')
1465        query('create table bytea_test ('
1466            'data bytea)')
1467        s = "It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1468        r = self.db.escape_bytea(s)
1469        query('insert into bytea_test values('
1470            "'%s')" % r)
1471        r = query('select * from bytea_test').getresult()
1472        self.assertTrue(len(r) == 1)
1473        r = r[0]
1474        self.assertTrue(len(r) == 1)
1475        r = r[0]
1476        r = self.db.unescape_bytea(r)
1477        self.assertEqual(r, s)
1478        query('drop table bytea_test')
1479
1480    def testNotificationHandler(self):
1481        # the notification handler itself is tested separately
1482        f = self.db.notification_handler
1483        callback = lambda arg_dict: None
1484        handler = f('test', callback)
1485        self.assertIsInstance(handler, pg.NotificationHandler)
1486        self.assertIs(handler.db, self.db)
1487        self.assertEqual(handler.event, 'test')
1488        self.assertEqual(handler.stop_event, 'stop_test')
1489        self.assertIs(handler.callback, callback)
1490        self.assertIsInstance(handler.arg_dict, dict)
1491        self.assertEqual(handler.arg_dict, {})
1492        self.assertIsNone(handler.timeout)
1493        self.assertFalse(handler.listening)
1494        handler.close()
1495        self.assertIsNone(handler.db)
1496        self.db.reopen()
1497        self.assertIsNone(handler.db)
1498        handler = f('test2', callback, timeout=2)
1499        self.assertIsInstance(handler, pg.NotificationHandler)
1500        self.assertIs(handler.db, self.db)
1501        self.assertEqual(handler.event, 'test2')
1502        self.assertEqual(handler.stop_event, 'stop_test2')
1503        self.assertIs(handler.callback, callback)
1504        self.assertIsInstance(handler.arg_dict, dict)
1505        self.assertEqual(handler.arg_dict, {})
1506        self.assertEqual(handler.timeout, 2)
1507        self.assertFalse(handler.listening)
1508        handler.close()
1509        self.assertIsNone(handler.db)
1510        self.db.reopen()
1511        self.assertIsNone(handler.db)
1512        arg_dict = {'testing': 3}
1513        handler = f('test3', callback, arg_dict=arg_dict)
1514        self.assertIsInstance(handler, pg.NotificationHandler)
1515        self.assertIs(handler.db, self.db)
1516        self.assertEqual(handler.event, 'test3')
1517        self.assertEqual(handler.stop_event, 'stop_test3')
1518        self.assertIs(handler.callback, callback)
1519        self.assertIs(handler.arg_dict, arg_dict)
1520        self.assertEqual(arg_dict['testing'], 3)
1521        self.assertIsNone(handler.timeout)
1522        self.assertFalse(handler.listening)
1523        handler.close()
1524        self.assertIsNone(handler.db)
1525        self.db.reopen()
1526        self.assertIsNone(handler.db)
1527        handler = f('test4', callback, stop_event='stop4')
1528        self.assertIsInstance(handler, pg.NotificationHandler)
1529        self.assertIs(handler.db, self.db)
1530        self.assertEqual(handler.event, 'test4')
1531        self.assertEqual(handler.stop_event, 'stop4')
1532        self.assertIs(handler.callback, callback)
1533        self.assertIsInstance(handler.arg_dict, dict)
1534        self.assertEqual(handler.arg_dict, {})
1535        self.assertIsNone(handler.timeout)
1536        self.assertFalse(handler.listening)
1537        handler.close()
1538        self.assertIsNone(handler.db)
1539        self.db.reopen()
1540        self.assertIsNone(handler.db)
1541        arg_dict = {'testing': 5}
1542        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
1543        self.assertIsInstance(handler, pg.NotificationHandler)
1544        self.assertIs(handler.db, self.db)
1545        self.assertEqual(handler.event, 'test5')
1546        self.assertEqual(handler.stop_event, 'stop5')
1547        self.assertIs(handler.callback, callback)
1548        self.assertIs(handler.arg_dict, arg_dict)
1549        self.assertEqual(arg_dict['testing'], 5)
1550        self.assertEqual(handler.timeout, 1.5)
1551        self.assertFalse(handler.listening)
1552        handler.close()
1553        self.assertIsNone(handler.db)
1554        self.db.reopen()
1555        self.assertIsNone(handler.db)
1556
1557    def testDebugWithCallable(self):
1558        if debug:
1559            self.assertEqual(self.db.debug, debug)
1560        else:
1561            self.assertIsNone(self.db.debug)
1562        s = []
1563        self.db.debug = s.append
1564        try:
1565            self.db.query("select 1")
1566            self.db.query("select 2")
1567            self.assertEqual(s, ["select 1", "select 2"])
1568        finally:
1569            self.db.debug = debug
1570
1571
1572class TestSchemas(unittest.TestCase):
1573    """Test correct handling of schemas (namespaces)."""
1574
1575    cls_set_up = False
1576
1577    @classmethod
1578    def setUpClass(cls):
1579        db = DB()
1580        query = db.query
1581        for num_schema in range(5):
1582            if num_schema:
1583                schema = "s%d" % num_schema
1584                query("drop schema if exists %s cascade" % (schema,))
1585                try:
1586                    query("create schema %s" % (schema,))
1587                except pg.ProgrammingError:
1588                    raise RuntimeError("The test user cannot create schemas.\n"
1589                        "Grant create on database %s to the user"
1590                        " for running these tests." % dbname)
1591            else:
1592                schema = "public"
1593                query("drop table if exists %s.t" % (schema,))
1594                query("drop table if exists %s.t%d" % (schema, num_schema))
1595            query("create table %s.t with oids as select 1 as n, %d as d"
1596                  % (schema, num_schema))
1597            query("create table %s.t%d with oids as select 1 as n, %d as d"
1598                  % (schema, num_schema, num_schema))
1599        db.close()
1600        cls.cls_set_up = True
1601
1602    @classmethod
1603    def tearDownClass(cls):
1604        db = DB()
1605        query = db.query
1606        for num_schema in range(5):
1607            if num_schema:
1608                schema = "s%d" % num_schema
1609                query("drop schema %s cascade" % (schema,))
1610            else:
1611                schema = "public"
1612                query("drop table %s.t" % (schema,))
1613                query("drop table %s.t%d" % (schema, num_schema))
1614        db.close()
1615
1616    def setUp(self):
1617        self.assertTrue(self.cls_set_up)
1618        self.db = DB()
1619
1620    def tearDown(self):
1621        self.db.close()
1622
1623    def testGetTables(self):
1624        tables = self.db.get_tables()
1625        for num_schema in range(5):
1626            if num_schema:
1627                schema = "s" + str(num_schema)
1628            else:
1629                schema = "public"
1630            for t in (schema + ".t",
1631                    schema + ".t" + str(num_schema)):
1632                self.assertIn(t, tables)
1633
1634    def testGetAttnames(self):
1635        get_attnames = self.db.get_attnames
1636        query = self.db.query
1637        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
1638        r = get_attnames("t")
1639        self.assertEqual(r, result)
1640        r = get_attnames("s4.t4")
1641        self.assertEqual(r, result)
1642        query("drop table if exists s3.t3m")
1643        query("create table s3.t3m with oids as select 1 as m")
1644        result_m = {'oid': 'int', 'm': 'int'}
1645        r = get_attnames("s3.t3m")
1646        self.assertEqual(r, result_m)
1647        query("set search_path to s1,s3")
1648        r = get_attnames("t3")
1649        self.assertEqual(r, result)
1650        r = get_attnames("t3m")
1651        self.assertEqual(r, result_m)
1652        query("drop table s3.t3m")
1653
1654    def testGet(self):
1655        get = self.db.get
1656        query = self.db.query
1657        PrgError = pg.ProgrammingError
1658        self.assertEqual(get("t", 1, 'n')['d'], 0)
1659        self.assertEqual(get("t0", 1, 'n')['d'], 0)
1660        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
1661        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
1662        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
1663        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
1664        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
1665        query("set search_path to s2,s4")
1666        self.assertRaises(PrgError, get, "t1", 1, 'n')
1667        self.assertEqual(get("t4", 1, 'n')['d'], 4)
1668        self.assertRaises(PrgError, get, "t3", 1, 'n')
1669        self.assertEqual(get("t", 1, 'n')['d'], 2)
1670        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
1671        query("set search_path to s1,s3")
1672        self.assertRaises(PrgError, get, "t2", 1, 'n')
1673        self.assertEqual(get("t3", 1, 'n')['d'], 3)
1674        self.assertRaises(PrgError, get, "t4", 1, 'n')
1675        self.assertEqual(get("t", 1, 'n')['d'], 1)
1676        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)
1677
1678    def testMangling(self):
1679        get = self.db.get
1680        query = self.db.query
1681        r = get("t", 1, 'n')
1682        self.assertIn('oid(public.t)', r)
1683        query("set search_path to s2")
1684        r = get("t2", 1, 'n')
1685        self.assertIn('oid(s2.t2)', r)
1686        query("set search_path to s3")
1687        r = get("t", 1, 'n')
1688        self.assertIn('oid(s3.t)', r)
1689
1690
1691if __name__ == '__main__':
1692    unittest.main()
Note: See TracBrowser for help on using the repository browser.