source: trunk/tests/test_classic_dbwrapper.py @ 744

Last change on this file since 744 was 744, checked in by cito, 4 years ago

Use cleanup feature in unittests

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 63.4 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'  # name of the test database
dbhost = None  # None: no explicit host is configured
dbport = 5432

debug = False  # let DB wrapper print debugging output

# The settings above may be overridden by a LOCAL_PyGreSQL module, imported
# relative to this package first and then as a top-level module (the
# relative import fails when this file is not run as part of a package).
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Python 2/3 compatibility aliases used by the tests below.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# Fall back to a plain dict where OrderedDict is missing; tests that
# check column order skip that check in this case.
try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
62
63
def DB():
    """Open a pg.DB wrapper on the configured test database.

    The module-level ``debug`` flag is copied to the wrapper when set, and
    the session's client_min_messages is set to warning before returning.
    """
    connection = pg.DB(dbname, dbhost, dbport)
    if debug:
        connection.debug = debug
    connection.query("set client_min_messages=warning")
    return connection
71
72
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # Some tests close the connection themselves; closing it a second
        # time raises InternalError (see testMethodClose), which is fine here.
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        # The complete public interface of the DB wrapper, in the sorted
        # order in which dir() lists it.
        attributes = [
            'begin',
            'cancel',
            'clear',
            'close',
            'commit',
            'db',
            'dbname',
            'debug',
            'delete',
            'end',
            'endcopy',
            'error',
            'escape_bytea',
            'escape_identifier',
            'escape_literal',
            'escape_string',
            'fileno',
            'get',
            'get_attnames',
            'get_databases',
            'get_notice_receiver',
            'get_relations',
            'get_tables',
            'getline',
            'getlo',
            'getnotify',
            'has_table_privilege',
            'host',
            'insert',
            'inserttable',
            'locreate',
            'loimport',
            'notification_handler',
            'options',
            'parameter',
            'pkey',
            'port',
            'protocol_version',
            'putline',
            'query',
            'release',
            'reopen',
            'reset',
            'rollback',
            'savepoint',
            'server_version',
            'set_notice_receiver',
            'source',
            'start',
            'status',
            'transaction',
            'unescape_bytea',
            'update',
            'upsert',
            'use_regtypes',
            'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        # NOTE(review): a leftover 'krb5_' message is tolerated here,
        # presumably from a Kerberos-enabled libpq -- confirm.
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        # Since PostgreSQL 10 the numeric version is major * 10000 + minor
        # (e.g. 100004), so the upper bound must allow six digits.
        self.assertTrue(70400 <= server_version < 1000000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # strip the hex prefix / backslashes so both bytea formats pass
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        # parameters may be passed as separate arguments, a tuple or a list
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        # The error must actually be raised; previously the test also
        # passed silently when the query did not fail at all.
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            self.assertEqual(error.sqlstate, '22012')
        else:
            self.fail('Division by zero should raise a ProgrammingError')

    def testMethodEndcopy(self):
        # calling endcopy() outside of a copy may raise IOError
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        # a DB wrapper can be built around an existing connection, another
        # DB wrapper, or any object with a _cnx attribute; closing such a
        # wrapper leaves the shared connection object in place
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
279
280
281class TestDBClass(unittest.TestCase):
282    """Test the methods of the DB class wrapped pg connection."""
283
284    @classmethod
285    def setUpClass(cls):
286        db = DB()
287        db.query("drop table if exists test cascade")
288        db.query("create table test ("
289            "i2 smallint, i4 integer, i8 bigint,"
290            "d numeric, f4 real, f8 double precision, m money, "
291            "v4 varchar(4), c4 char(4), t text)")
292        db.query("create or replace view test_view as"
293            " select i4, v4 from test")
294        db.close()
295
296    @classmethod
297    def tearDownClass(cls):
298        db = DB()
299        db.query("drop table test cascade")
300        db.close()
301
302    def setUp(self):
303        self.db = DB()
304        query = self.db.query
305        query('set client_encoding=utf8')
306        query('set standard_conforming_strings=on')
307        query("set lc_monetary='C'")
308        query("set datestyle='ISO,YMD'")
309        query('set bytea_output=hex')
310
    def tearDown(self):
        # Run the cleanups registered by the test (mostly "drop table"
        # statements) while the connection is still open, then close it.
        self.doCleanups()
        self.db.close()
314
315    def testClassName(self):
316        self.assertEqual(self.db.__class__.__name__, 'DB')
317
318    def testModuleName(self):
319        self.assertEqual(self.db.__module__, 'pg')
320        self.assertEqual(self.db.__class__.__module__, 'pg')
321
322    def testEscapeLiteral(self):
323        f = self.db.escape_literal
324        r = f(b"plain")
325        self.assertIsInstance(r, bytes)
326        self.assertEqual(r, b"'plain'")
327        r = f(u"plain")
328        self.assertIsInstance(r, unicode)
329        self.assertEqual(r, u"'plain'")
330        r = f(u"that's kÀse".encode('utf-8'))
331        self.assertIsInstance(r, bytes)
332        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
333        r = f(u"that's kÀse")
334        self.assertIsInstance(r, unicode)
335        self.assertEqual(r, u"'that''s kÀse'")
336        self.assertEqual(f(r"It's fine to have a \ inside."),
337            r" E'It''s fine to have a \\ inside.'")
338        self.assertEqual(f('No "quotes" must be escaped.'),
339            "'No \"quotes\" must be escaped.'")
340
341    def testEscapeIdentifier(self):
342        f = self.db.escape_identifier
343        r = f(b"plain")
344        self.assertIsInstance(r, bytes)
345        self.assertEqual(r, b'"plain"')
346        r = f(u"plain")
347        self.assertIsInstance(r, unicode)
348        self.assertEqual(r, u'"plain"')
349        r = f(u"that's kÀse".encode('utf-8'))
350        self.assertIsInstance(r, bytes)
351        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
352        r = f(u"that's kÀse")
353        self.assertIsInstance(r, unicode)
354        self.assertEqual(r, u'"that\'s kÀse"')
355        self.assertEqual(f(r"It's fine to have a \ inside."),
356            '"It\'s fine to have a \\ inside."')
357        self.assertEqual(f('All "quotes" must be escaped.'),
358            '"All ""quotes"" must be escaped."')
359
360    def testEscapeString(self):
361        f = self.db.escape_string
362        r = f(b"plain")
363        self.assertIsInstance(r, bytes)
364        self.assertEqual(r, b"plain")
365        r = f(u"plain")
366        self.assertIsInstance(r, unicode)
367        self.assertEqual(r, u"plain")
368        r = f(u"that's kÀse".encode('utf-8'))
369        self.assertIsInstance(r, bytes)
370        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
371        r = f(u"that's kÀse")
372        self.assertIsInstance(r, unicode)
373        self.assertEqual(r, u"that''s kÀse")
374        self.assertEqual(f(r"It's fine to have a \ inside."),
375            r"It''s fine to have a \ inside.")
376
377    def testEscapeBytea(self):
378        f = self.db.escape_bytea
379        # note that escape_byte always returns hex output since Pg 9.0,
380        # regardless of the bytea_output setting
381        r = f(b'plain')
382        self.assertIsInstance(r, bytes)
383        self.assertEqual(r, b'\\x706c61696e')
384        r = f(u'plain')
385        self.assertIsInstance(r, unicode)
386        self.assertEqual(r, u'\\x706c61696e')
387        r = f(u"das is' kÀse".encode('utf-8'))
388        self.assertIsInstance(r, bytes)
389        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
390        r = f(u"das is' kÀse")
391        self.assertIsInstance(r, unicode)
392        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
393        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
394
395    def testUnescapeBytea(self):
396        f = self.db.unescape_bytea
397        r = f(b'plain')
398        self.assertIsInstance(r, bytes)
399        self.assertEqual(r, b'plain')
400        r = f(u'plain')
401        self.assertIsInstance(r, bytes)
402        self.assertEqual(r, b'plain')
403        r = f(b"das is' k\\303\\244se")
404        self.assertIsInstance(r, bytes)
405        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
406        r = f(u"das is' k\\303\\244se")
407        self.assertIsInstance(r, bytes)
408        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
409        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
410        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
411        self.assertEqual(f(r'\\x746861742773206be47365'),
412            b'\\x746861742773206be47365')
413        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
414
415    def testGetAttnames(self):
416        get_attnames = self.db.get_attnames
417        query = self.db.query
418        query("drop table if exists test_table")
419        self.addCleanup(query, "drop table test_table")
420        query("create table test_table("
421            " n int, alpha smallint, beta bool,"
422            " gamma char(5), tau text, v varchar(3))")
423        r = get_attnames("test_table")
424        self.assertIsInstance(r, dict)
425        self.assertEquals(r, dict(
426            n='int', alpha='int', beta='bool',
427            gamma='text', tau='text', v='text'))
428
429    def testGetAttnamesWithQuotes(self):
430        get_attnames = self.db.get_attnames
431        query = self.db.query
432        table = 'test table for get_attnames()'
433        query('drop table if exists "%s"' % table)
434        self.addCleanup(query, 'drop table "%s"' % table)
435        query('create table "%s"('
436            '"Prime!" smallint,'
437            '"much space" integer, "Questions?" text)' % table)
438        r = get_attnames(table)
439        self.assertIsInstance(r, dict)
440        self.assertEquals(r, {
441            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
442
443    def testGetAttnamesWithRegtypes(self):
444        get_attnames = self.db.get_attnames
445        query = self.db.query
446        query("drop table if exists test_table")
447        self.addCleanup(query, "drop table test_table")
448        query("create table test_table("
449            " n int, alpha smallint, beta bool,"
450            " gamma char(5), tau text, v varchar(3))")
451        self.db.use_regtypes(True)
452        try:
453            r = get_attnames("test_table")
454            self.assertIsInstance(r, dict)
455        finally:
456            self.db.use_regtypes(False)
457        self.assertEquals(r, dict(
458            n='integer', alpha='smallint', beta='boolean',
459            gamma='character', tau='text', v='character varying'))
460
461    def testGetAttnamesIsCached(self):
462        get_attnames = self.db.get_attnames
463        query = self.db.query
464        query("drop table if exists test_table")
465        self.addCleanup(query, "drop table if exists test_table")
466        query("create table test_table(col int)")
467        r = get_attnames("test_table")
468        self.assertIsInstance(r, dict)
469        self.assertEquals(r, dict(col='int'))
470        query("drop table test_table")
471        query("create table test_table(col text)")
472        r = get_attnames("test_table")
473        self.assertEquals(r, dict(col='int'))
474        r = get_attnames("test_table", flush=True)
475        self.assertEquals(r, dict(col='text'))
476        query("drop table test_table")
477        r = get_attnames("test_table")
478        self.assertEquals(r, dict(col='text'))
479        self.assertRaises(pg.ProgrammingError,
480            get_attnames, "test_table", flush=True)
481
482    def testGetAttnamesIsOrdered(self):
483        get_attnames = self.db.get_attnames
484        query = self.db.query
485        query("drop table if exists test_table")
486        self.addCleanup(query, "drop table test_table")
487        query("create table test_table("
488            " n int, alpha smallint, v varchar(3),"
489            " gamma char(5), tau text, beta bool)")
490        r = get_attnames("test_table")
491        self.assertIsInstance(r, OrderedDict)
492        self.assertEquals(r, OrderedDict([
493            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
494            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
495        if OrderedDict is dict:
496            self.skipTest('OrderedDict is not supported')
497        r = ' '.join(list(r.keys()))
498        self.assertEquals(r, 'n alpha v gamma tau beta')
499
500    def testQuery(self):
501        query = self.db.query
502        query("drop table if exists test_table")
503        self.addCleanup(query, "drop table test_table")
504        q = "create table test_table (n integer) with oids"
505        r = query(q)
506        self.assertIsNone(r)
507        q = "insert into test_table values (1)"
508        r = query(q)
509        self.assertIsInstance(r, int)
510        q = "insert into test_table select 2"
511        r = query(q)
512        self.assertIsInstance(r, int)
513        oid = r
514        q = "select oid from test_table where n=2"
515        r = query(q).getresult()
516        self.assertEqual(len(r), 1)
517        r = r[0]
518        self.assertEqual(len(r), 1)
519        r = r[0]
520        self.assertEqual(r, oid)
521        q = "insert into test_table select 3 union select 4 union select 5"
522        r = query(q)
523        self.assertIsInstance(r, str)
524        self.assertEqual(r, '3')
525        q = "update test_table set n=4 where n<5"
526        r = query(q)
527        self.assertIsInstance(r, str)
528        self.assertEqual(r, '4')
529        q = "delete from test_table"
530        r = query(q)
531        self.assertIsInstance(r, str)
532        self.assertEqual(r, '5')
533
534    def testMultipleQueries(self):
535        self.assertEqual(self.db.query(
536            "create temporary table test_multi (n integer);"
537            "insert into test_multi values (4711);"
538            "select n from test_multi").getresult()[0][0], 4711)
539
540    def testQueryWithParams(self):
541        query = self.db.query
542        query("drop table if exists test_table")
543        self.addCleanup(query, "drop table test_table")
544        q = "create table test_table (n1 integer, n2 integer) with oids"
545        query(q)
546        q = "insert into test_table values ($1, $2)"
547        r = query(q, (1, 2))
548        self.assertIsInstance(r, int)
549        r = query(q, [3, 4])
550        self.assertIsInstance(r, int)
551        r = query(q, [5, 6])
552        self.assertIsInstance(r, int)
553        q = "select * from test_table order by 1, 2"
554        self.assertEqual(query(q).getresult(),
555            [(1, 2), (3, 4), (5, 6)])
556        q = "select * from test_table where n1=$1 and n2=$2"
557        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
558        q = "update test_table set n2=$2 where n1=$1"
559        r = query(q, 3, 7)
560        self.assertEqual(r, '1')
561        q = "select * from test_table order by 1, 2"
562        self.assertEqual(query(q).getresult(),
563            [(1, 2), (3, 7), (5, 6)])
564        q = "delete from test_table where n2!=$1"
565        r = query(q, 4)
566        self.assertEqual(r, '3')
567
568    def testEmptyQuery(self):
569        self.assertRaises(ValueError, self.db.query, '')
570
571    def testQueryProgrammingError(self):
572        try:
573            self.db.query("select 1/0")
574        except pg.ProgrammingError as error:
575            self.assertEqual(error.sqlstate, '22012')
576
    def testPkey(self):
        # Check pkey() on tables without, with a single-column and with a
        # multi-column primary key, for plain and for quoted table names,
        # and check that its results are cached until flushed.
        query = self.db.query
        pkey = self.db.pkey
        for t in ('pkeytest', 'primary key test'):
            # the tables are named "<t>0" ... "<t>6"; drop leftovers first
            # and schedule their removal at the end of the test
            for n in range(7):
                query('drop table if exists "%s%d"' % (t, n))
                self.addCleanup(query, 'drop table "%s%d"' % (t, n))
            # no primary key
            query('create table "%s0" ('
                "a smallint)" % t)
            # single-column primary keys
            query('create table "%s1" ('
                "b smallint primary key)" % t)
            query('create table "%s2" ('
                "c smallint, d smallint primary key)" % t)
            # two-column primary key
            query('create table "%s3" ('
                "e smallint, f smallint, g smallint, "
                "h smallint, i smallint, "
                "primary key (f, h))" % t)
            # longer and quoted primary key column names
            query('create table "%s4" ('
                "more_than_one_letter varchar primary key)" % t)
            query('create table "%s5" ('
                '"with space" date primary key)' % t)
            query('create table "%s6" ('
                'a_very_long_column_name varchar, '
                '"with space" date, '
                '"42" int, '
                "primary key (a_very_long_column_name, "
                '"with space", "42"))' % t)
            # a table without primary key raises KeyError
            self.assertRaises(KeyError, pkey, '%s0' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s2' % t), 'd')
            # a multi-column primary key is a frozenset of column names
            r = pkey('%s3' % t)
            self.assertIsInstance(r, frozenset)
            self.assertEqual(r, frozenset('fh'))
            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s5' % t), 'with space')
            r = pkey('%s6' % t)
            self.assertIsInstance(r, frozenset)
            self.assertEqual(r, frozenset([
                'a_very_long_column_name', 'with space', '42']))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
625
626    def testGetDatabases(self):
627        databases = self.db.get_databases()
628        self.assertIn('template0', databases)
629        self.assertIn('template1', databases)
630        self.assertNotIn('not existing database', databases)
631        self.assertIn('postgres', databases)
632        self.assertIn(dbname, databases)
633
634    def testGetTables(self):
635        get_tables = self.db.get_tables
636        result1 = get_tables()
637        self.assertIsInstance(result1, list)
638        for t in result1:
639            t = t.split('.', 1)
640            self.assertGreaterEqual(len(t), 2)
641            if len(t) > 2:
642                self.assertTrue(t[1].startswith('"'))
643            t = t[0]
644            self.assertNotEqual(t, 'information_schema')
645            self.assertFalse(t.startswith('pg_'))
646        tables = ('"A very Special Name"',
647            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
648            'A_MiXeD_NaMe', '"another special name"',
649            'averyveryveryveryveryveryverylongtablename',
650            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
651        for t in tables:
652            self.db.query('drop table if exists %s' % t)
653            self.db.query("create table %s"
654                " as select 0" % t)
655        result3 = get_tables()
656        result2 = []
657        for t in result3:
658            if t not in result1:
659                result2.append(t)
660        result3 = []
661        for t in tables:
662            if not t.startswith('"'):
663                t = t.lower()
664            result3.append('public.' + t)
665        self.assertEqual(result2, result3)
666        for t in result2:
667            self.db.query('drop table %s' % t)
668        result2 = get_tables()
669        self.assertEqual(result2, result1)
670
671    def testGetRelations(self):
672        get_relations = self.db.get_relations
673        result = get_relations()
674        self.assertIn('public.test', result)
675        self.assertIn('public.test_view', result)
676        result = get_relations('rv')
677        self.assertIn('public.test', result)
678        self.assertIn('public.test_view', result)
679        result = get_relations('r')
680        self.assertIn('public.test', result)
681        self.assertNotIn('public.test_view', result)
682        result = get_relations('v')
683        self.assertNotIn('public.test', result)
684        self.assertIn('public.test_view', result)
685        result = get_relations('cisSt')
686        self.assertNotIn('public.test', result)
687        self.assertNotIn('public.test_view', result)
688
689    def testAttnames(self):
690        self.assertRaises(pg.ProgrammingError,
691            self.db.get_attnames, 'does_not_exist')
692        self.assertRaises(pg.ProgrammingError,
693            self.db.get_attnames, 'has.too.many.dots')
694        for table in ('attnames_test_table', 'test table for attnames'):
695            self.db.query('drop table if exists "%s"' % table)
696            self.addCleanup(self.db.query, 'drop table "%s"' % table)
697            self.db.query('create table "%s" ('
698                'a smallint, b integer, c bigint, '
699                'e numeric, f float, f2 double precision, m money, '
700                'x smallint, y smallint, z smallint, '
701                'Normal_NaMe smallint, "Special Name" smallint, '
702                't text, u char(2), v varchar(2), '
703                'primary key (y, u)) with oids' % table)
704            attributes = self.db.get_attnames(table)
705            result = {'a': 'int', 'c': 'int', 'b': 'int',
706                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
707                'normal_name': 'int', 'Special Name': 'int',
708                'u': 'text', 't': 'text', 'v': 'text',
709                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
710            self.assertEqual(attributes, result)
711
712    def testHasTablePrivilege(self):
713        can = self.db.has_table_privilege
714        self.assertEqual(can('test'), True)
715        self.assertEqual(can('test', 'select'), True)
716        self.assertEqual(can('test', 'SeLeCt'), True)
717        self.assertEqual(can('test', 'SELECT'), True)
718        self.assertEqual(can('test', 'insert'), True)
719        self.assertEqual(can('test', 'update'), True)
720        self.assertEqual(can('test', 'delete'), True)
721        self.assertEqual(can('pg_views', 'select'), True)
722        self.assertEqual(can('pg_views', 'delete'), False)
723        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
724        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
725
    def testGet(self):
        # Check get() with an explicit key column, by oid and, after a
        # primary key has been added, with the key column omitted.
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        query('drop table if exists "%s"' % table)
        self.addCleanup(query, 'drop table "%s"' % table)
        query('create table "%s" ('
            "n integer, t text) with oids" % table)
        # rows (1, 'x'), (2, 'y'), (3, 'z'); the two adjacent literals
        # below form 'insert into "%s" values(%d, \'%s\')'
        for n, t in enumerate('xyz'):
            query('insert into "%s" values('"%d, '%s')"
                % (table, n + 1, t))
        # without a primary key the key column must be given
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        # the returned row also carries its oid under "oid(<table>)"
        oid_table = 'oid(%s)' % table
        self.assertIn(oid_table, r)
        oid = r[oid_table]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, oid_table: oid}
        self.assertEqual(r, result)
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        # a missing row raises DatabaseError
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        # a row dict may be passed instead of a value; its key column is used
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        # once the table has a primary key, the key column can be omitted
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
765
766    def testGetWithCompositeKey(self):
767        get = self.db.get
768        query = self.db.query
769        table = 'get_test_table_1'
770        query('drop table if exists "%s"' % table)
771        self.addCleanup(query, 'drop table "%s"' % table)
772        query('create table "%s" ('
773            "n integer, t text, primary key (n))" % table)
774        for n, t in enumerate('abc'):
775            query('insert into "%s" values('
776                "%d, '%s')" % (table, n + 1, t))
777        self.assertEqual(get(table, 2)['t'], 'b')
778        table = 'get_test_table_2'
779        query('drop table if exists "%s"' % table)
780        self.addCleanup(query, 'drop table "%s"' % table)
781        query('create table "%s" ('
782            "n integer, m integer, t text, primary key (n, m))" % table)
783        for n in range(3):
784            for m in range(2):
785                t = chr(ord('a') + 2 * n + m)
786                query('insert into "%s" values('
787                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
788        self.assertRaises(pg.ProgrammingError, get, table, 2)
789        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
790        r = get(table, dict(n=1, m=2), ('n', 'm'))
791        self.assertEqual(r['t'], 'b')
792        r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
793        self.assertEqual(r['t'], 'f')
794
795    def testGetWithQuotedNames(self):
796        get = self.db.get
797        query = self.db.query
798        table = 'test table for get()'
799        query('drop table if exists "%s"' % table)
800        self.addCleanup(query, 'drop table "%s"' % table)
801        query('create table "%s" ('
802            '"Prime!" smallint primary key,'
803            '"much space" integer, "Questions?" text)' % table)
804        query('insert into "%s"'
805              " values(17, 1001, 'No!')" % table)
806        r = get(table, 17)
807        self.assertIsInstance(r, dict)
808        self.assertEqual(r['Prime!'], 17)
809        self.assertEqual(r['much space'], 1001)
810        self.assertEqual(r['Questions?'], 'No!')
811
812    def testGetFromView(self):
813        self.db.query('delete from test where i4=14')
814        self.db.query('insert into test (i4, v4) values('
815            "14, 'abc4')")
816        r = self.db.get('test_view', 14, 'i4')
817        self.assertIn('v4', r)
818        self.assertEqual(r['v4'], 'abc4')
819
820    def testGetLittleBobbyTables(self):
821        get = self.db.get
822        query = self.db.query
823        query("drop table if exists test_students")
824        self.addCleanup(query, "drop table test_students")
825        query("create table test_students (firstname varchar primary key,"
826            " nickname varchar, grade char(2))")
827        query("insert into test_students values ("
828              "'D''Arcy', 'Darcey', 'A+')")
829        query("insert into test_students values ("
830              "'Sheldon', 'Moonpie', 'A+')")
831        query("insert into test_students values ("
832              "'Robert', 'Little Bobby Tables', 'D-')")
833        r = get('test_students', 'Sheldon')
834        self.assertEqual(r, dict(
835            firstname="Sheldon", nickname='Moonpie', grade='A+'))
836        r = get('test_students', 'Robert')
837        self.assertEqual(r, dict(
838            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
839        r = get('test_students', "D'Arcy")
840        self.assertEqual(r, dict(
841            firstname="D'Arcy", nickname='Darcey', grade='A+'))
842        try:
843            get('test_students', "D' Arcy")
844        except pg.DatabaseError as error:
845            self.assertEqual(str(error),
846                'No such record in test_students\nwhere "firstname" = $1\n'
847                'with $1="D\' Arcy"')
848        try:
849            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
850        except pg.DatabaseError as error:
851            self.assertEqual(str(error),
852                'No such record in test_students\nwhere "firstname" = $1\n'
853                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
854        q = "select * from test_students order by 1 limit 4"
855        r = query(q).getresult()
856        self.assertEqual(len(r), 3)
857        self.assertEqual(r[1][2], 'D-')
858
859    def testInsert(self):
860        insert = self.db.insert
861        query = self.db.query
862        bool_on = pg.get_bool()
863        decimal = pg.get_decimal()
864        table = 'insert_test_table'
865        query('drop table if exists "%s"' % table)
866        self.addCleanup(query, 'drop table "%s"' % table)
867        query('create table "%s" ('
868            "i2 smallint, i4 integer, i8 bigint,"
869            " d numeric, f4 real, f8 double precision, m money,"
870            " v4 varchar(4), c4 char(4), t text,"
871            " b boolean, ts timestamp) with oids" % table)
872        oid_table = 'oid(%s)' % table
873        tests = [dict(i2=None, i4=None, i8=None),
874            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
875            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
876            dict(i2=42, i4=123456, i8=9876543210),
877            dict(i2=2 ** 15 - 1,
878                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
879            dict(d=None), (dict(d=''), dict(d=None)),
880            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
881            dict(f4=None, f8=None), dict(f4=0, f8=0),
882            (dict(f4='', f8=''), dict(f4=None, f8=None)),
883            (dict(d=1234.5, f4=1234.5, f8=1234.5),
884                  dict(d=Decimal('1234.5'))),
885            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
886            dict(d=Decimal('123456789.9876543212345678987654321')),
887            dict(m=None), (dict(m=''), dict(m=None)),
888            dict(m=Decimal('-1234.56')),
889            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
890            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
891            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
892            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
893            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
894            (dict(m=123456), dict(m=Decimal('123456'))),
895            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
896            dict(b=None), (dict(b=''), dict(b=None)),
897            dict(b='f'), dict(b='t'),
898            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
899            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
900            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
901            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
902            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
903            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
904            dict(v4=None, c4=None, t=None),
905            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
906            dict(v4='1234', c4='1234', t='1234' * 10),
907            dict(v4='abcd', c4='abcd', t='abcdefg'),
908            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
909            dict(ts=None), (dict(ts=''), dict(ts=None)),
910            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
911            dict(ts='2012-12-21 00:00:00'),
912            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
913            dict(ts='2012-12-21 12:21:12'),
914            dict(ts='2013-01-05 12:13:14'),
915            dict(ts='current_timestamp')]
916        for test in tests:
917            if isinstance(test, dict):
918                data = test
919                change = {}
920            else:
921                data, change = test
922            expect = data.copy()
923            expect.update(change)
924            if bool_on:
925                b = expect.get('b')
926                if b is not None:
927                    expect['b'] = b == 't'
928            if decimal is not Decimal:
929                d = expect.get('d')
930                if d is not None:
931                    expect['d'] = decimal(d)
932                m = expect.get('m')
933                if m is not None:
934                    expect['m'] = decimal(m)
935            self.assertEqual(insert(table, data), data)
936            self.assertIn(oid_table, data)
937            oid = data[oid_table]
938            self.assertIsInstance(oid, int)
939            data = dict(item for item in data.items()
940                if item[0] in expect)
941            ts = expect.get('ts')
942            if ts == 'current_timestamp':
943                ts = expect['ts'] = data['ts']
944                if len(ts) > 19:
945                    self.assertEqual(ts[19], '.')
946                    ts = ts[:19]
947                else:
948                    self.assertEqual(len(ts), 19)
949                self.assertTrue(ts[:4].isdigit())
950                self.assertEqual(ts[4], '-')
951                self.assertEqual(ts[10], ' ')
952                self.assertTrue(ts[11:13].isdigit())
953                self.assertEqual(ts[13], ':')
954            self.assertEqual(data, expect)
955            data = query(
956                'select oid,* from "%s"' % table).dictresult()[0]
957            self.assertEqual(data['oid'], oid)
958            data = dict(item for item in data.items()
959                if item[0] in expect)
960            self.assertEqual(data, expect)
961            query('delete from "%s"' % table)
962
963    def testInsertWithQuotedNames(self):
964        insert = self.db.insert
965        query = self.db.query
966        table = 'test table for insert()'
967        query('drop table if exists "%s"' % table)
968        self.addCleanup(query, 'drop table "%s"' % table)
969        query('create table "%s" ('
970            '"Prime!" smallint primary key,'
971            '"much space" integer, "Questions?" text)' % table)
972        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
973        r = insert(table, r)
974        self.assertIsInstance(r, dict)
975        self.assertEqual(r['Prime!'], 11)
976        self.assertEqual(r['much space'], 2002)
977        self.assertEqual(r['Questions?'], 'What?')
978        r = query('select * from "%s" limit 2' % table).dictresult()
979        self.assertEqual(len(r), 1)
980        r = r[0]
981        self.assertEqual(r['Prime!'], 11)
982        self.assertEqual(r['much space'], 2002)
983        self.assertEqual(r['Questions?'], 'What?')
984
985    def testUpdate(self):
986        update = self.db.update
987        query = self.db.query
988        table = 'update_test_table'
989        query('drop table if exists "%s"' % table)
990        self.addCleanup(query, 'drop table "%s"' % table)
991        query('create table "%s" ('
992            "n integer, t text) with oids" % table)
993        for n, t in enumerate('xyz'):
994            query('insert into "%s" values('
995                "%d, '%s')" % (table, n + 1, t))
996        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
997        r = self.db.get(table, 2, 'n')
998        r['t'] = 'u'
999        s = update(table, r)
1000        self.assertEqual(s, r)
1001        q = 'select t from "%s" where n=2' % table
1002        r = query(q).getresult()[0][0]
1003        self.assertEqual(r, 'u')
1004
1005    def testUpdateWithCompositeKey(self):
1006        update = self.db.update
1007        query = self.db.query
1008        table = 'update_test_table_1'
1009        query('drop table if exists "%s"' % table)
1010        self.addCleanup(query, 'drop table if exists "%s"' % table)
1011        query('create table "%s" ('
1012            "n integer, t text, primary key (n))" % table)
1013        for n, t in enumerate('abc'):
1014            query('insert into "%s" values('
1015                "%d, '%s')" % (table, n + 1, t))
1016        self.assertRaises(pg.ProgrammingError, update,
1017                          table, dict(t='b'))
1018        s = dict(n=2, t='d')
1019        r = update(table, s)
1020        self.assertIs(r, s)
1021        self.assertEqual(r['n'], 2)
1022        self.assertEqual(r['t'], 'd')
1023        q = 'select t from "%s" where n=2' % table
1024        r = query(q).getresult()[0][0]
1025        self.assertEqual(r, 'd')
1026        s.update(dict(n=4, t='e'))
1027        r = update(table, s)
1028        self.assertEqual(r['n'], 4)
1029        self.assertEqual(r['t'], 'e')
1030        q = 'select t from "%s" where n=2' % table
1031        r = query(q).getresult()[0][0]
1032        self.assertEqual(r, 'd')
1033        q = 'select t from "%s" where n=4' % table
1034        r = query(q).getresult()
1035        self.assertEqual(len(r), 0)
1036        query('drop table "%s"' % table)
1037        table = 'update_test_table_2'
1038        query('drop table if exists "%s"' % table)
1039        query('create table "%s" ('
1040            "n integer, m integer, t text, primary key (n, m))" % table)
1041        for n in range(3):
1042            for m in range(2):
1043                t = chr(ord('a') + 2 * n + m)
1044                query('insert into "%s" values('
1045                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1046        self.assertRaises(pg.ProgrammingError, update,
1047                          table, dict(n=2, t='b'))
1048        self.assertEqual(update(table,
1049                                dict(n=2, m=2, t='x'))['t'], 'x')
1050        q = 'select t from "%s" where n=2 order by m' % table
1051        r = [r[0] for r in query(q).getresult()]
1052        self.assertEqual(r, ['c', 'x'])
1053
1054    def testUpdateWithQuotedNames(self):
1055        update = self.db.update
1056        query = self.db.query
1057        table = 'test table for update()'
1058        query('drop table if exists "%s"' % table)
1059        self.addCleanup(query, 'drop table "%s"' % table)
1060        query('create table "%s" ('
1061            '"Prime!" smallint primary key,'
1062            '"much space" integer, "Questions?" text)' % table)
1063        query('insert into "%s"'
1064              " values(13, 3003, 'Why!')" % table)
1065        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1066        r = update(table, r)
1067        self.assertIsInstance(r, dict)
1068        self.assertEqual(r['Prime!'], 13)
1069        self.assertEqual(r['much space'], 7007)
1070        self.assertEqual(r['Questions?'], 'When?')
1071        r = query('select * from "%s" limit 2' % table).dictresult()
1072        self.assertEqual(len(r), 1)
1073        r = r[0]
1074        self.assertEqual(r['Prime!'], 13)
1075        self.assertEqual(r['much space'], 7007)
1076        self.assertEqual(r['Questions?'], 'When?')
1077
1078    def testUpsert(self):
1079        upsert = self.db.upsert
1080        query = self.db.query
1081        table = 'upsert_test_table'
1082        query('drop table if exists "%s"' % table)
1083        self.addCleanup(query, 'drop table "%s"' % table)
1084        query('create table "%s" ('
1085            "n integer primary key, t text) with oids" % table)
1086        s = dict(n=1, t='x')
1087        try:
1088            r = upsert(table, s)
1089        except pg.ProgrammingError as error:
1090            if self.db.server_version < 90500:
1091                self.skipTest('database does not support upsert')
1092            self.fail(str(error))
1093        self.assertIs(r, s)
1094        self.assertEqual(r['n'], 1)
1095        self.assertEqual(r['t'], 'x')
1096        s.update(n=2, t='y')
1097        r = upsert(table, s, **dict.fromkeys(s))
1098        self.assertIs(r, s)
1099        self.assertEqual(r['n'], 2)
1100        self.assertEqual(r['t'], 'y')
1101        q = 'select n, t from "%s" order by n limit 3' % table
1102        r = query(q).getresult()
1103        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1104        s.update(t='z')
1105        r = upsert(table, s)
1106        self.assertIs(r, s)
1107        self.assertEqual(r['n'], 2)
1108        self.assertEqual(r['t'], 'z')
1109        r = query(q).getresult()
1110        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1111        s.update(t='n')
1112        r = upsert(table, s, t=False)
1113        self.assertIs(r, s)
1114        self.assertEqual(r['n'], 2)
1115        self.assertEqual(r['t'], 'z')
1116        r = query(q).getresult()
1117        self.assertEqual(r, [(1, 'x'), (2, 'z')])
1118        s.update(t='y')
1119        r = upsert(table, s, t=True)
1120        self.assertIs(r, s)
1121        self.assertEqual(r['n'], 2)
1122        self.assertEqual(r['t'], 'y')
1123        r = query(q).getresult()
1124        self.assertEqual(r, [(1, 'x'), (2, 'y')])
1125        s.update(t='n')
1126        r = upsert(table, s, t="included.t || '2'")
1127        self.assertIs(r, s)
1128        self.assertEqual(r['n'], 2)
1129        self.assertEqual(r['t'], 'y2')
1130        r = query(q).getresult()
1131        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
1132        s.update(t='y')
1133        r = upsert(table, s, t="excluded.t || '3'")
1134        self.assertIs(r, s)
1135        self.assertEqual(r['n'], 2)
1136        self.assertEqual(r['t'], 'y3')
1137        r = query(q).getresult()
1138        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
1139        s.update(n=1, t='2')
1140        r = upsert(table, s, t="included.t || excluded.t")
1141        self.assertIs(r, s)
1142        self.assertEqual(r['n'], 1)
1143        self.assertEqual(r['t'], 'x2')
1144        r = query(q).getresult()
1145        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
1146
1147    def testUpsertWithCompositeKey(self):
1148        upsert = self.db.upsert
1149        query = self.db.query
1150        table = 'upsert_test_table_2'
1151        query('drop table if exists "%s"' % table)
1152        self.addCleanup(query, 'drop table "%s"' % table)
1153        query('create table "%s" ('
1154            "n integer, m integer, t text, primary key (n, m))" % table)
1155        s = dict(n=1, m=2, t='x')
1156        try:
1157            r = upsert(table, s)
1158        except pg.ProgrammingError as error:
1159            if self.db.server_version < 90500:
1160                self.skipTest('database does not support upsert')
1161            self.fail(str(error))
1162        self.assertIs(r, s)
1163        self.assertEqual(r['n'], 1)
1164        self.assertEqual(r['m'], 2)
1165        self.assertEqual(r['t'], 'x')
1166        s.update(m=3, t='y')
1167        r = upsert(table, s, **dict.fromkeys(s))
1168        self.assertIs(r, s)
1169        self.assertEqual(r['n'], 1)
1170        self.assertEqual(r['m'], 3)
1171        self.assertEqual(r['t'], 'y')
1172        q = 'select n, m, t from "%s" order by n, m limit 3' % table
1173        r = query(q).getresult()
1174        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
1175        s.update(t='z')
1176        r = upsert(table, s)
1177        self.assertIs(r, s)
1178        self.assertEqual(r['n'], 1)
1179        self.assertEqual(r['m'], 3)
1180        self.assertEqual(r['t'], 'z')
1181        r = query(q).getresult()
1182        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1183        s.update(t='n')
1184        r = upsert(table, s, t=False)
1185        self.assertIs(r, s)
1186        self.assertEqual(r['n'], 1)
1187        self.assertEqual(r['m'], 3)
1188        self.assertEqual(r['t'], 'z')
1189        r = query(q).getresult()
1190        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
1191        s.update(t='n')
1192        r = upsert(table, s, t=True)
1193        self.assertIs(r, s)
1194        self.assertEqual(r['n'], 1)
1195        self.assertEqual(r['m'], 3)
1196        self.assertEqual(r['t'], 'n')
1197        r = query(q).getresult()
1198        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
1199        s.update(n=2, t='y')
1200        r = upsert(table, s, t="'z'")
1201        self.assertIs(r, s)
1202        self.assertEqual(r['n'], 2)
1203        self.assertEqual(r['m'], 3)
1204        self.assertEqual(r['t'], 'y')
1205        r = query(q).getresult()
1206        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
1207        s.update(n=1, t='m')
1208        r = upsert(table, s, t='included.t || excluded.t')
1209        self.assertIs(r, s)
1210        self.assertEqual(r['n'], 1)
1211        self.assertEqual(r['m'], 3)
1212        self.assertEqual(r['t'], 'nm')
1213        r = query(q).getresult()
1214        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
1215
1216    def testUpsertWithQuotedNames(self):
1217        upsert = self.db.upsert
1218        query = self.db.query
1219        table = 'test table for upsert()'
1220        query('drop table if exists "%s"' % table)
1221        self.addCleanup(query, 'drop table "%s"' % table)
1222        query('create table "%s" ('
1223            '"Prime!" smallint primary key,'
1224            '"much space" integer, "Questions?" text)' % table)
1225        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1226        try:
1227            r = upsert(table, s)
1228        except pg.ProgrammingError as error:
1229            if self.db.server_version < 90500:
1230                self.skipTest('database does not support upsert')
1231            self.fail(str(error))
1232        self.assertIs(r, s)
1233        self.assertEqual(r['Prime!'], 31)
1234        self.assertEqual(r['much space'], 9009)
1235        self.assertEqual(r['Questions?'], 'Yes.')
1236        q = 'select * from "%s" limit 2' % table
1237        r = query(q).getresult()
1238        self.assertEqual(r, [(31, 9009, 'Yes.')])
1239        s.update({'Questions?': 'No.'})
1240        r = upsert(table, s)
1241        self.assertIs(r, s)
1242        self.assertEqual(r['Prime!'], 31)
1243        self.assertEqual(r['much space'], 9009)
1244        self.assertEqual(r['Questions?'], 'No.')
1245        r = query(q).getresult()
1246        self.assertEqual(r, [(31, 9009, 'No.')])
1247
1248    def testClear(self):
1249        clear = self.db.clear
1250        query = self.db.query
1251        f = False if pg.get_bool() else 'f'
1252        table = 'clear_test_table'
1253        query('drop table if exists "%s"' % table)
1254        self.addCleanup(query, 'drop table "%s"' % table)
1255        query('create table "%s" ('
1256            "n integer, b boolean, d date, t text)" % table)
1257        r = clear(table)
1258        result = {'n': 0, 'b': f, 'd': '', 't': ''}
1259        self.assertEqual(r, result)
1260        r['a'] = r['n'] = 1
1261        r['d'] = r['t'] = 'x'
1262        r['b'] = 't'
1263        r['oid'] = long(1)
1264        r = clear(table, r)
1265        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
1266            'oid': long(1)}
1267        self.assertEqual(r, result)
1268
1269    def testClearWithQuotedNames(self):
1270        clear = self.db.clear
1271        query = self.db.query
1272        table = 'test table for clear()'
1273        query('drop table if exists "%s"' % table)
1274        self.addCleanup(query, 'drop table "%s"' % table)
1275        query('create table "%s" ('
1276            '"Prime!" smallint primary key,'
1277            '"much space" integer, "Questions?" text)' % table)
1278        r = clear(table)
1279        self.assertIsInstance(r, dict)
1280        self.assertEqual(r['Prime!'], 0)
1281        self.assertEqual(r['much space'], 0)
1282        self.assertEqual(r['Questions?'], '')
1283
1284    def testDelete(self):
1285        delete = self.db.delete
1286        query = self.db.query
1287        table = 'delete_test_table'
1288        query('drop table if exists "%s"' % table)
1289        self.addCleanup(query, 'drop table "%s"' % table)
1290        query('create table "%s" ('
1291            "n integer, t text) with oids" % table)
1292        for n, t in enumerate('xyz'):
1293            query('insert into "%s" values('
1294                "%d, '%s')" % (table, n + 1, t))
1295        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1296        r = self.db.get(table, 1, 'n')
1297        s = delete(table, r)
1298        self.assertEqual(s, 1)
1299        r = self.db.get(table, 3, 'n')
1300        s = delete(table, r)
1301        self.assertEqual(s, 1)
1302        s = delete(table, r)
1303        self.assertEqual(s, 0)
1304        r = query('select * from "%s"' % table).dictresult()
1305        self.assertEqual(len(r), 1)
1306        r = r[0]
1307        result = {'n': 2, 't': 'y'}
1308        self.assertEqual(r, result)
1309        r = self.db.get(table, 2, 'n')
1310        s = delete(table, r)
1311        self.assertEqual(s, 1)
1312        s = delete(table, r)
1313        self.assertEqual(s, 0)
1314        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1315
1316    def testDeleteWithCompositeKey(self):
1317        query = self.db.query
1318        table = 'delete_test_table_1'
1319        query('drop table if exists "%s"' % table)
1320        self.addCleanup(query, 'drop table "%s"' % table)
1321        query('create table "%s" ('
1322            "n integer, t text, primary key (n))" % table)
1323        for n, t in enumerate('abc'):
1324            query("insert into %s values("
1325                "%d, '%s')" % (table, n + 1, t))
1326        self.assertRaises(pg.ProgrammingError, self.db.delete,
1327            table, dict(t='b'))
1328        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1329        r = query('select t from "%s" where n=2' % table
1330                  ).getresult()
1331        self.assertEqual(r, [])
1332        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1333        r = query('select t from "%s" where n=3' % table
1334                  ).getresult()[0][0]
1335        self.assertEqual(r, 'c')
1336        table = 'delete_test_table_2'
1337        query('drop table if exists "%s"' % table)
1338        self.addCleanup(query, 'drop table "%s"' % table)
1339        query('create table "%s" ('
1340            "n integer, m integer, t text, primary key (n, m))" % table)
1341        for n in range(3):
1342            for m in range(2):
1343                t = chr(ord('a') + 2 * n + m)
1344                query('insert into "%s" values('
1345                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1346        self.assertRaises(pg.ProgrammingError, self.db.delete,
1347            table, dict(n=2, t='b'))
1348        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1349        r = [r[0] for r in query('select t from "%s" where n=2'
1350            ' order by m' % table).getresult()]
1351        self.assertEqual(r, ['c'])
1352        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1353        r = [r[0] for r in query('select t from "%s" where n=3'
1354            ' order by m' % table).getresult()]
1355        self.assertEqual(r, ['e', 'f'])
1356        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1357        r = [r[0] for r in query('select t from "%s" where n=3'
1358            ' order by m' % table).getresult()]
1359        self.assertEqual(r, ['f'])
1360
1361    def testDeleteWithQuotedNames(self):
1362        delete = self.db.delete
1363        query = self.db.query
1364        table = 'test table for delete()'
1365        query('drop table if exists "%s"' % table)
1366        self.addCleanup(query, 'drop table "%s"' % table)
1367        query('create table "%s" ('
1368            '"Prime!" smallint primary key,'
1369            '"much space" integer, "Questions?" text)' % table)
1370        query('insert into "%s"'
1371              " values(19, 5005, 'Yes!')" % table)
1372        r = {'Prime!': 17}
1373        r = delete(table, r)
1374        self.assertEqual(r, 0)
1375        r = query('select count(*) from "%s"' % table).getresult()
1376        self.assertEqual(r[0][0], 1)
1377        r = {'Prime!': 19}
1378        r = delete(table, r)
1379        self.assertEqual(r, 1)
1380        r = query('select count(*) from "%s"' % table).getresult()
1381        self.assertEqual(r[0][0], 0)
1382
1383    def testTransaction(self):
1384        query = self.db.query
1385        query("drop table if exists test_table")
1386        self.addCleanup(query, "drop table test_table")
1387        query("create table test_table (n integer)")
1388        self.db.begin()
1389        query("insert into test_table values (1)")
1390        query("insert into test_table values (2)")
1391        self.db.commit()
1392        self.db.begin()
1393        query("insert into test_table values (3)")
1394        query("insert into test_table values (4)")
1395        self.db.rollback()
1396        self.db.begin()
1397        query("insert into test_table values (5)")
1398        self.db.savepoint('before6')
1399        query("insert into test_table values (6)")
1400        self.db.rollback('before6')
1401        query("insert into test_table values (7)")
1402        self.db.commit()
1403        self.db.begin()
1404        self.db.savepoint('before8')
1405        query("insert into test_table values (8)")
1406        self.db.release('before8')
1407        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1408        self.db.commit()
1409        self.db.start()
1410        query("insert into test_table values (9)")
1411        self.db.end()
1412        r = [r[0] for r in query(
1413            "select * from test_table order by 1").getresult()]
1414        self.assertEqual(r, [1, 2, 5, 7, 9])
1415
1416    def testContextManager(self):
1417        query = self.db.query
1418        query("drop table if exists test_table")
1419        self.addCleanup(query, "drop table test_table")
1420        query("create table test_table (n integer check(n>0))")
1421        with self.db:
1422            query("insert into test_table values (1)")
1423            query("insert into test_table values (2)")
1424        try:
1425            with self.db:
1426                query("insert into test_table values (3)")
1427                query("insert into test_table values (4)")
1428                raise ValueError('test transaction should rollback')
1429        except ValueError as error:
1430            self.assertEqual(str(error), 'test transaction should rollback')
1431        with self.db:
1432            query("insert into test_table values (5)")
1433        try:
1434            with self.db:
1435                query("insert into test_table values (6)")
1436                query("insert into test_table values (-1)")
1437        except pg.ProgrammingError as error:
1438            self.assertTrue('check' in str(error))
1439        with self.db:
1440            query("insert into test_table values (7)")
1441        r = [r[0] for r in query(
1442            "select * from test_table order by 1").getresult()]
1443        self.assertEqual(r, [1, 2, 5, 7])
1444
1445    def testBytea(self):
1446        query = self.db.query
1447        query('drop table if exists bytea_test')
1448        self.addCleanup(query, 'drop table bytea_test')
1449        query('create table bytea_test (n smallint primary key, data bytea)')
1450        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1451        r = self.db.escape_bytea(s)
1452        query('insert into bytea_test values(3,$1)', (r,))
1453        r = query('select * from bytea_test where n=3').getresult()
1454        self.assertEqual(len(r), 1)
1455        r = r[0]
1456        self.assertEqual(len(r), 2)
1457        self.assertEqual(r[0], 3)
1458        r = r[1]
1459        self.assertIsInstance(r, str)
1460        r = self.db.unescape_bytea(r)
1461        self.assertIsInstance(r, bytes)
1462        self.assertEqual(r, s)
1463
1464    def testInsertUpdateGetBytea(self):
1465        query = self.db.query
1466        query('drop table if exists bytea_test')
1467        self.addCleanup(query, 'drop table bytea_test')
1468        query('create table bytea_test (n smallint primary key, data bytea)')
1469        # insert null value
1470        r = self.db.insert('bytea_test', n=0, data=None)
1471        self.assertIsInstance(r, dict)
1472        self.assertIn('n', r)
1473        self.assertEqual(r['n'], 0)
1474        self.assertIn('data', r)
1475        self.assertIsNone(r['data'])
1476        s = b'None'
1477        r = self.db.update('bytea_test', n=0, data=s)
1478        self.assertIsInstance(r, dict)
1479        self.assertIn('n', r)
1480        self.assertEqual(r['n'], 0)
1481        self.assertIn('data', r)
1482        r = r['data']
1483        self.assertIsInstance(r, bytes)
1484        self.assertEqual(r, s)
1485        r = self.db.update('bytea_test', n=0, data=None)
1486        self.assertIsNone(r['data'])
1487        # insert as bytes
1488        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1489        r = self.db.insert('bytea_test', n=5, data=s)
1490        self.assertIsInstance(r, dict)
1491        self.assertIn('n', r)
1492        self.assertEqual(r['n'], 5)
1493        self.assertIn('data', r)
1494        r = r['data']
1495        self.assertIsInstance(r, bytes)
1496        self.assertEqual(r, s)
1497        # update as bytes
1498        s += b"and now even more \x00 nasty \t stuff!\f"
1499        r = self.db.update('bytea_test', n=5, data=s)
1500        self.assertIsInstance(r, dict)
1501        self.assertIn('n', r)
1502        self.assertEqual(r['n'], 5)
1503        self.assertIn('data', r)
1504        r = r['data']
1505        self.assertIsInstance(r, bytes)
1506        self.assertEqual(r, s)
1507        r = query('select * from bytea_test where n=5').getresult()
1508        self.assertEqual(len(r), 1)
1509        r = r[0]
1510        self.assertEqual(len(r), 2)
1511        self.assertEqual(r[0], 5)
1512        r = r[1]
1513        self.assertIsInstance(r, str)
1514        r = self.db.unescape_bytea(r)
1515        self.assertIsInstance(r, bytes)
1516        self.assertEqual(r, s)
1517        r = self.db.get('bytea_test', dict(n=5))
1518        self.assertIsInstance(r, dict)
1519        self.assertIn('n', r)
1520        self.assertEqual(r['n'], 5)
1521        self.assertIn('data', r)
1522        r = r['data']
1523        self.assertIsInstance(r, bytes)
1524        self.assertEqual(r, s)
1525
1526    def testDebugWithCallable(self):
1527        if debug:
1528            self.assertEqual(self.db.debug, debug)
1529        else:
1530            self.assertIsNone(self.db.debug)
1531        s = []
1532        self.db.debug = s.append
1533        try:
1534            self.db.query("select 1")
1535            self.db.query("select 2")
1536            self.assertEqual(s, ["select 1", "select 2"])
1537        finally:
1538            self.db.debug = debug
1539
1540
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    # Global pg options changed by this class, in the order they are set;
    # they are restored in the reverse order.
    _changed_options = ('decimal', 'bool', 'namedresult')

    @classmethod
    def setUpClass(cls):
        """Switch pg to non-standard global options, then set up the base.

        unittest does not call tearDownClass when setUpClass raises, so on
        failure the saved options are restored here before re-raising.
        Otherwise the changed global options would leak into later tests.
        """
        cls.saved_options = {}

        def unnamed_result(q):
            return q.getresult()

        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            cls._restore_options()
            raise

    @classmethod
    def tearDownClass(cls):
        """Tear down the base class, then always restore the pg options."""
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            cls._restore_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a pg option, then set it to value.

        The option is accessed via pg.get_<option> and pg.set_<option>.
        Returns whatever pg.set_<option> returns.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Set a pg option back to the value saved by set_option."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _restore_options(cls):
        """Restore every saved option, in reverse order of setting it."""
        saved = cls.saved_options
        for option in reversed(cls._changed_options):
            # Only options that were actually saved can be restored; a
            # failure part-way through setUpClass may have saved fewer.
            if option in saved:
                cls.reset_option(option)
        saved.clear()
1569
1570
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    # Number of schemas used: number 0 is "public", the others are s1, s2, ...
    num_schemas = 5

    @staticmethod
    def _schema_name(num_schema):
        """Return the name of the test schema with the given number."""
        return "s%d" % num_schema if num_schema else "public"

    @classmethod
    def setUpClass(cls):
        """Create the test schemas and tables.

        Every schema gets two tables, t and t<num>, each holding the single
        row n=1, d=<num> and created with oids.
        """
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(cls.num_schemas):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    # The public schema itself is kept, only its tables
                    # from a previous run are removed.
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d" % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # unittest does not call tearDownClass when setUpClass fails,
            # so make sure the connection is not leaked in that case.
            db.close()

    @classmethod
    def tearDownClass(cls):
        """Drop all schemas and public tables created by setUpClass."""
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(cls.num_schemas):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema %s cascade" % (schema,))
                else:
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.db = DB()
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        # Cleanups registered by the tests issue queries, so run them while
        # the connection is still open; close it even if they fail.
        try:
            self.doCleanups()
        finally:
            self.db.close()

    def testGetTables(self):
        """All tables of all test schemas are listed by get_tables()."""
        tables = self.db.get_tables()
        for num_schema in range(self.num_schemas):
            schema = self._schema_name(num_schema)
            for t in (schema + ".t", "%s.t%d" % (schema, num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        """get_attnames() honors explicit schemas and the search path."""
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        # Register the cleanup only after the table really exists, so that
        # a failing create does not cause a second, misleading error.
        self.addCleanup(query, "drop table s3.t3m")
        result_m = {'oid': 'int', 'm': 'int'}
        r = get_attnames("s3.t3m")
        self.assertEqual(r, result_m)
        query("set search_path to s1,s3")
        r = get_attnames("t3")
        self.assertEqual(r, result)
        r = get_attnames("t3m")
        self.assertEqual(r, result_m)

    def testGet(self):
        """get() resolves table names via explicit schemas or search path."""
        # Each table holds the single row n=1, d=<schema number>, so the
        # value of 'd' tells in which schema the table was found.
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        """The row from get() contains the oid under the key 'oid(table)'."""
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
1688
1689
if __name__ == "__main__":
    # Run all test cases of this module when it is executed as a script.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.