source: trunk/tests/test_classic_dbwrapper.py @ 748

Last change on this file since 748 was 748, checked in by cito, 4 years ago

Add method truncate() to DB wrapper class

This method can be used to quickly truncate tables.

Since this is pretty useful and will not break anything, I have
also backported this addition to the 4.x branch.

Everything is well documented and tested, of course.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 79.3 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
# Connection parameters of the database the tests run against.  Any of
# them can be overridden by a module LOCAL_PyGreSQL.py (imported below);
# otherwise these defaults are used.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # set to True to let the DB wrapper print debugging output

# Local overrides: first look for the module next to this one inside the
# package, then anywhere on the import path; keep the defaults otherwise.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Python 2/3 compatibility aliases.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

# There is a known bug in libpq under Windows which can cause the
# interface to crash when calling PQhost(), so the host is not queried
# on that platform.
windows = do_not_ask_for_host = os.name == 'nt'
do_not_ask_for_host_reason = 'libpq issue on Windows'
62
63
def DB():
    """Return a new DB wrapper connected to the test database.

    The module-level connection parameters are used.  Debugging output is
    switched on when the module-level ``debug`` flag is set, and server
    messages below warning level are suppressed for the session.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    # keep informational server notices out of the test output
    connection.query("set client_min_messages=warning")
    return connection
71
72
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # testMethodClose closes the connection itself; closing it a second
        # time raises an InternalError, which is deliberately ignored here
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        """The public attributes of DB are exactly the ones listed here."""
        attributes = [
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'delete',
            'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        # the error attribute is either empty or only carries a krb5_ message
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # strip the hex marker and any backslashes, nothing must remain
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        """Parameters may be passed separately, as a tuple or as a list."""
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        """A division by zero must raise a ProgrammingError (SQLSTATE 22012)."""
        # the previous try/except form passed silently when nothing was
        # raised at all; assertRaises makes the exception mandatory
        with self.assertRaises(pg.ProgrammingError) as context:
            self.db.query("select 1/0")
        self.assertEqual(context.exception.sqlstate, '22012')

    def testMethodEndcopy(self):
        # endcopy() without a pending copy may raise an IOError;
        # the test only checks that nothing else goes wrong
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        """After close(), reset, close and query must all fail."""
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        """A DB can wrap an existing connection or another wrapper.

        Accepted are a raw connection (positionally or as ``db``), another
        DB instance, or any object with a ``_cnx`` attribute.
        """
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        # closing or reopening the wrapper leaves its db attribute set
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
249
250
251class TestDBClass(unittest.TestCase):
252    """Test the methods of the DB class wrapped pg connection."""
253
254    @classmethod
255    def setUpClass(cls):
256        db = DB()
257        db.query("drop table if exists test cascade")
258        db.query("create table test ("
259            "i2 smallint, i4 integer, i8 bigint,"
260            " d numeric, f4 real, f8 double precision, m money,"
261            " v4 varchar(4), c4 char(4), t text)")
262        db.query("create or replace view test_view as"
263            " select i4, v4 from test")
264        db.close()
265
266    @classmethod
267    def tearDownClass(cls):
268        db = DB()
269        db.query("drop table test cascade")
270        db.close()
271
272    def setUp(self):
273        self.db = DB()
274        query = self.db.query
275        query('set client_encoding=utf8')
276        query('set standard_conforming_strings=on')
277        query("set lc_monetary='C'")
278        query("set datestyle='ISO,YMD'")
279        query('set bytea_output=hex')
280
281    def tearDown(self):
282        self.doCleanups()
283        self.db.close()
284
285    def testClassName(self):
286        self.assertEqual(self.db.__class__.__name__, 'DB')
287
288    def testModuleName(self):
289        self.assertEqual(self.db.__module__, 'pg')
290        self.assertEqual(self.db.__class__.__module__, 'pg')
291
292    def testEscapeLiteral(self):
293        f = self.db.escape_literal
294        r = f(b"plain")
295        self.assertIsInstance(r, bytes)
296        self.assertEqual(r, b"'plain'")
297        r = f(u"plain")
298        self.assertIsInstance(r, unicode)
299        self.assertEqual(r, u"'plain'")
300        r = f(u"that's kÀse".encode('utf-8'))
301        self.assertIsInstance(r, bytes)
302        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
303        r = f(u"that's kÀse")
304        self.assertIsInstance(r, unicode)
305        self.assertEqual(r, u"'that''s kÀse'")
306        self.assertEqual(f(r"It's fine to have a \ inside."),
307            r" E'It''s fine to have a \\ inside.'")
308        self.assertEqual(f('No "quotes" must be escaped.'),
309            "'No \"quotes\" must be escaped.'")
310
311    def testEscapeIdentifier(self):
312        f = self.db.escape_identifier
313        r = f(b"plain")
314        self.assertIsInstance(r, bytes)
315        self.assertEqual(r, b'"plain"')
316        r = f(u"plain")
317        self.assertIsInstance(r, unicode)
318        self.assertEqual(r, u'"plain"')
319        r = f(u"that's kÀse".encode('utf-8'))
320        self.assertIsInstance(r, bytes)
321        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
322        r = f(u"that's kÀse")
323        self.assertIsInstance(r, unicode)
324        self.assertEqual(r, u'"that\'s kÀse"')
325        self.assertEqual(f(r"It's fine to have a \ inside."),
326            '"It\'s fine to have a \\ inside."')
327        self.assertEqual(f('All "quotes" must be escaped.'),
328            '"All ""quotes"" must be escaped."')
329
330    def testEscapeString(self):
331        f = self.db.escape_string
332        r = f(b"plain")
333        self.assertIsInstance(r, bytes)
334        self.assertEqual(r, b"plain")
335        r = f(u"plain")
336        self.assertIsInstance(r, unicode)
337        self.assertEqual(r, u"plain")
338        r = f(u"that's kÀse".encode('utf-8'))
339        self.assertIsInstance(r, bytes)
340        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
341        r = f(u"that's kÀse")
342        self.assertIsInstance(r, unicode)
343        self.assertEqual(r, u"that''s kÀse")
344        self.assertEqual(f(r"It's fine to have a \ inside."),
345            r"It''s fine to have a \ inside.")
346
347    def testEscapeBytea(self):
348        f = self.db.escape_bytea
349        # note that escape_byte always returns hex output since Pg 9.0,
350        # regardless of the bytea_output setting
351        r = f(b'plain')
352        self.assertIsInstance(r, bytes)
353        self.assertEqual(r, b'\\x706c61696e')
354        r = f(u'plain')
355        self.assertIsInstance(r, unicode)
356        self.assertEqual(r, u'\\x706c61696e')
357        r = f(u"das is' kÀse".encode('utf-8'))
358        self.assertIsInstance(r, bytes)
359        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
360        r = f(u"das is' kÀse")
361        self.assertIsInstance(r, unicode)
362        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
363        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
364
365    def testUnescapeBytea(self):
366        f = self.db.unescape_bytea
367        r = f(b'plain')
368        self.assertIsInstance(r, bytes)
369        self.assertEqual(r, b'plain')
370        r = f(u'plain')
371        self.assertIsInstance(r, bytes)
372        self.assertEqual(r, b'plain')
373        r = f(b"das is' k\\303\\244se")
374        self.assertIsInstance(r, bytes)
375        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
376        r = f(u"das is' k\\303\\244se")
377        self.assertIsInstance(r, bytes)
378        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
379        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
380        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
381        self.assertEqual(f(r'\\x746861742773206be47365'),
382            b'\\x746861742773206be47365')
383        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
384
385    def testGetParameter(self):
386        f = self.db.get_parameter
387        self.assertRaises(TypeError, f)
388        self.assertRaises(TypeError, f, None)
389        self.assertRaises(TypeError, f, 42)
390        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
391        r = f('standard_conforming_strings')
392        self.assertEqual(r, 'on')
393        r = f('lc_monetary')
394        self.assertEqual(r, 'C')
395        r = f('datestyle')
396        self.assertEqual(r, 'ISO, YMD')
397        r = f('bytea_output')
398        self.assertEqual(r, 'hex')
399        r = f(['bytea_output', 'lc_monetary'])
400        self.assertIsInstance(r, list)
401        self.assertEqual(r, ['hex', 'C'])
402        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
403        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
404        r = f(set(['bytea_output', 'lc_monetary']))
405        self.assertIsInstance(r, dict)
406        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
407        r = f(set(['Bytea_Output', ' LC_Monetary ']))
408        self.assertIsInstance(r, dict)
409        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
410        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
411        r = f(s)
412        self.assertIs(r, s)
413        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
414        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
415        r = f(s)
416        self.assertIs(r, s)
417        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
418
419    def testGetParameterServerVersion(self):
420        r = self.db.get_parameter('server_version_num')
421        self.assertIsInstance(r, str)
422        s = self.db.server_version
423        self.assertIsInstance(s, int)
424        self.assertEqual(r, str(s))
425
426    def testGetParameterAll(self):
427        f = self.db.get_parameter
428        r = f('all')
429        self.assertIsInstance(r, dict)
430        self.assertEqual(r['standard_conforming_strings'], 'on')
431        self.assertEqual(r['lc_monetary'], 'C')
432        self.assertEqual(r['DateStyle'], 'ISO, YMD')
433        self.assertEqual(r['bytea_output'], 'hex')
434
435    def testSetParameter(self):
436        f = self.db.set_parameter
437        g = self.db.get_parameter
438        self.assertRaises(TypeError, f)
439        self.assertRaises(TypeError, f, None)
440        self.assertRaises(TypeError, f, 42)
441        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
442        f('standard_conforming_strings', 'off')
443        self.assertEqual(g('standard_conforming_strings'), 'off')
444        f('datestyle', 'ISO, DMY')
445        self.assertEqual(g('datestyle'), 'ISO, DMY')
446        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
447        self.assertEqual(g('standard_conforming_strings'), 'on')
448        self.assertEqual(g('datestyle'), 'ISO, DMY')
449        f(['default_with_oids', 'standard_conforming_strings'], 'off')
450        self.assertEqual(g('default_with_oids'), 'off')
451        self.assertEqual(g('standard_conforming_strings'), 'off')
452        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
453        self.assertEqual(g('standard_conforming_strings'), 'on')
454        self.assertEqual(g('datestyle'), 'ISO, YMD')
455        f(('default_with_oids', 'standard_conforming_strings'), 'off')
456        self.assertEqual(g('default_with_oids'), 'off')
457        self.assertEqual(g('standard_conforming_strings'), 'off')
458        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
459        self.assertEqual(g('default_with_oids'), 'on')
460        self.assertEqual(g('standard_conforming_strings'), 'on')
461        self.assertRaises(ValueError, f, set([ 'default_with_oids',
462            'standard_conforming_strings']), ['off', 'on'])
463        f(set(['default_with_oids', 'standard_conforming_strings']),
464            ['off', 'off'])
465        self.assertEqual(g('default_with_oids'), 'off')
466        self.assertEqual(g('standard_conforming_strings'), 'off')
467        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
468        self.assertEqual(g('standard_conforming_strings'), 'on')
469        self.assertEqual(g('datestyle'), 'ISO, YMD')
470
471    def testResetParameter(self):
472        db = DB()
473        f = db.set_parameter
474        g = db.get_parameter
475        r = g('default_with_oids')
476        self.assertIn(r, ('on', 'off'))
477        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
478        r = g('standard_conforming_strings')
479        self.assertIn(r, ('on', 'off'))
480        scs, not_scs = r, 'off' if r == 'on' else 'on'
481        f('default_with_oids', not_dwi)
482        f('standard_conforming_strings', not_scs)
483        self.assertEqual(g('default_with_oids'), not_dwi)
484        self.assertEqual(g('standard_conforming_strings'), not_scs)
485        f('default_with_oids')
486        f('standard_conforming_strings', None)
487        self.assertEqual(g('default_with_oids'), dwi)
488        self.assertEqual(g('standard_conforming_strings'), scs)
489        f('default_with_oids', not_dwi)
490        f('standard_conforming_strings', not_scs)
491        self.assertEqual(g('default_with_oids'), not_dwi)
492        self.assertEqual(g('standard_conforming_strings'), not_scs)
493        f(['default_with_oids', 'standard_conforming_strings'], None)
494        self.assertEqual(g('default_with_oids'), dwi)
495        self.assertEqual(g('standard_conforming_strings'), scs)
496        f('default_with_oids', not_dwi)
497        f('standard_conforming_strings', not_scs)
498        self.assertEqual(g('default_with_oids'), not_dwi)
499        self.assertEqual(g('standard_conforming_strings'), not_scs)
500        f(('default_with_oids', 'standard_conforming_strings'))
501        self.assertEqual(g('default_with_oids'), dwi)
502        self.assertEqual(g('standard_conforming_strings'), scs)
503        f('default_with_oids', not_dwi)
504        f('standard_conforming_strings', not_scs)
505        self.assertEqual(g('default_with_oids'), not_dwi)
506        self.assertEqual(g('standard_conforming_strings'), not_scs)
507        f(set(['default_with_oids', 'standard_conforming_strings']))
508        self.assertEqual(g('default_with_oids'), dwi)
509        self.assertEqual(g('standard_conforming_strings'), scs)
510
511    def testResetParameterAll(self):
512        db = DB()
513        f = db.set_parameter
514        self.assertRaises(ValueError, f, 'all', 0)
515        self.assertRaises(ValueError, f, 'all', 'off')
516        g = db.get_parameter
517        r = g('default_with_oids')
518        self.assertIn(r, ('on', 'off'))
519        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
520        r = g('standard_conforming_strings')
521        self.assertIn(r, ('on', 'off'))
522        scs, not_scs = r, 'off' if r == 'on' else 'on'
523        f('default_with_oids', not_dwi)
524        f('standard_conforming_strings', not_scs)
525        self.assertEqual(g('default_with_oids'), not_dwi)
526        self.assertEqual(g('standard_conforming_strings'), not_scs)
527        f('all')
528        self.assertEqual(g('default_with_oids'), dwi)
529        self.assertEqual(g('standard_conforming_strings'), scs)
530
531    def testSetParameterLocal(self):
532        f = self.db.set_parameter
533        g = self.db.get_parameter
534        self.assertEqual(g('standard_conforming_strings'), 'on')
535        self.db.begin()
536        f('standard_conforming_strings', 'off', local=True)
537        self.assertEqual(g('standard_conforming_strings'), 'off')
538        self.db.end()
539        self.assertEqual(g('standard_conforming_strings'), 'on')
540
541    def testSetParameterSession(self):
542        f = self.db.set_parameter
543        g = self.db.get_parameter
544        self.assertEqual(g('standard_conforming_strings'), 'on')
545        self.db.begin()
546        f('standard_conforming_strings', 'off', local=False)
547        self.assertEqual(g('standard_conforming_strings'), 'off')
548        self.db.end()
549        self.assertEqual(g('standard_conforming_strings'), 'off')
550
551    def testQuery(self):
552        query = self.db.query
553        query("drop table if exists test_table")
554        self.addCleanup(query, "drop table test_table")
555        q = "create table test_table (n integer) with oids"
556        r = query(q)
557        self.assertIsNone(r)
558        q = "insert into test_table values (1)"
559        r = query(q)
560        self.assertIsInstance(r, int)
561        q = "insert into test_table select 2"
562        r = query(q)
563        self.assertIsInstance(r, int)
564        oid = r
565        q = "select oid from test_table where n=2"
566        r = query(q).getresult()
567        self.assertEqual(len(r), 1)
568        r = r[0]
569        self.assertEqual(len(r), 1)
570        r = r[0]
571        self.assertEqual(r, oid)
572        q = "insert into test_table select 3 union select 4 union select 5"
573        r = query(q)
574        self.assertIsInstance(r, str)
575        self.assertEqual(r, '3')
576        q = "update test_table set n=4 where n<5"
577        r = query(q)
578        self.assertIsInstance(r, str)
579        self.assertEqual(r, '4')
580        q = "delete from test_table"
581        r = query(q)
582        self.assertIsInstance(r, str)
583        self.assertEqual(r, '5')
584
585    def testMultipleQueries(self):
586        self.assertEqual(self.db.query(
587            "create temporary table test_multi (n integer);"
588            "insert into test_multi values (4711);"
589            "select n from test_multi").getresult()[0][0], 4711)
590
591    def testQueryWithParams(self):
592        query = self.db.query
593        query("drop table if exists test_table")
594        self.addCleanup(query, "drop table test_table")
595        q = "create table test_table (n1 integer, n2 integer) with oids"
596        query(q)
597        q = "insert into test_table values ($1, $2)"
598        r = query(q, (1, 2))
599        self.assertIsInstance(r, int)
600        r = query(q, [3, 4])
601        self.assertIsInstance(r, int)
602        r = query(q, [5, 6])
603        self.assertIsInstance(r, int)
604        q = "select * from test_table order by 1, 2"
605        self.assertEqual(query(q).getresult(),
606            [(1, 2), (3, 4), (5, 6)])
607        q = "select * from test_table where n1=$1 and n2=$2"
608        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
609        q = "update test_table set n2=$2 where n1=$1"
610        r = query(q, 3, 7)
611        self.assertEqual(r, '1')
612        q = "select * from test_table order by 1, 2"
613        self.assertEqual(query(q).getresult(),
614            [(1, 2), (3, 7), (5, 6)])
615        q = "delete from test_table where n2!=$1"
616        r = query(q, 4)
617        self.assertEqual(r, '3')
618
619    def testEmptyQuery(self):
620        self.assertRaises(ValueError, self.db.query, '')
621
622    def testQueryProgrammingError(self):
623        try:
624            self.db.query("select 1/0")
625        except pg.ProgrammingError as error:
626            self.assertEqual(error.sqlstate, '22012')
627
628    def testPkey(self):
629        query = self.db.query
630        pkey = self.db.pkey
631        for t in ('pkeytest', 'primary key test'):
632            for n in range(7):
633                query('drop table if exists "%s%d"' % (t, n))
634                self.addCleanup(query, 'drop table "%s%d"' % (t, n))
635            query('create table "%s0" ('
636                "a smallint)" % t)
637            query('create table "%s1" ('
638                "b smallint primary key)" % t)
639            query('create table "%s2" ('
640                "c smallint, d smallint primary key)" % t)
641            query('create table "%s3" ('
642                "e smallint, f smallint, g smallint,"
643                " h smallint, i smallint,"
644                " primary key (f, h))" % t)
645            query('create table "%s4" ('
646                "more_than_one_letter varchar primary key)" % t)
647            query('create table "%s5" ('
648                '"with space" date primary key)' % t)
649            query('create table "%s6" ('
650                'a_very_long_column_name varchar,'
651                ' "with space" date,'
652                ' "42" int,'
653                " primary key (a_very_long_column_name,"
654                ' "with space", "42"))' % t)
655            self.assertRaises(KeyError, pkey, '%s0' % t)
656            self.assertEqual(pkey('%s1' % t), 'b')
657            self.assertEqual(pkey('%s2' % t), 'd')
658            r = pkey('%s3' % t)
659            self.assertIsInstance(r, frozenset)
660            self.assertEqual(r, frozenset('fh'))
661            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
662            self.assertEqual(pkey('%s5' % t), 'with space')
663            r = pkey('%s6' % t)
664            self.assertIsInstance(r, frozenset)
665            self.assertEqual(r, frozenset([
666                'a_very_long_column_name', 'with space', '42']))
667            # a newly added primary key will be detected
668            query('alter table "%s0" add primary key (a)' % t)
669            self.assertEqual(pkey('%s0' % t), 'a')
670            # a changed primary key will not be detected,
671            # indicating that the internal cache is operating
672            query('alter table "%s1" rename column b to x' % t)
673            self.assertEqual(pkey('%s1' % t), 'b')
674            # we get the changed primary key when the cache is flushed
675            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
676
677    def testGetDatabases(self):
678        databases = self.db.get_databases()
679        self.assertIn('template0', databases)
680        self.assertIn('template1', databases)
681        self.assertNotIn('not existing database', databases)
682        self.assertIn('postgres', databases)
683        self.assertIn(dbname, databases)
684
685    def testGetTables(self):
686        get_tables = self.db.get_tables
687        result1 = get_tables()
688        self.assertIsInstance(result1, list)
689        for t in result1:
690            t = t.split('.', 1)
691            self.assertGreaterEqual(len(t), 2)
692            if len(t) > 2:
693                self.assertTrue(t[1].startswith('"'))
694            t = t[0]
695            self.assertNotEqual(t, 'information_schema')
696            self.assertFalse(t.startswith('pg_'))
697        tables = ('"A very Special Name"',
698            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
699            'A_MiXeD_NaMe', '"another special name"',
700            'averyveryveryveryveryveryverylongtablename',
701            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
702        for t in tables:
703            self.db.query('drop table if exists %s' % t)
704            self.db.query("create table %s"
705                " as select 0" % t)
706        result3 = get_tables()
707        result2 = []
708        for t in result3:
709            if t not in result1:
710                result2.append(t)
711        result3 = []
712        for t in tables:
713            if not t.startswith('"'):
714                t = t.lower()
715            result3.append('public.' + t)
716        self.assertEqual(result2, result3)
717        for t in result2:
718            self.db.query('drop table %s' % t)
719        result2 = get_tables()
720        self.assertEqual(result2, result1)
721
722    def testGetRelations(self):
723        get_relations = self.db.get_relations
724        result = get_relations()
725        self.assertIn('public.test', result)
726        self.assertIn('public.test_view', result)
727        result = get_relations('rv')
728        self.assertIn('public.test', result)
729        self.assertIn('public.test_view', result)
730        result = get_relations('r')
731        self.assertIn('public.test', result)
732        self.assertNotIn('public.test_view', result)
733        result = get_relations('v')
734        self.assertNotIn('public.test', result)
735        self.assertIn('public.test_view', result)
736        result = get_relations('cisSt')
737        self.assertNotIn('public.test', result)
738        self.assertNotIn('public.test_view', result)
739
740    def testGetAttnames(self):
741        get_attnames = self.db.get_attnames
742        self.assertRaises(pg.ProgrammingError,
743            self.db.get_attnames, 'does_not_exist')
744        self.assertRaises(pg.ProgrammingError,
745            self.db.get_attnames, 'has.too.many.dots')
746        query = self.db.query
747        query("drop table if exists test_table")
748        self.addCleanup(query, "drop table test_table")
749        query("create table test_table("
750            " n int, alpha smallint, beta bool,"
751            " gamma char(5), tau text, v varchar(3))")
752        r = get_attnames('test_table')
753        self.assertIsInstance(r, dict)
754        self.assertEqual(r, dict(
755            n='int', alpha='int', beta='bool',
756            gamma='text', tau='text', v='text'))
757
758    def testGetAttnamesWithQuotes(self):
759        get_attnames = self.db.get_attnames
760        query = self.db.query
761        table = 'test table for get_attnames()'
762        query('drop table if exists "%s"' % table)
763        self.addCleanup(query, 'drop table "%s"' % table)
764        query('create table "%s"('
765            '"Prime!" smallint,'
766            ' "much space" integer, "Questions?" text)' % table)
767        r = get_attnames(table)
768        self.assertIsInstance(r, dict)
769        self.assertEqual(r, {
770            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
771        table = 'yet another test table for get_attnames()'
772        query('drop table if exists "%s"' % table)
773        self.addCleanup(query, 'drop table "%s"' % table)
774        self.db.query('create table "%s" ('
775            'a smallint, b integer, c bigint,'
776            ' e numeric, f float, f2 double precision, m money,'
777            ' x smallint, y smallint, z smallint,'
778            ' Normal_NaMe smallint, "Special Name" smallint,'
779            ' t text, u char(2), v varchar(2),'
780            ' primary key (y, u)) with oids' % table)
781        r = get_attnames(table)
782        self.assertIsInstance(r, dict)
783        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
784            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
785            'normal_name': 'int', 'Special Name': 'int',
786            'u': 'text', 't': 'text', 'v': 'text',
787            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
788
789    def testGetAttnamesWithRegtypes(self):
790        get_attnames = self.db.get_attnames
791        query = self.db.query
792        query("drop table if exists test_table")
793        self.addCleanup(query, "drop table test_table")
794        query("create table test_table("
795            " n int, alpha smallint, beta bool,"
796            " gamma char(5), tau text, v varchar(3))")
797        self.db.use_regtypes(True)
798        try:
799            r = get_attnames("test_table")
800            self.assertIsInstance(r, dict)
801        finally:
802            self.db.use_regtypes(False)
803        self.assertEqual(r, dict(
804            n='integer', alpha='smallint', beta='boolean',
805            gamma='character', tau='text', v='character varying'))
806
807    def testGetAttnamesIsCached(self):
808        get_attnames = self.db.get_attnames
809        query = self.db.query
810        query("drop table if exists test_table")
811        self.addCleanup(query, "drop table if exists test_table")
812        query("create table test_table(col int)")
813        r = get_attnames("test_table")
814        self.assertIsInstance(r, dict)
815        self.assertEqual(r, dict(col='int'))
816        query("drop table test_table")
817        query("create table test_table(col text)")
818        r = get_attnames("test_table")
819        self.assertEqual(r, dict(col='int'))
820        r = get_attnames("test_table", flush=True)
821        self.assertEqual(r, dict(col='text'))
822        query("drop table test_table")
823        r = get_attnames("test_table")
824        self.assertEqual(r, dict(col='text'))
825        self.assertRaises(pg.ProgrammingError,
826            get_attnames, "test_table", flush=True)
827
828    def testGetAttnamesIsOrdered(self):
829        get_attnames = self.db.get_attnames
830        query = self.db.query
831        query("drop table if exists test_table")
832        self.addCleanup(query, "drop table test_table")
833        query("create table test_table("
834            " n int, alpha smallint, v varchar(3),"
835            " gamma char(5), tau text, beta bool)")
836        r = get_attnames("test_table")
837        self.assertIsInstance(r, OrderedDict)
838        self.assertEqual(r, OrderedDict([
839            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
840            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
841        if OrderedDict is dict:
842            self.skipTest('OrderedDict is not supported')
843        r = ' '.join(list(r.keys()))
844        self.assertEqual(r, 'n alpha v gamma tau beta')
845
846    def testHasTablePrivilege(self):
847        can = self.db.has_table_privilege
848        self.assertEqual(can('test'), True)
849        self.assertEqual(can('test', 'select'), True)
850        self.assertEqual(can('test', 'SeLeCt'), True)
851        self.assertEqual(can('test', 'SELECT'), True)
852        self.assertEqual(can('test', 'insert'), True)
853        self.assertEqual(can('test', 'update'), True)
854        self.assertEqual(can('test', 'delete'), True)
855        self.assertEqual(can('pg_views', 'select'), True)
856        self.assertEqual(can('pg_views', 'delete'), False)
857        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
858        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
859
860    def testGet(self):
861        get = self.db.get
862        query = self.db.query
863        table = 'get_test_table'
864        query('drop table if exists "%s"' % table)
865        self.addCleanup(query, 'drop table "%s"' % table)
866        query('create table "%s" ('
867            "n integer, t text) with oids" % table)
868        for n, t in enumerate('xyz'):
869            query('insert into "%s" values('"%d, '%s')"
870                % (table, n + 1, t))
871        self.assertRaises(pg.ProgrammingError, get, table, 2)
872        r = get(table, 2, 'n')
873        oid_table = 'oid(%s)' % table
874        self.assertIn(oid_table, r)
875        oid = r[oid_table]
876        self.assertIsInstance(oid, int)
877        result = {'t': 'y', 'n': 2, oid_table: oid}
878        self.assertEqual(r, result)
879        self.assertEqual(get(table + ' *', 2, 'n'), r)
880        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
881        self.assertEqual(get(table, 1, 'n')['t'], 'x')
882        self.assertEqual(get(table, 3, 'n')['t'], 'z')
883        self.assertEqual(get(table, 2, 'n')['t'], 'y')
884        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
885        r['n'] = 3
886        self.assertEqual(get(table, r, 'n')['t'], 'z')
887        self.assertEqual(get(table, 1, 'n')['t'], 'x')
888        query('alter table "%s" alter n set not null' % table)
889        query('alter table "%s" add primary key (n)' % table)
890        self.assertEqual(get(table, 3)['t'], 'z')
891        self.assertEqual(get(table, 1)['t'], 'x')
892        self.assertEqual(get(table, 2)['t'], 'y')
893        r['n'] = 1
894        self.assertEqual(get(table, r)['t'], 'x')
895        r['n'] = 3
896        self.assertEqual(get(table, r)['t'], 'z')
897        r['n'] = 2
898        self.assertEqual(get(table, r)['t'], 'y')
899
900    def testGetWithCompositeKey(self):
901        get = self.db.get
902        query = self.db.query
903        table = 'get_test_table_1'
904        query('drop table if exists "%s"' % table)
905        self.addCleanup(query, 'drop table "%s"' % table)
906        query('create table "%s" ('
907            "n integer, t text, primary key (n))" % table)
908        for n, t in enumerate('abc'):
909            query('insert into "%s" values('
910                "%d, '%s')" % (table, n + 1, t))
911        self.assertEqual(get(table, 2)['t'], 'b')
912        table = 'get_test_table_2'
913        query('drop table if exists "%s"' % table)
914        self.addCleanup(query, 'drop table "%s"' % table)
915        query('create table "%s" ('
916            "n integer, m integer, t text, primary key (n, m))" % table)
917        for n in range(3):
918            for m in range(2):
919                t = chr(ord('a') + 2 * n + m)
920                query('insert into "%s" values('
921                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
922        self.assertRaises(pg.ProgrammingError, get, table, 2)
923        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
924        r = get(table, dict(n=1, m=2), ('n', 'm'))
925        self.assertEqual(r['t'], 'b')
926        r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
927        self.assertEqual(r['t'], 'f')
928
929    def testGetWithQuotedNames(self):
930        get = self.db.get
931        query = self.db.query
932        table = 'test table for get()'
933        query('drop table if exists "%s"' % table)
934        self.addCleanup(query, 'drop table "%s"' % table)
935        query('create table "%s" ('
936            '"Prime!" smallint primary key,'
937            ' "much space" integer, "Questions?" text)' % table)
938        query('insert into "%s"'
939              " values(17, 1001, 'No!')" % table)
940        r = get(table, 17)
941        self.assertIsInstance(r, dict)
942        self.assertEqual(r['Prime!'], 17)
943        self.assertEqual(r['much space'], 1001)
944        self.assertEqual(r['Questions?'], 'No!')
945
946    def testGetFromView(self):
947        self.db.query('delete from test where i4=14')
948        self.db.query('insert into test (i4, v4) values('
949            "14, 'abc4')")
950        r = self.db.get('test_view', 14, 'i4')
951        self.assertIn('v4', r)
952        self.assertEqual(r['v4'], 'abc4')
953
954    def testGetLittleBobbyTables(self):
955        get = self.db.get
956        query = self.db.query
957        query("drop table if exists test_students")
958        self.addCleanup(query, "drop table test_students")
959        query("create table test_students (firstname varchar primary key,"
960            " nickname varchar, grade char(2))")
961        query("insert into test_students values ("
962              "'D''Arcy', 'Darcey', 'A+')")
963        query("insert into test_students values ("
964              "'Sheldon', 'Moonpie', 'A+')")
965        query("insert into test_students values ("
966              "'Robert', 'Little Bobby Tables', 'D-')")
967        r = get('test_students', 'Sheldon')
968        self.assertEqual(r, dict(
969            firstname="Sheldon", nickname='Moonpie', grade='A+'))
970        r = get('test_students', 'Robert')
971        self.assertEqual(r, dict(
972            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
973        r = get('test_students', "D'Arcy")
974        self.assertEqual(r, dict(
975            firstname="D'Arcy", nickname='Darcey', grade='A+'))
976        try:
977            get('test_students', "D' Arcy")
978        except pg.DatabaseError as error:
979            self.assertEqual(str(error),
980                'No such record in test_students\nwhere "firstname" = $1\n'
981                'with $1="D\' Arcy"')
982        try:
983            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
984        except pg.DatabaseError as error:
985            self.assertEqual(str(error),
986                'No such record in test_students\nwhere "firstname" = $1\n'
987                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
988        q = "select * from test_students order by 1 limit 4"
989        r = query(q).getresult()
990        self.assertEqual(len(r), 3)
991        self.assertEqual(r[1][2], 'D-')
992
993    def testInsert(self):
994        insert = self.db.insert
995        query = self.db.query
996        bool_on = pg.get_bool()
997        decimal = pg.get_decimal()
998        table = 'insert_test_table'
999        query('drop table if exists "%s"' % table)
1000        self.addCleanup(query, 'drop table "%s"' % table)
1001        query('create table "%s" ('
1002            "i2 smallint, i4 integer, i8 bigint,"
1003            " d numeric, f4 real, f8 double precision, m money,"
1004            " v4 varchar(4), c4 char(4), t text,"
1005            " b boolean, ts timestamp) with oids" % table)
1006        oid_table = 'oid(%s)' % table
1007        tests = [dict(i2=None, i4=None, i8=None),
1008            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1009            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1010            dict(i2=42, i4=123456, i8=9876543210),
1011            dict(i2=2 ** 15 - 1,
1012                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1013            dict(d=None), (dict(d=''), dict(d=None)),
1014            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1015            dict(f4=None, f8=None), dict(f4=0, f8=0),
1016            (dict(f4='', f8=''), dict(f4=None, f8=None)),
1017            (dict(d=1234.5, f4=1234.5, f8=1234.5),
1018                  dict(d=Decimal('1234.5'))),
1019            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1020            dict(d=Decimal('123456789.9876543212345678987654321')),
1021            dict(m=None), (dict(m=''), dict(m=None)),
1022            dict(m=Decimal('-1234.56')),
1023            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1024            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1025            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1026            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1027            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1028            (dict(m=123456), dict(m=Decimal('123456'))),
1029            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1030            dict(b=None), (dict(b=''), dict(b=None)),
1031            dict(b='f'), dict(b='t'),
1032            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1033            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1034            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1035            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1036            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1037            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1038            dict(v4=None, c4=None, t=None),
1039            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1040            dict(v4='1234', c4='1234', t='1234' * 10),
1041            dict(v4='abcd', c4='abcd', t='abcdefg'),
1042            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1043            dict(ts=None), (dict(ts=''), dict(ts=None)),
1044            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1045            dict(ts='2012-12-21 00:00:00'),
1046            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1047            dict(ts='2012-12-21 12:21:12'),
1048            dict(ts='2013-01-05 12:13:14'),
1049            dict(ts='current_timestamp')]
1050        for test in tests:
1051            if isinstance(test, dict):
1052                data = test
1053                change = {}
1054            else:
1055                data, change = test
1056            expect = data.copy()
1057            expect.update(change)
1058            if bool_on:
1059                b = expect.get('b')
1060                if b is not None:
1061                    expect['b'] = b == 't'
1062            if decimal is not Decimal:
1063                d = expect.get('d')
1064                if d is not None:
1065                    expect['d'] = decimal(d)
1066                m = expect.get('m')
1067                if m is not None:
1068                    expect['m'] = decimal(m)
1069            self.assertEqual(insert(table, data), data)
1070            self.assertIn(oid_table, data)
1071            oid = data[oid_table]
1072            self.assertIsInstance(oid, int)
1073            data = dict(item for item in data.items()
1074                if item[0] in expect)
1075            ts = expect.get('ts')
1076            if ts == 'current_timestamp':
1077                ts = expect['ts'] = data['ts']
1078                if len(ts) > 19:
1079                    self.assertEqual(ts[19], '.')
1080                    ts = ts[:19]
1081                else:
1082                    self.assertEqual(len(ts), 19)
1083                self.assertTrue(ts[:4].isdigit())
1084                self.assertEqual(ts[4], '-')
1085                self.assertEqual(ts[10], ' ')
1086                self.assertTrue(ts[11:13].isdigit())
1087                self.assertEqual(ts[13], ':')
1088            self.assertEqual(data, expect)
1089            data = query(
1090                'select oid,* from "%s"' % table).dictresult()[0]
1091            self.assertEqual(data['oid'], oid)
1092            data = dict(item for item in data.items()
1093                if item[0] in expect)
1094            self.assertEqual(data, expect)
1095            query('delete from "%s"' % table)
1096
1097    def testInsertWithQuotedNames(self):
1098        insert = self.db.insert
1099        query = self.db.query
1100        table = 'test table for insert()'
1101        query('drop table if exists "%s"' % table)
1102        self.addCleanup(query, 'drop table "%s"' % table)
1103        query('create table "%s" ('
1104            '"Prime!" smallint primary key,'
1105            ' "much space" integer, "Questions?" text)' % table)
1106        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1107        r = insert(table, r)
1108        self.assertIsInstance(r, dict)
1109        self.assertEqual(r['Prime!'], 11)
1110        self.assertEqual(r['much space'], 2002)
1111        self.assertEqual(r['Questions?'], 'What?')
1112        r = query('select * from "%s" limit 2' % table).dictresult()
1113        self.assertEqual(len(r), 1)
1114        r = r[0]
1115        self.assertEqual(r['Prime!'], 11)
1116        self.assertEqual(r['much space'], 2002)
1117        self.assertEqual(r['Questions?'], 'What?')
1118
1119    def testUpdate(self):
1120        update = self.db.update
1121        query = self.db.query
1122        table = 'update_test_table'
1123        query('drop table if exists "%s"' % table)
1124        self.addCleanup(query, 'drop table "%s"' % table)
1125        query('create table "%s" ('
1126            "n integer, t text) with oids" % table)
1127        for n, t in enumerate('xyz'):
1128            query('insert into "%s" values('
1129                "%d, '%s')" % (table, n + 1, t))
1130        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1131        r = self.db.get(table, 2, 'n')
1132        r['t'] = 'u'
1133        s = update(table, r)
1134        self.assertEqual(s, r)
1135        q = 'select t from "%s" where n=2' % table
1136        r = query(q).getresult()[0][0]
1137        self.assertEqual(r, 'u')
1138
1139    def testUpdateWithCompositeKey(self):
1140        update = self.db.update
1141        query = self.db.query
1142        table = 'update_test_table_1'
1143        query('drop table if exists "%s"' % table)
1144        self.addCleanup(query, 'drop table if exists "%s"' % table)
1145        query('create table "%s" ('
1146            "n integer, t text, primary key (n))" % table)
1147        for n, t in enumerate('abc'):
1148            query('insert into "%s" values('
1149                "%d, '%s')" % (table, n + 1, t))
1150        self.assertRaises(pg.ProgrammingError, update,
1151                          table, dict(t='b'))
1152        s = dict(n=2, t='d')
1153        r = update(table, s)
1154        self.assertIs(r, s)
1155        self.assertEqual(r['n'], 2)
1156        self.assertEqual(r['t'], 'd')
1157        q = 'select t from "%s" where n=2' % table
1158        r = query(q).getresult()[0][0]
1159        self.assertEqual(r, 'd')
1160        s.update(dict(n=4, t='e'))
1161        r = update(table, s)
1162        self.assertEqual(r['n'], 4)
1163        self.assertEqual(r['t'], 'e')
1164        q = 'select t from "%s" where n=2' % table
1165        r = query(q).getresult()[0][0]
1166        self.assertEqual(r, 'd')
1167        q = 'select t from "%s" where n=4' % table
1168        r = query(q).getresult()
1169        self.assertEqual(len(r), 0)
1170        query('drop table "%s"' % table)
1171        table = 'update_test_table_2'
1172        query('drop table if exists "%s"' % table)
1173        query('create table "%s" ('
1174            "n integer, m integer, t text, primary key (n, m))" % table)
1175        for n in range(3):
1176            for m in range(2):
1177                t = chr(ord('a') + 2 * n + m)
1178                query('insert into "%s" values('
1179                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1180        self.assertRaises(pg.ProgrammingError, update,
1181                          table, dict(n=2, t='b'))
1182        self.assertEqual(update(table,
1183                                dict(n=2, m=2, t='x'))['t'], 'x')
1184        q = 'select t from "%s" where n=2 order by m' % table
1185        r = [r[0] for r in query(q).getresult()]
1186        self.assertEqual(r, ['c', 'x'])
1187
1188    def testUpdateWithQuotedNames(self):
1189        update = self.db.update
1190        query = self.db.query
1191        table = 'test table for update()'
1192        query('drop table if exists "%s"' % table)
1193        self.addCleanup(query, 'drop table "%s"' % table)
1194        query('create table "%s" ('
1195            '"Prime!" smallint primary key,'
1196            ' "much space" integer, "Questions?" text)' % table)
1197        query('insert into "%s"'
1198              " values(13, 3003, 'Why!')" % table)
1199        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1200        r = update(table, r)
1201        self.assertIsInstance(r, dict)
1202        self.assertEqual(r['Prime!'], 13)
1203        self.assertEqual(r['much space'], 7007)
1204        self.assertEqual(r['Questions?'], 'When?')
1205        r = query('select * from "%s" limit 2' % table).dictresult()
1206        self.assertEqual(len(r), 1)
1207        r = r[0]
1208        self.assertEqual(r['Prime!'], 13)
1209        self.assertEqual(r['much space'], 7007)
1210        self.assertEqual(r['Questions?'], 'When?')
1211
1212    def testUpsert(self):
1213        upsert = self.db.upsert
1214        query = self.db.query
1215        table = 'upsert_test_table'
1216        query('drop table if exists "%s"' % table)
1217        self.addCleanup(query, 'drop table "%s"' % table)
1218        query('create table "%s" ('
1219            "n integer primary key, t text) with oids" % table)
1220        s = dict(n=1, t='x')
1221        try:
1222            r = upsert(table, s)
1223        except pg.ProgrammingError as error:
1224            if self.db.server_version < 90500:
1225                self.skipTest('database does not support upsert')
1226            self.fail(str(error))
1227        self.assertIs(r, s)
1228        self.assertEqual(r['n'], 1)
1229        self.assertEqual(r['t'], 'x')
1230        s.update(n=2, t='y')
1231        r = upsert(table, s, **dict.fromkeys(s))
1232        self.assertIs(r, s)
1233        self.assertEqual(r['n'], 2)
1234        self.assertEqual(r['t'], 'y')
1235        q = 'select n, t from "%s" order by n limit 3' % table
1236        r = query(q).getresult()
1237        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1238        s.update(t='z')
1239        r = upsert(table, s)
1240        self.assertIs(r, s)
1241        self.assertEqual(r['n'], 2)
1242        self.assertEqual(r['t'], 'z')
1243        r = query(q).getresult()
1244        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1245        s.update(t='n')
1246        r = upsert(table, s, t=False)
1247        self.assertIs(r, s)
1248        self.assertEqual(r['n'], 2)
1249        self.assertEqual(r['t'], 'z')
1250        r = query(q).getresult()
1251        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1252        s.update(t='y')
1253        r = upsert(table, s, t=True)
1254        self.assertIs(r, s)
1255        self.assertEqual(r['n'], 2)
1256        self.assertEqual(r['t'], 'y')
1257        r = query(q).getresult()
1258        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1259        s.update(t='n')
1260        r = upsert(table, s, t="included.t || '2'")
1261        self.assertIs(r, s)
1262        self.assertEqual(r['n'], 2)
1263        self.assertEqual(r['t'], 'y2')
1264        r = query(q).getresult()
1265        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1266        s.update(t='y')
1267        r = upsert(table, s, t="excluded.t || '3'")
1268        self.assertIs(r, s)
1269        self.assertEqual(r['n'], 2)
1270        self.assertEqual(r['t'], 'y3')
1271        r = query(q).getresult()
1272        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1273        s.update(n=1, t='2')
1274        r = upsert(table, s, t="included.t || excluded.t")
1275        self.assertIs(r, s)
1276        self.assertEqual(r['n'], 1)
1277        self.assertEqual(r['t'], 'x2')
1278        r = query(q).getresult()
1279        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1280
1281    def testUpsertWithCompositeKey(self):
1282        upsert = self.db.upsert
1283        query = self.db.query
1284        table = 'upsert_test_table_2'
1285        query('drop table if exists "%s"' % table)
1286        self.addCleanup(query, 'drop table "%s"' % table)
1287        query('create table "%s" ('
1288            "n integer, m integer, t text, primary key (n, m))" % table)
1289        s = dict(n=1, m=2, t='x')
1290        try:
1291            r = upsert(table, s)
1292        except pg.ProgrammingError as error:
1293            if self.db.server_version < 90500:
1294                self.skipTest('database does not support upsert')
1295            self.fail(str(error))
1296        self.assertIs(r, s)
1297        self.assertEqual(r['n'], 1)
1298        self.assertEqual(r['m'], 2)
1299        self.assertEqual(r['t'], 'x')
1300        s.update(m=3, t='y')
1301        r = upsert(table, s, **dict.fromkeys(s))
1302        self.assertIs(r, s)
1303        self.assertEqual(r['n'], 1)
1304        self.assertEqual(r['m'], 3)
1305        self.assertEqual(r['t'], 'y')
1306        q = 'select n, m, t from "%s" order by n, m limit 3' % table
1307        r = query(q).getresult()
1308        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
1309        s.update(t='z')
1310        r = upsert(table, s)
1311        self.assertIs(r, s)
1312        self.assertEqual(r['n'], 1)
1313        self.assertEqual(r['m'], 3)
1314        self.assertEqual(r['t'], 'z')
1315        r = query(q).getresult()
1316        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1317        s.update(t='n')
1318        r = upsert(table, s, t=False)
1319        self.assertIs(r, s)
1320        self.assertEqual(r['n'], 1)
1321        self.assertEqual(r['m'], 3)
1322        self.assertEqual(r['t'], 'z')
1323        r = query(q).getresult()
1324        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1325        s.update(t='n')
1326        r = upsert(table, s, t=True)
1327        self.assertIs(r, s)
1328        self.assertEqual(r['n'], 1)
1329        self.assertEqual(r['m'], 3)
1330        self.assertEqual(r['t'], 'n')
1331        r = query(q).getresult()
1332        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
1333        s.update(n=2, t='y')
1334        r = upsert(table, s, t="'z'")
1335        self.assertIs(r, s)
1336        self.assertEqual(r['n'], 2)
1337        self.assertEqual(r['m'], 3)
1338        self.assertEqual(r['t'], 'y')
1339        r = query(q).getresult()
1340        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
1341        s.update(n=1, t='m')
1342        r = upsert(table, s, t='included.t || excluded.t')
1343        self.assertIs(r, s)
1344        self.assertEqual(r['n'], 1)
1345        self.assertEqual(r['m'], 3)
1346        self.assertEqual(r['t'], 'nm')
1347        r = query(q).getresult()
1348        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1349
1350    def testUpsertWithQuotedNames(self):
1351        upsert = self.db.upsert
1352        query = self.db.query
1353        table = 'test table for upsert()'
1354        query('drop table if exists "%s"' % table)
1355        self.addCleanup(query, 'drop table "%s"' % table)
1356        query('create table "%s" ('
1357            '"Prime!" smallint primary key,'
1358            ' "much space" integer, "Questions?" text)' % table)
1359        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1360        try:
1361            r = upsert(table, s)
1362        except pg.ProgrammingError as error:
1363            if self.db.server_version < 90500:
1364                self.skipTest('database does not support upsert')
1365            self.fail(str(error))
1366        self.assertIs(r, s)
1367        self.assertEqual(r['Prime!'], 31)
1368        self.assertEqual(r['much space'], 9009)
1369        self.assertEqual(r['Questions?'], 'Yes.')
1370        q = 'select * from "%s" limit 2' % table
1371        r = query(q).getresult()
1372        self.assertEqual(r, [(31, 9009, 'Yes.')])
1373        s.update({'Questions?': 'No.'})
1374        r = upsert(table, s)
1375        self.assertIs(r, s)
1376        self.assertEqual(r['Prime!'], 31)
1377        self.assertEqual(r['much space'], 9009)
1378        self.assertEqual(r['Questions?'], 'No.')
1379        r = query(q).getresult()
1380        self.assertEqual(r, [(31, 9009, 'No.')])
1381
1382    def testClear(self):
1383        clear = self.db.clear
1384        query = self.db.query
1385        f = False if pg.get_bool() else 'f'
1386        table = 'clear_test_table'
1387        query('drop table if exists "%s"' % table)
1388        self.addCleanup(query, 'drop table "%s"' % table)
1389        query('create table "%s" ('
1390            "n integer, b boolean, d date, t text)" % table)
1391        r = clear(table)
1392        result = {'n': 0, 'b': f, 'd': '', 't': ''}
1393        self.assertEqual(r, result)
1394        r['a'] = r['n'] = 1
1395        r['d'] = r['t'] = 'x'
1396        r['b'] = 't'
1397        r['oid'] = long(1)
1398        r = clear(table, r)
1399        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
1400            'oid': long(1)}
1401        self.assertEqual(r, result)
1402
1403    def testClearWithQuotedNames(self):
1404        clear = self.db.clear
1405        query = self.db.query
1406        table = 'test table for clear()'
1407        query('drop table if exists "%s"' % table)
1408        self.addCleanup(query, 'drop table "%s"' % table)
1409        query('create table "%s" ('
1410            '"Prime!" smallint primary key,'
1411            ' "much space" integer, "Questions?" text)' % table)
1412        r = clear(table)
1413        self.assertIsInstance(r, dict)
1414        self.assertEqual(r['Prime!'], 0)
1415        self.assertEqual(r['much space'], 0)
1416        self.assertEqual(r['Questions?'], '')
1417
1418    def testDelete(self):
1419        delete = self.db.delete
1420        query = self.db.query
1421        table = 'delete_test_table'
1422        query('drop table if exists "%s"' % table)
1423        self.addCleanup(query, 'drop table "%s"' % table)
1424        query('create table "%s" ('
1425            "n integer, t text) with oids" % table)
1426        for n, t in enumerate('xyz'):
1427            query('insert into "%s" values('
1428                "%d, '%s')" % (table, n + 1, t))
1429        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1430        r = self.db.get(table, 1, 'n')
1431        s = delete(table, r)
1432        self.assertEqual(s, 1)
1433        r = self.db.get(table, 3, 'n')
1434        s = delete(table, r)
1435        self.assertEqual(s, 1)
1436        s = delete(table, r)
1437        self.assertEqual(s, 0)
1438        r = query('select * from "%s"' % table).dictresult()
1439        self.assertEqual(len(r), 1)
1440        r = r[0]
1441        result = {'n': 2, 't': 'y'}
1442        self.assertEqual(r, result)
1443        r = self.db.get(table, 2, 'n')
1444        s = delete(table, r)
1445        self.assertEqual(s, 1)
1446        s = delete(table, r)
1447        self.assertEqual(s, 0)
1448        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1449
1450    def testDeleteWithCompositeKey(self):
1451        query = self.db.query
1452        table = 'delete_test_table_1'
1453        query('drop table if exists "%s"' % table)
1454        self.addCleanup(query, 'drop table "%s"' % table)
1455        query('create table "%s" ('
1456            "n integer, t text, primary key (n))" % table)
1457        for n, t in enumerate('abc'):
1458            query("insert into %s values("
1459                "%d, '%s')" % (table, n + 1, t))
1460        self.assertRaises(pg.ProgrammingError, self.db.delete,
1461            table, dict(t='b'))
1462        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1463        r = query('select t from "%s" where n=2' % table
1464                  ).getresult()
1465        self.assertEqual(r, [])
1466        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1467        r = query('select t from "%s" where n=3' % table
1468                  ).getresult()[0][0]
1469        self.assertEqual(r, 'c')
1470        table = 'delete_test_table_2'
1471        query('drop table if exists "%s"' % table)
1472        self.addCleanup(query, 'drop table "%s"' % table)
1473        query('create table "%s" ('
1474            "n integer, m integer, t text, primary key (n, m))" % table)
1475        for n in range(3):
1476            for m in range(2):
1477                t = chr(ord('a') + 2 * n + m)
1478                query('insert into "%s" values('
1479                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1480        self.assertRaises(pg.ProgrammingError, self.db.delete,
1481            table, dict(n=2, t='b'))
1482        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1483        r = [r[0] for r in query('select t from "%s" where n=2'
1484            ' order by m' % table).getresult()]
1485        self.assertEqual(r, ['c'])
1486        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1487        r = [r[0] for r in query('select t from "%s" where n=3'
1488            ' order by m' % table).getresult()]
1489        self.assertEqual(r, ['e', 'f'])
1490        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1491        r = [r[0] for r in query('select t from "%s" where n=3'
1492            ' order by m' % table).getresult()]
1493        self.assertEqual(r, ['f'])
1494
1495    def testDeleteWithQuotedNames(self):
1496        delete = self.db.delete
1497        query = self.db.query
1498        table = 'test table for delete()'
1499        query('drop table if exists "%s"' % table)
1500        self.addCleanup(query, 'drop table "%s"' % table)
1501        query('create table "%s" ('
1502            '"Prime!" smallint primary key,'
1503            ' "much space" integer, "Questions?" text)' % table)
1504        query('insert into "%s"'
1505              " values(19, 5005, 'Yes!')" % table)
1506        r = {'Prime!': 17}
1507        r = delete(table, r)
1508        self.assertEqual(r, 0)
1509        r = query('select count(*) from "%s"' % table).getresult()
1510        self.assertEqual(r[0][0], 1)
1511        r = {'Prime!': 19}
1512        r = delete(table, r)
1513        self.assertEqual(r, 1)
1514        r = query('select count(*) from "%s"' % table).getresult()
1515        self.assertEqual(r[0][0], 0)
1516
1517    def testTruncate(self):
1518        truncate = self.db.truncate
1519        self.assertRaises(TypeError, truncate, None)
1520        self.assertRaises(TypeError, truncate, 42)
1521        self.assertRaises(TypeError, truncate, dict(test_table=None))
1522        query = self.db.query
1523        query("drop table if exists test_table")
1524        self.addCleanup(query, "drop table test_table")
1525        query("create table test_table (n smallint)")
1526        for i in range(3):
1527            query("insert into test_table values (1)")
1528        q = "select count(*) from test_table"
1529        r = query(q).getresult()[0][0]
1530        self.assertEqual(r, 3)
1531        truncate('test_table')
1532        r = query(q).getresult()[0][0]
1533        self.assertEqual(r, 0)
1534        for i in range(3):
1535            query("insert into test_table values (1)")
1536        r = query(q).getresult()[0][0]
1537        self.assertEqual(r, 3)
1538        truncate('public.test_table')
1539        r = query(q).getresult()[0][0]
1540        self.assertEqual(r, 0)
1541        query("drop table if exists test_table_2")
1542        self.addCleanup(query, "drop table test_table_2")
1543        query('create table test_table_2 (n smallint)')
1544        for t in (list, tuple, set):
1545            for i in range(3):
1546                query("insert into test_table values (1)")
1547                query("insert into test_table_2 values (2)")
1548            q = ("select (select count(*) from test_table),"
1549                " (select count(*) from test_table_2)")
1550            r = query(q).getresult()[0]
1551            self.assertEqual(r, (3, 3))
1552            truncate(t(['test_table', 'test_table_2']))
1553            r = query(q).getresult()[0]
1554            self.assertEqual(r, (0, 0))
1555
1556    def testTruncateRestart(self):
1557        truncate = self.db.truncate
1558        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
1559        query = self.db.query
1560        query("drop table if exists test_table")
1561        self.addCleanup(query, "drop table test_table")
1562        query("create table test_table (n serial, t text)")
1563        for n in range(3):
1564            query("insert into test_table (t) values ('test')")
1565        q = "select count(n), min(n), max(n) from test_table"
1566        r = query(q).getresult()[0]
1567        self.assertEqual(r, (3, 1, 3))
1568        truncate('test_table')
1569        r = query(q).getresult()[0]
1570        self.assertEqual(r, (0, None, None))
1571        for n in range(3):
1572            query("insert into test_table (t) values ('test')")
1573        r = query(q).getresult()[0]
1574        self.assertEqual(r, (3, 4, 6))
1575        truncate('test_table', restart=True)
1576        r = query(q).getresult()[0]
1577        self.assertEqual(r, (0, None, None))
1578        for n in range(3):
1579            query("insert into test_table (t) values ('test')")
1580        r = query(q).getresult()[0]
1581        self.assertEqual(r, (3, 1, 3))
1582
1583    def testTruncateCascade(self):
1584        truncate = self.db.truncate
1585        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
1586        query = self.db.query
1587        query("drop table if exists test_child")
1588        query("drop table if exists test_parent")
1589        self.addCleanup(query, "drop table test_parent")
1590        query("create table test_parent (n smallint primary key)")
1591        self.addCleanup(query, "drop table test_child")
1592        query("create table test_child ("
1593            " n smallint primary key references test_parent (n))")
1594        for n in range(3):
1595            query("insert into test_parent (n) values (%d)" % n)
1596            query("insert into test_child (n) values (%d)" % n)
1597        q = ("select (select count(*) from test_parent),"
1598            " (select count(*) from test_child)")
1599        r = query(q).getresult()[0]
1600        self.assertEqual(r, (3, 3))
1601        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
1602        truncate(['test_parent', 'test_child'])
1603        r = query(q).getresult()[0]
1604        self.assertEqual(r, (0, 0))
1605        for n in range(3):
1606            query("insert into test_parent (n) values (%d)" % n)
1607            query("insert into test_child (n) values (%d)" % n)
1608        r = query(q).getresult()[0]
1609        self.assertEqual(r, (3, 3))
1610        truncate('test_parent', cascade=True)
1611        r = query(q).getresult()[0]
1612        self.assertEqual(r, (0, 0))
1613        for n in range(3):
1614            query("insert into test_parent (n) values (%d)" % n)
1615            query("insert into test_child (n) values (%d)" % n)
1616        r = query(q).getresult()[0]
1617        self.assertEqual(r, (3, 3))
1618        truncate('test_child')
1619        r = query(q).getresult()[0]
1620        self.assertEqual(r, (3, 0))
1621        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
1622        truncate('test_parent', cascade=True)
1623        r = query(q).getresult()[0]
1624        self.assertEqual(r, (0, 0))
1625
1626    def testTruncateOnly(self):
1627        truncate = self.db.truncate
1628        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
1629        query = self.db.query
1630        query("drop table if exists test_child")
1631        query("drop table if exists test_parent")
1632        self.addCleanup(query, "drop table test_parent")
1633        query("create table test_parent (n smallint)")
1634        self.addCleanup(query, "drop table test_child")
1635        query("create table test_child ("
1636            " m smallint) inherits (test_parent)")
1637        for n in range(3):
1638            query("insert into test_parent (n) values (1)")
1639            query("insert into test_child (n, m) values (2, 3)")
1640        q = ("select (select count(*) from test_parent),"
1641            " (select count(*) from test_child)")
1642        r = query(q).getresult()[0]
1643        self.assertEqual(r, (6, 3))
1644        truncate('test_parent')
1645        r = query(q).getresult()[0]
1646        self.assertEqual(r, (0, 0))
1647        for n in range(3):
1648            query("insert into test_parent (n) values (1)")
1649            query("insert into test_child (n, m) values (2, 3)")
1650        r = query(q).getresult()[0]
1651        self.assertEqual(r, (6, 3))
1652        truncate('test_parent*')
1653        r = query(q).getresult()[0]
1654        self.assertEqual(r, (0, 0))
1655        for n in range(3):
1656            query("insert into test_parent (n) values (1)")
1657            query("insert into test_child (n, m) values (2, 3)")
1658        r = query(q).getresult()[0]
1659        self.assertEqual(r, (6, 3))
1660        truncate('test_parent', only=True)
1661        r = query(q).getresult()[0]
1662        self.assertEqual(r, (3, 3))
1663        truncate('test_parent', only=False)
1664        r = query(q).getresult()[0]
1665        self.assertEqual(r, (0, 0))
1666        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
1667        truncate('test_parent*', only=False)
1668        query("drop table if exists test_parent_2")
1669        self.addCleanup(query, "drop table test_parent_2")
1670        query("create table test_parent_2 (n smallint)")
1671        query("drop table if exists test_child_2")
1672        self.addCleanup(query, "drop table test_child_2")
1673        query("create table test_child_2 ("
1674            " m smallint) inherits (test_parent_2)")
1675        for n in range(3):
1676            query("insert into test_parent (n) values (1)")
1677            query("insert into test_child (n, m) values (2, 3)")
1678            query("insert into test_parent_2 (n) values (1)")
1679            query("insert into test_child_2 (n, m) values (2, 3)")
1680        q = ("select (select count(*) from test_parent),"
1681            " (select count(*) from test_child),"
1682            " (select count(*) from test_parent_2),"
1683            " (select count(*) from test_child_2)")
1684        r = query(q).getresult()[0]
1685        self.assertEqual(r, (6, 3, 6, 3))
1686        truncate(['test_parent', 'test_parent_2'], only=[False, True])
1687        r = query(q).getresult()[0]
1688        self.assertEqual(r, (0, 0, 3, 3))
1689        truncate(['test_parent', 'test_parent_2'], only=False)
1690        r = query(q).getresult()[0]
1691        self.assertEqual(r, (0, 0, 0, 0))
1692        self.assertRaises(ValueError, truncate,
1693            ['test_parent*', 'test_child'], only=[True, False])
1694        truncate(['test_parent*', 'test_child'], only=[False, True])
1695
1696    def testTruncateQuoted(self):
1697        truncate = self.db.truncate
1698        query = self.db.query
1699        table = "test table for truncate()"
1700        query('drop table if exists "%s"' % table)
1701        self.addCleanup(query, 'drop table "%s"' % table)
1702        query('create table "%s" (n smallint)' % table)
1703        for i in range(3):
1704            query('insert into "%s" values (1)' % table)
1705        q = 'select count(*) from "%s"' % table
1706        r = query(q).getresult()[0][0]
1707        self.assertEqual(r, 3)
1708        truncate(table)
1709        r = query(q).getresult()[0][0]
1710        self.assertEqual(r, 0)
1711        for i in range(3):
1712            query('insert into "%s" values (1)' % table)
1713        r = query(q).getresult()[0][0]
1714        self.assertEqual(r, 3)
1715        truncate('public."%s"' % table)
1716        r = query(q).getresult()[0][0]
1717        self.assertEqual(r, 0)
1718
1719    def testTransaction(self):
1720        query = self.db.query
1721        query("drop table if exists test_table")
1722        self.addCleanup(query, "drop table test_table")
1723        query("create table test_table (n integer)")
1724        self.db.begin()
1725        query("insert into test_table values (1)")
1726        query("insert into test_table values (2)")
1727        self.db.commit()
1728        self.db.begin()
1729        query("insert into test_table values (3)")
1730        query("insert into test_table values (4)")
1731        self.db.rollback()
1732        self.db.begin()
1733        query("insert into test_table values (5)")
1734        self.db.savepoint('before6')
1735        query("insert into test_table values (6)")
1736        self.db.rollback('before6')
1737        query("insert into test_table values (7)")
1738        self.db.commit()
1739        self.db.begin()
1740        self.db.savepoint('before8')
1741        query("insert into test_table values (8)")
1742        self.db.release('before8')
1743        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1744        self.db.commit()
1745        self.db.start()
1746        query("insert into test_table values (9)")
1747        self.db.end()
1748        r = [r[0] for r in query(
1749            "select * from test_table order by 1").getresult()]
1750        self.assertEqual(r, [1, 2, 5, 7, 9])
1751
1752    def testContextManager(self):
1753        query = self.db.query
1754        query("drop table if exists test_table")
1755        self.addCleanup(query, "drop table test_table")
1756        query("create table test_table (n integer check(n>0))")
1757        with self.db:
1758            query("insert into test_table values (1)")
1759            query("insert into test_table values (2)")
1760        try:
1761            with self.db:
1762                query("insert into test_table values (3)")
1763                query("insert into test_table values (4)")
1764                raise ValueError('test transaction should rollback')
1765        except ValueError as error:
1766            self.assertEqual(str(error), 'test transaction should rollback')
1767        with self.db:
1768            query("insert into test_table values (5)")
1769        try:
1770            with self.db:
1771                query("insert into test_table values (6)")
1772                query("insert into test_table values (-1)")
1773        except pg.ProgrammingError as error:
1774            self.assertTrue('check' in str(error))
1775        with self.db:
1776            query("insert into test_table values (7)")
1777        r = [r[0] for r in query(
1778            "select * from test_table order by 1").getresult()]
1779        self.assertEqual(r, [1, 2, 5, 7])
1780
1781    def testBytea(self):
1782        query = self.db.query
1783        query('drop table if exists bytea_test')
1784        self.addCleanup(query, 'drop table bytea_test')
1785        query('create table bytea_test (n smallint primary key, data bytea)')
1786        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1787        r = self.db.escape_bytea(s)
1788        query('insert into bytea_test values(3,$1)', (r,))
1789        r = query('select * from bytea_test where n=3').getresult()
1790        self.assertEqual(len(r), 1)
1791        r = r[0]
1792        self.assertEqual(len(r), 2)
1793        self.assertEqual(r[0], 3)
1794        r = r[1]
1795        self.assertIsInstance(r, str)
1796        r = self.db.unescape_bytea(r)
1797        self.assertIsInstance(r, bytes)
1798        self.assertEqual(r, s)
1799
1800    def testInsertUpdateGetBytea(self):
1801        query = self.db.query
1802        query('drop table if exists bytea_test')
1803        self.addCleanup(query, 'drop table bytea_test')
1804        query('create table bytea_test (n smallint primary key, data bytea)')
1805        # insert null value
1806        r = self.db.insert('bytea_test', n=0, data=None)
1807        self.assertIsInstance(r, dict)
1808        self.assertIn('n', r)
1809        self.assertEqual(r['n'], 0)
1810        self.assertIn('data', r)
1811        self.assertIsNone(r['data'])
1812        s = b'None'
1813        r = self.db.update('bytea_test', n=0, data=s)
1814        self.assertIsInstance(r, dict)
1815        self.assertIn('n', r)
1816        self.assertEqual(r['n'], 0)
1817        self.assertIn('data', r)
1818        r = r['data']
1819        self.assertIsInstance(r, bytes)
1820        self.assertEqual(r, s)
1821        r = self.db.update('bytea_test', n=0, data=None)
1822        self.assertIsNone(r['data'])
1823        # insert as bytes
1824        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1825        r = self.db.insert('bytea_test', n=5, data=s)
1826        self.assertIsInstance(r, dict)
1827        self.assertIn('n', r)
1828        self.assertEqual(r['n'], 5)
1829        self.assertIn('data', r)
1830        r = r['data']
1831        self.assertIsInstance(r, bytes)
1832        self.assertEqual(r, s)
1833        # update as bytes
1834        s += b"and now even more \x00 nasty \t stuff!\f"
1835        r = self.db.update('bytea_test', n=5, data=s)
1836        self.assertIsInstance(r, dict)
1837        self.assertIn('n', r)
1838        self.assertEqual(r['n'], 5)
1839        self.assertIn('data', r)
1840        r = r['data']
1841        self.assertIsInstance(r, bytes)
1842        self.assertEqual(r, s)
1843        r = query('select * from bytea_test where n=5').getresult()
1844        self.assertEqual(len(r), 1)
1845        r = r[0]
1846        self.assertEqual(len(r), 2)
1847        self.assertEqual(r[0], 5)
1848        r = r[1]
1849        self.assertIsInstance(r, str)
1850        r = self.db.unescape_bytea(r)
1851        self.assertIsInstance(r, bytes)
1852        self.assertEqual(r, s)
1853        r = self.db.get('bytea_test', dict(n=5))
1854        self.assertIsInstance(r, dict)
1855        self.assertIn('n', r)
1856        self.assertEqual(r['n'], 5)
1857        self.assertIn('data', r)
1858        r = r['data']
1859        self.assertIsInstance(r, bytes)
1860        self.assertEqual(r, s)
1861
1862    def testDebugWithCallable(self):
1863        if debug:
1864            self.assertEqual(self.db.debug, debug)
1865        else:
1866            self.assertIsNone(self.db.debug)
1867        s = []
1868        self.db.debug = s.append
1869        try:
1870            self.db.query("select 1")
1871            self.db.query("select 2")
1872            self.assertEqual(s, ["select 1", "select 2"])
1873        finally:
1874            self.db.debug = debug
1875
1876
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        """Switch the global pg options, then set up the base class.

        The previous option values are saved in ``saved_options``.  If
        anything here fails, unittest will not call tearDownClass(), so
        the saved options are restored right away before re-raising.
        This keeps the changed options from leaking into other tests.
        """
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)

            def unnamed_result(q):
                return q.getresult()

            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            cls._reset_options()
            raise

    @classmethod
    def tearDownClass(cls):
        """Tear down the base class and always restore the saved options."""
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls._reset_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of the pg option, then set the new one."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore the saved value of the given pg option."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _reset_options(cls):
        """Restore all options saved so far, in reverse order of setting."""
        for option in ('namedresult', 'bool', 'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
1905
1906
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    @classmethod
    def setUpClass(cls):
        """Create the test tables in the public schema and in s1 to s4.

        The public schema (number 0) and each schema s<num> get the tables
        t and t<num>, each with a single row where n=1 and d=<num>.
        """
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                # NOTE(review): "with oids" is not supported by
                # PostgreSQL 12 and later; the oid checks depend on it.
                query("create table %s.t with oids"
                      " as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids"
                      " as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # do not leak the connection if the setup fails half way
            db.close()

    @classmethod
    def tearDownClass(cls):
        """Drop the schemas s1 to s4 and the test tables in public."""
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.db = DB()
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        # run the cleanups while the connection is still open
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
2024
2025
if __name__ == '__main__':
    # run all test cases of this module when executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.