source: trunk/tests/test_classic_dbwrapper.py @ 735

Last change on this file since 735 was 735, checked in by cito, 4 years ago

Implement "upsert" method for PostgreSQL 9.5

A new method upsert() has been added to the DB wrapper class that
nicely complements the existing get/insert/update/delete() methods.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 57.6 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
# Connection parameters for the test database.  A LOCAL_PyGreSQL.py module,
# if present, may override them; otherwise these defaults are used.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# Look for local overrides as a package-relative module first, then as a
# top-level module; keep the defaults above when neither can be imported.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Aliases so the same test code runs under Python 2 and Python 3.
try:
    long, unicode = long, unicode
except NameError:  # Python >= 3.0
    long, unicode = int, str

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
57
58
def DB():
    """Open and return a pg.DB wrapper for the test database.

    The module-level dbname/dbhost/dbport settings are used for the
    connection.  Wrapper debugging is switched on when ``debug`` is set,
    and client_min_messages is set to warning for the session.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
66
67
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # some tests close the connection themselves; closing it again
        # raises InternalError, which is harmless at this point
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        """The public attributes of DB must be exactly these."""
        attributes = [
            'begin',
            'cancel',
            'clear',
            'close',
            'commit',
            'db',
            'dbname',
            'debug',
            'delete',
            'end',
            'endcopy',
            'error',
            'escape_bytea',
            'escape_identifier',
            'escape_literal',
            'escape_string',
            'fileno',
            'get',
            'get_attnames',
            'get_databases',
            'get_notice_receiver',
            'get_relations',
            'get_tables',
            'getline',
            'getlo',
            'getnotify',
            'has_table_privilege',
            'host',
            'insert',
            'inserttable',
            'locreate',
            'loimport',
            'notification_handler',
            'options',
            'parameter',
            'pkey',
            'port',
            'protocol_version',
            'putline',
            'query',
            'release',
            'reopen',
            'reset',
            'rollback',
            'savepoint',
            'server_version',
            'set_notice_receiver',
            'source',
            'start',
            'status',
            'transaction',
            'unescape_bytea',
            'update',
            'upsert',
            'use_regtypes',
            'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        # PostgreSQL 10 and later report major * 10000 + minor
        # (e.g. 100001 for 10.1), so do not cap the value below 100000
        self.assertTrue(70400 <= server_version < 1000000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        """A division by zero must raise an error with SQLSTATE 22012."""
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            self.assertEqual(error.sqlstate, '22012')
        else:
            # without this the test would pass even if no error was raised
            self.fail('Division by zero should raise an error')

    def testMethodEndcopy(self):
        # there is no copy in progress, so an IOError is acceptable here
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
274
275
276class TestDBClass(unittest.TestCase):
277    """Test the methods of the DB class wrapped pg connection."""
278
279    @classmethod
280    def setUpClass(cls):
281        db = DB()
282        db.query("drop table if exists test cascade")
283        db.query("create table test ("
284            "i2 smallint, i4 integer, i8 bigint,"
285            "d numeric, f4 real, f8 double precision, m money, "
286            "v4 varchar(4), c4 char(4), t text)")
287        db.query("create or replace view test_view as"
288            " select i4, v4 from test")
289        db.close()
290
291    @classmethod
292    def tearDownClass(cls):
293        db = DB()
294        db.query("drop table test cascade")
295        db.close()
296
297    def setUp(self):
298        self.db = DB()
299        query = self.db.query
300        query('set client_encoding=utf8')
301        query('set standard_conforming_strings=on')
302        query("set lc_monetary='C'")
303        query("set datestyle='ISO,YMD'")
304        query('set bytea_output=hex')
305
306    def tearDown(self):
307        self.db.close()
308
309    def testClassName(self):
310        self.assertEqual(self.db.__class__.__name__, 'DB')
311
312    def testModuleName(self):
313        self.assertEqual(self.db.__module__, 'pg')
314        self.assertEqual(self.db.__class__.__module__, 'pg')
315
316    def testEscapeLiteral(self):
317        f = self.db.escape_literal
318        r = f(b"plain")
319        self.assertIsInstance(r, bytes)
320        self.assertEqual(r, b"'plain'")
321        r = f(u"plain")
322        self.assertIsInstance(r, unicode)
323        self.assertEqual(r, u"'plain'")
324        r = f(u"that's kÀse".encode('utf-8'))
325        self.assertIsInstance(r, bytes)
326        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
327        r = f(u"that's kÀse")
328        self.assertIsInstance(r, unicode)
329        self.assertEqual(r, u"'that''s kÀse'")
330        self.assertEqual(f(r"It's fine to have a \ inside."),
331            r" E'It''s fine to have a \\ inside.'")
332        self.assertEqual(f('No "quotes" must be escaped.'),
333            "'No \"quotes\" must be escaped.'")
334
335    def testEscapeIdentifier(self):
336        f = self.db.escape_identifier
337        r = f(b"plain")
338        self.assertIsInstance(r, bytes)
339        self.assertEqual(r, b'"plain"')
340        r = f(u"plain")
341        self.assertIsInstance(r, unicode)
342        self.assertEqual(r, u'"plain"')
343        r = f(u"that's kÀse".encode('utf-8'))
344        self.assertIsInstance(r, bytes)
345        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
346        r = f(u"that's kÀse")
347        self.assertIsInstance(r, unicode)
348        self.assertEqual(r, u'"that\'s kÀse"')
349        self.assertEqual(f(r"It's fine to have a \ inside."),
350            '"It\'s fine to have a \\ inside."')
351        self.assertEqual(f('All "quotes" must be escaped.'),
352            '"All ""quotes"" must be escaped."')
353
354    def testEscapeString(self):
355        f = self.db.escape_string
356        r = f(b"plain")
357        self.assertIsInstance(r, bytes)
358        self.assertEqual(r, b"plain")
359        r = f(u"plain")
360        self.assertIsInstance(r, unicode)
361        self.assertEqual(r, u"plain")
362        r = f(u"that's kÀse".encode('utf-8'))
363        self.assertIsInstance(r, bytes)
364        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
365        r = f(u"that's kÀse")
366        self.assertIsInstance(r, unicode)
367        self.assertEqual(r, u"that''s kÀse")
368        self.assertEqual(f(r"It's fine to have a \ inside."),
369            r"It''s fine to have a \ inside.")
370
371    def testEscapeBytea(self):
372        f = self.db.escape_bytea
373        # note that escape_byte always returns hex output since Pg 9.0,
374        # regardless of the bytea_output setting
375        r = f(b'plain')
376        self.assertIsInstance(r, bytes)
377        self.assertEqual(r, b'\\x706c61696e')
378        r = f(u'plain')
379        self.assertIsInstance(r, unicode)
380        self.assertEqual(r, u'\\x706c61696e')
381        r = f(u"das is' kÀse".encode('utf-8'))
382        self.assertIsInstance(r, bytes)
383        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
384        r = f(u"das is' kÀse")
385        self.assertIsInstance(r, unicode)
386        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
387        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
388
389    def testUnescapeBytea(self):
390        f = self.db.unescape_bytea
391        r = f(b'plain')
392        self.assertIsInstance(r, bytes)
393        self.assertEqual(r, b'plain')
394        r = f(u'plain')
395        self.assertIsInstance(r, bytes)
396        self.assertEqual(r, b'plain')
397        r = f(b"das is' k\\303\\244se")
398        self.assertIsInstance(r, bytes)
399        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
400        r = f(u"das is' k\\303\\244se")
401        self.assertIsInstance(r, bytes)
402        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
403        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
404        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
405        self.assertEqual(f(r'\\x746861742773206be47365'),
406            b'\\x746861742773206be47365')
407        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
408
409    def testQuery(self):
410        query = self.db.query
411        query("drop table if exists test_table")
412        q = "create table test_table (n integer) with oids"
413        r = query(q)
414        self.assertIsNone(r)
415        q = "insert into test_table values (1)"
416        r = query(q)
417        self.assertIsInstance(r, int)
418        q = "insert into test_table select 2"
419        r = query(q)
420        self.assertIsInstance(r, int)
421        oid = r
422        q = "select oid from test_table where n=2"
423        r = query(q).getresult()
424        self.assertEqual(len(r), 1)
425        r = r[0]
426        self.assertEqual(len(r), 1)
427        r = r[0]
428        self.assertEqual(r, oid)
429        q = "insert into test_table select 3 union select 4 union select 5"
430        r = query(q)
431        self.assertIsInstance(r, str)
432        self.assertEqual(r, '3')
433        q = "update test_table set n=4 where n<5"
434        r = query(q)
435        self.assertIsInstance(r, str)
436        self.assertEqual(r, '4')
437        q = "delete from test_table"
438        r = query(q)
439        self.assertIsInstance(r, str)
440        self.assertEqual(r, '5')
441        query("drop table test_table")
442
443    def testMultipleQueries(self):
444        self.assertEqual(self.db.query(
445            "create temporary table test_multi (n integer);"
446            "insert into test_multi values (4711);"
447            "select n from test_multi").getresult()[0][0], 4711)
448
449    def testQueryWithParams(self):
450        query = self.db.query
451        query("drop table if exists test_table")
452        q = "create table test_table (n1 integer, n2 integer) with oids"
453        query(q)
454        q = "insert into test_table values ($1, $2)"
455        r = query(q, (1, 2))
456        self.assertIsInstance(r, int)
457        r = query(q, [3, 4])
458        self.assertIsInstance(r, int)
459        r = query(q, [5, 6])
460        self.assertIsInstance(r, int)
461        q = "select * from test_table order by 1, 2"
462        self.assertEqual(query(q).getresult(),
463            [(1, 2), (3, 4), (5, 6)])
464        q = "select * from test_table where n1=$1 and n2=$2"
465        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
466        q = "update test_table set n2=$2 where n1=$1"
467        r = query(q, 3, 7)
468        self.assertEqual(r, '1')
469        q = "select * from test_table order by 1, 2"
470        self.assertEqual(query(q).getresult(),
471            [(1, 2), (3, 7), (5, 6)])
472        q = "delete from test_table where n2!=$1"
473        r = query(q, 4)
474        self.assertEqual(r, '3')
475        query("drop table test_table")
476
477    def testEmptyQuery(self):
478        self.assertRaises(ValueError, self.db.query, '')
479
480    def testQueryProgrammingError(self):
481        try:
482            self.db.query("select 1/0")
483        except pg.ProgrammingError as error:
484            self.assertEqual(error.sqlstate, '22012')
485
486    def testPkey(self):
487        query = self.db.query
488        pkey = self.db.pkey
489        for t in ('pkeytest', 'primary key test'):
490            for n in range(7):
491                query('drop table if exists "%s%d"' % (t, n))
492            query('create table "%s0" ('
493                "a smallint)" % t)
494            query('create table "%s1" ('
495                "b smallint primary key)" % t)
496            query('create table "%s2" ('
497                "c smallint, d smallint primary key)" % t)
498            query('create table "%s3" ('
499                "e smallint, f smallint, g smallint, "
500                "h smallint, i smallint, "
501                "primary key (f, h))" % t)
502            query('create table "%s4" ('
503                "more_than_one_letter varchar primary key)" % t)
504            query('create table "%s5" ('
505                '"with space" date primary key)' % t)
506            query('create table "%s6" ('
507                'a_very_long_column_name varchar, '
508                '"with space" date, '
509                '"42" int, '
510                "primary key (a_very_long_column_name, "
511                '"with space", "42"))' % t)
512            self.assertRaises(KeyError, pkey, '%s0' % t)
513            self.assertEqual(pkey('%s1' % t), 'b')
514            self.assertEqual(pkey('%s2' % t), 'd')
515            r = pkey('%s3' % t)
516            self.assertIsInstance(r, frozenset)
517            self.assertEqual(r, frozenset('fh'))
518            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
519            self.assertEqual(pkey('%s5' % t), 'with space')
520            r = pkey('%s6' % t)
521            self.assertIsInstance(r, frozenset)
522            self.assertEqual(r, frozenset([
523                'a_very_long_column_name', 'with space', '42']))
524            # a newly added primary key will be detected
525            query('alter table "%s0" add primary key (a)' % t)
526            self.assertEqual(pkey('%s0' % t), 'a')
527            # a changed primary key will not be detected,
528            # indicating that the internal cache is operating
529            query('alter table "%s1" rename column b to x' % t)
530            self.assertEqual(pkey('%s1' % t), 'b')
531            # we get the changed primary key when the cache is flushed
532            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
533            for n in range(7):
534                query('drop table "%s%d"' % (t, n))
535
536    def testGetDatabases(self):
537        databases = self.db.get_databases()
538        self.assertIn('template0', databases)
539        self.assertIn('template1', databases)
540        self.assertNotIn('not existing database', databases)
541        self.assertIn('postgres', databases)
542        self.assertIn(dbname, databases)
543
544    def testGetTables(self):
545        get_tables = self.db.get_tables
546        result1 = get_tables()
547        self.assertIsInstance(result1, list)
548        for t in result1:
549            t = t.split('.', 1)
550            self.assertGreaterEqual(len(t), 2)
551            if len(t) > 2:
552                self.assertTrue(t[1].startswith('"'))
553            t = t[0]
554            self.assertNotEqual(t, 'information_schema')
555            self.assertFalse(t.startswith('pg_'))
556        tables = ('"A very Special Name"',
557            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
558            'A_MiXeD_NaMe', '"another special name"',
559            'averyveryveryveryveryveryverylongtablename',
560            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
561        for t in tables:
562            self.db.query('drop table if exists %s' % t)
563            self.db.query("create table %s"
564                " as select 0" % t)
565        result3 = get_tables()
566        result2 = []
567        for t in result3:
568            if t not in result1:
569                result2.append(t)
570        result3 = []
571        for t in tables:
572            if not t.startswith('"'):
573                t = t.lower()
574            result3.append('public.' + t)
575        self.assertEqual(result2, result3)
576        for t in result2:
577            self.db.query('drop table %s' % t)
578        result2 = get_tables()
579        self.assertEqual(result2, result1)
580
581    def testGetRelations(self):
582        get_relations = self.db.get_relations
583        result = get_relations()
584        self.assertIn('public.test', result)
585        self.assertIn('public.test_view', result)
586        result = get_relations('rv')
587        self.assertIn('public.test', result)
588        self.assertIn('public.test_view', result)
589        result = get_relations('r')
590        self.assertIn('public.test', result)
591        self.assertNotIn('public.test_view', result)
592        result = get_relations('v')
593        self.assertNotIn('public.test', result)
594        self.assertIn('public.test_view', result)
595        result = get_relations('cisSt')
596        self.assertNotIn('public.test', result)
597        self.assertNotIn('public.test_view', result)
598
599    def testAttnames(self):
600        self.assertRaises(pg.ProgrammingError,
601            self.db.get_attnames, 'does_not_exist')
602        self.assertRaises(pg.ProgrammingError,
603            self.db.get_attnames, 'has.too.many.dots')
604        for table in ('attnames_test_table', 'test table for attnames'):
605            self.db.query('drop table if exists "%s"' % table)
606            self.db.query('create table "%s" ('
607                'a smallint, b integer, c bigint, '
608                'e numeric, f float, f2 double precision, m money, '
609                'x smallint, y smallint, z smallint, '
610                'Normal_NaMe smallint, "Special Name" smallint, '
611                't text, u char(2), v varchar(2), '
612                'primary key (y, u)) with oids' % table)
613            attributes = self.db.get_attnames(table)
614            result = {'a': 'int', 'c': 'int', 'b': 'int',
615                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
616                'normal_name': 'int', 'Special Name': 'int',
617                'u': 'text', 't': 'text', 'v': 'text',
618                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
619            self.assertEqual(attributes, result)
620            self.db.query('drop table "%s"' % table)
621
622    def testHasTablePrivilege(self):
623        can = self.db.has_table_privilege
624        self.assertEqual(can('test'), True)
625        self.assertEqual(can('test', 'select'), True)
626        self.assertEqual(can('test', 'SeLeCt'), True)
627        self.assertEqual(can('test', 'SELECT'), True)
628        self.assertEqual(can('test', 'insert'), True)
629        self.assertEqual(can('test', 'update'), True)
630        self.assertEqual(can('test', 'delete'), True)
631        self.assertEqual(can('pg_views', 'select'), True)
632        self.assertEqual(can('pg_views', 'delete'), False)
633        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
634        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
635
636    def testGet(self):
637        get = self.db.get
638        query = self.db.query
639        table = 'get_test_table'
640        query('drop table if exists "%s"' % table)
641        query('create table "%s" ('
642            "n integer, t text) with oids" % table)
643        for n, t in enumerate('xyz'):
644            query('insert into "%s" values('"%d, '%s')"
645                % (table, n + 1, t))
646        self.assertRaises(pg.ProgrammingError, get, table, 2)
647        r = get(table, 2, 'n')
648        oid_table = 'oid(%s)' % table
649        self.assertIn(oid_table, r)
650        oid = r[oid_table]
651        self.assertIsInstance(oid, int)
652        result = {'t': 'y', 'n': 2, oid_table: oid}
653        self.assertEqual(r, result)
654        self.assertEqual(get(table + ' *', 2, 'n'), r)
655        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
656        self.assertEqual(get(table, 1, 'n')['t'], 'x')
657        self.assertEqual(get(table, 3, 'n')['t'], 'z')
658        self.assertEqual(get(table, 2, 'n')['t'], 'y')
659        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
660        r['n'] = 3
661        self.assertEqual(get(table, r, 'n')['t'], 'z')
662        self.assertEqual(get(table, 1, 'n')['t'], 'x')
663        query('alter table "%s" alter n set not null' % table)
664        query('alter table "%s" add primary key (n)' % table)
665        self.assertEqual(get(table, 3)['t'], 'z')
666        self.assertEqual(get(table, 1)['t'], 'x')
667        self.assertEqual(get(table, 2)['t'], 'y')
668        r['n'] = 1
669        self.assertEqual(get(table, r)['t'], 'x')
670        r['n'] = 3
671        self.assertEqual(get(table, r)['t'], 'z')
672        r['n'] = 2
673        self.assertEqual(get(table, r)['t'], 'y')
674        query('drop table "%s"' % table)
675
676    def testGetWithCompositeKey(self):
677        get = self.db.get
678        query = self.db.query
679        table = 'get_test_table_1'
680        query('drop table if exists "%s"' % table)
681        query('create table "%s" ('
682            "n integer, t text, primary key (n))" % table)
683        for n, t in enumerate('abc'):
684            query('insert into "%s" values('
685                "%d, '%s')" % (table, n + 1, t))
686        self.assertEqual(get(table, 2)['t'], 'b')
687        query('drop table "%s"' % table)
688        table = 'get_test_table_2'
689        query('drop table if exists "%s"' % table)
690        query('create table "%s" ('
691            "n integer, m integer, t text, primary key (n, m))" % table)
692        for n in range(3):
693            for m in range(2):
694                t = chr(ord('a') + 2 * n + m)
695                query('insert into "%s" values('
696                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
697        self.assertRaises(pg.ProgrammingError, get, table, 2)
698        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
699        self.assertEqual(get(table, dict(n=1, m=2),
700                             ('n', 'm'))['t'], 'b')
701        self.assertEqual(get(table, dict(n=3, m=2),
702                             frozenset(['n', 'm']))['t'], 'f')
703        query('drop table "%s"' % table)
704
705    def testGetWithQuotedNames(self):
706        get = self.db.get
707        query = self.db.query
708        table = 'test table for get()'
709        query('drop table if exists "%s"' % table)
710        query('create table "%s" ('
711            '"Prime!" smallint primary key,'
712            '"much space" integer, "Questions?" text)' % table)
713        query('insert into "%s"'
714              " values(17, 1001, 'No!')" % table)
715        r = get(table, 17)
716        self.assertIsInstance(r, dict)
717        self.assertEqual(r['Prime!'], 17)
718        self.assertEqual(r['much space'], 1001)
719        self.assertEqual(r['Questions?'], 'No!')
720        query('drop table "%s"' % table)
721
722    def testGetFromView(self):
723        self.db.query('delete from test where i4=14')
724        self.db.query('insert into test (i4, v4) values('
725            "14, 'abc4')")
726        r = self.db.get('test_view', 14, 'i4')
727        self.assertIn('v4', r)
728        self.assertEqual(r['v4'], 'abc4')
729
730    def testInsert(self):
731        insert = self.db.insert
732        query = self.db.query
733        bool_on = pg.get_bool()
734        decimal = pg.get_decimal()
735        table = 'insert_test_table'
736        query('drop table if exists "%s"' % table)
737        query('create table "%s" ('
738            "i2 smallint, i4 integer, i8 bigint,"
739            " d numeric, f4 real, f8 double precision, m money,"
740            " v4 varchar(4), c4 char(4), t text,"
741            " b boolean, ts timestamp) with oids" % table)
742        oid_table = 'oid(%s)' % table
743        tests = [dict(i2=None, i4=None, i8=None),
744            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
745            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
746            dict(i2=42, i4=123456, i8=9876543210),
747            dict(i2=2 ** 15 - 1,
748                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
749            dict(d=None), (dict(d=''), dict(d=None)),
750            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
751            dict(f4=None, f8=None), dict(f4=0, f8=0),
752            (dict(f4='', f8=''), dict(f4=None, f8=None)),
753            (dict(d=1234.5, f4=1234.5, f8=1234.5),
754                  dict(d=Decimal('1234.5'))),
755            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
756            dict(d=Decimal('123456789.9876543212345678987654321')),
757            dict(m=None), (dict(m=''), dict(m=None)),
758            dict(m=Decimal('-1234.56')),
759            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
760            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
761            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
762            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
763            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
764            (dict(m=123456), dict(m=Decimal('123456'))),
765            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
766            dict(b=None), (dict(b=''), dict(b=None)),
767            dict(b='f'), dict(b='t'),
768            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
769            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
770            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
771            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
772            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
773            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
774            dict(v4=None, c4=None, t=None),
775            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
776            dict(v4='1234', c4='1234', t='1234' * 10),
777            dict(v4='abcd', c4='abcd', t='abcdefg'),
778            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
779            dict(ts=None), (dict(ts=''), dict(ts=None)),
780            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
781            dict(ts='2012-12-21 00:00:00'),
782            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
783            dict(ts='2012-12-21 12:21:12'),
784            dict(ts='2013-01-05 12:13:14'),
785            dict(ts='current_timestamp')]
786        for test in tests:
787            if isinstance(test, dict):
788                data = test
789                change = {}
790            else:
791                data, change = test
792            expect = data.copy()
793            expect.update(change)
794            if bool_on:
795                b = expect.get('b')
796                if b is not None:
797                    expect['b'] = b == 't'
798            if decimal is not Decimal:
799                d = expect.get('d')
800                if d is not None:
801                    expect['d'] = decimal(d)
802                m = expect.get('m')
803                if m is not None:
804                    expect['m'] = decimal(m)
805            self.assertEqual(insert(table, data), data)
806            self.assertIn(oid_table, data)
807            oid = data[oid_table]
808            self.assertIsInstance(oid, int)
809            data = dict(item for item in data.items()
810                if item[0] in expect)
811            ts = expect.get('ts')
812            if ts == 'current_timestamp':
813                ts = expect['ts'] = data['ts']
814                if len(ts) > 19:
815                    self.assertEqual(ts[19], '.')
816                    ts = ts[:19]
817                else:
818                    self.assertEqual(len(ts), 19)
819                self.assertTrue(ts[:4].isdigit())
820                self.assertEqual(ts[4], '-')
821                self.assertEqual(ts[10], ' ')
822                self.assertTrue(ts[11:13].isdigit())
823                self.assertEqual(ts[13], ':')
824            self.assertEqual(data, expect)
825            data = query(
826                'select oid,* from "%s"' % table).dictresult()[0]
827            self.assertEqual(data['oid'], oid)
828            data = dict(item for item in data.items()
829                if item[0] in expect)
830            self.assertEqual(data, expect)
831            query('delete from "%s"' % table)
832        query('drop table "%s"' % table)
833
834    def testInsertWithQuotedNames(self):
835        insert = self.db.insert
836        query = self.db.query
837        table = 'test table for insert()'
838        query('drop table if exists "%s"' % table)
839        query('create table "%s" ('
840            '"Prime!" smallint primary key,'
841            '"much space" integer, "Questions?" text)' % table)
842        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
843        r = insert(table, r)
844        self.assertIsInstance(r, dict)
845        self.assertEqual(r['Prime!'], 11)
846        self.assertEqual(r['much space'], 2002)
847        self.assertEqual(r['Questions?'], 'What?')
848        r = query('select * from "%s" limit 2' % table).dictresult()
849        self.assertEqual(len(r), 1)
850        r = r[0]
851        self.assertEqual(r['Prime!'], 11)
852        self.assertEqual(r['much space'], 2002)
853        self.assertEqual(r['Questions?'], 'What?')
854        query('drop table "%s"' % table)
855
856    def testUpdate(self):
857        update = self.db.update
858        query = self.db.query
859        table = 'update_test_table'
860        query('drop table if exists "%s"' % table)
861        query('create table "%s" ('
862            "n integer, t text) with oids" % table)
863        for n, t in enumerate('xyz'):
864            query('insert into "%s" values('
865                "%d, '%s')" % (table, n + 1, t))
866        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
867        r = self.db.get(table, 2, 'n')
868        r['t'] = 'u'
869        s = update(table, r)
870        self.assertEqual(s, r)
871        q = 'select t from "%s" where n=2' % table
872        r = query(q).getresult()[0][0]
873        self.assertEqual(r, 'u')
874        query('drop table "%s"' % table)
875
876    def testUpdateWithCompositeKey(self):
877        update = self.db.update
878        query = self.db.query
879        table = 'update_test_table_1'
880        query('drop table if exists "%s"' % table)
881        query('create table "%s" ('
882            "n integer, t text, primary key (n))" % table)
883        for n, t in enumerate('abc'):
884            query('insert into "%s" values('
885                "%d, '%s')" % (table, n + 1, t))
886        self.assertRaises(pg.ProgrammingError, update,
887                          table, dict(t='b'))
888        s = dict(n=2, t='d')
889        r = update(table, s)
890        self.assertIs(r, s)
891        self.assertEqual(r['n'], 2)
892        self.assertEqual(r['t'], 'd')
893        q = 'select t from "%s" where n=2' % table
894        r = query(q).getresult()[0][0]
895        self.assertEqual(r, 'd')
896        s.update(dict(n=4, t='e'))
897        r = update(table, s)
898        self.assertEqual(r['n'], 4)
899        self.assertEqual(r['t'], 'e')
900        q = 'select t from "%s" where n=2' % table
901        r = query(q).getresult()[0][0]
902        self.assertEqual(r, 'd')
903        q = 'select t from "%s" where n=4' % table
904        r = query(q).getresult()
905        self.assertEqual(len(r), 0)
906        query('drop table "%s"' % table)
907        table = 'update_test_table_2'
908        query('drop table if exists "%s"' % table)
909        query('create table "%s" ('
910            "n integer, m integer, t text, primary key (n, m))" % table)
911        for n in range(3):
912            for m in range(2):
913                t = chr(ord('a') + 2 * n + m)
914                query('insert into "%s" values('
915                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
916        self.assertRaises(pg.ProgrammingError, update,
917                          table, dict(n=2, t='b'))
918        self.assertEqual(update(table,
919                                dict(n=2, m=2, t='x'))['t'], 'x')
920        q = 'select t from "%s" where n=2 order by m' % table
921        r = [r[0] for r in query(q).getresult()]
922        self.assertEqual(r, ['c', 'x'])
923        query('drop table "%s"' % table)
924
925    def testUpdateWithQuotedNames(self):
926        update = self.db.update
927        query = self.db.query
928        table = 'test table for update()'
929        query('drop table if exists "%s"' % table)
930        query('create table "%s" ('
931            '"Prime!" smallint primary key,'
932            '"much space" integer, "Questions?" text)' % table)
933        query('insert into "%s"'
934              " values(13, 3003, 'Why!')" % table)
935        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
936        r = update(table, r)
937        self.assertIsInstance(r, dict)
938        self.assertEqual(r['Prime!'], 13)
939        self.assertEqual(r['much space'], 7007)
940        self.assertEqual(r['Questions?'], 'When?')
941        r = query('select * from "%s" limit 2' % table).dictresult()
942        self.assertEqual(len(r), 1)
943        r = r[0]
944        self.assertEqual(r['Prime!'], 13)
945        self.assertEqual(r['much space'], 7007)
946        self.assertEqual(r['Questions?'], 'When?')
947        query('drop table "%s"' % table)
948
949    def testUpsert(self):
950        upsert = self.db.upsert
951        query = self.db.query
952        table = 'upsert_test_table'
953        query('drop table if exists "%s"' % table)
954        query('create table "%s" ('
955            "n integer primary key, t text) with oids" % table)
956        s = dict(n=1, t='x')
957        try:
958            r = upsert(table, s)
959        except pg.ProgrammingError as error:
960            if self.db.server_version < 90500:
961                self.skipTest('database does not support upsert')
962            self.fail(str(error))
963        self.assertIs(r, s)
964        self.assertEqual(r['n'], 1)
965        self.assertEqual(r['t'], 'x')
966        s.update(n=2, t='y')
967        r = upsert(table, s, **dict.fromkeys(s))
968        self.assertIs(r, s)
969        self.assertEqual(r['n'], 2)
970        self.assertEqual(r['t'], 'y')
971        q = 'select n, t from "%s" order by n limit 3' % table
972        r = query(q).getresult()
973        self.assertEqual(r, [(1, 'x'), (2, 'y')])
974        s.update(t='z')
975        r = upsert(table, s)
976        self.assertIs(r, s)
977        self.assertEqual(r['n'], 2)
978        self.assertEqual(r['t'], 'z')
979        r = query(q).getresult()
980        self.assertEqual(r, [(1, 'x'), (2, 'z')])
981        s.update(t='n')
982        r = upsert(table, s, t=False)
983        self.assertIs(r, s)
984        self.assertEqual(r['n'], 2)
985        self.assertEqual(r['t'], 'z')
986        r = query(q).getresult()
987        self.assertEqual(r, [(1, 'x'), (2, 'z')])
988        s.update(t='y')
989        r = upsert(table, s, t=True)
990        self.assertIs(r, s)
991        self.assertEqual(r['n'], 2)
992        self.assertEqual(r['t'], 'y')
993        r = query(q).getresult()
994        self.assertEqual(r, [(1, 'x'), (2, 'y')])
995        s.update(t='n')
996        r = upsert(table, s, t="included.t || '2'")
997        self.assertIs(r, s)
998        self.assertEqual(r['n'], 2)
999        self.assertEqual(r['t'], 'y2')
1000        r = query(q).getresult()
1001        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1002        s.update(t='y')
1003        r = upsert(table, s, t="excluded.t || '3'")
1004        self.assertIs(r, s)
1005        self.assertEqual(r['n'], 2)
1006        self.assertEqual(r['t'], 'y3')
1007        r = query(q).getresult()
1008        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1009        s.update(n=1, t='2')
1010        r = upsert(table, s, t="included.t || excluded.t")
1011        self.assertIs(r, s)
1012        self.assertEqual(r['n'], 1)
1013        self.assertEqual(r['t'], 'x2')
1014        r = query(q).getresult()
1015        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1016        query('drop table "%s"' % table)
1017
1018    def testUpsertWithCompositeKey(self):
1019        upsert = self.db.upsert
1020        query = self.db.query
1021        table = 'upsert_test_table_2'
1022        query('drop table if exists "%s"' % table)
1023        query('create table "%s" ('
1024            "n integer, m integer, t text, primary key (n, m))" % table)
1025        s = dict(n=1, m=2, t='x')
1026        try:
1027            r = upsert(table, s)
1028        except pg.ProgrammingError as error:
1029            if self.db.server_version < 90500:
1030                self.skipTest('database does not support upsert')
1031            self.fail(str(error))
1032        self.assertIs(r, s)
1033        self.assertEqual(r['n'], 1)
1034        self.assertEqual(r['m'], 2)
1035        self.assertEqual(r['t'], 'x')
1036        s.update(m=3, t='y')
1037        r = upsert(table, s, **dict.fromkeys(s))
1038        self.assertIs(r, s)
1039        self.assertEqual(r['n'], 1)
1040        self.assertEqual(r['m'], 3)
1041        self.assertEqual(r['t'], 'y')
1042        q = 'select n, m, t from "%s" order by n, m limit 3' % table
1043        r = query(q).getresult()
1044        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
1045        s.update(t='z')
1046        r = upsert(table, s)
1047        self.assertIs(r, s)
1048        self.assertEqual(r['n'], 1)
1049        self.assertEqual(r['m'], 3)
1050        self.assertEqual(r['t'], 'z')
1051        r = query(q).getresult()
1052        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1053        s.update(t='n')
1054        r = upsert(table, s, t=False)
1055        self.assertIs(r, s)
1056        self.assertEqual(r['n'], 1)
1057        self.assertEqual(r['m'], 3)
1058        self.assertEqual(r['t'], 'z')
1059        r = query(q).getresult()
1060        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1061        s.update(t='n')
1062        r = upsert(table, s, t=True)
1063        self.assertIs(r, s)
1064        self.assertEqual(r['n'], 1)
1065        self.assertEqual(r['m'], 3)
1066        self.assertEqual(r['t'], 'n')
1067        r = query(q).getresult()
1068        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
1069        s.update(n=2, t='y')
1070        r = upsert(table, s, t="'z'")
1071        self.assertIs(r, s)
1072        self.assertEqual(r['n'], 2)
1073        self.assertEqual(r['m'], 3)
1074        self.assertEqual(r['t'], 'y')
1075        r = query(q).getresult()
1076        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
1077        s.update(n=1, t='m')
1078        r = upsert(table, s, t='included.t || excluded.t')
1079        self.assertIs(r, s)
1080        self.assertEqual(r['n'], 1)
1081        self.assertEqual(r['m'], 3)
1082        self.assertEqual(r['t'], 'nm')
1083        r = query(q).getresult()
1084        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1085        query('drop table "%s"' % table)
1086
1087    def testUpsertWithQuotedNames(self):
1088        upsert = self.db.upsert
1089        query = self.db.query
1090        table = 'test table for upsert()'
1091        query('drop table if exists "%s"' % table)
1092        query('create table "%s" ('
1093            '"Prime!" smallint primary key,'
1094            '"much space" integer, "Questions?" text)' % table)
1095        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1096        try:
1097            r = upsert(table, s)
1098        except pg.ProgrammingError as error:
1099            if self.db.server_version < 90500:
1100                self.skipTest('database does not support upsert')
1101            self.fail(str(error))
1102        self.assertIs(r, s)
1103        self.assertEqual(r['Prime!'], 31)
1104        self.assertEqual(r['much space'], 9009)
1105        self.assertEqual(r['Questions?'], 'Yes.')
1106        q = 'select * from "%s" limit 2' % table
1107        r = query(q).getresult()
1108        self.assertEqual(r, [(31, 9009, 'Yes.')])
1109        s.update({'Questions?': 'No.'})
1110        r = upsert(table, s)
1111        self.assertIs(r, s)
1112        self.assertEqual(r['Prime!'], 31)
1113        self.assertEqual(r['much space'], 9009)
1114        self.assertEqual(r['Questions?'], 'No.')
1115        r = query(q).getresult()
1116        self.assertEqual(r, [(31, 9009, 'No.')])
1117
1118    def testClear(self):
1119        clear = self.db.clear
1120        query = self.db.query
1121        f = False if pg.get_bool() else 'f'
1122        table = 'clear_test_table'
1123        query('drop table if exists "%s"' % table)
1124        query('create table "%s" ('
1125            "n integer, b boolean, d date, t text)" % table)
1126        r = clear(table)
1127        result = {'n': 0, 'b': f, 'd': '', 't': ''}
1128        self.assertEqual(r, result)
1129        r['a'] = r['n'] = 1
1130        r['d'] = r['t'] = 'x'
1131        r['b'] = 't'
1132        r['oid'] = long(1)
1133        r = clear(table, r)
1134        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
1135            'oid': long(1)}
1136        self.assertEqual(r, result)
1137        query('drop table "%s"' % table)
1138
1139    def testClearWithQuotedNames(self):
1140        clear = self.db.clear
1141        query = self.db.query
1142        table = 'test table for clear()'
1143        query('drop table if exists "%s"' % table)
1144        query('create table "%s" ('
1145            '"Prime!" smallint primary key,'
1146            '"much space" integer, "Questions?" text)' % table)
1147        r = clear(table)
1148        self.assertIsInstance(r, dict)
1149        self.assertEqual(r['Prime!'], 0)
1150        self.assertEqual(r['much space'], 0)
1151        self.assertEqual(r['Questions?'], '')
1152        query('drop table "%s"' % table)
1153
1154    def testDelete(self):
1155        delete = self.db.delete
1156        query = self.db.query
1157        table = 'delete_test_table'
1158        query('drop table if exists "%s"' % table)
1159        query('create table "%s" ('
1160            "n integer, t text) with oids" % table)
1161        for n, t in enumerate('xyz'):
1162            query('insert into "%s" values('
1163                "%d, '%s')" % (table, n + 1, t))
1164        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1165        r = self.db.get(table, 1, 'n')
1166        s = delete(table, r)
1167        self.assertEqual(s, 1)
1168        r = self.db.get(table, 3, 'n')
1169        s = delete(table, r)
1170        self.assertEqual(s, 1)
1171        s = delete(table, r)
1172        self.assertEqual(s, 0)
1173        r = query('select * from "%s"' % table).dictresult()
1174        self.assertEqual(len(r), 1)
1175        r = r[0]
1176        result = {'n': 2, 't': 'y'}
1177        self.assertEqual(r, result)
1178        r = self.db.get(table, 2, 'n')
1179        s = delete(table, r)
1180        self.assertEqual(s, 1)
1181        s = delete(table, r)
1182        self.assertEqual(s, 0)
1183        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1184        query('drop table "%s"' % table)
1185
1186    def testDeleteWithCompositeKey(self):
1187        query = self.db.query
1188        table = 'delete_test_table_1'
1189        query('drop table if exists "%s"' % table)
1190        query('create table "%s" ('
1191            "n integer, t text, primary key (n))" % table)
1192        for n, t in enumerate('abc'):
1193            query("insert into %s values("
1194                "%d, '%s')" % (table, n + 1, t))
1195        self.assertRaises(pg.ProgrammingError, self.db.delete,
1196            table, dict(t='b'))
1197        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1198        r = query('select t from "%s" where n=2' % table
1199                  ).getresult()
1200        self.assertEqual(r, [])
1201        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1202        r = query('select t from "%s" where n=3' % table
1203                  ).getresult()[0][0]
1204        self.assertEqual(r, 'c')
1205        query('drop table "%s"' % table)
1206        table = 'delete_test_table_2'
1207        query('drop table if exists "%s"' % table)
1208        query('create table "%s" ('
1209            "n integer, m integer, t text, primary key (n, m))" % table)
1210        for n in range(3):
1211            for m in range(2):
1212                t = chr(ord('a') + 2 * n + m)
1213                query('insert into "%s" values('
1214                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1215        self.assertRaises(pg.ProgrammingError, self.db.delete,
1216            table, dict(n=2, t='b'))
1217        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1218        r = [r[0] for r in query('select t from "%s" where n=2'
1219            ' order by m' % table).getresult()]
1220        self.assertEqual(r, ['c'])
1221        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1222        r = [r[0] for r in query('select t from "%s" where n=3'
1223            ' order by m' % table).getresult()]
1224        self.assertEqual(r, ['e', 'f'])
1225        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1226        r = [r[0] for r in query('select t from "%s" where n=3'
1227            ' order by m' % table).getresult()]
1228        self.assertEqual(r, ['f'])
1229        query('drop table "%s"' % table)
1230
1231    def testDeleteWithQuotedNames(self):
1232        delete = self.db.delete
1233        query = self.db.query
1234        table = 'test table for delete()'
1235        query('drop table if exists "%s"' % table)
1236        query('create table "%s" ('
1237            '"Prime!" smallint primary key,'
1238            '"much space" integer, "Questions?" text)' % table)
1239        query('insert into "%s"'
1240              " values(19, 5005, 'Yes!')" % table)
1241        r = {'Prime!': 17}
1242        r = delete(table, r)
1243        self.assertEqual(r, 0)
1244        r = query('select count(*) from "%s"' % table).getresult()
1245        self.assertEqual(r[0][0], 1)
1246        r = {'Prime!': 19}
1247        r = delete(table, r)
1248        self.assertEqual(r, 1)
1249        r = query('select count(*) from "%s"' % table).getresult()
1250        self.assertEqual(r[0][0], 0)
1251        query('drop table "%s"' % table)
1252
1253    def testTransaction(self):
1254        query = self.db.query
1255        query("drop table if exists test_table")
1256        query("create table test_table (n integer)")
1257        self.db.begin()
1258        query("insert into test_table values (1)")
1259        query("insert into test_table values (2)")
1260        self.db.commit()
1261        self.db.begin()
1262        query("insert into test_table values (3)")
1263        query("insert into test_table values (4)")
1264        self.db.rollback()
1265        self.db.begin()
1266        query("insert into test_table values (5)")
1267        self.db.savepoint('before6')
1268        query("insert into test_table values (6)")
1269        self.db.rollback('before6')
1270        query("insert into test_table values (7)")
1271        self.db.commit()
1272        self.db.begin()
1273        self.db.savepoint('before8')
1274        query("insert into test_table values (8)")
1275        self.db.release('before8')
1276        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1277        self.db.commit()
1278        self.db.start()
1279        query("insert into test_table values (9)")
1280        self.db.end()
1281        r = [r[0] for r in query(
1282            "select * from test_table order by 1").getresult()]
1283        self.assertEqual(r, [1, 2, 5, 7, 9])
1284        query("drop table test_table")
1285
1286    def testContextManager(self):
1287        query = self.db.query
1288        query("drop table if exists test_table")
1289        query("create table test_table (n integer check(n>0))")
1290        with self.db:
1291            query("insert into test_table values (1)")
1292            query("insert into test_table values (2)")
1293        try:
1294            with self.db:
1295                query("insert into test_table values (3)")
1296                query("insert into test_table values (4)")
1297                raise ValueError('test transaction should rollback')
1298        except ValueError as error:
1299            self.assertEqual(str(error), 'test transaction should rollback')
1300        with self.db:
1301            query("insert into test_table values (5)")
1302        try:
1303            with self.db:
1304                query("insert into test_table values (6)")
1305                query("insert into test_table values (-1)")
1306        except pg.ProgrammingError as error:
1307            self.assertTrue('check' in str(error))
1308        with self.db:
1309            query("insert into test_table values (7)")
1310        r = [r[0] for r in query(
1311            "select * from test_table order by 1").getresult()]
1312        self.assertEqual(r, [1, 2, 5, 7])
1313        query("drop table test_table")
1314
1315    def testBytea(self):
1316        query = self.db.query
1317        query('drop table if exists bytea_test')
1318        query('create table bytea_test (n smallint primary key, data bytea)')
1319        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1320        r = self.db.escape_bytea(s)
1321        query('insert into bytea_test values(3,$1)', (r,))
1322        r = query('select * from bytea_test where n=3').getresult()
1323        self.assertEqual(len(r), 1)
1324        r = r[0]
1325        self.assertEqual(len(r), 2)
1326        self.assertEqual(r[0], 3)
1327        r = r[1]
1328        self.assertIsInstance(r, str)
1329        r = self.db.unescape_bytea(r)
1330        self.assertIsInstance(r, bytes)
1331        self.assertEqual(r, s)
1332        query('drop table bytea_test')
1333
1334    def testInsertUpdateGetBytea(self):
1335        query = self.db.query
1336        query('drop table if exists bytea_test')
1337        query('create table bytea_test (n smallint primary key, data bytea)')
1338        # insert null value
1339        r = self.db.insert('bytea_test', n=0, data=None)
1340        self.assertIsInstance(r, dict)
1341        self.assertIn('n', r)
1342        self.assertEqual(r['n'], 0)
1343        self.assertIn('data', r)
1344        self.assertIsNone(r['data'])
1345        s = b'None'
1346        r = self.db.update('bytea_test', n=0, data=s)
1347        self.assertIsInstance(r, dict)
1348        self.assertIn('n', r)
1349        self.assertEqual(r['n'], 0)
1350        self.assertIn('data', r)
1351        r = r['data']
1352        self.assertIsInstance(r, bytes)
1353        self.assertEqual(r, s)
1354        r = self.db.update('bytea_test', n=0, data=None)
1355        self.assertIsNone(r['data'])
1356        # insert as bytes
1357        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1358        r = self.db.insert('bytea_test', n=5, data=s)
1359        self.assertIsInstance(r, dict)
1360        self.assertIn('n', r)
1361        self.assertEqual(r['n'], 5)
1362        self.assertIn('data', r)
1363        r = r['data']
1364        self.assertIsInstance(r, bytes)
1365        self.assertEqual(r, s)
1366        # update as bytes
1367        s += b"and now even more \x00 nasty \t stuff!\f"
1368        r = self.db.update('bytea_test', n=5, data=s)
1369        self.assertIsInstance(r, dict)
1370        self.assertIn('n', r)
1371        self.assertEqual(r['n'], 5)
1372        self.assertIn('data', r)
1373        r = r['data']
1374        self.assertIsInstance(r, bytes)
1375        self.assertEqual(r, s)
1376        r = query('select * from bytea_test where n=5').getresult()
1377        self.assertEqual(len(r), 1)
1378        r = r[0]
1379        self.assertEqual(len(r), 2)
1380        self.assertEqual(r[0], 5)
1381        r = r[1]
1382        self.assertIsInstance(r, str)
1383        r = self.db.unescape_bytea(r)
1384        self.assertIsInstance(r, bytes)
1385        self.assertEqual(r, s)
1386        r = self.db.get('bytea_test', dict(n=5))
1387        self.assertIsInstance(r, dict)
1388        self.assertIn('n', r)
1389        self.assertEqual(r['n'], 5)
1390        self.assertIn('data', r)
1391        r = r['data']
1392        self.assertIsInstance(r, bytes)
1393        self.assertEqual(r, s)
1394        query('drop table bytea_test')
1395
1396    def testDebugWithCallable(self):
1397        if debug:
1398            self.assertEqual(self.db.debug, debug)
1399        else:
1400            self.assertIsNone(self.db.debug)
1401        s = []
1402        self.db.debug = s.append
1403        try:
1404            self.db.query("select 1")
1405            self.db.query("select 2")
1406            self.assertEqual(s, ["select 1", "select 2"])
1407        finally:
1408            self.db.debug = debug
1409
1410
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        """Switch the global pg options, then run the parent class setup.

        Uses float for decimals, the inverted bool setting, and plain
        tuples instead of named results.  If the setup fails (including
        being skipped), unittest does not call tearDownClass(), so the
        options are restored here to avoid leaking into other test classes.
        """
        cls.saved_options = {}

        def unnamed_result(q):
            return q.getresult()

        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            cls.reset_all_options()
            raise

    @classmethod
    def tearDownClass(cls):
        """Run the parent class teardown and always restore the options."""
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls.reset_all_options()

    @classmethod
    def set_option(cls, option, value):
        """Remember the current value of a pg option, then set a new one."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_all_options(cls):
        """Restore every option that was saved, in reverse order of setting."""
        for option in ('namedresult', 'bool', 'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
1439
1440
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    @classmethod
    def setUpClass(cls):
        """Create the schemas s1 to s4 and the test tables.

        Each schema sN (and also public, as number 0) gets tables t and tN.
        Every table has a column n (always 1) and a column d holding the
        schema number, so the tests can tell which schema a row came from.
        """
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    schema = "public"
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids"
                      " as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # close the connection even if creating the fixtures failed
            db.close()

    @classmethod
    def tearDownClass(cls):
        """Drop the schemas s1 to s4 and the test tables in public."""
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(5):
                if num_schema:
                    schema = "s%d" % num_schema
                    query("drop schema %s cascade" % (schema,))
                else:
                    schema = "public"
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            # close the connection even if dropping the fixtures failed
            db.close()

    def setUp(self):
        self.db = DB()
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(5):
            if num_schema:
                schema = "s" + str(num_schema)
            else:
                schema = "public"
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)
        query("drop table s3.t3m")

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
1558
if __name__ == '__main__':
    # run all test cases of this module when it is executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.