source: trunk/tests/test_classic_dbwrapper.py @ 765

Last change on this file since 765 was 765, checked in by cito, 4 years ago

Improve support for access by primary key

Composite primary keys are now returned as tuples instead of frozensets,
where the ordering of the tuple reflects the primary key index.

Primary keys now take precedence if both an OID and a primary key are
available (in 4.x it was the other way around). The use of OIDs is thus
slightly more discouraged, though it still works as before for tables with
OIDs where no primary key is available.

This changeset also clarifies some docstrings, makes the code a bit clearer,
handles and tests some more edge cases (pg module still has 100% coverage).

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 108.9 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18
19import os
20import sys
21import tempfile
22
23import pg  # the module under test
24
25from decimal import Decimal
26
# Connection settings for the test database.  A LOCAL_PyGreSQL.py module,
# if present, is imported below and may override them; otherwise these
# defaults are used.  The connecting user needs the privilege to create
# schemas in this database.
dbname, dbhost, dbport = 'unittest', None, 5432

# when True, the DB wrapper prints debugging output
debug = False
35
36try:
37    from .LOCAL_PyGreSQL import *
38except (ImportError, ValueError):
39    try:
40        from LOCAL_PyGreSQL import *
41    except ImportError:
42        pass
43
# Compatibility names so that the tests run under Python 2 and Python 3.

try:
    long
except NameError:  # Python >= 3.0 has only int
    long = int

try:
    unicode
except NameError:  # Python >= 3.0 has only str
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0 lack OrderedDict
    OrderedDict = dict

if str is not bytes:  # Python 3: text buffers live in the io module
    from io import StringIO
else:  # Python 2
    from StringIO import StringIO
63
windows = (os.name == 'nt')

# libpq under Windows has a known bug that can crash the interface
# when PQhost() is called, so tests asking for the host are skipped there:
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
70
71
def DB():
    """Return a new DB wrapper object connected to the test database.

    Debugging output is enabled when the module-level ``debug`` flag is
    set, and server messages below warning level are suppressed.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
79
80
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        try:
            self.db.close()
        except pg.InternalError:
            pass  # some tests close the connection themselves

    def testAllDBAttributes(self):
        """The public attributes of DB are exactly the expected ones."""
        attributes = [
            'abort',
            'begin',
            'cancel', 'clear', 'close', 'commit',
            'db', 'dbname', 'debug', 'delete',
            'end', 'endcopy', 'error',
            'escape_bytea', 'escape_identifier',
            'escape_literal', 'escape_string',
            'fileno',
            'get', 'get_attnames', 'get_databases',
            'get_notice_receiver', 'get_parameter',
            'get_relations', 'get_tables',
            'getline', 'getlo', 'getnotify',
            'has_table_privilege', 'host',
            'insert', 'inserttable',
            'locreate', 'loimport',
            'notification_handler',
            'options',
            'parameter', 'pkey', 'port',
            'protocol_version', 'putline',
            'query',
            'release', 'reopen', 'reset', 'rollback',
            'savepoint', 'server_version',
            'set_notice_receiver', 'set_parameter',
            'source', 'start', 'status',
            'transaction', 'truncate',
            'unescape_bytea', 'update', 'upsert',
            'use_regtypes', 'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        # no error is expected; messages mentioning 'krb5_' (presumably
        # Kerberos warnings from libpq) are tolerated
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # strip the hex prefix and escape backslashes, whichever is used
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        # parameters may be passed as separate arguments, tuple or list
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        """A failing query raises ProgrammingError carrying its SQLSTATE."""
        # assertRaises makes the test fail when no error is raised at all
        with self.assertRaises(pg.ProgrammingError) as cm:
            self.db.query("select 1/0")
        self.assertEqual(cm.exception.sqlstate, '22012')  # division by zero

    def testMethodEndcopy(self):
        # best effort: presumably raises IOError when no copy is in progress
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertIsNone(self.db.db)
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')
        self.assertRaises(pg.InternalError, getattr, self.db, 'status')
        self.assertRaises(pg.InternalError, getattr, self.db, 'error')
        self.assertRaises(pg.InternalError, getattr, self.db, 'absent')

    def testMethodReset(self):
        con = self.db.db
        self.db.reset()
        # reset keeps the same connection object
        self.assertIs(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()
        self.assertRaises(pg.InternalError, self.db.reset)

    def testMethodReopen(self):
        con = self.db.db
        self.db.reopen()
        # reopen replaces the connection object, even after close
        self.assertIsNot(self.db.db, con)
        con = self.db.db
        self.db.query("select 1+1")
        self.db.close()
        self.db.reopen()
        self.assertIsNot(self.db.db, con)
        self.db.query("select 1+1")
        self.db.close()

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        # closing or reopening a wrapper around a borrowed connection
        # leaves that connection in place
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        # any object with a _cnx attribute can supply the connection
        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
282
283
284class TestDBClass(unittest.TestCase):
285    """Test the methods of the DB class wrapped pg connection."""
286
287    @classmethod
288    def setUpClass(cls):
289        db = DB()
290        db.query("drop table if exists test cascade")
291        db.query("create table test ("
292            "i2 smallint, i4 integer, i8 bigint,"
293            " d numeric, f4 real, f8 double precision, m money,"
294            " v4 varchar(4), c4 char(4), t text)")
295        db.query("create or replace view test_view as"
296            " select i4, v4 from test")
297        db.close()
298
299    @classmethod
300    def tearDownClass(cls):
301        db = DB()
302        db.query("drop table test cascade")
303        db.close()
304
305    def setUp(self):
306        self.db = DB()
307        query = self.db.query
308        query('set client_encoding=utf8')
309        query('set standard_conforming_strings=on')
310        query("set lc_monetary='C'")
311        query("set datestyle='ISO,YMD'")
312        query('set bytea_output=hex')
313
314    def tearDown(self):
315        self.doCleanups()
316        self.db.close()
317
318    def testClassName(self):
319        self.assertEqual(self.db.__class__.__name__, 'DB')
320
321    def testModuleName(self):
322        self.assertEqual(self.db.__module__, 'pg')
323        self.assertEqual(self.db.__class__.__module__, 'pg')
324
325    def testEscapeLiteral(self):
326        f = self.db.escape_literal
327        r = f(b"plain")
328        self.assertIsInstance(r, bytes)
329        self.assertEqual(r, b"'plain'")
330        r = f(u"plain")
331        self.assertIsInstance(r, unicode)
332        self.assertEqual(r, u"'plain'")
333        r = f(u"that's kÀse".encode('utf-8'))
334        self.assertIsInstance(r, bytes)
335        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
336        r = f(u"that's kÀse")
337        self.assertIsInstance(r, unicode)
338        self.assertEqual(r, u"'that''s kÀse'")
339        self.assertEqual(f(r"It's fine to have a \ inside."),
340            r" E'It''s fine to have a \\ inside.'")
341        self.assertEqual(f('No "quotes" must be escaped.'),
342            "'No \"quotes\" must be escaped.'")
343
344    def testEscapeIdentifier(self):
345        f = self.db.escape_identifier
346        r = f(b"plain")
347        self.assertIsInstance(r, bytes)
348        self.assertEqual(r, b'"plain"')
349        r = f(u"plain")
350        self.assertIsInstance(r, unicode)
351        self.assertEqual(r, u'"plain"')
352        r = f(u"that's kÀse".encode('utf-8'))
353        self.assertIsInstance(r, bytes)
354        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
355        r = f(u"that's kÀse")
356        self.assertIsInstance(r, unicode)
357        self.assertEqual(r, u'"that\'s kÀse"')
358        self.assertEqual(f(r"It's fine to have a \ inside."),
359            '"It\'s fine to have a \\ inside."')
360        self.assertEqual(f('All "quotes" must be escaped.'),
361            '"All ""quotes"" must be escaped."')
362
363    def testEscapeString(self):
364        f = self.db.escape_string
365        r = f(b"plain")
366        self.assertIsInstance(r, bytes)
367        self.assertEqual(r, b"plain")
368        r = f(u"plain")
369        self.assertIsInstance(r, unicode)
370        self.assertEqual(r, u"plain")
371        r = f(u"that's kÀse".encode('utf-8'))
372        self.assertIsInstance(r, bytes)
373        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
374        r = f(u"that's kÀse")
375        self.assertIsInstance(r, unicode)
376        self.assertEqual(r, u"that''s kÀse")
377        self.assertEqual(f(r"It's fine to have a \ inside."),
378            r"It''s fine to have a \ inside.")
379
380    def testEscapeBytea(self):
381        f = self.db.escape_bytea
382        # note that escape_byte always returns hex output since Pg 9.0,
383        # regardless of the bytea_output setting
384        r = f(b'plain')
385        self.assertIsInstance(r, bytes)
386        self.assertEqual(r, b'\\x706c61696e')
387        r = f(u'plain')
388        self.assertIsInstance(r, unicode)
389        self.assertEqual(r, u'\\x706c61696e')
390        r = f(u"das is' kÀse".encode('utf-8'))
391        self.assertIsInstance(r, bytes)
392        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
393        r = f(u"das is' kÀse")
394        self.assertIsInstance(r, unicode)
395        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
396        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
397
398    def testUnescapeBytea(self):
399        f = self.db.unescape_bytea
400        r = f(b'plain')
401        self.assertIsInstance(r, bytes)
402        self.assertEqual(r, b'plain')
403        r = f(u'plain')
404        self.assertIsInstance(r, bytes)
405        self.assertEqual(r, b'plain')
406        r = f(b"das is' k\\303\\244se")
407        self.assertIsInstance(r, bytes)
408        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
409        r = f(u"das is' k\\303\\244se")
410        self.assertIsInstance(r, bytes)
411        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
412        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
413        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
414        self.assertEqual(f(r'\\x746861742773206be47365'),
415            b'\\x746861742773206be47365')
416        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
417
418    def testGetParameter(self):
419        f = self.db.get_parameter
420        self.assertRaises(TypeError, f)
421        self.assertRaises(TypeError, f, None)
422        self.assertRaises(TypeError, f, 42)
423        self.assertRaises(TypeError, f, '')
424        self.assertRaises(TypeError, f, [])
425        self.assertRaises(TypeError, f, [''])
426        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
427        r = f('standard_conforming_strings')
428        self.assertEqual(r, 'on')
429        r = f('lc_monetary')
430        self.assertEqual(r, 'C')
431        r = f('datestyle')
432        self.assertEqual(r, 'ISO, YMD')
433        r = f('bytea_output')
434        self.assertEqual(r, 'hex')
435        r = f(['bytea_output', 'lc_monetary'])
436        self.assertIsInstance(r, list)
437        self.assertEqual(r, ['hex', 'C'])
438        r = f(('standard_conforming_strings', 'datestyle', 'bytea_output'))
439        self.assertEqual(r, ['on', 'ISO, YMD', 'hex'])
440        r = f(set(['bytea_output', 'lc_monetary']))
441        self.assertIsInstance(r, dict)
442        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
443        r = f(set(['Bytea_Output', ' LC_Monetary ']))
444        self.assertIsInstance(r, dict)
445        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
446        s = dict.fromkeys(('bytea_output', 'lc_monetary'))
447        r = f(s)
448        self.assertIs(r, s)
449        self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'})
450        s = dict.fromkeys(('Bytea_Output', ' LC_Monetary '))
451        r = f(s)
452        self.assertIs(r, s)
453        self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'})
454
455    def testGetParameterServerVersion(self):
456        r = self.db.get_parameter('server_version_num')
457        self.assertIsInstance(r, str)
458        s = self.db.server_version
459        self.assertIsInstance(s, int)
460        self.assertEqual(r, str(s))
461
462    def testGetParameterAll(self):
463        f = self.db.get_parameter
464        r = f('all')
465        self.assertIsInstance(r, dict)
466        self.assertEqual(r['standard_conforming_strings'], 'on')
467        self.assertEqual(r['lc_monetary'], 'C')
468        self.assertEqual(r['DateStyle'], 'ISO, YMD')
469        self.assertEqual(r['bytea_output'], 'hex')
470
471    def testSetParameter(self):
472        f = self.db.set_parameter
473        g = self.db.get_parameter
474        self.assertRaises(TypeError, f)
475        self.assertRaises(TypeError, f, None)
476        self.assertRaises(TypeError, f, 42)
477        self.assertRaises(TypeError, f, '')
478        self.assertRaises(TypeError, f, [])
479        self.assertRaises(TypeError, f, [''])
480        self.assertRaises(ValueError, f, 'all', 'invalid')
481        self.assertRaises(ValueError, f, {
482            'invalid1': 'value1', 'invalid2': 'value2'}, 'value')
483        self.assertRaises(pg.ProgrammingError, f, 'this_does_not_exist')
484        f('standard_conforming_strings', 'off')
485        self.assertEqual(g('standard_conforming_strings'), 'off')
486        f('datestyle', 'ISO, DMY')
487        self.assertEqual(g('datestyle'), 'ISO, DMY')
488        f(['standard_conforming_strings', 'datestyle'], ['on', 'ISO, DMY'])
489        self.assertEqual(g('standard_conforming_strings'), 'on')
490        self.assertEqual(g('datestyle'), 'ISO, DMY')
491        f(['default_with_oids', 'standard_conforming_strings'], 'off')
492        self.assertEqual(g('default_with_oids'), 'off')
493        self.assertEqual(g('standard_conforming_strings'), 'off')
494        f(('standard_conforming_strings', 'datestyle'), ('on', 'ISO, YMD'))
495        self.assertEqual(g('standard_conforming_strings'), 'on')
496        self.assertEqual(g('datestyle'), 'ISO, YMD')
497        f(('default_with_oids', 'standard_conforming_strings'), 'off')
498        self.assertEqual(g('default_with_oids'), 'off')
499        self.assertEqual(g('standard_conforming_strings'), 'off')
500        f(set(['default_with_oids', 'standard_conforming_strings']), 'on')
501        self.assertEqual(g('default_with_oids'), 'on')
502        self.assertEqual(g('standard_conforming_strings'), 'on')
503        self.assertRaises(ValueError, f, set([ 'default_with_oids',
504            'standard_conforming_strings']), ['off', 'on'])
505        f(set(['default_with_oids', 'standard_conforming_strings']),
506            ['off', 'off'])
507        self.assertEqual(g('default_with_oids'), 'off')
508        self.assertEqual(g('standard_conforming_strings'), 'off')
509        f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'})
510        self.assertEqual(g('standard_conforming_strings'), 'on')
511        self.assertEqual(g('datestyle'), 'ISO, YMD')
512
513    def testResetParameter(self):
514        db = DB()
515        f = db.set_parameter
516        g = db.get_parameter
517        r = g('default_with_oids')
518        self.assertIn(r, ('on', 'off'))
519        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
520        r = g('standard_conforming_strings')
521        self.assertIn(r, ('on', 'off'))
522        scs, not_scs = r, 'off' if r == 'on' else 'on'
523        f('default_with_oids', not_dwi)
524        f('standard_conforming_strings', not_scs)
525        self.assertEqual(g('default_with_oids'), not_dwi)
526        self.assertEqual(g('standard_conforming_strings'), not_scs)
527        f('default_with_oids')
528        f('standard_conforming_strings', None)
529        self.assertEqual(g('default_with_oids'), dwi)
530        self.assertEqual(g('standard_conforming_strings'), scs)
531        f('default_with_oids', not_dwi)
532        f('standard_conforming_strings', not_scs)
533        self.assertEqual(g('default_with_oids'), not_dwi)
534        self.assertEqual(g('standard_conforming_strings'), not_scs)
535        f(['default_with_oids', 'standard_conforming_strings'], None)
536        self.assertEqual(g('default_with_oids'), dwi)
537        self.assertEqual(g('standard_conforming_strings'), scs)
538        f('default_with_oids', not_dwi)
539        f('standard_conforming_strings', not_scs)
540        self.assertEqual(g('default_with_oids'), not_dwi)
541        self.assertEqual(g('standard_conforming_strings'), not_scs)
542        f(('default_with_oids', 'standard_conforming_strings'))
543        self.assertEqual(g('default_with_oids'), dwi)
544        self.assertEqual(g('standard_conforming_strings'), scs)
545        f('default_with_oids', not_dwi)
546        f('standard_conforming_strings', not_scs)
547        self.assertEqual(g('default_with_oids'), not_dwi)
548        self.assertEqual(g('standard_conforming_strings'), not_scs)
549        f(set(['default_with_oids', 'standard_conforming_strings']))
550        self.assertEqual(g('default_with_oids'), dwi)
551        self.assertEqual(g('standard_conforming_strings'), scs)
552
553    def testResetParameterAll(self):
554        db = DB()
555        f = db.set_parameter
556        self.assertRaises(ValueError, f, 'all', 0)
557        self.assertRaises(ValueError, f, 'all', 'off')
558        g = db.get_parameter
559        r = g('default_with_oids')
560        self.assertIn(r, ('on', 'off'))
561        dwi, not_dwi = r, 'off' if r == 'on' else 'on'
562        r = g('standard_conforming_strings')
563        self.assertIn(r, ('on', 'off'))
564        scs, not_scs = r, 'off' if r == 'on' else 'on'
565        f('default_with_oids', not_dwi)
566        f('standard_conforming_strings', not_scs)
567        self.assertEqual(g('default_with_oids'), not_dwi)
568        self.assertEqual(g('standard_conforming_strings'), not_scs)
569        f('all')
570        self.assertEqual(g('default_with_oids'), dwi)
571        self.assertEqual(g('standard_conforming_strings'), scs)
572
573    def testSetParameterLocal(self):
574        f = self.db.set_parameter
575        g = self.db.get_parameter
576        self.assertEqual(g('standard_conforming_strings'), 'on')
577        self.db.begin()
578        f('standard_conforming_strings', 'off', local=True)
579        self.assertEqual(g('standard_conforming_strings'), 'off')
580        self.db.end()
581        self.assertEqual(g('standard_conforming_strings'), 'on')
582
583    def testSetParameterSession(self):
584        f = self.db.set_parameter
585        g = self.db.get_parameter
586        self.assertEqual(g('standard_conforming_strings'), 'on')
587        self.db.begin()
588        f('standard_conforming_strings', 'off', local=False)
589        self.assertEqual(g('standard_conforming_strings'), 'off')
590        self.db.end()
591        self.assertEqual(g('standard_conforming_strings'), 'off')
592
593    def testReset(self):
594        db = DB()
595        default_datestyle = db.get_parameter('datestyle')
596        changed_datestyle = 'ISO, DMY'
597        if changed_datestyle == default_datestyle:
598            changed_datestyle == 'ISO, YMD'
599        self.db.set_parameter('datestyle', changed_datestyle)
600        r = self.db.get_parameter('datestyle')
601        self.assertEqual(r, changed_datestyle)
602        con = self.db.db
603        q = con.query("show datestyle")
604        self.db.reset()
605        r = q.getresult()[0][0]
606        self.assertEqual(r, changed_datestyle)
607        q = con.query("show datestyle")
608        r = q.getresult()[0][0]
609        self.assertEqual(r, default_datestyle)
610        r = self.db.get_parameter('datestyle')
611        self.assertEqual(r, default_datestyle)
612
613    def testReopen(self):
614        db = DB()
615        default_datestyle = db.get_parameter('datestyle')
616        changed_datestyle = 'ISO, DMY'
617        if changed_datestyle == default_datestyle:
618            changed_datestyle == 'ISO, YMD'
619        self.db.set_parameter('datestyle', changed_datestyle)
620        r = self.db.get_parameter('datestyle')
621        self.assertEqual(r, changed_datestyle)
622        con = self.db.db
623        q = con.query("show datestyle")
624        self.db.reopen()
625        r = q.getresult()[0][0]
626        self.assertEqual(r, changed_datestyle)
627        self.assertRaises(TypeError, getattr, con, 'query')
628        r = self.db.get_parameter('datestyle')
629        self.assertEqual(r, default_datestyle)
630
631    def testQuery(self):
632        query = self.db.query
633        query("drop table if exists test_table")
634        self.addCleanup(query, "drop table test_table")
635        q = "create table test_table (n integer) with oids"
636        r = query(q)
637        self.assertIsNone(r)
638        q = "insert into test_table values (1)"
639        r = query(q)
640        self.assertIsInstance(r, int)
641        q = "insert into test_table select 2"
642        r = query(q)
643        self.assertIsInstance(r, int)
644        oid = r
645        q = "select oid from test_table where n=2"
646        r = query(q).getresult()
647        self.assertEqual(len(r), 1)
648        r = r[0]
649        self.assertEqual(len(r), 1)
650        r = r[0]
651        self.assertEqual(r, oid)
652        q = "insert into test_table select 3 union select 4 union select 5"
653        r = query(q)
654        self.assertIsInstance(r, str)
655        self.assertEqual(r, '3')
656        q = "update test_table set n=4 where n<5"
657        r = query(q)
658        self.assertIsInstance(r, str)
659        self.assertEqual(r, '4')
660        q = "delete from test_table"
661        r = query(q)
662        self.assertIsInstance(r, str)
663        self.assertEqual(r, '5')
664
665    def testMultipleQueries(self):
666        self.assertEqual(self.db.query(
667            "create temporary table test_multi (n integer);"
668            "insert into test_multi values (4711);"
669            "select n from test_multi").getresult()[0][0], 4711)
670
671    def testQueryWithParams(self):
672        query = self.db.query
673        query("drop table if exists test_table")
674        self.addCleanup(query, "drop table test_table")
675        q = "create table test_table (n1 integer, n2 integer) with oids"
676        query(q)
677        q = "insert into test_table values ($1, $2)"
678        r = query(q, (1, 2))
679        self.assertIsInstance(r, int)
680        r = query(q, [3, 4])
681        self.assertIsInstance(r, int)
682        r = query(q, [5, 6])
683        self.assertIsInstance(r, int)
684        q = "select * from test_table order by 1, 2"
685        self.assertEqual(query(q).getresult(),
686            [(1, 2), (3, 4), (5, 6)])
687        q = "select * from test_table where n1=$1 and n2=$2"
688        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
689        q = "update test_table set n2=$2 where n1=$1"
690        r = query(q, 3, 7)
691        self.assertEqual(r, '1')
692        q = "select * from test_table order by 1, 2"
693        self.assertEqual(query(q).getresult(),
694            [(1, 2), (3, 7), (5, 6)])
695        q = "delete from test_table where n2!=$1"
696        r = query(q, 4)
697        self.assertEqual(r, '3')
698
699    def testEmptyQuery(self):
700        self.assertRaises(ValueError, self.db.query, '')
701
702    def testQueryProgrammingError(self):
703        try:
704            self.db.query("select 1/0")
705        except pg.ProgrammingError as error:
706            self.assertEqual(error.sqlstate, '22012')
707
708    def testPkey(self):
709        query = self.db.query
710        pkey = self.db.pkey
711        self.assertRaises(KeyError, pkey, 'test')
712        for t in ('pkeytest', 'primary key test'):
713            for n in range(8):
714                query('drop table if exists "%s%d"' % (t, n))
715                self.addCleanup(query, 'drop table "%s%d"' % (t, n))
716            query('create table "%s0" ('
717                "a smallint)" % t)
718            query('create table "%s1" ('
719                "b smallint primary key)" % t)
720            query('create table "%s2" ('
721                "c smallint, d smallint primary key)" % t)
722            query('create table "%s3" ('
723                "e smallint, f smallint, g smallint,"
724                " h smallint, i smallint,"
725                " primary key (f, h))" % t)
726            query('create table "%s4" ('
727                "e smallint, f smallint, g smallint,"
728                " h smallint, i smallint,"
729                " primary key (h, f))" % t)
730            query('create table "%s5" ('
731                "more_than_one_letter varchar primary key)" % t)
732            query('create table "%s6" ('
733                '"with space" date primary key)' % t)
734            query('create table "%s7" ('
735                'a_very_long_column_name varchar,'
736                ' "with space" date,'
737                ' "42" int,'
738                " primary key (a_very_long_column_name,"
739                ' "with space", "42"))' % t)
740            self.assertRaises(KeyError, pkey, '%s0' % t)
741            self.assertEqual(pkey('%s1' % t), 'b')
742            self.assertEqual(pkey('%s1' % t, True), ('b',))
743            self.assertEqual(pkey('%s1' % t, composite=False), 'b')
744            self.assertEqual(pkey('%s1' % t, composite=True), ('b',))
745            self.assertEqual(pkey('%s2' % t), 'd')
746            self.assertEqual(pkey('%s2' % t, composite=True), ('d',))
747            r = pkey('%s3' % t)
748            self.assertIsInstance(r, tuple)
749            self.assertEqual(r, ('f', 'h'))
750            r = pkey('%s3' % t, composite=False)
751            self.assertIsInstance(r, tuple)
752            self.assertEqual(r, ('f', 'h'))
753            r = pkey('%s4' % t)
754            self.assertIsInstance(r, tuple)
755            self.assertEqual(r, ('h', 'f'))
756            self.assertEqual(pkey('%s5' % t), 'more_than_one_letter')
757            self.assertEqual(pkey('%s6' % t), 'with space')
758            r = pkey('%s7' % t)
759            self.assertIsInstance(r, tuple)
760            self.assertEqual(r, (
761                'a_very_long_column_name', 'with space', '42'))
762            # a newly added primary key will be detected
763            query('alter table "%s0" add primary key (a)' % t)
764            self.assertEqual(pkey('%s0' % t), 'a')
765            # a changed primary key will not be detected,
766            # indicating that the internal cache is operating
767            query('alter table "%s1" rename column b to x' % t)
768            self.assertEqual(pkey('%s1' % t), 'b')
769            # we get the changed primary key when the cache is flushed
770            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
771
772    def testGetDatabases(self):
773        databases = self.db.get_databases()
774        self.assertIn('template0', databases)
775        self.assertIn('template1', databases)
776        self.assertNotIn('not existing database', databases)
777        self.assertIn('postgres', databases)
778        self.assertIn(dbname, databases)
779
780    def testGetTables(self):
781        get_tables = self.db.get_tables
782        result1 = get_tables()
783        self.assertIsInstance(result1, list)
784        for t in result1:
785            t = t.split('.', 1)
786            self.assertGreaterEqual(len(t), 2)
787            if len(t) > 2:
788                self.assertTrue(t[1].startswith('"'))
789            t = t[0]
790            self.assertNotEqual(t, 'information_schema')
791            self.assertFalse(t.startswith('pg_'))
792        tables = ('"A very Special Name"',
793            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
794            'A_MiXeD_NaMe', '"another special name"',
795            'averyveryveryveryveryveryverylongtablename',
796            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
797        for t in tables:
798            self.db.query('drop table if exists %s' % t)
799            self.db.query("create table %s"
800                " as select 0" % t)
801        result3 = get_tables()
802        result2 = []
803        for t in result3:
804            if t not in result1:
805                result2.append(t)
806        result3 = []
807        for t in tables:
808            if not t.startswith('"'):
809                t = t.lower()
810            result3.append('public.' + t)
811        self.assertEqual(result2, result3)
812        for t in result2:
813            self.db.query('drop table %s' % t)
814        result2 = get_tables()
815        self.assertEqual(result2, result1)
816
817    def testGetRelations(self):
818        get_relations = self.db.get_relations
819        result = get_relations()
820        self.assertIn('public.test', result)
821        self.assertIn('public.test_view', result)
822        result = get_relations('rv')
823        self.assertIn('public.test', result)
824        self.assertIn('public.test_view', result)
825        result = get_relations('r')
826        self.assertIn('public.test', result)
827        self.assertNotIn('public.test_view', result)
828        result = get_relations('v')
829        self.assertNotIn('public.test', result)
830        self.assertIn('public.test_view', result)
831        result = get_relations('cisSt')
832        self.assertNotIn('public.test', result)
833        self.assertNotIn('public.test_view', result)
834
835    def testGetAttnames(self):
836        get_attnames = self.db.get_attnames
837        self.assertRaises(pg.ProgrammingError,
838            self.db.get_attnames, 'does_not_exist')
839        self.assertRaises(pg.ProgrammingError,
840            self.db.get_attnames, 'has.too.many.dots')
841        r = get_attnames('test')
842        self.assertIsInstance(r, dict)
843        self.assertEqual(r, dict(
844            i2='int', i4='int', i8='int', d='num',
845            f4='float', f8='float', m='money',
846            v4='text', c4='text', t='text'))
847        query = self.db.query
848        query("drop table if exists test_table")
849        self.addCleanup(query, "drop table test_table")
850        query("create table test_table("
851            " n int, alpha smallint, beta bool,"
852            " gamma char(5), tau text, v varchar(3))")
853        r = get_attnames('test_table')
854        self.assertIsInstance(r, dict)
855        self.assertEqual(r, dict(
856            n='int', alpha='int', beta='bool',
857            gamma='text', tau='text', v='text'))
858
859    def testGetAttnamesWithQuotes(self):
860        get_attnames = self.db.get_attnames
861        query = self.db.query
862        table = 'test table for get_attnames()'
863        query('drop table if exists "%s"' % table)
864        self.addCleanup(query, 'drop table "%s"' % table)
865        query('create table "%s"('
866            '"Prime!" smallint,'
867            ' "much space" integer, "Questions?" text)' % table)
868        r = get_attnames(table)
869        self.assertIsInstance(r, dict)
870        self.assertEqual(r, {
871            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
872        table = 'yet another test table for get_attnames()'
873        query('drop table if exists "%s"' % table)
874        self.addCleanup(query, 'drop table "%s"' % table)
875        self.db.query('create table "%s" ('
876            'a smallint, b integer, c bigint,'
877            ' e numeric, f float, f2 double precision, m money,'
878            ' x smallint, y smallint, z smallint,'
879            ' Normal_NaMe smallint, "Special Name" smallint,'
880            ' t text, u char(2), v varchar(2),'
881            ' primary key (y, u)) with oids' % table)
882        r = get_attnames(table)
883        self.assertIsInstance(r, dict)
884        self.assertEqual(r, {'a': 'int', 'c': 'int', 'b': 'int',
885            'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
886            'normal_name': 'int', 'Special Name': 'int',
887            'u': 'text', 't': 'text', 'v': 'text',
888            'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'})
889
890    def testGetAttnamesWithRegtypes(self):
891        get_attnames = self.db.get_attnames
892        query = self.db.query
893        query("drop table if exists test_table")
894        self.addCleanup(query, "drop table test_table")
895        query("create table test_table("
896            " n int, alpha smallint, beta bool,"
897            " gamma char(5), tau text, v varchar(3))")
898        use_regtypes = self.db.use_regtypes
899        regtypes = use_regtypes()
900        self.assertFalse(regtypes)
901        use_regtypes(True)
902        try:
903            r = get_attnames("test_table")
904            self.assertIsInstance(r, dict)
905        finally:
906            use_regtypes(regtypes)
907        self.assertEqual(r, dict(
908            n='integer', alpha='smallint', beta='boolean',
909            gamma='character', tau='text', v='character varying'))
910
911    def testGetAttnamesIsCached(self):
912        get_attnames = self.db.get_attnames
913        query = self.db.query
914        query("drop table if exists test_table")
915        self.addCleanup(query, "drop table test_table")
916        query("create table test_table(col int)")
917        r = get_attnames("test_table")
918        self.assertIsInstance(r, dict)
919        self.assertEqual(r, dict(col='int'))
920        query("alter table test_table alter column col type text")
921        query("alter table test_table add column col2 int")
922        r = get_attnames("test_table")
923        self.assertEqual(r, dict(col='int'))
924        r = get_attnames("test_table", flush=True)
925        self.assertEqual(r, dict(col='text', col2='int'))
926        query("alter table test_table drop column col2")
927        r = get_attnames("test_table")
928        self.assertEqual(r, dict(col='text', col2='int'))
929        r = get_attnames("test_table", flush=True)
930        self.assertEqual(r, dict(col='text'))
931        query("alter table test_table drop column col")
932        r = get_attnames("test_table")
933        self.assertEqual(r, dict(col='text'))
934        r = get_attnames("test_table", flush=True)
935        self.assertEqual(r, dict())
936
937    def testGetAttnamesIsOrdered(self):
938        get_attnames = self.db.get_attnames
939        query = self.db.query
940        query("drop table if exists test_table")
941        self.addCleanup(query, "drop table test_table")
942        query("create table test_table("
943            " n int, alpha smallint, v varchar(3),"
944            " gamma char(5), tau text, beta bool)")
945        r = get_attnames("test_table")
946        self.assertIsInstance(r, OrderedDict)
947        self.assertEqual(r, OrderedDict([
948            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
949            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
950        if OrderedDict is dict:
951            self.skipTest('OrderedDict is not supported')
952        r = ' '.join(list(r.keys()))
953        self.assertEqual(r, 'n alpha v gamma tau beta')
954
955    def testHasTablePrivilege(self):
956        can = self.db.has_table_privilege
957        self.assertEqual(can('test'), True)
958        self.assertEqual(can('test', 'select'), True)
959        self.assertEqual(can('test', 'SeLeCt'), True)
960        self.assertEqual(can('test', 'SELECT'), True)
961        self.assertEqual(can('test', 'insert'), True)
962        self.assertEqual(can('test', 'update'), True)
963        self.assertEqual(can('test', 'delete'), True)
964        self.assertEqual(can('pg_views', 'select'), True)
965        self.assertEqual(can('pg_views', 'delete'), False)
966        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
967        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
968
969    def testGet(self):
970        get = self.db.get
971        query = self.db.query
972        table = 'get_test_table'
973        self.assertRaises(TypeError, get)
974        self.assertRaises(TypeError, get, table)
975        query('drop table if exists "%s"' % table)
976        self.addCleanup(query, 'drop table "%s"' % table)
977        query('create table "%s" ('
978            "n integer, t text) without oids" % table)
979        for n, t in enumerate('xyz'):
980            query('insert into "%s" values('"%d, '%s')"
981                % (table, n + 1, t))
982        self.assertRaises(pg.ProgrammingError, get, table, 2)
983        r = get(table, 2, 'n')
984        self.assertIsInstance(r, dict)
985        self.assertEqual(r, dict(n=2, t='y'))
986        r = get(table, 1, 'n')
987        self.assertEqual(r, dict(n=1, t='x'))
988        r = get(table, (3,), ('n',))
989        self.assertEqual(r, dict(n=3, t='z'))
990        r = get(table, 'y', 't')
991        self.assertEqual(r, dict(n=2, t='y'))
992        self.assertRaises(pg.DatabaseError, get, table, 4)
993        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
994        self.assertRaises(pg.DatabaseError, get, table, 'y')
995        self.assertRaises(pg.DatabaseError, get, table, 2, 't')
996        s = dict(n=3)
997        self.assertRaises(pg.ProgrammingError, get, table, s)
998        r = get(table, s, 'n')
999        self.assertIs(r, s)
1000        self.assertEqual(r, dict(n=3, t='z'))
1001        s.update(t='x')
1002        r = get(table, s, 't')
1003        self.assertIs(r, s)
1004        self.assertEqual(s, dict(n=1, t='x'))
1005        r = get(table, s, ('n', 't'))
1006        self.assertIs(r, s)
1007        self.assertEqual(r, dict(n=1, t='x'))
1008        query('alter table "%s" alter n set not null' % table)
1009        query('alter table "%s" add primary key (n)' % table)
1010        r = get(table, 2)
1011        self.assertIsInstance(r, dict)
1012        self.assertEqual(r, dict(n=2, t='y'))
1013        self.assertEqual(get(table, 1)['t'], 'x')
1014        self.assertEqual(get(table, 3)['t'], 'z')
1015        self.assertEqual(get(table + '*', 2)['t'], 'y')
1016        self.assertEqual(get(table + ' *', 2)['t'], 'y')
1017        self.assertRaises(KeyError, get, table, (2, 2))
1018        s = dict(n=3)
1019        r = get(table, s)
1020        self.assertIs(r, s)
1021        self.assertEqual(r, dict(n=3, t='z'))
1022        s.update(n=1)
1023        self.assertEqual(get(table, s)['t'], 'x')
1024        s.update(n=2)
1025        self.assertEqual(get(table, r)['t'], 'y')
1026        s.pop('n')
1027        self.assertRaises(KeyError, get, table, s)
1028
1029    def testGetWithOid(self):
1030        get = self.db.get
1031        query = self.db.query
1032        table = 'get_with_oid_test_table'
1033        query('drop table if exists "%s"' % table)
1034        self.addCleanup(query, 'drop table "%s"' % table)
1035        query('create table "%s" ('
1036            "n integer, t text) with oids" % table)
1037        for n, t in enumerate('xyz'):
1038            query('insert into "%s" values('"%d, '%s')"
1039                % (table, n + 1, t))
1040        self.assertRaises(pg.ProgrammingError, get, table, 2)
1041        self.assertRaises(KeyError, get, table, {}, 'oid')
1042        r = get(table, 2, 'n')
1043        qoid = 'oid(%s)' % table
1044        self.assertIn(qoid, r)
1045        oid = r[qoid]
1046        self.assertIsInstance(oid, int)
1047        result = {'t': 'y', 'n': 2, qoid: oid}
1048        self.assertEqual(r, result)
1049        r = get(table, oid, 'oid')
1050        self.assertEqual(r, result)
1051        r = get(table, dict(oid=oid))
1052        self.assertEqual(r, result)
1053        r = get(table, dict(oid=oid), 'oid')
1054        self.assertEqual(r, result)
1055        r = get(table, {qoid: oid})
1056        self.assertEqual(r, result)
1057        r = get(table, {qoid: oid}, 'oid')
1058        self.assertEqual(r, result)
1059        self.assertEqual(get(table + '*', 2, 'n'), r)
1060        self.assertEqual(get(table + ' *', 2, 'n'), r)
1061        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
1062        self.assertEqual(get(table, 1, 'n')['t'], 'x')
1063        self.assertEqual(get(table, 3, 'n')['t'], 'z')
1064        self.assertEqual(get(table, 2, 'n')['t'], 'y')
1065        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
1066        r['n'] = 3
1067        self.assertEqual(get(table, r, 'n')['t'], 'z')
1068        self.assertEqual(get(table, 1, 'n')['t'], 'x')
1069        self.assertEqual(get(table, r, 'oid')['t'], 'z')
1070        query('alter table "%s" alter n set not null' % table)
1071        query('alter table "%s" add primary key (n)' % table)
1072        self.assertEqual(get(table, 3)['t'], 'z')
1073        self.assertEqual(get(table, 1)['t'], 'x')
1074        self.assertEqual(get(table, 2)['t'], 'y')
1075        r['n'] = 1
1076        self.assertEqual(get(table, r)['t'], 'x')
1077        r['n'] = 3
1078        self.assertEqual(get(table, r)['t'], 'z')
1079        r['n'] = 2
1080        self.assertEqual(get(table, r)['t'], 'y')
1081        r = get(table, oid, 'oid')
1082        self.assertEqual(r, result)
1083        r = get(table, dict(oid=oid))
1084        self.assertEqual(r, result)
1085        r = get(table, dict(oid=oid), 'oid')
1086        self.assertEqual(r, result)
1087        r = get(table, {qoid: oid})
1088        self.assertEqual(r, result)
1089        r = get(table, {qoid: oid}, 'oid')
1090        self.assertEqual(r, result)
1091        r = get(table, dict(oid=oid, n=1))
1092        self.assertEqual(r['n'], 1)
1093        self.assertNotEqual(r[qoid], oid)
1094        r = get(table, dict(oid=oid, t='z'), 't')
1095        self.assertEqual(r['n'], 3)
1096        self.assertNotEqual(r[qoid], oid)
1097
1098    def testGetWithCompositeKey(self):
1099        get = self.db.get
1100        query = self.db.query
1101        table = 'get_test_table_1'
1102        query('drop table if exists "%s"' % table)
1103        self.addCleanup(query, 'drop table "%s"' % table)
1104        query('create table "%s" ('
1105            "n integer, t text, primary key (n))" % table)
1106        for n, t in enumerate('abc'):
1107            query('insert into "%s" values('
1108                "%d, '%s')" % (table, n + 1, t))
1109        self.assertEqual(get(table, 2)['t'], 'b')
1110        self.assertEqual(get(table, 1, 'n')['t'], 'a')
1111        self.assertEqual(get(table, 2, ('n',))['t'], 'b')
1112        self.assertEqual(get(table, 3, ['n'])['t'], 'c')
1113        self.assertEqual(get(table, (2,), ('n',))['t'], 'b')
1114        self.assertEqual(get(table, 'b', 't')['n'], 2)
1115        self.assertEqual(get(table, ('a',), ('t',))['n'], 1)
1116        self.assertEqual(get(table, ['c'], ['t'])['n'], 3)
1117        table = 'get_test_table_2'
1118        query('drop table if exists "%s"' % table)
1119        self.addCleanup(query, 'drop table "%s"' % table)
1120        query('create table "%s" ('
1121            "n integer, m integer, t text, primary key (n, m))" % table)
1122        for n in range(3):
1123            for m in range(2):
1124                t = chr(ord('a') + 2 * n + m)
1125                query('insert into "%s" values('
1126                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1127        self.assertRaises(KeyError, get, table, 2)
1128        self.assertEqual(get(table, (1, 1))['t'], 'a')
1129        self.assertEqual(get(table, (1, 2))['t'], 'b')
1130        self.assertEqual(get(table, (2, 1))['t'], 'c')
1131        self.assertEqual(get(table, (1, 2), ('n', 'm'))['t'], 'b')
1132        self.assertEqual(get(table, (1, 2), ('m', 'n'))['t'], 'c')
1133        self.assertEqual(get(table, (3, 1), ('n', 'm'))['t'], 'e')
1134        self.assertEqual(get(table, (1, 3), ('m', 'n'))['t'], 'e')
1135        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
1136        self.assertEqual(get(table, dict(n=1, m=2), ('n', 'm'))['t'], 'b')
1137        self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c')
1138        self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f')
1139
1140    def testGetWithQuotedNames(self):
1141        get = self.db.get
1142        query = self.db.query
1143        table = 'test table for get()'
1144        query('drop table if exists "%s"' % table)
1145        self.addCleanup(query, 'drop table "%s"' % table)
1146        query('create table "%s" ('
1147            '"Prime!" smallint primary key,'
1148            ' "much space" integer, "Questions?" text)' % table)
1149        query('insert into "%s"'
1150              " values(17, 1001, 'No!')" % table)
1151        r = get(table, 17)
1152        self.assertIsInstance(r, dict)
1153        self.assertEqual(r['Prime!'], 17)
1154        self.assertEqual(r['much space'], 1001)
1155        self.assertEqual(r['Questions?'], 'No!')
1156
1157    def testGetFromView(self):
1158        self.db.query('delete from test where i4=14')
1159        self.db.query('insert into test (i4, v4) values('
1160            "14, 'abc4')")
1161        r = self.db.get('test_view', 14, 'i4')
1162        self.assertIn('v4', r)
1163        self.assertEqual(r['v4'], 'abc4')
1164
1165    def testGetLittleBobbyTables(self):
1166        get = self.db.get
1167        query = self.db.query
1168        query("drop table if exists test_students")
1169        self.addCleanup(query, "drop table test_students")
1170        query("create table test_students (firstname varchar primary key,"
1171            " nickname varchar, grade char(2))")
1172        query("insert into test_students values ("
1173              "'D''Arcy', 'Darcey', 'A+')")
1174        query("insert into test_students values ("
1175              "'Sheldon', 'Moonpie', 'A+')")
1176        query("insert into test_students values ("
1177              "'Robert', 'Little Bobby Tables', 'D-')")
1178        r = get('test_students', 'Sheldon')
1179        self.assertEqual(r, dict(
1180            firstname="Sheldon", nickname='Moonpie', grade='A+'))
1181        r = get('test_students', 'Robert')
1182        self.assertEqual(r, dict(
1183            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
1184        r = get('test_students', "D'Arcy")
1185        self.assertEqual(r, dict(
1186            firstname="D'Arcy", nickname='Darcey', grade='A+'))
1187        try:
1188            get('test_students', "D' Arcy")
1189        except pg.DatabaseError as error:
1190            self.assertEqual(str(error),
1191                'No such record in test_students\nwhere "firstname" = $1\n'
1192                'with $1="D\' Arcy"')
1193        try:
1194            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
1195        except pg.DatabaseError as error:
1196            self.assertEqual(str(error),
1197                'No such record in test_students\nwhere "firstname" = $1\n'
1198                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
1199        q = "select * from test_students order by 1 limit 4"
1200        r = query(q).getresult()
1201        self.assertEqual(len(r), 3)
1202        self.assertEqual(r[1][2], 'D-')
1203
1204    def testInsert(self):
1205        insert = self.db.insert
1206        query = self.db.query
1207        bool_on = pg.get_bool()
1208        decimal = pg.get_decimal()
1209        table = 'insert_test_table'
1210        query('drop table if exists "%s"' % table)
1211        self.addCleanup(query, 'drop table "%s"' % table)
1212        query('create table "%s" ('
1213            "i2 smallint, i4 integer, i8 bigint,"
1214            " d numeric, f4 real, f8 double precision, m money,"
1215            " v4 varchar(4), c4 char(4), t text,"
1216            " b boolean, ts timestamp) with oids" % table)
1217        oid_table = 'oid(%s)' % table
1218        tests = [dict(i2=None, i4=None, i8=None),
1219            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
1220            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
1221            dict(i2=42, i4=123456, i8=9876543210),
1222            dict(i2=2 ** 15 - 1,
1223                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
1224            dict(d=None), (dict(d=''), dict(d=None)),
1225            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
1226            dict(f4=None, f8=None), dict(f4=0, f8=0),
1227            (dict(f4='', f8=''), dict(f4=None, f8=None)),
1228            (dict(d=1234.5, f4=1234.5, f8=1234.5),
1229                  dict(d=Decimal('1234.5'))),
1230            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
1231            dict(d=Decimal('123456789.9876543212345678987654321')),
1232            dict(m=None), (dict(m=''), dict(m=None)),
1233            dict(m=Decimal('-1234.56')),
1234            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
1235            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
1236            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
1237            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
1238            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
1239            (dict(m=123456), dict(m=Decimal('123456'))),
1240            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
1241            dict(b=None), (dict(b=''), dict(b=None)),
1242            dict(b='f'), dict(b='t'),
1243            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
1244            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
1245            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
1246            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
1247            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
1248            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
1249            dict(v4=None, c4=None, t=None),
1250            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
1251            dict(v4='1234', c4='1234', t='1234' * 10),
1252            dict(v4='abcd', c4='abcd', t='abcdefg'),
1253            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
1254            dict(ts=None), (dict(ts=''), dict(ts=None)),
1255            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
1256            dict(ts='2012-12-21 00:00:00'),
1257            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
1258            dict(ts='2012-12-21 12:21:12'),
1259            dict(ts='2013-01-05 12:13:14'),
1260            dict(ts='current_timestamp')]
1261        for test in tests:
1262            if isinstance(test, dict):
1263                data = test
1264                change = {}
1265            else:
1266                data, change = test
1267            expect = data.copy()
1268            expect.update(change)
1269            if bool_on:
1270                b = expect.get('b')
1271                if b is not None:
1272                    expect['b'] = b == 't'
1273            if decimal is not Decimal:
1274                d = expect.get('d')
1275                if d is not None:
1276                    expect['d'] = decimal(d)
1277                m = expect.get('m')
1278                if m is not None:
1279                    expect['m'] = decimal(m)
1280            self.assertEqual(insert(table, data), data)
1281            self.assertIn(oid_table, data)
1282            oid = data[oid_table]
1283            self.assertIsInstance(oid, int)
1284            data = dict(item for item in data.items()
1285                if item[0] in expect)
1286            ts = expect.get('ts')
1287            if ts == 'current_timestamp':
1288                ts = expect['ts'] = data['ts']
1289                if len(ts) > 19:
1290                    self.assertEqual(ts[19], '.')
1291                    ts = ts[:19]
1292                else:
1293                    self.assertEqual(len(ts), 19)
1294                self.assertTrue(ts[:4].isdigit())
1295                self.assertEqual(ts[4], '-')
1296                self.assertEqual(ts[10], ' ')
1297                self.assertTrue(ts[11:13].isdigit())
1298                self.assertEqual(ts[13], ':')
1299            self.assertEqual(data, expect)
1300            data = query(
1301                'select oid,* from "%s"' % table).dictresult()[0]
1302            self.assertEqual(data['oid'], oid)
1303            data = dict(item for item in data.items()
1304                if item[0] in expect)
1305            self.assertEqual(data, expect)
1306            query('delete from "%s"' % table)
1307
1308    def testInsertWithOid(self):
1309        insert = self.db.insert
1310        query = self.db.query
1311        query("drop table if exists test_table")
1312        self.addCleanup(query, "drop table test_table")
1313        query("create table test_table (n int) with oids")
1314        r = insert('test_table', n=1)
1315        self.assertIsInstance(r, dict)
1316        self.assertEqual(r['n'], 1)
1317        self.assertNotIn('oid', r)
1318        qoid = 'oid(test_table)'
1319        self.assertIn(qoid, r)
1320        oid = r[qoid]
1321        self.assertEqual(sorted(r.keys()), ['n', qoid])
1322        r = insert('test_table', n=2, oid=oid)
1323        self.assertIsInstance(r, dict)
1324        self.assertEqual(r['n'], 2)
1325        self.assertIn(qoid, r)
1326        self.assertNotEqual(r[qoid], oid)
1327        self.assertNotIn('oid', r)
1328        r = insert('test_table', None, n=3)
1329        self.assertIsInstance(r, dict)
1330        self.assertEqual(r['n'], 3)
1331        s = r
1332        r = insert('test_table', r)
1333        self.assertIs(r, s)
1334        self.assertEqual(r['n'], 3)
1335        r = insert('test_table *', r)
1336        self.assertIs(r, s)
1337        self.assertEqual(r['n'], 3)
1338        r = insert('test_table', r, n=4)
1339        self.assertIs(r, s)
1340        self.assertEqual(r['n'], 4)
1341        self.assertNotIn('oid', r)
1342        self.assertIn(qoid, r)
1343        oid = r[qoid]
1344        r = insert('test_table', r, n=5, oid=oid)
1345        self.assertIs(r, s)
1346        self.assertEqual(r['n'], 5)
1347        self.assertIn(qoid, r)
1348        self.assertNotEqual(r[qoid], oid)
1349        self.assertNotIn('oid', r)
1350        r['oid'] = oid = r[qoid]
1351        r = insert('test_table', r, n=6)
1352        self.assertIs(r, s)
1353        self.assertEqual(r['n'], 6)
1354        self.assertIn(qoid, r)
1355        self.assertNotEqual(r[qoid], oid)
1356        self.assertNotIn('oid', r)
1357        q = 'select n from test_table order by 1 limit 9'
1358        r = ' '.join(str(row[0]) for row in query(q).getresult())
1359        self.assertEqual(r, '1 2 3 3 3 4 5 6')
1360        query("truncate test_table")
1361        query("alter table test_table add unique (n)")
1362        r = insert('test_table', dict(n=7))
1363        self.assertIsInstance(r, dict)
1364        self.assertEqual(r['n'], 7)
1365        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r)
1366        r['n'] = 6
1367        self.assertRaises(pg.ProgrammingError, insert, 'test_table', r, n=7)
1368        self.assertIsInstance(r, dict)
1369        self.assertEqual(r['n'], 7)
1370        r['n'] = 6
1371        r = insert('test_table', r)
1372        self.assertIsInstance(r, dict)
1373        self.assertEqual(r['n'], 6)
1374        r = query(q).getresult()
1375        r = ' '.join(str(row[0]) for row in query(q).getresult())
1376        self.assertEqual(r, '6 7')
1377
1378    def testInsertWithQuotedNames(self):
1379        insert = self.db.insert
1380        query = self.db.query
1381        table = 'test table for insert()'
1382        query('drop table if exists "%s"' % table)
1383        self.addCleanup(query, 'drop table "%s"' % table)
1384        query('create table "%s" ('
1385            '"Prime!" smallint primary key,'
1386            ' "much space" integer, "Questions?" text)' % table)
1387        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
1388        r = insert(table, r)
1389        self.assertIsInstance(r, dict)
1390        self.assertEqual(r['Prime!'], 11)
1391        self.assertEqual(r['much space'], 2002)
1392        self.assertEqual(r['Questions?'], 'What?')
1393        r = query('select * from "%s" limit 2' % table).dictresult()
1394        self.assertEqual(len(r), 1)
1395        r = r[0]
1396        self.assertEqual(r['Prime!'], 11)
1397        self.assertEqual(r['much space'], 2002)
1398        self.assertEqual(r['Questions?'], 'What?')
1399
1400    def testUpdate(self):
1401        update = self.db.update
1402        query = self.db.query
1403        self.assertRaises(pg.ProgrammingError, update,
1404            'test', i2=2, i4=4, i8=8)
1405        table = 'update_test_table'
1406        query('drop table if exists "%s"' % table)
1407        self.addCleanup(query, 'drop table "%s"' % table)
1408        query('create table "%s" ('
1409            "n integer, t text) with oids" % table)
1410        for n, t in enumerate('xyz'):
1411            query('insert into "%s" values('
1412                "%d, '%s')" % (table, n + 1, t))
1413        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1414        r = self.db.get(table, 2, 'n')
1415        r['t'] = 'u'
1416        s = update(table, r)
1417        self.assertEqual(s, r)
1418        q = 'select t from "%s" where n=2' % table
1419        r = query(q).getresult()[0][0]
1420        self.assertEqual(r, 'u')
1421
    def testUpdateWithOid(self):
        """Test update() on a table with OIDs, before and after adding a pkey.

        While the table has no primary key, the row to update is found via
        the OID that get() and update() store in the row dict under the
        qualified key 'oid(test_table)', or via an explicit oid keyword.
        After a primary key has been added, its value in the row is used.
        """
        update = self.db.update
        get = self.db.get
        query = self.db.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        # table with OIDs, but without a primary key for now
        query("create table test_table (n int) with oids")
        query("insert into test_table values (1)")
        s = get('test_table', 1, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 1)
        s['n'] = 2
        # update() modifies and returns the very dict that was passed
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        # the OID is kept under a qualified key, not under plain 'oid'
        qoid = 'oid(test_table)'
        self.assertIn(qoid, r)
        self.assertNotIn('oid', r)
        self.assertEqual(sorted(r.keys()), ['n', qoid])
        r['n'] = 3
        # the OID can also be passed explicitly as keyword argument,
        # and update() puts it back into the row dict
        oid = r.pop(qoid)
        r = update('test_table', r, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 3)
        # without OID and without primary key the row cannot be found
        r.pop(qoid)
        self.assertRaises(pg.ProgrammingError, update, 'test_table', r)
        s = get('test_table', 3, 'n')
        self.assertIsInstance(s, dict)
        self.assertEqual(s['n'], 3)
        # a row dict containing nothing but the OID can be updated, too
        s.pop('n')
        r = update('test_table', s)
        oid = r.pop(qoid)
        self.assertEqual(r, {})
        q = "select n from test_table limit 2"
        r = query(q).getresult()
        self.assertEqual(r, [(3,)])
        query("insert into test_table values (1)")
        # a plain 'oid' key in the row dict does not identify the row
        self.assertRaises(pg.ProgrammingError,
            update, 'test_table', dict(oid=oid, n=4))
        r = update('test_table', dict(n=4), oid=oid)
        self.assertEqual(r['n'], 4)
        # a table name with ' *' suffix is accepted as well
        r = update('test_table *', dict(n=5), oid=oid)
        self.assertEqual(r['n'], 5)
        # add a primary key and re-read the table metadata with flush=True
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        # the row to update is now found via the primary key
        s = dict(n=1, m=4)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 4)
        # the primary key value can also be passed as keyword argument
        s = dict(m=7)
        r = update('test_table', s, n=5)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 7)
        q = "select n, m from test_table order by 1 limit 3"
        r = query(q).getresult()
        self.assertEqual(r, [(1, 4), (5, 7)])
        # without primary key value in the row, a plain 'oid' key is not
        # enough (KeyError), but an explicit oid keyword argument is
        s = dict(m=9, oid=oid)
        self.assertRaises(KeyError, update, 'test_table', s)
        r = update('test_table', s, oid=oid)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 5)
        self.assertEqual(r['m'], 9)
        # the primary key in the row is used, the plain 'oid' key ignored
        s = dict(n=1, m=3, oid=oid)
        r = update('test_table', s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        r = query(q).getresult()
        self.assertEqual(r, [(1, 3), (5, 9)])
1495
1496    def testUpdateWithCompositeKey(self):
1497        update = self.db.update
1498        query = self.db.query
1499        table = 'update_test_table_1'
1500        query('drop table if exists "%s"' % table)
1501        self.addCleanup(query, 'drop table if exists "%s"' % table)
1502        query('create table "%s" ('
1503            "n integer, t text, primary key (n))" % table)
1504        for n, t in enumerate('abc'):
1505            query('insert into "%s" values('
1506                "%d, '%s')" % (table, n + 1, t))
1507        self.assertRaises(KeyError, update, table, dict(t='b'))
1508        s = dict(n=2, t='d')
1509        r = update(table, s)
1510        self.assertIs(r, s)
1511        self.assertEqual(r['n'], 2)
1512        self.assertEqual(r['t'], 'd')
1513        q = 'select t from "%s" where n=2' % table
1514        r = query(q).getresult()[0][0]
1515        self.assertEqual(r, 'd')
1516        s.update(dict(n=4, t='e'))
1517        r = update(table, s)
1518        self.assertEqual(r['n'], 4)
1519        self.assertEqual(r['t'], 'e')
1520        q = 'select t from "%s" where n=2' % table
1521        r = query(q).getresult()[0][0]
1522        self.assertEqual(r, 'd')
1523        q = 'select t from "%s" where n=4' % table
1524        r = query(q).getresult()
1525        self.assertEqual(len(r), 0)
1526        query('drop table "%s"' % table)
1527        table = 'update_test_table_2'
1528        query('drop table if exists "%s"' % table)
1529        query('create table "%s" ('
1530            "n integer, m integer, t text, primary key (n, m))" % table)
1531        for n in range(3):
1532            for m in range(2):
1533                t = chr(ord('a') + 2 * n + m)
1534                query('insert into "%s" values('
1535                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1536        self.assertRaises(KeyError, update, table, dict(n=2, t='b'))
1537        self.assertEqual(update(table,
1538            dict(n=2, m=2, t='x'))['t'], 'x')
1539        q = 'select t from "%s" where n=2 order by m' % table
1540        r = [r[0] for r in query(q).getresult()]
1541        self.assertEqual(r, ['c', 'x'])
1542
1543    def testUpdateWithQuotedNames(self):
1544        update = self.db.update
1545        query = self.db.query
1546        table = 'test table for update()'
1547        query('drop table if exists "%s"' % table)
1548        self.addCleanup(query, 'drop table "%s"' % table)
1549        query('create table "%s" ('
1550            '"Prime!" smallint primary key,'
1551            ' "much space" integer, "Questions?" text)' % table)
1552        query('insert into "%s"'
1553              " values(13, 3003, 'Why!')" % table)
1554        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1555        r = update(table, r)
1556        self.assertIsInstance(r, dict)
1557        self.assertEqual(r['Prime!'], 13)
1558        self.assertEqual(r['much space'], 7007)
1559        self.assertEqual(r['Questions?'], 'When?')
1560        r = query('select * from "%s" limit 2' % table).dictresult()
1561        self.assertEqual(len(r), 1)
1562        r = r[0]
1563        self.assertEqual(r['Prime!'], 13)
1564        self.assertEqual(r['much space'], 7007)
1565        self.assertEqual(r['Questions?'], 'When?')
1566
    def testUpsert(self):
        """Test upsert() on a table with a single-column primary key.

        Keyword arguments control what happens on conflict: False keeps the
        existing value, True (or no keyword) takes the new value, and a
        string is used as an SQL expression in which 'included' refers to
        the existing row and 'excluded' to the row proposed for insertion.
        The passed dict is returned, updated with the values actually stored.
        """
        upsert = self.db.upsert
        query = self.db.query
        # the 'test' table is set up outside this method; upsert() on it
        # is expected to raise (presumably because it has no primary key)
        self.assertRaises(pg.ProgrammingError, upsert,
            'test', i2=2, i4=4, i8=8)
        table = 'upsert_test_table'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        query('create table "%s" ('
            "n integer primary key, t text) with oids" % table)
        s = dict(n=1, t='x')
        # skip the test on servers older than 9.5 where upsert fails
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        # None for all columns as keywords; the key n=2 is new, so the
        # row is simply inserted
        s.update(n=2, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # existing key without keywords: the row is updated
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=False keeps the stored value and reports it back in the dict
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=True takes the new value
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # 'included' refers to the value already stored ('y')
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # 'excluded' refers to the value proposed for insertion ('y')
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        # both can be combined in one expression
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        # not existing columns and oid parameter should be ignored
        s = dict(m=3, u='z')
        r = upsert(table, s, oid='invalid')
        self.assertIs(r, s)
1641
1642    def testUpsertWithOid(self):
1643        upsert = self.db.upsert
1644        get = self.db.get
1645        query = self.db.query
1646        query("drop table if exists test_table")
1647        self.addCleanup(query, "drop table test_table")
1648        query("create table test_table (n int) with oids")
1649        query("insert into test_table values (1)")
1650        self.assertRaises(pg.ProgrammingError,
1651            upsert, 'test_table', dict(n=2))
1652        r = get('test_table', 1, 'n')
1653        self.assertIsInstance(r, dict)
1654        self.assertEqual(r['n'], 1)
1655        qoid = 'oid(test_table)'
1656        self.assertIn(qoid, r)
1657        self.assertNotIn('oid', r)
1658        oid = r[qoid]
1659        self.assertRaises(pg.ProgrammingError,
1660            upsert, 'test_table', dict(n=2, oid=oid))
1661        query("alter table test_table add column m int")
1662        query("alter table test_table add primary key (n)")
1663        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
1664        self.assertEqual('n', self.db.pkey('test_table', flush=True))
1665        s = dict(n=2)
1666        r = upsert('test_table', s)
1667        self.assertIs(r, s)
1668        self.assertEqual(r['n'], 2)
1669        self.assertIsNone(r['m'])
1670        q = query("select n, m from test_table order by n limit 3")
1671        self.assertEqual(q.getresult(), [(1, None), (2, None)])
1672        r['oid'] = oid
1673        r = upsert('test_table', r)
1674        self.assertIs(r, s)
1675        self.assertEqual(r['n'], 2)
1676        self.assertIsNone(r['m'])
1677        self.assertIn(qoid, r)
1678        self.assertNotIn('oid', r)
1679        self.assertNotEqual(r[qoid], oid)
1680        r['m'] = 7
1681        r = upsert('test_table', r)
1682        self.assertIs(r, s)
1683        self.assertEqual(r['n'], 2)
1684        self.assertEqual(r['m'], 7)
1685        r.update(n=1, m=3)
1686        r = upsert('test_table', r)
1687        self.assertIs(r, s)
1688        self.assertEqual(r['n'], 1)
1689        self.assertEqual(r['m'], 3)
1690        q = query("select n, m from test_table order by n limit 3")
1691        self.assertEqual(q.getresult(), [(1, 3), (2, 7)])
1692        r = upsert('test_table', r, oid='invalid')
1693        self.assertIs(r, s)
1694        self.assertEqual(r['n'], 1)
1695        self.assertEqual(r['m'], 3)
1696        r['m'] = 5
1697        r = upsert('test_table', r, m=False)
1698        self.assertIs(r, s)
1699        self.assertEqual(r['n'], 1)
1700        self.assertEqual(r['m'], 3)
1701        r['m'] = 5
1702        r = upsert('test_table', r, m=True)
1703        self.assertIs(r, s)
1704        self.assertEqual(r['n'], 1)
1705        self.assertEqual(r['m'], 5)
1706        r.update(n=2, m=1)
1707        r = upsert('test_table', r, m='included.m')
1708        self.assertIs(r, s)
1709        self.assertEqual(r['n'], 2)
1710        self.assertEqual(r['m'], 7)
1711        r['m'] = 9
1712        r = upsert('test_table', r, m='excluded.m')
1713        self.assertIs(r, s)
1714        self.assertEqual(r['n'], 2)
1715        self.assertEqual(r['m'], 9)
1716        r['m'] = 8
1717        r = upsert('test_table *', r, m='included.m + 1')
1718        self.assertIs(r, s)
1719        self.assertEqual(r['n'], 2)
1720        self.assertEqual(r['m'], 10)
1721        q = query("select n, m from test_table order by n limit 3")
1722        self.assertEqual(q.getresult(), [(1, 5), (2, 10)])
1723
    def testUpsertWithCompositeKey(self):
        """Test upsert() on a table with a two-column primary key (n, m).

        A row conflicts only if both key columns match; otherwise it is
        inserted and on-conflict keyword arguments have no effect.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        query('create table "%s" ('
            "n integer, m integer, t text, primary key (n, m))" % table)
        s = dict(n=1, m=2, t='x')
        # skip the test on servers older than 9.5 where upsert fails
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # same n but different m: a new row is inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # same (n, m) key without keywords: the row is updated
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False keeps the stored value
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True takes the new value
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # new key (2, 3): the row is inserted, the expression is not used
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # 'included' is the stored row ('n'), 'excluded' the new one ('m')
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1792
1793    def testUpsertWithQuotedNames(self):
1794        upsert = self.db.upsert
1795        query = self.db.query
1796        table = 'test table for upsert()'
1797        query('drop table if exists "%s"' % table)
1798        self.addCleanup(query, 'drop table "%s"' % table)
1799        query('create table "%s" ('
1800            '"Prime!" smallint primary key,'
1801            ' "much space" integer, "Questions?" text)' % table)
1802        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1803        try:
1804            r = upsert(table, s)
1805        except pg.ProgrammingError as error:
1806            if self.db.server_version < 90500:
1807                self.skipTest('database does not support upsert')
1808            self.fail(str(error))
1809        self.assertIs(r, s)
1810        self.assertEqual(r['Prime!'], 31)
1811        self.assertEqual(r['much space'], 9009)
1812        self.assertEqual(r['Questions?'], 'Yes.')
1813        q = 'select * from "%s" limit 2' % table
1814        r = query(q).getresult()
1815        self.assertEqual(r, [(31, 9009, 'Yes.')])
1816        s.update({'Questions?': 'No.'})
1817        r = upsert(table, s)
1818        self.assertIs(r, s)
1819        self.assertEqual(r['Prime!'], 31)
1820        self.assertEqual(r['much space'], 9009)
1821        self.assertEqual(r['Questions?'], 'No.')
1822        r = query(q).getresult()
1823        self.assertEqual(r, [(31, 9009, 'No.')])
1824
1825    def testClear(self):
1826        clear = self.db.clear
1827        query = self.db.query
1828        f = False if pg.get_bool() else 'f'
1829        r = clear('test')
1830        result = dict(
1831            i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='')
1832        self.assertEqual(r, result)
1833        table = 'clear_test_table'
1834        query('drop table if exists "%s"' % table)
1835        self.addCleanup(query, 'drop table "%s"' % table)
1836        query('create table "%s" ('
1837            "n integer, b boolean, d date, t text) with oids" % table)
1838        r = clear(table)
1839        result = dict(n=0, b=f, d='', t='')
1840        self.assertEqual(r, result)
1841        r['a'] = r['n'] = 1
1842        r['d'] = r['t'] = 'x'
1843        r['b'] = 't'
1844        r['oid'] = long(1)
1845        r = clear(table, r)
1846        result = dict(a=1, n=0, b=f, d='', t='', oid=long(1))
1847        self.assertEqual(r, result)
1848
1849    def testClearWithQuotedNames(self):
1850        clear = self.db.clear
1851        query = self.db.query
1852        table = 'test table for clear()'
1853        query('drop table if exists "%s"' % table)
1854        self.addCleanup(query, 'drop table "%s"' % table)
1855        query('create table "%s" ('
1856            '"Prime!" smallint primary key,'
1857            ' "much space" integer, "Questions?" text)' % table)
1858        r = clear(table)
1859        self.assertIsInstance(r, dict)
1860        self.assertEqual(r['Prime!'], 0)
1861        self.assertEqual(r['much space'], 0)
1862        self.assertEqual(r['Questions?'], '')
1863
1864    def testDelete(self):
1865        delete = self.db.delete
1866        query = self.db.query
1867        self.assertRaises(pg.ProgrammingError, delete,
1868            'test', dict(i2=2, i4=4, i8=8))
1869        table = 'delete_test_table'
1870        query('drop table if exists "%s"' % table)
1871        self.addCleanup(query, 'drop table "%s"' % table)
1872        query('create table "%s" ('
1873            "n integer, t text) with oids" % table)
1874        for n, t in enumerate('xyz'):
1875            query('insert into "%s" values('
1876                "%d, '%s')" % (table, n + 1, t))
1877        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1878        r = self.db.get(table, 1, 'n')
1879        s = delete(table, r)
1880        self.assertEqual(s, 1)
1881        r = self.db.get(table, 3, 'n')
1882        s = delete(table, r)
1883        self.assertEqual(s, 1)
1884        s = delete(table, r)
1885        self.assertEqual(s, 0)
1886        r = query('select * from "%s"' % table).dictresult()
1887        self.assertEqual(len(r), 1)
1888        r = r[0]
1889        result = {'n': 2, 't': 'y'}
1890        self.assertEqual(r, result)
1891        r = self.db.get(table, 2, 'n')
1892        s = delete(table, r)
1893        self.assertEqual(s, 1)
1894        s = delete(table, r)
1895        self.assertEqual(s, 0)
1896        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1897        # not existing columns and oid parameter should be ignored
1898        r.update(m=3, u='z', oid='invalid')
1899        s = delete(table, r)
1900        self.assertEqual(s, 0)
1901
    def testDeleteWithOid(self):
        """Test delete() on a table with OIDs, before and after adding a pkey.

        Without a primary key, rows are found via the qualified OID key
        'oid(test_table)' in the row dict or an explicit oid keyword.
        After a primary key has been added, the key value from the row or
        from keyword arguments is used.  delete() returns the number of
        deleted rows; the query q checks min(n) and count(n) afterwards.
        """
        delete = self.db.delete
        get = self.db.get
        query = self.db.query
        query("drop table if exists test_table")
        self.addCleanup(query, "drop table test_table")
        query("create table test_table (n int) with oids")
        for i in range(6):
            query("insert into test_table values (%d)" % (i + 1))
        # no primary key and no OID: the row cannot be identified
        r = dict(n=3)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', r)
        s = get('test_table', 1, 'n')
        qoid = 'oid(test_table)'
        self.assertIn(qoid, s)
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        q = "select min(n),count(n) from test_table"
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # a plain 'oid' key in the row does not identify the row,
        # but an oid keyword argument does
        oid = get('test_table', 2, 'n')[qoid]
        s = dict(oid=oid, n=2)
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', None, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # s still holds the OID of the deleted row 2, while the keyword
        # argument now carries the OID of row 3
        s = dict(oid=oid, n=2)
        oid = get('test_table', 3, 'n')[qoid]
        self.assertRaises(pg.ProgrammingError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        # a table name with ' *' suffix is accepted as well
        s = get('test_table', 4, 'n')
        r = delete('test_table *', s)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # m is not a column yet, so it does not restrict the deletion
        oid = get('test_table', 5, 'n')[qoid]
        s = {qoid: oid, 'm': 4}
        r = delete('test_table', s, m=6)
        self.assertEqual(r, 1)
        r = delete('test_table *', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
        # add a primary key and re-read the table metadata with flush=True
        query("alter table test_table add column m int")
        query("alter table test_table add primary key (n)")
        self.assertIn('m', self.db.get_attnames('test_table', flush=True))
        self.assertEqual('n', self.db.pkey('test_table', flush=True))
        for i in range(5):
            query("insert into test_table values (%d, %d)" % (i + 1, i + 2))
        # the primary key is missing in the row: KeyError, even if the
        # row contains a plain 'oid' key
        s = dict(m=2)
        self.assertRaises(KeyError, delete, 'test_table', s)
        s = dict(m=2, oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        # oid now refers to the already deleted row 5, so nothing happens
        r = delete('test_table', dict(m=2), oid=oid)
        self.assertEqual(r, 0)
        oid = get('test_table', 1, 'n')[qoid]
        s = dict(oid=oid)
        self.assertRaises(KeyError, delete, 'test_table', s)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 1)
        r = delete('test_table', s, oid=oid)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (2, 5))
        # a row dict as returned by get() still works without the key
        s = get('test_table', 2, 'n')
        del s['n']
        r = delete('test_table', s)
        self.assertEqual(r, 1)
        r = delete('test_table', s)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (3, 4))
        # the primary key can be passed as keyword argument
        r = delete('test_table', n=3)
        self.assertEqual(r, 1)
        r = delete('test_table', n=3)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (4, 3))
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 1)
        r = delete('test_table', None, n=4)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (5, 2))
        # the keyword argument n=5 takes precedence over n=6 in the row
        s = dict(n=6)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 1)
        r = delete('test_table', s, n=5)
        self.assertEqual(r, 0)
        self.assertEqual(query(q).getresult()[0], (6, 1))
1994
1995    def testDeleteWithCompositeKey(self):
1996        query = self.db.query
1997        table = 'delete_test_table_1'
1998        query('drop table if exists "%s"' % table)
1999        self.addCleanup(query, 'drop table "%s"' % table)
2000        query('create table "%s" ('
2001            "n integer, t text, primary key (n))" % table)
2002        for n, t in enumerate('abc'):
2003            query("insert into %s values("
2004                "%d, '%s')" % (table, n + 1, t))
2005        self.assertRaises(KeyError, self.db.delete, table, dict(t='b'))
2006        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
2007        r = query('select t from "%s" where n=2' % table).getresult()
2008        self.assertEqual(r, [])
2009        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
2010        r = query('select t from "%s" where n=3' % table).getresult()[0][0]
2011        self.assertEqual(r, 'c')
2012        table = 'delete_test_table_2'
2013        query('drop table if exists "%s"' % table)
2014        self.addCleanup(query, 'drop table "%s"' % table)
2015        query('create table "%s" ('
2016            "n integer, m integer, t text, primary key (n, m))" % table)
2017        for n in range(3):
2018            for m in range(2):
2019                t = chr(ord('a') + 2 * n + m)
2020                query('insert into "%s" values('
2021                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
2022        self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b'))
2023        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
2024        r = [r[0] for r in query('select t from "%s" where n=2'
2025            ' order by m' % table).getresult()]
2026        self.assertEqual(r, ['c'])
2027        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
2028        r = [r[0] for r in query('select t from "%s" where n=3'
2029            ' order by m' % table).getresult()]
2030        self.assertEqual(r, ['e', 'f'])
2031        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
2032        r = [r[0] for r in query('select t from "%s" where n=3'
2033            ' order by m' % table).getresult()]
2034        self.assertEqual(r, ['f'])
2035
2036    def testDeleteWithQuotedNames(self):
2037        delete = self.db.delete
2038        query = self.db.query
2039        table = 'test table for delete()'
2040        query('drop table if exists "%s"' % table)
2041        self.addCleanup(query, 'drop table "%s"' % table)
2042        query('create table "%s" ('
2043            '"Prime!" smallint primary key,'
2044            ' "much space" integer, "Questions?" text)' % table)
2045        query('insert into "%s"'
2046              " values(19, 5005, 'Yes!')" % table)
2047        r = {'Prime!': 17}
2048        r = delete(table, r)
2049        self.assertEqual(r, 0)
2050        r = query('select count(*) from "%s"' % table).getresult()
2051        self.assertEqual(r[0][0], 1)
2052        r = {'Prime!': 19}
2053        r = delete(table, r)
2054        self.assertEqual(r, 1)
2055        r = query('select count(*) from "%s"' % table).getresult()
2056        self.assertEqual(r[0][0], 0)
2057
2058    def testDeleteReferenced(self):
2059        delete = self.db.delete
2060        query = self.db.query
2061        query("drop table if exists test_child")
2062        query("drop table if exists test_parent")
2063        self.addCleanup(query, "drop table test_parent")
2064        query("create table test_parent (n smallint primary key)")
2065        self.addCleanup(query, "drop table test_child")
2066        query("create table test_child ("
2067            " n smallint primary key references test_parent (n))")
2068        for n in range(3):
2069            query("insert into test_parent (n) values (%d)" % n)
2070            query("insert into test_child (n) values (%d)" % n)
2071        q = ("select (select count(*) from test_parent),"
2072            " (select count(*) from test_child)")
2073        self.assertEqual(query(q).getresult()[0], (3, 3))
2074        self.assertRaises(pg.ProgrammingError,
2075            delete, 'test_parent', None, n=2)
2076        self.assertRaises(pg.ProgrammingError,
2077            delete, 'test_parent *', None, n=2)
2078        r = delete('test_child', None, n=2)
2079        self.assertEqual(r, 1)
2080        self.assertEqual(query(q).getresult()[0], (3, 2))
2081        r = delete('test_parent', None, n=2)
2082        self.assertEqual(r, 1)
2083        self.assertEqual(query(q).getresult()[0], (2, 2))
2084        self.assertRaises(pg.ProgrammingError,
2085            delete, 'test_parent', dict(n=0))
2086        self.assertRaises(pg.ProgrammingError,
2087            delete, 'test_parent *', dict(n=0))
2088        r = delete('test_child', dict(n=0))
2089        self.assertEqual(r, 1)
2090        self.assertEqual(query(q).getresult()[0], (2, 1))
2091        r = delete('test_child', dict(n=0))
2092        self.assertEqual(r, 0)
2093        r = delete('test_parent', dict(n=0))
2094        self.assertEqual(r, 1)
2095        self.assertEqual(query(q).getresult()[0], (1, 1))
2096        r = delete('test_parent', None, n=0)
2097        self.assertEqual(r, 0)
2098        q = "select n from test_parent natural join test_child limit 2"
2099        self.assertEqual(query(q).getresult(), [(1,)])
2100
2101    def testTruncate(self):
2102        truncate = self.db.truncate
2103        self.assertRaises(TypeError, truncate, None)
2104        self.assertRaises(TypeError, truncate, 42)
2105        self.assertRaises(TypeError, truncate, dict(test_table=None))
2106        query = self.db.query
2107        query("drop table if exists test_table")
2108        self.addCleanup(query, "drop table test_table")
2109        query("create table test_table (n smallint)")
2110        for i in range(3):
2111            query("insert into test_table values (1)")
2112        q = "select count(*) from test_table"
2113        r = query(q).getresult()[0][0]
2114        self.assertEqual(r, 3)
2115        truncate('test_table')
2116        r = query(q).getresult()[0][0]
2117        self.assertEqual(r, 0)
2118        for i in range(3):
2119            query("insert into test_table values (1)")
2120        r = query(q).getresult()[0][0]
2121        self.assertEqual(r, 3)
2122        truncate('public.test_table')
2123        r = query(q).getresult()[0][0]
2124        self.assertEqual(r, 0)
2125        query("drop table if exists test_table_2")
2126        self.addCleanup(query, "drop table test_table_2")
2127        query('create table test_table_2 (n smallint)')
2128        for t in (list, tuple, set):
2129            for i in range(3):
2130                query("insert into test_table values (1)")
2131                query("insert into test_table_2 values (2)")
2132            q = ("select (select count(*) from test_table),"
2133                " (select count(*) from test_table_2)")
2134            r = query(q).getresult()[0]
2135            self.assertEqual(r, (3, 3))
2136            truncate(t(['test_table', 'test_table_2']))
2137            r = query(q).getresult()[0]
2138            self.assertEqual(r, (0, 0))
2139
2140    def testTruncateRestart(self):
2141        truncate = self.db.truncate
2142        self.assertRaises(TypeError, truncate, 'test_table', restart='invalid')
2143        query = self.db.query
2144        query("drop table if exists test_table")
2145        self.addCleanup(query, "drop table test_table")
2146        query("create table test_table (n serial, t text)")
2147        for n in range(3):
2148            query("insert into test_table (t) values ('test')")
2149        q = "select count(n), min(n), max(n) from test_table"
2150        r = query(q).getresult()[0]
2151        self.assertEqual(r, (3, 1, 3))
2152        truncate('test_table')
2153        r = query(q).getresult()[0]
2154        self.assertEqual(r, (0, None, None))
2155        for n in range(3):
2156            query("insert into test_table (t) values ('test')")
2157        r = query(q).getresult()[0]
2158        self.assertEqual(r, (3, 4, 6))
2159        truncate('test_table', restart=True)
2160        r = query(q).getresult()[0]
2161        self.assertEqual(r, (0, None, None))
2162        for n in range(3):
2163            query("insert into test_table (t) values ('test')")
2164        r = query(q).getresult()[0]
2165        self.assertEqual(r, (3, 1, 3))
2166
2167    def testTruncateCascade(self):
2168        truncate = self.db.truncate
2169        self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid')
2170        query = self.db.query
2171        query("drop table if exists test_child")
2172        query("drop table if exists test_parent")
2173        self.addCleanup(query, "drop table test_parent")
2174        query("create table test_parent (n smallint primary key)")
2175        self.addCleanup(query, "drop table test_child")
2176        query("create table test_child ("
2177            " n smallint primary key references test_parent (n))")
2178        for n in range(3):
2179            query("insert into test_parent (n) values (%d)" % n)
2180            query("insert into test_child (n) values (%d)" % n)
2181        q = ("select (select count(*) from test_parent),"
2182            " (select count(*) from test_child)")
2183        r = query(q).getresult()[0]
2184        self.assertEqual(r, (3, 3))
2185        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2186        truncate(['test_parent', 'test_child'])
2187        r = query(q).getresult()[0]
2188        self.assertEqual(r, (0, 0))
2189        for n in range(3):
2190            query("insert into test_parent (n) values (%d)" % n)
2191            query("insert into test_child (n) values (%d)" % n)
2192        r = query(q).getresult()[0]
2193        self.assertEqual(r, (3, 3))
2194        truncate('test_parent', cascade=True)
2195        r = query(q).getresult()[0]
2196        self.assertEqual(r, (0, 0))
2197        for n in range(3):
2198            query("insert into test_parent (n) values (%d)" % n)
2199            query("insert into test_child (n) values (%d)" % n)
2200        r = query(q).getresult()[0]
2201        self.assertEqual(r, (3, 3))
2202        truncate('test_child')
2203        r = query(q).getresult()[0]
2204        self.assertEqual(r, (3, 0))
2205        self.assertRaises(pg.ProgrammingError, truncate, 'test_parent')
2206        truncate('test_parent', cascade=True)
2207        r = query(q).getresult()[0]
2208        self.assertEqual(r, (0, 0))
2209
2210    def testTruncateOnly(self):
2211        truncate = self.db.truncate
2212        self.assertRaises(TypeError, truncate, 'test_table', only='invalid')
2213        query = self.db.query
2214        query("drop table if exists test_child")
2215        query("drop table if exists test_parent")
2216        self.addCleanup(query, "drop table test_parent")
2217        query("create table test_parent (n smallint)")
2218        self.addCleanup(query, "drop table test_child")
2219        query("create table test_child ("
2220            " m smallint) inherits (test_parent)")
2221        for n in range(3):
2222            query("insert into test_parent (n) values (1)")
2223            query("insert into test_child (n, m) values (2, 3)")
2224        q = ("select (select count(*) from test_parent),"
2225            " (select count(*) from test_child)")
2226        r = query(q).getresult()[0]
2227        self.assertEqual(r, (6, 3))
2228        truncate('test_parent')
2229        r = query(q).getresult()[0]
2230        self.assertEqual(r, (0, 0))
2231        for n in range(3):
2232            query("insert into test_parent (n) values (1)")
2233            query("insert into test_child (n, m) values (2, 3)")
2234        r = query(q).getresult()[0]
2235        self.assertEqual(r, (6, 3))
2236        truncate('test_parent*')
2237        r = query(q).getresult()[0]
2238        self.assertEqual(r, (0, 0))
2239        for n in range(3):
2240            query("insert into test_parent (n) values (1)")
2241            query("insert into test_child (n, m) values (2, 3)")
2242        r = query(q).getresult()[0]
2243        self.assertEqual(r, (6, 3))
2244        truncate('test_parent', only=True)
2245        r = query(q).getresult()[0]
2246        self.assertEqual(r, (3, 3))
2247        truncate('test_parent', only=False)
2248        r = query(q).getresult()[0]
2249        self.assertEqual(r, (0, 0))
2250        self.assertRaises(ValueError, truncate, 'test_parent*', only=True)
2251        truncate('test_parent*', only=False)
2252        query("drop table if exists test_parent_2")
2253        self.addCleanup(query, "drop table test_parent_2")
2254        query("create table test_parent_2 (n smallint)")
2255        query("drop table if exists test_child_2")
2256        self.addCleanup(query, "drop table test_child_2")
2257        query("create table test_child_2 ("
2258            " m smallint) inherits (test_parent_2)")
2259        for n in range(3):
2260            query("insert into test_parent (n) values (1)")
2261            query("insert into test_child (n, m) values (2, 3)")
2262            query("insert into test_parent_2 (n) values (1)")
2263            query("insert into test_child_2 (n, m) values (2, 3)")
2264        q = ("select (select count(*) from test_parent),"
2265            " (select count(*) from test_child),"
2266            " (select count(*) from test_parent_2),"
2267            " (select count(*) from test_child_2)")
2268        r = query(q).getresult()[0]
2269        self.assertEqual(r, (6, 3, 6, 3))
2270        truncate(['test_parent', 'test_parent_2'], only=[False, True])
2271        r = query(q).getresult()[0]
2272        self.assertEqual(r, (0, 0, 3, 3))
2273        truncate(['test_parent', 'test_parent_2'], only=False)
2274        r = query(q).getresult()[0]
2275        self.assertEqual(r, (0, 0, 0, 0))
2276        self.assertRaises(ValueError, truncate,
2277            ['test_parent*', 'test_child'], only=[True, False])
2278        truncate(['test_parent*', 'test_child'], only=[False, True])
2279
2280    def testTruncateQuoted(self):
2281        truncate = self.db.truncate
2282        query = self.db.query
2283        table = "test table for truncate()"
2284        query('drop table if exists "%s"' % table)
2285        self.addCleanup(query, 'drop table "%s"' % table)
2286        query('create table "%s" (n smallint)' % table)
2287        for i in range(3):
2288            query('insert into "%s" values (1)' % table)
2289        q = 'select count(*) from "%s"' % table
2290        r = query(q).getresult()[0][0]
2291        self.assertEqual(r, 3)
2292        truncate(table)
2293        r = query(q).getresult()[0][0]
2294        self.assertEqual(r, 0)
2295        for i in range(3):
2296            query('insert into "%s" values (1)' % table)
2297        r = query(q).getresult()[0][0]
2298        self.assertEqual(r, 3)
2299        truncate('public."%s"' % table)
2300        r = query(q).getresult()[0][0]
2301        self.assertEqual(r, 0)
2302
2303    def testTransaction(self):
2304        query = self.db.query
2305        query("drop table if exists test_table")
2306        self.addCleanup(query, "drop table test_table")
2307        query("create table test_table (n integer)")
2308        self.db.begin()
2309        query("insert into test_table values (1)")
2310        query("insert into test_table values (2)")
2311        self.db.commit()
2312        self.db.begin()
2313        query("insert into test_table values (3)")
2314        query("insert into test_table values (4)")
2315        self.db.rollback()
2316        self.db.begin()
2317        query("insert into test_table values (5)")
2318        self.db.savepoint('before6')
2319        query("insert into test_table values (6)")
2320        self.db.rollback('before6')
2321        query("insert into test_table values (7)")
2322        self.db.commit()
2323        self.db.begin()
2324        self.db.savepoint('before8')
2325        query("insert into test_table values (8)")
2326        self.db.release('before8')
2327        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
2328        self.db.commit()
2329        self.db.start()
2330        query("insert into test_table values (9)")
2331        self.db.end()
2332        r = [r[0] for r in query(
2333            "select * from test_table order by 1").getresult()]
2334        self.assertEqual(r, [1, 2, 5, 7, 9])
2335        self.db.begin(mode='read only')
2336        self.assertRaises(pg.ProgrammingError,
2337            query, "insert into test_table values (0)")
2338        self.db.rollback()
2339        self.db.start(mode='Read Only')
2340        self.assertRaises(pg.ProgrammingError,
2341            query, "insert into test_table values (0)")
2342        self.db.abort()
2343
2344    def testTransactionAliases(self):
2345        self.assertEqual(self.db.begin, self.db.start)
2346        self.assertEqual(self.db.commit, self.db.end)
2347        self.assertEqual(self.db.rollback, self.db.abort)
2348
2349    def testContextManager(self):
2350        query = self.db.query
2351        query("drop table if exists test_table")
2352        self.addCleanup(query, "drop table test_table")
2353        query("create table test_table (n integer check(n>0))")
2354        with self.db:
2355            query("insert into test_table values (1)")
2356            query("insert into test_table values (2)")
2357        try:
2358            with self.db:
2359                query("insert into test_table values (3)")
2360                query("insert into test_table values (4)")
2361                raise ValueError('test transaction should rollback')
2362        except ValueError as error:
2363            self.assertEqual(str(error), 'test transaction should rollback')
2364        with self.db:
2365            query("insert into test_table values (5)")
2366        try:
2367            with self.db:
2368                query("insert into test_table values (6)")
2369                query("insert into test_table values (-1)")
2370        except pg.ProgrammingError as error:
2371            self.assertTrue('check' in str(error))
2372        with self.db:
2373            query("insert into test_table values (7)")
2374        r = [r[0] for r in query(
2375            "select * from test_table order by 1").getresult()]
2376        self.assertEqual(r, [1, 2, 5, 7])
2377
2378    def testBytea(self):
2379        query = self.db.query
2380        query('drop table if exists bytea_test')
2381        self.addCleanup(query, 'drop table bytea_test')
2382        query('create table bytea_test (n smallint primary key, data bytea)')
2383        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2384        r = self.db.escape_bytea(s)
2385        query('insert into bytea_test values(3,$1)', (r,))
2386        r = query('select * from bytea_test where n=3').getresult()
2387        self.assertEqual(len(r), 1)
2388        r = r[0]
2389        self.assertEqual(len(r), 2)
2390        self.assertEqual(r[0], 3)
2391        r = r[1]
2392        self.assertIsInstance(r, str)
2393        r = self.db.unescape_bytea(r)
2394        self.assertIsInstance(r, bytes)
2395        self.assertEqual(r, s)
2396
2397    def testInsertUpdateGetBytea(self):
2398        query = self.db.query
2399        query('drop table if exists bytea_test')
2400        self.addCleanup(query, 'drop table bytea_test')
2401        query('create table bytea_test (n smallint primary key, data bytea)')
2402        # insert null value
2403        r = self.db.insert('bytea_test', n=0, data=None)
2404        self.assertIsInstance(r, dict)
2405        self.assertIn('n', r)
2406        self.assertEqual(r['n'], 0)
2407        self.assertIn('data', r)
2408        self.assertIsNone(r['data'])
2409        s = b'None'
2410        r = self.db.update('bytea_test', n=0, data=s)
2411        self.assertIsInstance(r, dict)
2412        self.assertIn('n', r)
2413        self.assertEqual(r['n'], 0)
2414        self.assertIn('data', r)
2415        r = r['data']
2416        self.assertIsInstance(r, bytes)
2417        self.assertEqual(r, s)
2418        r = self.db.update('bytea_test', n=0, data=None)
2419        self.assertIsNone(r['data'])
2420        # insert as bytes
2421        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2422        r = self.db.insert('bytea_test', n=5, data=s)
2423        self.assertIsInstance(r, dict)
2424        self.assertIn('n', r)
2425        self.assertEqual(r['n'], 5)
2426        self.assertIn('data', r)
2427        r = r['data']
2428        self.assertIsInstance(r, bytes)
2429        self.assertEqual(r, s)
2430        # update as bytes
2431        s += b"and now even more \x00 nasty \t stuff!\f"
2432        r = self.db.update('bytea_test', n=5, data=s)
2433        self.assertIsInstance(r, dict)
2434        self.assertIn('n', r)
2435        self.assertEqual(r['n'], 5)
2436        self.assertIn('data', r)
2437        r = r['data']
2438        self.assertIsInstance(r, bytes)
2439        self.assertEqual(r, s)
2440        r = query('select * from bytea_test where n=5').getresult()
2441        self.assertEqual(len(r), 1)
2442        r = r[0]
2443        self.assertEqual(len(r), 2)
2444        self.assertEqual(r[0], 5)
2445        r = r[1]
2446        self.assertIsInstance(r, str)
2447        r = self.db.unescape_bytea(r)
2448        self.assertIsInstance(r, bytes)
2449        self.assertEqual(r, s)
2450        r = self.db.get('bytea_test', dict(n=5))
2451        self.assertIsInstance(r, dict)
2452        self.assertIn('n', r)
2453        self.assertEqual(r['n'], 5)
2454        self.assertIn('data', r)
2455        r = r['data']
2456        self.assertIsInstance(r, bytes)
2457        self.assertEqual(r, s)
2458
2459    def testUpsertBytea(self):
2460        query = self.db.query
2461        query('drop table if exists bytea_test')
2462        self.addCleanup(query, 'drop table bytea_test')
2463        query('create table bytea_test (n smallint primary key, data bytea)')
2464        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
2465        r = dict(n=7, data=s)
2466        try:
2467            r = self.db.upsert('bytea_test', r)
2468        except pg.ProgrammingError as error:
2469            if self.db.server_version < 90500:
2470                self.skipTest('database does not support upsert')
2471            self.fail(str(error))
2472        self.assertIsInstance(r, dict)
2473        self.assertIn('n', r)
2474        self.assertEqual(r['n'], 7)
2475        self.assertIn('data', r)
2476        self.assertIsInstance(r['data'], bytes)
2477        self.assertEqual(r['data'], s)
2478        r['data'] = None
2479        r = self.db.upsert('bytea_test', r)
2480        self.assertIsInstance(r, dict)
2481        self.assertIn('n', r)
2482        self.assertEqual(r['n'], 7)
2483        self.assertIn('data', r)
2484        self.assertIsNone(r['data'], bytes)
2485
2486    def testNotificationHandler(self):
2487        # the notification handler itself is tested separately
2488        f = self.db.notification_handler
2489        callback = lambda arg_dict: None
2490        handler = f('test', callback)
2491        self.assertIsInstance(handler, pg.NotificationHandler)
2492        self.assertIs(handler.db, self.db)
2493        self.assertEqual(handler.event, 'test')
2494        self.assertEqual(handler.stop_event, 'stop_test')
2495        self.assertIs(handler.callback, callback)
2496        self.assertIsInstance(handler.arg_dict, dict)
2497        self.assertEqual(handler.arg_dict, {})
2498        self.assertIsNone(handler.timeout)
2499        self.assertFalse(handler.listening)
2500        handler.close()
2501        self.assertIsNone(handler.db)
2502        self.db.reopen()
2503        self.assertIsNone(handler.db)
2504        handler = f('test2', callback, timeout=2)
2505        self.assertIsInstance(handler, pg.NotificationHandler)
2506        self.assertIs(handler.db, self.db)
2507        self.assertEqual(handler.event, 'test2')
2508        self.assertEqual(handler.stop_event, 'stop_test2')
2509        self.assertIs(handler.callback, callback)
2510        self.assertIsInstance(handler.arg_dict, dict)
2511        self.assertEqual(handler.arg_dict, {})
2512        self.assertEqual(handler.timeout, 2)
2513        self.assertFalse(handler.listening)
2514        handler.close()
2515        self.assertIsNone(handler.db)
2516        self.db.reopen()
2517        self.assertIsNone(handler.db)
2518        arg_dict = {'testing': 3}
2519        handler = f('test3', callback, arg_dict=arg_dict)
2520        self.assertIsInstance(handler, pg.NotificationHandler)
2521        self.assertIs(handler.db, self.db)
2522        self.assertEqual(handler.event, 'test3')
2523        self.assertEqual(handler.stop_event, 'stop_test3')
2524        self.assertIs(handler.callback, callback)
2525        self.assertIs(handler.arg_dict, arg_dict)
2526        self.assertEqual(arg_dict['testing'], 3)
2527        self.assertIsNone(handler.timeout)
2528        self.assertFalse(handler.listening)
2529        handler.close()
2530        self.assertIsNone(handler.db)
2531        self.db.reopen()
2532        self.assertIsNone(handler.db)
2533        handler = f('test4', callback, stop_event='stop4')
2534        self.assertIsInstance(handler, pg.NotificationHandler)
2535        self.assertIs(handler.db, self.db)
2536        self.assertEqual(handler.event, 'test4')
2537        self.assertEqual(handler.stop_event, 'stop4')
2538        self.assertIs(handler.callback, callback)
2539        self.assertIsInstance(handler.arg_dict, dict)
2540        self.assertEqual(handler.arg_dict, {})
2541        self.assertIsNone(handler.timeout)
2542        self.assertFalse(handler.listening)
2543        handler.close()
2544        self.assertIsNone(handler.db)
2545        self.db.reopen()
2546        self.assertIsNone(handler.db)
2547        arg_dict = {'testing': 5}
2548        handler = f('test5', callback, arg_dict, 1.5, 'stop5')
2549        self.assertIsInstance(handler, pg.NotificationHandler)
2550        self.assertIs(handler.db, self.db)
2551        self.assertEqual(handler.event, 'test5')
2552        self.assertEqual(handler.stop_event, 'stop5')
2553        self.assertIs(handler.callback, callback)
2554        self.assertIs(handler.arg_dict, arg_dict)
2555        self.assertEqual(arg_dict['testing'], 5)
2556        self.assertEqual(handler.timeout, 1.5)
2557        self.assertFalse(handler.listening)
2558        handler.close()
2559        self.assertIsNone(handler.db)
2560        self.db.reopen()
2561        self.assertIsNone(handler.db)
2562
2563
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    All tests of TestDBClass are run again, with the global pg options
    decimal, bool and namedresult changed.  The original option values
    are saved and restored afterwards, even if setting up or tearing
    down the class fails.
    """

    @classmethod
    def setUpClass(cls):
        cls.saved_options = {}
        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            unnamed_result = lambda q: q.getresult()
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # tearDownClass is not called when setUpClass fails, so undo
            # the global changes here to keep them out of other tests
            cls.reset_options()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls.reset_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of the option, then set the new one."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore the saved value of the option."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_options(cls):
        """Restore all saved options in reverse order of setting them."""
        for option in ('namedresult', 'bool', 'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
2592
2593
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces).

    The class fixture creates the schemas s1 to s4.  The public schema and
    each of these schemas get a table t and a table t<num>, where num is 0
    for public and 1 to 4 for s1 to s4.  Every table holds one row with
    n = 1 and d = num, so the value of d tells in which schema a table
    has been found.
    """

    @classmethod
    def setUpClass(cls):
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids"
                      " as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # do not leak the connection if creating the fixture fails
            db.close()

    @classmethod
    def tearDownClass(cls):
        db = DB()
        try:
            query = db.query
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        self.doCleanups()
        self.db.close()

    def testGetTables(self):
        """Check that get_tables() lists the fixture tables of all schemas."""
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        """Check get_attnames() with qualified names and the search path."""
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        self.addCleanup(query, "drop table s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        """Check that get() finds tables according to the search path.

        The column d of the fetched row tells which schema was used.
        """
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        """Check that rows from get() carry the oid as 'oid(<table>)'."""
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
2709
class TestDebug(unittest.TestCase):
    """Test the debug attribute of the DB class.

    Standard output is captured during each test, so that debug output
    printed by the DB object can be compared with the expected text.
    """

    def setUp(self):
        db = DB()
        self.db = db
        self.query = db.query
        self.debug = db.debug
        # remember the real stdout and replace it with an in-memory buffer
        self.stdout = sys.stdout
        self.output = sys.stdout = StringIO()

    def tearDown(self):
        sys.stdout = self.stdout  # undo the redirection from setUp
        self.output.close()
        self.db.debug = debug
        self.db.close()

    def get_output(self):
        """Return everything printed to stdout so far."""
        return self.output.getvalue()

    def send_queries(self):
        """Run the two trivial queries used by all debug tests."""
        for sql in ("select 1", "select 2"):
            self.db.query(sql)

    def testDebugDefault(self):
        # without a truthy module-level debug setting, debug starts as None
        if not debug:
            self.assertIsNone(self.db.debug)
            return
        self.assertEqual(self.db.debug, debug)

    def testDebugIsFalse(self):
        self.db.debug = False
        self.send_queries()
        self.assertEqual(self.get_output(), "")

    def testDebugIsTrue(self):
        self.db.debug = True
        self.send_queries()
        expected = "select 1\nselect 2\n"
        self.assertEqual(self.get_output(), expected)

    def testDebugIsString(self):
        # a string is used as a format with the query inserted at %s
        self.db.debug = "Test with string: %s."
        self.send_queries()
        expected = ("Test with string: select 1.\n"
                    "Test with string: select 2.\n")
        self.assertEqual(self.get_output(), expected)

    def testDebugIsFileLike(self):
        # output goes to the file object instead of stdout
        with tempfile.TemporaryFile('w+') as debug_file:
            self.db.debug = debug_file
            self.send_queries()
            debug_file.seek(0)
            written = debug_file.read()
            self.assertEqual(written, "select 1\nselect 2\n")
            self.assertEqual(self.get_output(), "")

    def testDebugIsCallable(self):
        # a callable receives each query string instead of printing it
        collected = []
        self.db.debug = collected.append
        self.send_queries()
        self.assertEqual(collected, ["select 1", "select 2"])
        self.assertEqual(self.get_output(), "")
2771
2772
if __name__ == '__main__':
    # run the test cases of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.