source: trunk/tests/test_classic_dbwrapper.py @ 743

Last change on this file since 743 was 743, checked in by cito, 4 years ago

Test error messages and security of the get() method

The get() method should be immune to SQL injection via apostrophes in
values, and give a proper and helpful error message if a row is not found.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 62.8 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
# Settings for the test database.  They can be overridden by a module
# LOCAL_PyGreSQL.py whose names are star-imported below; otherwise the
# defaults given here are used.
# The current user must have create schema privilege on the database.
dbname, dbhost, dbport = 'unittest', None, 5432

debug = False  # if set, DB() passes it on to the DB wrapper objects

try:  # LOCAL_PyGreSQL as a sibling module within a package
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError):
    try:  # LOCAL_PyGreSQL as a top-level module
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings, keep the defaults

# The builtin names long and unicode do not exist in Python 3.
try:
    long, unicode
except NameError:  # Python >= 3.0
    long, unicode = int, str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    OrderedDict = dict  # fallback without guaranteed key order

windows = (os.name == 'nt')

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(), so tests that
# ask for the host are skipped there:
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
62
63
def DB():
    """Return a DB wrapper object connected to the test database.

    The connection uses the module settings dbname, dbhost and dbport.
    If the module setting debug is set, it is passed on to the wrapper.
    Client messages are restricted to the level warning and above.
    """
    wrapper = pg.DB(dbname, dbhost, dbport)
    if debug:
        wrapper.debug = debug
    wrapper.query("set client_min_messages=warning")
    return wrapper
71
72
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # some tests close the connection themselves; closing it a
        # second time raises an InternalError, which is ignored here
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        # the complete public API of the DB wrapper, sorted like dir()
        attributes = [
            'begin',
            'cancel',
            'clear',
            'close',
            'commit',
            'db',
            'dbname',
            'debug',
            'delete',
            'end',
            'endcopy',
            'error',
            'escape_bytea',
            'escape_identifier',
            'escape_literal',
            'escape_string',
            'fileno',
            'get',
            'get_attnames',
            'get_databases',
            'get_notice_receiver',
            'get_relations',
            'get_tables',
            'getline',
            'getlo',
            'getnotify',
            'has_table_privilege',
            'host',
            'insert',
            'inserttable',
            'locreate',
            'loimport',
            'notification_handler',
            'options',
            'parameter',
            'pkey',
            'port',
            'protocol_version',
            'putline',
            'query',
            'release',
            'reopen',
            'reset',
            'rollback',
            'savepoint',
            'server_version',
            'set_notice_receiver',
            'source',
            'start',
            'status',
            'transaction',
            'unescape_bytea',
            'update',
            'upsert',
            'use_regtypes',
            'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        # db is the underlying connection; its db attribute is the name
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        # no error is expected, except for messages mentioning krb5_
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        # without a configured host, 'localhost' is reported
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        # frontend/backend protocol version 2 or 3
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        # the server version is encoded as an integer (7.4 -> 70400);
        # NOTE(review): PostgreSQL 10 and later report values >= 100000,
        # which this upper bound rejects
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1  # status value of a good connection
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        # apart from a hex prefix or backslashes, nothing must remain
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        # parameters can be passed as separate arguments, tuple or list
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        try:
            self.db.query("select 1/0")
        except pg.ProgrammingError as error:
            # 22012 is the SQLSTATE code for division_by_zero
            self.assertEqual(error.sqlstate, '22012')
        else:
            # without this, the test would pass if no error was raised
            self.fail('Division by zero should raise a ProgrammingError')

    def testMethodEndcopy(self):
        # calling endcopy() outside of a copy operation may raise
        # an IOError, which is acceptable here
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        # a wrapper can be created from an existing pg connection ...
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        # ... from another DB wrapper ...
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        # ... passed with the db keyword ...
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        # ... or from any object with a _cnx attribute
        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
279
280
281class TestDBClass(unittest.TestCase):
282    """Test the methods of the DB class wrapped pg connection."""
283
284    @classmethod
285    def setUpClass(cls):
286        db = DB()
287        db.query("drop table if exists test cascade")
288        db.query("create table test ("
289            "i2 smallint, i4 integer, i8 bigint,"
290            "d numeric, f4 real, f8 double precision, m money, "
291            "v4 varchar(4), c4 char(4), t text)")
292        db.query("create or replace view test_view as"
293            " select i4, v4 from test")
294        db.close()
295
296    @classmethod
297    def tearDownClass(cls):
298        db = DB()
299        db.query("drop table test cascade")
300        db.close()
301
302    def setUp(self):
303        self.db = DB()
304        query = self.db.query
305        query('set client_encoding=utf8')
306        query('set standard_conforming_strings=on')
307        query("set lc_monetary='C'")
308        query("set datestyle='ISO,YMD'")
309        query('set bytea_output=hex')
310
311    def tearDown(self):
312        self.db.close()
313
314    def testClassName(self):
315        self.assertEqual(self.db.__class__.__name__, 'DB')
316
317    def testModuleName(self):
318        self.assertEqual(self.db.__module__, 'pg')
319        self.assertEqual(self.db.__class__.__module__, 'pg')
320
321    def testEscapeLiteral(self):
322        f = self.db.escape_literal
323        r = f(b"plain")
324        self.assertIsInstance(r, bytes)
325        self.assertEqual(r, b"'plain'")
326        r = f(u"plain")
327        self.assertIsInstance(r, unicode)
328        self.assertEqual(r, u"'plain'")
329        r = f(u"that's kÀse".encode('utf-8'))
330        self.assertIsInstance(r, bytes)
331        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
332        r = f(u"that's kÀse")
333        self.assertIsInstance(r, unicode)
334        self.assertEqual(r, u"'that''s kÀse'")
335        self.assertEqual(f(r"It's fine to have a \ inside."),
336            r" E'It''s fine to have a \\ inside.'")
337        self.assertEqual(f('No "quotes" must be escaped.'),
338            "'No \"quotes\" must be escaped.'")
339
340    def testEscapeIdentifier(self):
341        f = self.db.escape_identifier
342        r = f(b"plain")
343        self.assertIsInstance(r, bytes)
344        self.assertEqual(r, b'"plain"')
345        r = f(u"plain")
346        self.assertIsInstance(r, unicode)
347        self.assertEqual(r, u'"plain"')
348        r = f(u"that's kÀse".encode('utf-8'))
349        self.assertIsInstance(r, bytes)
350        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
351        r = f(u"that's kÀse")
352        self.assertIsInstance(r, unicode)
353        self.assertEqual(r, u'"that\'s kÀse"')
354        self.assertEqual(f(r"It's fine to have a \ inside."),
355            '"It\'s fine to have a \\ inside."')
356        self.assertEqual(f('All "quotes" must be escaped.'),
357            '"All ""quotes"" must be escaped."')
358
359    def testEscapeString(self):
360        f = self.db.escape_string
361        r = f(b"plain")
362        self.assertIsInstance(r, bytes)
363        self.assertEqual(r, b"plain")
364        r = f(u"plain")
365        self.assertIsInstance(r, unicode)
366        self.assertEqual(r, u"plain")
367        r = f(u"that's kÀse".encode('utf-8'))
368        self.assertIsInstance(r, bytes)
369        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
370        r = f(u"that's kÀse")
371        self.assertIsInstance(r, unicode)
372        self.assertEqual(r, u"that''s kÀse")
373        self.assertEqual(f(r"It's fine to have a \ inside."),
374            r"It''s fine to have a \ inside.")
375
376    def testEscapeBytea(self):
377        f = self.db.escape_bytea
378        # note that escape_byte always returns hex output since Pg 9.0,
379        # regardless of the bytea_output setting
380        r = f(b'plain')
381        self.assertIsInstance(r, bytes)
382        self.assertEqual(r, b'\\x706c61696e')
383        r = f(u'plain')
384        self.assertIsInstance(r, unicode)
385        self.assertEqual(r, u'\\x706c61696e')
386        r = f(u"das is' kÀse".encode('utf-8'))
387        self.assertIsInstance(r, bytes)
388        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
389        r = f(u"das is' kÀse")
390        self.assertIsInstance(r, unicode)
391        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
392        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
393
394    def testUnescapeBytea(self):
395        f = self.db.unescape_bytea
396        r = f(b'plain')
397        self.assertIsInstance(r, bytes)
398        self.assertEqual(r, b'plain')
399        r = f(u'plain')
400        self.assertIsInstance(r, bytes)
401        self.assertEqual(r, b'plain')
402        r = f(b"das is' k\\303\\244se")
403        self.assertIsInstance(r, bytes)
404        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
405        r = f(u"das is' k\\303\\244se")
406        self.assertIsInstance(r, bytes)
407        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
408        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
409        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
410        self.assertEqual(f(r'\\x746861742773206be47365'),
411            b'\\x746861742773206be47365')
412        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
413
414    def testGetAttnames(self):
415        get_attnames = self.db.get_attnames
416        query = self.db.query
417        query("drop table if exists test_table")
418        query("create table test_table("
419            " n int, alpha smallint, beta bool,"
420            " gamma char(5), tau text, v varchar(3))")
421        r = get_attnames("test_table")
422        self.assertIsInstance(r, dict)
423        self.assertEquals(r, dict(
424            n='int', alpha='int', beta='bool',
425            gamma='text', tau='text', v='text'))
426        query("drop table test_table")
427
428    def testGetAttnamesWithQuotes(self):
429        get_attnames = self.db.get_attnames
430        query = self.db.query
431        table = 'test table for get_attnames()'
432        query('drop table if exists "%s"' % table)
433        query('create table "%s"('
434            '"Prime!" smallint,'
435            '"much space" integer, "Questions?" text)' % table)
436        r = get_attnames(table)
437        self.assertIsInstance(r, dict)
438        self.assertEquals(r, {
439            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
440        query('drop table "%s"' % table)
441
442    def testGetAttnamesWithRegtypes(self):
443        get_attnames = self.db.get_attnames
444        query = self.db.query
445        query("drop table if exists test_table")
446        query("create table test_table("
447            " n int, alpha smallint, beta bool,"
448            " gamma char(5), tau text, v varchar(3))")
449        self.db.use_regtypes(True)
450        try:
451            r = get_attnames("test_table")
452            self.assertIsInstance(r, dict)
453        finally:
454            self.db.use_regtypes(False)
455        self.assertEquals(r, dict(
456            n='integer', alpha='smallint', beta='boolean',
457            gamma='character', tau='text', v='character varying'))
458        query("drop table test_table")
459
460    def testGetAttnamesIsCached(self):
461        get_attnames = self.db.get_attnames
462        query = self.db.query
463        query("drop table if exists test_table")
464        query("create table test_table(col int)")
465        r = get_attnames("test_table")
466        self.assertIsInstance(r, dict)
467        self.assertEquals(r, dict(col='int'))
468        query("drop table test_table")
469        query("create table test_table(col text)")
470        r = get_attnames("test_table")
471        self.assertEquals(r, dict(col='int'))
472        r = get_attnames("test_table", flush=True)
473        self.assertEquals(r, dict(col='text'))
474        query("drop table test_table")
475        r = get_attnames("test_table")
476        self.assertEquals(r, dict(col='text'))
477        self.assertRaises(pg.ProgrammingError,
478            get_attnames, "test_table", flush=True)
479
480    def testGetAttnamesIsOrdered(self):
481        get_attnames = self.db.get_attnames
482        query = self.db.query
483        query("drop table if exists test_table")
484        query("create table test_table("
485            " n int, alpha smallint, v varchar(3),"
486            " gamma char(5), tau text, beta bool)")
487        r = get_attnames("test_table")
488        self.assertIsInstance(r, OrderedDict)
489        self.assertEquals(r, OrderedDict([
490            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
491            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
492        query("drop table test_table")
493        if OrderedDict is dict:
494            self.skipTest('OrderedDict is not supported')
495        r = ' '.join(list(r.keys()))
496        self.assertEquals(r, 'n alpha v gamma tau beta')
497
498    def testQuery(self):
499        query = self.db.query
500        query("drop table if exists test_table")
501        q = "create table test_table (n integer) with oids"
502        r = query(q)
503        self.assertIsNone(r)
504        q = "insert into test_table values (1)"
505        r = query(q)
506        self.assertIsInstance(r, int)
507        q = "insert into test_table select 2"
508        r = query(q)
509        self.assertIsInstance(r, int)
510        oid = r
511        q = "select oid from test_table where n=2"
512        r = query(q).getresult()
513        self.assertEqual(len(r), 1)
514        r = r[0]
515        self.assertEqual(len(r), 1)
516        r = r[0]
517        self.assertEqual(r, oid)
518        q = "insert into test_table select 3 union select 4 union select 5"
519        r = query(q)
520        self.assertIsInstance(r, str)
521        self.assertEqual(r, '3')
522        q = "update test_table set n=4 where n<5"
523        r = query(q)
524        self.assertIsInstance(r, str)
525        self.assertEqual(r, '4')
526        q = "delete from test_table"
527        r = query(q)
528        self.assertIsInstance(r, str)
529        self.assertEqual(r, '5')
530        query("drop table test_table")
531
532    def testMultipleQueries(self):
533        self.assertEqual(self.db.query(
534            "create temporary table test_multi (n integer);"
535            "insert into test_multi values (4711);"
536            "select n from test_multi").getresult()[0][0], 4711)
537
538    def testQueryWithParams(self):
539        query = self.db.query
540        query("drop table if exists test_table")
541        q = "create table test_table (n1 integer, n2 integer) with oids"
542        query(q)
543        q = "insert into test_table values ($1, $2)"
544        r = query(q, (1, 2))
545        self.assertIsInstance(r, int)
546        r = query(q, [3, 4])
547        self.assertIsInstance(r, int)
548        r = query(q, [5, 6])
549        self.assertIsInstance(r, int)
550        q = "select * from test_table order by 1, 2"
551        self.assertEqual(query(q).getresult(),
552            [(1, 2), (3, 4), (5, 6)])
553        q = "select * from test_table where n1=$1 and n2=$2"
554        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
555        q = "update test_table set n2=$2 where n1=$1"
556        r = query(q, 3, 7)
557        self.assertEqual(r, '1')
558        q = "select * from test_table order by 1, 2"
559        self.assertEqual(query(q).getresult(),
560            [(1, 2), (3, 7), (5, 6)])
561        q = "delete from test_table where n2!=$1"
562        r = query(q, 4)
563        self.assertEqual(r, '3')
564        query("drop table test_table")
565
566    def testEmptyQuery(self):
567        self.assertRaises(ValueError, self.db.query, '')
568
569    def testQueryProgrammingError(self):
570        try:
571            self.db.query("select 1/0")
572        except pg.ProgrammingError as error:
573            self.assertEqual(error.sqlstate, '22012')
574
575    def testPkey(self):
576        query = self.db.query
577        pkey = self.db.pkey
578        for t in ('pkeytest', 'primary key test'):
579            for n in range(7):
580                query('drop table if exists "%s%d"' % (t, n))
581            query('create table "%s0" ('
582                "a smallint)" % t)
583            query('create table "%s1" ('
584                "b smallint primary key)" % t)
585            query('create table "%s2" ('
586                "c smallint, d smallint primary key)" % t)
587            query('create table "%s3" ('
588                "e smallint, f smallint, g smallint, "
589                "h smallint, i smallint, "
590                "primary key (f, h))" % t)
591            query('create table "%s4" ('
592                "more_than_one_letter varchar primary key)" % t)
593            query('create table "%s5" ('
594                '"with space" date primary key)' % t)
595            query('create table "%s6" ('
596                'a_very_long_column_name varchar, '
597                '"with space" date, '
598                '"42" int, '
599                "primary key (a_very_long_column_name, "
600                '"with space", "42"))' % t)
601            self.assertRaises(KeyError, pkey, '%s0' % t)
602            self.assertEqual(pkey('%s1' % t), 'b')
603            self.assertEqual(pkey('%s2' % t), 'd')
604            r = pkey('%s3' % t)
605            self.assertIsInstance(r, frozenset)
606            self.assertEqual(r, frozenset('fh'))
607            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
608            self.assertEqual(pkey('%s5' % t), 'with space')
609            r = pkey('%s6' % t)
610            self.assertIsInstance(r, frozenset)
611            self.assertEqual(r, frozenset([
612                'a_very_long_column_name', 'with space', '42']))
613            # a newly added primary key will be detected
614            query('alter table "%s0" add primary key (a)' % t)
615            self.assertEqual(pkey('%s0' % t), 'a')
616            # a changed primary key will not be detected,
617            # indicating that the internal cache is operating
618            query('alter table "%s1" rename column b to x' % t)
619            self.assertEqual(pkey('%s1' % t), 'b')
620            # we get the changed primary key when the cache is flushed
621            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
622            for n in range(7):
623                query('drop table "%s%d"' % (t, n))
624
625    def testGetDatabases(self):
626        databases = self.db.get_databases()
627        self.assertIn('template0', databases)
628        self.assertIn('template1', databases)
629        self.assertNotIn('not existing database', databases)
630        self.assertIn('postgres', databases)
631        self.assertIn(dbname, databases)
632
633    def testGetTables(self):
634        get_tables = self.db.get_tables
635        result1 = get_tables()
636        self.assertIsInstance(result1, list)
637        for t in result1:
638            t = t.split('.', 1)
639            self.assertGreaterEqual(len(t), 2)
640            if len(t) > 2:
641                self.assertTrue(t[1].startswith('"'))
642            t = t[0]
643            self.assertNotEqual(t, 'information_schema')
644            self.assertFalse(t.startswith('pg_'))
645        tables = ('"A very Special Name"',
646            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
647            'A_MiXeD_NaMe', '"another special name"',
648            'averyveryveryveryveryveryverylongtablename',
649            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
650        for t in tables:
651            self.db.query('drop table if exists %s' % t)
652            self.db.query("create table %s"
653                " as select 0" % t)
654        result3 = get_tables()
655        result2 = []
656        for t in result3:
657            if t not in result1:
658                result2.append(t)
659        result3 = []
660        for t in tables:
661            if not t.startswith('"'):
662                t = t.lower()
663            result3.append('public.' + t)
664        self.assertEqual(result2, result3)
665        for t in result2:
666            self.db.query('drop table %s' % t)
667        result2 = get_tables()
668        self.assertEqual(result2, result1)
669
670    def testGetRelations(self):
671        get_relations = self.db.get_relations
672        result = get_relations()
673        self.assertIn('public.test', result)
674        self.assertIn('public.test_view', result)
675        result = get_relations('rv')
676        self.assertIn('public.test', result)
677        self.assertIn('public.test_view', result)
678        result = get_relations('r')
679        self.assertIn('public.test', result)
680        self.assertNotIn('public.test_view', result)
681        result = get_relations('v')
682        self.assertNotIn('public.test', result)
683        self.assertIn('public.test_view', result)
684        result = get_relations('cisSt')
685        self.assertNotIn('public.test', result)
686        self.assertNotIn('public.test_view', result)
687
688    def testAttnames(self):
689        self.assertRaises(pg.ProgrammingError,
690            self.db.get_attnames, 'does_not_exist')
691        self.assertRaises(pg.ProgrammingError,
692            self.db.get_attnames, 'has.too.many.dots')
693        for table in ('attnames_test_table', 'test table for attnames'):
694            self.db.query('drop table if exists "%s"' % table)
695            self.db.query('create table "%s" ('
696                'a smallint, b integer, c bigint, '
697                'e numeric, f float, f2 double precision, m money, '
698                'x smallint, y smallint, z smallint, '
699                'Normal_NaMe smallint, "Special Name" smallint, '
700                't text, u char(2), v varchar(2), '
701                'primary key (y, u)) with oids' % table)
702            attributes = self.db.get_attnames(table)
703            result = {'a': 'int', 'c': 'int', 'b': 'int',
704                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
705                'normal_name': 'int', 'Special Name': 'int',
706                'u': 'text', 't': 'text', 'v': 'text',
707                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
708            self.assertEqual(attributes, result)
709            self.db.query('drop table "%s"' % table)
710
711    def testHasTablePrivilege(self):
712        can = self.db.has_table_privilege
713        self.assertEqual(can('test'), True)
714        self.assertEqual(can('test', 'select'), True)
715        self.assertEqual(can('test', 'SeLeCt'), True)
716        self.assertEqual(can('test', 'SELECT'), True)
717        self.assertEqual(can('test', 'insert'), True)
718        self.assertEqual(can('test', 'update'), True)
719        self.assertEqual(can('test', 'delete'), True)
720        self.assertEqual(can('pg_views', 'select'), True)
721        self.assertEqual(can('pg_views', 'delete'), False)
722        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
723        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
724
725    def testGet(self):
726        get = self.db.get
727        query = self.db.query
728        table = 'get_test_table'
729        query('drop table if exists "%s"' % table)
730        query('create table "%s" ('
731            "n integer, t text) with oids" % table)
732        for n, t in enumerate('xyz'):
733            query('insert into "%s" values('"%d, '%s')"
734                % (table, n + 1, t))
735        self.assertRaises(pg.ProgrammingError, get, table, 2)
736        r = get(table, 2, 'n')
737        oid_table = 'oid(%s)' % table
738        self.assertIn(oid_table, r)
739        oid = r[oid_table]
740        self.assertIsInstance(oid, int)
741        result = {'t': 'y', 'n': 2, oid_table: oid}
742        self.assertEqual(r, result)
743        self.assertEqual(get(table + ' *', 2, 'n'), r)
744        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
745        self.assertEqual(get(table, 1, 'n')['t'], 'x')
746        self.assertEqual(get(table, 3, 'n')['t'], 'z')
747        self.assertEqual(get(table, 2, 'n')['t'], 'y')
748        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
749        r['n'] = 3
750        self.assertEqual(get(table, r, 'n')['t'], 'z')
751        self.assertEqual(get(table, 1, 'n')['t'], 'x')
752        query('alter table "%s" alter n set not null' % table)
753        query('alter table "%s" add primary key (n)' % table)
754        self.assertEqual(get(table, 3)['t'], 'z')
755        self.assertEqual(get(table, 1)['t'], 'x')
756        self.assertEqual(get(table, 2)['t'], 'y')
757        r['n'] = 1
758        self.assertEqual(get(table, r)['t'], 'x')
759        r['n'] = 3
760        self.assertEqual(get(table, r)['t'], 'z')
761        r['n'] = 2
762        self.assertEqual(get(table, r)['t'], 'y')
763        query('drop table "%s"' % table)
764
765    def testGetWithCompositeKey(self):
766        get = self.db.get
767        query = self.db.query
768        table = 'get_test_table_1'
769        query('drop table if exists "%s"' % table)
770        query('create table "%s" ('
771            "n integer, t text, primary key (n))" % table)
772        for n, t in enumerate('abc'):
773            query('insert into "%s" values('
774                "%d, '%s')" % (table, n + 1, t))
775        self.assertEqual(get(table, 2)['t'], 'b')
776        query('drop table "%s"' % table)
777        table = 'get_test_table_2'
778        query('drop table if exists "%s"' % table)
779        query('create table "%s" ('
780            "n integer, m integer, t text, primary key (n, m))" % table)
781        for n in range(3):
782            for m in range(2):
783                t = chr(ord('a') + 2 * n + m)
784                query('insert into "%s" values('
785                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
786        self.assertRaises(pg.ProgrammingError, get, table, 2)
787        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
788        r = get(table, dict(n=1, m=2), ('n', 'm'))
789        self.assertEqual(r['t'], 'b')
790        r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
791        self.assertEqual(r['t'], 'f')
792        query('drop table "%s"' % table)
793
794    def testGetWithQuotedNames(self):
795        get = self.db.get
796        query = self.db.query
797        table = 'test table for get()'
798        query('drop table if exists "%s"' % table)
799        query('create table "%s" ('
800            '"Prime!" smallint primary key,'
801            '"much space" integer, "Questions?" text)' % table)
802        query('insert into "%s"'
803              " values(17, 1001, 'No!')" % table)
804        r = get(table, 17)
805        self.assertIsInstance(r, dict)
806        self.assertEqual(r['Prime!'], 17)
807        self.assertEqual(r['much space'], 1001)
808        self.assertEqual(r['Questions?'], 'No!')
809        query('drop table "%s"' % table)
810
811    def testGetFromView(self):
812        self.db.query('delete from test where i4=14')
813        self.db.query('insert into test (i4, v4) values('
814            "14, 'abc4')")
815        r = self.db.get('test_view', 14, 'i4')
816        self.assertIn('v4', r)
817        self.assertEqual(r['v4'], 'abc4')
818
819    def testGetLittleBobbyTables(self):
820        get = self.db.get
821        query = self.db.query
822        query("drop table if exists test_students")
823        query("create table test_students (firstname varchar primary key,"
824            " nickname varchar, grade char(2))")
825        query("insert into test_students values ("
826              "'D''Arcy', 'Darcey', 'A+')")
827        query("insert into test_students values ("
828              "'Sheldon', 'Moonpie', 'A+')")
829        query("insert into test_students values ("
830              "'Robert', 'Little Bobby Tables', 'D-')")
831        r = get('test_students', 'Sheldon')
832        self.assertEqual(r, dict(
833            firstname="Sheldon", nickname='Moonpie', grade='A+'))
834        r = get('test_students', 'Robert')
835        self.assertEqual(r, dict(
836            firstname="Robert", nickname='Little Bobby Tables', grade='D-'))
837        r = get('test_students', "D'Arcy")
838        self.assertEqual(r, dict(
839            firstname="D'Arcy", nickname='Darcey', grade='A+'))
840        try:
841            get('test_students', "D' Arcy")
842        except pg.DatabaseError as error:
843            self.assertEqual(str(error),
844                'No such record in test_students\nwhere "firstname" = $1\n'
845                'with $1="D\' Arcy"')
846        try:
847            get('test_students', "Robert'); TRUNCATE TABLE test_students;--")
848        except pg.DatabaseError as error:
849            self.assertEqual(str(error),
850                'No such record in test_students\nwhere "firstname" = $1\n'
851                'with $1="Robert\'); TRUNCATE TABLE test_students;--"')
852        q = "select * from test_students order by 1 limit 4"
853        r = query(q).getresult()
854        self.assertEqual(len(r), 3)
855        self.assertEqual(r[1][2], 'D-')
856        query('drop table test_students')
857
    def testInsert(self):
        """Test insert() with a wide range of column types and values.

        Each row is checked twice: once in the dict passed to insert(), and
        once in the row read back from the table with a plain query.
        """
        insert = self.db.insert
        query = self.db.query
        # the global options change how booleans and numerics come back
        bool_on = pg.get_bool()
        decimal = pg.get_decimal()
        table = 'insert_test_table'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "i2 smallint, i4 integer, i8 bigint,"
            " d numeric, f4 real, f8 double precision, m money,"
            " v4 varchar(4), c4 char(4), t text,"
            " b boolean, ts timestamp) with oids" % table)
        # key under which the oid of the inserted row shows up in the dict
        oid_table = 'oid(%s)' % table
        # Each test is either a dict of values expected to come back
        # unchanged, or a pair (data, change) where change holds the
        # expected values that differ from the inserted data.
        tests = [dict(i2=None, i4=None, i8=None),
            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
            dict(i2=42, i4=123456, i8=9876543210),
            dict(i2=2 ** 15 - 1,
                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
            dict(d=None), (dict(d=''), dict(d=None)),
            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
            dict(f4=None, f8=None), dict(f4=0, f8=0),
            (dict(f4='', f8=''), dict(f4=None, f8=None)),
            (dict(d=1234.5, f4=1234.5, f8=1234.5),
                  dict(d=Decimal('1234.5'))),
            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
            dict(d=Decimal('123456789.9876543212345678987654321')),
            dict(m=None), (dict(m=''), dict(m=None)),
            dict(m=Decimal('-1234.56')),
            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
            (dict(m=123456), dict(m=Decimal('123456'))),
            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
            dict(b=None), (dict(b=''), dict(b=None)),
            dict(b='f'), dict(b='t'),
            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
            dict(v4=None, c4=None, t=None),
            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
            dict(v4='1234', c4='1234', t='1234' * 10),
            dict(v4='abcd', c4='abcd', t='abcdefg'),
            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
            dict(ts=None), (dict(ts=''), dict(ts=None)),
            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
            dict(ts='2012-12-21 00:00:00'),
            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
            dict(ts='2012-12-21 12:21:12'),
            dict(ts='2013-01-05 12:13:14'),
            dict(ts='current_timestamp')]
        for test in tests:
            if isinstance(test, dict):
                data = test
                change = {}
            else:
                data, change = test
            expect = data.copy()
            expect.update(change)
            # adapt the expectations to the current global options
            if bool_on:
                b = expect.get('b')
                if b is not None:
                    expect['b'] = b == 't'
            if decimal is not Decimal:
                d = expect.get('d')
                if d is not None:
                    expect['d'] = decimal(d)
                m = expect.get('m')
                if m is not None:
                    expect['m'] = decimal(m)
            # the result of insert() must equal data, which afterwards
            # also holds the oid of the new row
            self.assertEqual(insert(table, data), data)
            self.assertIn(oid_table, data)
            oid = data[oid_table]
            self.assertIsInstance(oid, int)
            # restrict the comparison to the keys of the expected values
            data = dict(item for item in data.items()
                if item[0] in expect)
            ts = expect.get('ts')
            if ts == 'current_timestamp':
                # the actual value is not known in advance, so take it
                # from the result and only check that it looks plausible
                ts = expect['ts'] = data['ts']
                if len(ts) > 19:
                    self.assertEqual(ts[19], '.')
                    ts = ts[:19]
                else:
                    self.assertEqual(len(ts), 19)
                self.assertTrue(ts[:4].isdigit())
                self.assertEqual(ts[4], '-')
                self.assertEqual(ts[10], ' ')
                self.assertTrue(ts[11:13].isdigit())
                self.assertEqual(ts[13], ':')
            self.assertEqual(data, expect)
            # read the row back from the database and compare again
            data = query(
                'select oid,* from "%s"' % table).dictresult()[0]
            self.assertEqual(data['oid'], oid)
            data = dict(item for item in data.items()
                if item[0] in expect)
            self.assertEqual(data, expect)
            query('delete from "%s"' % table)
        query('drop table "%s"' % table)
961
962    def testInsertWithQuotedNames(self):
963        insert = self.db.insert
964        query = self.db.query
965        table = 'test table for insert()'
966        query('drop table if exists "%s"' % table)
967        query('create table "%s" ('
968            '"Prime!" smallint primary key,'
969            '"much space" integer, "Questions?" text)' % table)
970        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
971        r = insert(table, r)
972        self.assertIsInstance(r, dict)
973        self.assertEqual(r['Prime!'], 11)
974        self.assertEqual(r['much space'], 2002)
975        self.assertEqual(r['Questions?'], 'What?')
976        r = query('select * from "%s" limit 2' % table).dictresult()
977        self.assertEqual(len(r), 1)
978        r = r[0]
979        self.assertEqual(r['Prime!'], 11)
980        self.assertEqual(r['much space'], 2002)
981        self.assertEqual(r['Questions?'], 'What?')
982        query('drop table "%s"' % table)
983
984    def testUpdate(self):
985        update = self.db.update
986        query = self.db.query
987        table = 'update_test_table'
988        query('drop table if exists "%s"' % table)
989        query('create table "%s" ('
990            "n integer, t text) with oids" % table)
991        for n, t in enumerate('xyz'):
992            query('insert into "%s" values('
993                "%d, '%s')" % (table, n + 1, t))
994        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
995        r = self.db.get(table, 2, 'n')
996        r['t'] = 'u'
997        s = update(table, r)
998        self.assertEqual(s, r)
999        q = 'select t from "%s" where n=2' % table
1000        r = query(q).getresult()[0][0]
1001        self.assertEqual(r, 'u')
1002        query('drop table "%s"' % table)
1003
    def testUpdateWithCompositeKey(self):
        """Test update() on tables with a single and a composite primary key.

        update() must raise ProgrammingError for a row dict lacking the
        complete primary key, and return the very dict it was passed.
        """
        update = self.db.update
        query = self.db.query
        table = 'update_test_table_1'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer, t text, primary key (n))" % table)
        for n, t in enumerate('abc'):
            query('insert into "%s" values('
                "%d, '%s')" % (table, n + 1, t))
        # no primary key value given
        self.assertRaises(pg.ProgrammingError, update,
                          table, dict(t='b'))
        s = dict(n=2, t='d')
        r = update(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'd')
        q = 'select t from "%s" where n=2' % table
        r = query(q).getresult()[0][0]
        self.assertEqual(r, 'd')
        # updating a non-existing row (n=4) changes nothing in the table,
        # but the returned dict still holds the passed values
        s.update(dict(n=4, t='e'))
        r = update(table, s)
        self.assertEqual(r['n'], 4)
        self.assertEqual(r['t'], 'e')
        q = 'select t from "%s" where n=2' % table
        r = query(q).getresult()[0][0]
        self.assertEqual(r, 'd')
        q = 'select t from "%s" where n=4' % table
        r = query(q).getresult()
        self.assertEqual(len(r), 0)
        query('drop table "%s"' % table)
        # composite primary key (n, m) with letters a..f
        table = 'update_test_table_2'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer, m integer, t text, primary key (n, m))" % table)
        for n in range(3):
            for m in range(2):
                t = chr(ord('a') + 2 * n + m)
                query('insert into "%s" values('
                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
        # only part of the primary key given
        self.assertRaises(pg.ProgrammingError, update,
                          table, dict(n=2, t='b'))
        self.assertEqual(update(table,
                                dict(n=2, m=2, t='x'))['t'], 'x')
        q = 'select t from "%s" where n=2 order by m' % table
        r = [r[0] for r in query(q).getresult()]
        self.assertEqual(r, ['c', 'x'])
        query('drop table "%s"' % table)
1052
1053    def testUpdateWithQuotedNames(self):
1054        update = self.db.update
1055        query = self.db.query
1056        table = 'test table for update()'
1057        query('drop table if exists "%s"' % table)
1058        query('create table "%s" ('
1059            '"Prime!" smallint primary key,'
1060            '"much space" integer, "Questions?" text)' % table)
1061        query('insert into "%s"'
1062              " values(13, 3003, 'Why!')" % table)
1063        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1064        r = update(table, r)
1065        self.assertIsInstance(r, dict)
1066        self.assertEqual(r['Prime!'], 13)
1067        self.assertEqual(r['much space'], 7007)
1068        self.assertEqual(r['Questions?'], 'When?')
1069        r = query('select * from "%s" limit 2' % table).dictresult()
1070        self.assertEqual(len(r), 1)
1071        r = r[0]
1072        self.assertEqual(r['Prime!'], 13)
1073        self.assertEqual(r['much space'], 7007)
1074        self.assertEqual(r['Questions?'], 'When?')
1075        query('drop table "%s"' % table)
1076
    def testUpsert(self):
        """Test upsert() on a table with a single-column primary key.

        Keyword arguments control how a conflicting row is updated. False
        keeps the stored value, True takes the passed value, and a string
        is used as an SQL expression. In that expression, included.t is the
        value stored in the table and excluded.t is the value passed in.
        The test is skipped on servers older than 9.5 where upsert() fails.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer primary key, t text) with oids" % table)
        s = dict(n=1, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        # n=2 does not exist yet, so a new row is inserted
        s.update(n=2, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # conflict on n=2: by default the passed value is stored
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=False keeps the stored value, which is copied back into s
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=True stores the passed value
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # expression based on the stored value 'y'
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        # expression based on the passed value 'y'
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        # expression combining stored 'x' and passed '2'
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        query('drop table "%s"' % table)
1145
    def testUpsertWithCompositeKey(self):
        """Test upsert() on a table with a composite primary key (n, m).

        The test is skipped on servers older than 9.5 where upsert() fails.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer, m integer, t text, primary key (n, m))" % table)
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # (1, 3) does not exist yet, so a new row is inserted
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # conflict on (1, 3): by default the passed value is stored
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False keeps the stored value, which is copied back into s
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True stores the passed value
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # (2, 3) is new, so the update expression "'z'" is not applied
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # stored 'n' concatenated with passed 'm'
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
        query('drop table "%s"' % table)
1214
1215    def testUpsertWithQuotedNames(self):
1216        upsert = self.db.upsert
1217        query = self.db.query
1218        table = 'test table for upsert()'
1219        query('drop table if exists "%s"' % table)
1220        query('create table "%s" ('
1221            '"Prime!" smallint primary key,'
1222            '"much space" integer, "Questions?" text)' % table)
1223        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1224        try:
1225            r = upsert(table, s)
1226        except pg.ProgrammingError as error:
1227            if self.db.server_version < 90500:
1228                self.skipTest('database does not support upsert')
1229            self.fail(str(error))
1230        self.assertIs(r, s)
1231        self.assertEqual(r['Prime!'], 31)
1232        self.assertEqual(r['much space'], 9009)
1233        self.assertEqual(r['Questions?'], 'Yes.')
1234        q = 'select * from "%s" limit 2' % table
1235        r = query(q).getresult()
1236        self.assertEqual(r, [(31, 9009, 'Yes.')])
1237        s.update({'Questions?': 'No.'})
1238        r = upsert(table, s)
1239        self.assertIs(r, s)
1240        self.assertEqual(r['Prime!'], 31)
1241        self.assertEqual(r['much space'], 9009)
1242        self.assertEqual(r['Questions?'], 'No.')
1243        r = query(q).getresult()
1244        self.assertEqual(r, [(31, 9009, 'No.')])
1245
1246    def testClear(self):
1247        clear = self.db.clear
1248        query = self.db.query
1249        f = False if pg.get_bool() else 'f'
1250        table = 'clear_test_table'
1251        query('drop table if exists "%s"' % table)
1252        query('create table "%s" ('
1253            "n integer, b boolean, d date, t text)" % table)
1254        r = clear(table)
1255        result = {'n': 0, 'b': f, 'd': '', 't': ''}
1256        self.assertEqual(r, result)
1257        r['a'] = r['n'] = 1
1258        r['d'] = r['t'] = 'x'
1259        r['b'] = 't'
1260        r['oid'] = long(1)
1261        r = clear(table, r)
1262        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
1263            'oid': long(1)}
1264        self.assertEqual(r, result)
1265        query('drop table "%s"' % table)
1266
1267    def testClearWithQuotedNames(self):
1268        clear = self.db.clear
1269        query = self.db.query
1270        table = 'test table for clear()'
1271        query('drop table if exists "%s"' % table)
1272        query('create table "%s" ('
1273            '"Prime!" smallint primary key,'
1274            '"much space" integer, "Questions?" text)' % table)
1275        r = clear(table)
1276        self.assertIsInstance(r, dict)
1277        self.assertEqual(r['Prime!'], 0)
1278        self.assertEqual(r['much space'], 0)
1279        self.assertEqual(r['Questions?'], '')
1280        query('drop table "%s"' % table)
1281
1282    def testDelete(self):
1283        delete = self.db.delete
1284        query = self.db.query
1285        table = 'delete_test_table'
1286        query('drop table if exists "%s"' % table)
1287        query('create table "%s" ('
1288            "n integer, t text) with oids" % table)
1289        for n, t in enumerate('xyz'):
1290            query('insert into "%s" values('
1291                "%d, '%s')" % (table, n + 1, t))
1292        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1293        r = self.db.get(table, 1, 'n')
1294        s = delete(table, r)
1295        self.assertEqual(s, 1)
1296        r = self.db.get(table, 3, 'n')
1297        s = delete(table, r)
1298        self.assertEqual(s, 1)
1299        s = delete(table, r)
1300        self.assertEqual(s, 0)
1301        r = query('select * from "%s"' % table).dictresult()
1302        self.assertEqual(len(r), 1)
1303        r = r[0]
1304        result = {'n': 2, 't': 'y'}
1305        self.assertEqual(r, result)
1306        r = self.db.get(table, 2, 'n')
1307        s = delete(table, r)
1308        self.assertEqual(s, 1)
1309        s = delete(table, r)
1310        self.assertEqual(s, 0)
1311        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1312        query('drop table "%s"' % table)
1313
1314    def testDeleteWithCompositeKey(self):
1315        query = self.db.query
1316        table = 'delete_test_table_1'
1317        query('drop table if exists "%s"' % table)
1318        query('create table "%s" ('
1319            "n integer, t text, primary key (n))" % table)
1320        for n, t in enumerate('abc'):
1321            query("insert into %s values("
1322                "%d, '%s')" % (table, n + 1, t))
1323        self.assertRaises(pg.ProgrammingError, self.db.delete,
1324            table, dict(t='b'))
1325        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1326        r = query('select t from "%s" where n=2' % table
1327                  ).getresult()
1328        self.assertEqual(r, [])
1329        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1330        r = query('select t from "%s" where n=3' % table
1331                  ).getresult()[0][0]
1332        self.assertEqual(r, 'c')
1333        query('drop table "%s"' % table)
1334        table = 'delete_test_table_2'
1335        query('drop table if exists "%s"' % table)
1336        query('create table "%s" ('
1337            "n integer, m integer, t text, primary key (n, m))" % table)
1338        for n in range(3):
1339            for m in range(2):
1340                t = chr(ord('a') + 2 * n + m)
1341                query('insert into "%s" values('
1342                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1343        self.assertRaises(pg.ProgrammingError, self.db.delete,
1344            table, dict(n=2, t='b'))
1345        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1346        r = [r[0] for r in query('select t from "%s" where n=2'
1347            ' order by m' % table).getresult()]
1348        self.assertEqual(r, ['c'])
1349        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1350        r = [r[0] for r in query('select t from "%s" where n=3'
1351            ' order by m' % table).getresult()]
1352        self.assertEqual(r, ['e', 'f'])
1353        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1354        r = [r[0] for r in query('select t from "%s" where n=3'
1355            ' order by m' % table).getresult()]
1356        self.assertEqual(r, ['f'])
1357        query('drop table "%s"' % table)
1358
1359    def testDeleteWithQuotedNames(self):
1360        delete = self.db.delete
1361        query = self.db.query
1362        table = 'test table for delete()'
1363        query('drop table if exists "%s"' % table)
1364        query('create table "%s" ('
1365            '"Prime!" smallint primary key,'
1366            '"much space" integer, "Questions?" text)' % table)
1367        query('insert into "%s"'
1368              " values(19, 5005, 'Yes!')" % table)
1369        r = {'Prime!': 17}
1370        r = delete(table, r)
1371        self.assertEqual(r, 0)
1372        r = query('select count(*) from "%s"' % table).getresult()
1373        self.assertEqual(r[0][0], 1)
1374        r = {'Prime!': 19}
1375        r = delete(table, r)
1376        self.assertEqual(r, 1)
1377        r = query('select count(*) from "%s"' % table).getresult()
1378        self.assertEqual(r[0][0], 0)
1379        query('drop table "%s"' % table)
1380
    def testTransaction(self):
        """Test begin/commit/rollback, savepoints and start/end.

        In the end, only rows inserted in committed parts of the
        transactions may remain in the table.
        """
        query = self.db.query
        query("drop table if exists test_table")
        query("create table test_table (n integer)")
        # committed: 1 and 2 are kept
        self.db.begin()
        query("insert into test_table values (1)")
        query("insert into test_table values (2)")
        self.db.commit()
        # rolled back: 3 and 4 are discarded
        self.db.begin()
        query("insert into test_table values (3)")
        query("insert into test_table values (4)")
        self.db.rollback()
        # rolling back to a savepoint discards only 6
        self.db.begin()
        query("insert into test_table values (5)")
        self.db.savepoint('before6')
        query("insert into test_table values (6)")
        self.db.rollback('before6')
        query("insert into test_table values (7)")
        self.db.commit()
        # a released savepoint cannot be rolled back to; after this error
        # 8 is not kept either, as the expected result below shows
        # (presumably the error aborts the whole transaction)
        self.db.begin()
        self.db.savepoint('before8')
        query("insert into test_table values (8)")
        self.db.release('before8')
        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
        self.db.commit()
        # rows inserted between start() and end() are kept as well
        self.db.start()
        query("insert into test_table values (9)")
        self.db.end()
        r = [r[0] for r in query(
            "select * from test_table order by 1").getresult()]
        self.assertEqual(r, [1, 2, 5, 7, 9])
        query("drop table test_table")
1413
1414    def testContextManager(self):
1415        query = self.db.query
1416        query("drop table if exists test_table")
1417        query("create table test_table (n integer check(n>0))")
1418        with self.db:
1419            query("insert into test_table values (1)")
1420            query("insert into test_table values (2)")
1421        try:
1422            with self.db:
1423                query("insert into test_table values (3)")
1424                query("insert into test_table values (4)")
1425                raise ValueError('test transaction should rollback')
1426        except ValueError as error:
1427            self.assertEqual(str(error), 'test transaction should rollback')
1428        with self.db:
1429            query("insert into test_table values (5)")
1430        try:
1431            with self.db:
1432                query("insert into test_table values (6)")
1433                query("insert into test_table values (-1)")
1434        except pg.ProgrammingError as error:
1435            self.assertTrue('check' in str(error))
1436        with self.db:
1437            query("insert into test_table values (7)")
1438        r = [r[0] for r in query(
1439            "select * from test_table order by 1").getresult()]
1440        self.assertEqual(r, [1, 2, 5, 7])
1441        query("drop table test_table")
1442
1443    def testBytea(self):
1444        query = self.db.query
1445        query('drop table if exists bytea_test')
1446        query('create table bytea_test (n smallint primary key, data bytea)')
1447        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1448        r = self.db.escape_bytea(s)
1449        query('insert into bytea_test values(3,$1)', (r,))
1450        r = query('select * from bytea_test where n=3').getresult()
1451        self.assertEqual(len(r), 1)
1452        r = r[0]
1453        self.assertEqual(len(r), 2)
1454        self.assertEqual(r[0], 3)
1455        r = r[1]
1456        self.assertIsInstance(r, str)
1457        r = self.db.unescape_bytea(r)
1458        self.assertIsInstance(r, bytes)
1459        self.assertEqual(r, s)
1460        query('drop table bytea_test')
1461
    def testInsertUpdateGetBytea(self):
        """Test insert(), update() and get() on a bytea column.

        Through these methods the data is passed and returned as bytes,
        while a plain query returns it as an escaped string.
        """
        query = self.db.query
        query('drop table if exists bytea_test')
        query('create table bytea_test (n smallint primary key, data bytea)')
        # insert null value
        r = self.db.insert('bytea_test', n=0, data=None)
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 0)
        self.assertIn('data', r)
        self.assertIsNone(r['data'])
        # update null value to bytes and back to null
        s = b'None'
        r = self.db.update('bytea_test', n=0, data=s)
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 0)
        self.assertIn('data', r)
        r = r['data']
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
        r = self.db.update('bytea_test', n=0, data=None)
        self.assertIsNone(r['data'])
        # insert as bytes
        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
        r = self.db.insert('bytea_test', n=5, data=s)
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 5)
        self.assertIn('data', r)
        r = r['data']
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
        # update as bytes
        s += b"and now even more \x00 nasty \t stuff!\f"
        r = self.db.update('bytea_test', n=5, data=s)
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 5)
        self.assertIn('data', r)
        r = r['data']
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
        # a plain query returns the escaped string representation
        r = query('select * from bytea_test where n=5').getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 2)
        self.assertEqual(r[0], 5)
        r = r[1]
        self.assertIsInstance(r, str)
        r = self.db.unescape_bytea(r)
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
        # get() returns the data as bytes again
        r = self.db.get('bytea_test', dict(n=5))
        self.assertIsInstance(r, dict)
        self.assertIn('n', r)
        self.assertEqual(r['n'], 5)
        self.assertIn('data', r)
        r = r['data']
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, s)
        query('drop table bytea_test')
1523
1524    def testDebugWithCallable(self):
1525        if debug:
1526            self.assertEqual(self.db.debug, debug)
1527        else:
1528            self.assertIsNone(self.db.debug)
1529        s = []
1530        self.db.debug = s.append
1531        try:
1532            self.db.query("select 1")
1533            self.db.query("select 2")
1534            self.assertEqual(s, ["select 1", "select 2"])
1535        finally:
1536            self.db.debug = debug
1537
1538
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options."""

    @classmethod
    def setUpClass(cls):
        """Switch the global pg options, then set up the base class.

        The decimal option is set to float, the bool option is inverted,
        and the namedresult option is replaced by a function that simply
        calls getresult() on the query.  All previous values are saved
        in cls.saved_options so that they can be restored afterwards.
        """
        cls.saved_options = {}

        def unnamed_result(q):
            return q.getresult()

        try:
            cls.set_option('decimal', float)
            not_bool = not pg.get_bool()
            cls.set_option('bool', not_bool)
            cls.set_option('namedresult', unnamed_result)
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # unittest does not call tearDownClass() when setUpClass()
            # fails, so undo the global option changes here; otherwise
            # they would leak into all test classes running afterwards
            cls._restore_options()
            raise

    @classmethod
    def tearDownClass(cls):
        """Tear down the base class, then restore the global pg options."""
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # restore the options even if the base teardown failed
            cls._restore_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option, then set it.

        Returns whatever the corresponding pg.set_<option>() returns.
        """
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def _restore_options(cls):
        """Restore every global pg option saved so far by set_option()."""
        for option in list(cls.saved_options):
            cls.reset_option(option)
1567
1568
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces)."""

    # number of schemas used: "public" plus the schemas s1, s2, s3 and s4
    num_schemas = 5

    @staticmethod
    def schema_name(num_schema):
        """Return the name of the test schema with the given number.

        Number 0 denotes the "public" schema; all other numbers denote
        schemas named "s1", "s2" and so on.
        """
        return "s%d" % num_schema if num_schema else "public"

    @classmethod
    def setUpClass(cls):
        """Create the test schemas and their tables.

        Every schema gets a table "t" and a table "t<num>", each created
        with oids and holding one row with n = 1 and d = <schema number>.

        Raises RuntimeError if the test user cannot create schemas.
        """
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(cls.num_schemas):
                schema = cls.schema_name(num_schema)
                if num_schema:
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    # the public schema itself cannot be dropped,
                    # so only drop left-over tables from earlier runs
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d"
                          % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # close the connection even if setting up the schemas failed
            db.close()

    @classmethod
    def tearDownClass(cls):
        """Drop all schemas and tables created by setUpClass()."""
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(cls.num_schemas):
                schema = cls.schema_name(num_schema)
                if num_schema:
                    query("drop schema %s cascade" % (schema,))
                else:
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            db.close()

    def setUp(self):
        self.db = DB()
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        self.db.close()

    def testGetTables(self):
        # all tables must be listed with schema-qualified names
        tables = self.db.get_tables()
        for num_schema in range(self.num_schemas):
            schema = self.schema_name(num_schema)
            for t in (schema + ".t",
                    schema + ".t" + str(num_schema)):
                self.assertIn(t, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        try:
            result_m = {'oid': 'int', 'm': 'int'}
            r = get_attnames("s3.t3m")
            self.assertEqual(r, result_m)
            # unqualified names must be resolved using the search path
            query("set search_path to s1,s3")
            r = get_attnames("t3")
            self.assertEqual(r, result)
            r = get_attnames("t3m")
            self.assertEqual(r, result_m)
        finally:
            # drop the extra table even if one of the assertions failed
            query("drop table s3.t3m")

    def testGet(self):
        # column d holds the number of the schema the row was read from;
        # a table t<num> exists only in the public schema (num 0) and in
        # schema s<num>, so it must not be found via other schemas
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        # the oid of a row is returned under the key "oid(<table name>)"
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
1685
1686
# Run all test cases of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
Note: See TracBrowser for help on using the repository browser.