source: trunk/tests/test_classic_dbwrapper.py @ 739

Last change on this file since 739 was 739, checked in by cito, 4 years ago

Return ordered dict for attributes if possible

Sometimes it's important to know the order of the columns in a table.
By returning an OrderedDict instead of a dict in get_attnames, we can
deliver that information en passant, while staying backward compatible.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 61.0 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the DB wrapper object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import os
19
20import pg  # the module under test
21
22from decimal import Decimal
23
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# The current user must have create schema privilege on the database.
dbname = 'unittest'
dbhost = None
dbport = 5432

debug = False  # let DB wrapper print debugging output

# The defaults above must be defined before this import,
# because the optional local settings module overrides them.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, ValueError, SystemError):
    # A relative import outside of a package raises ValueError on
    # Python 2, SystemError on Python 3.3 to 3.5 and ImportError on
    # newer versions; in all these cases try an absolute import instead.
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass

# Compatibility aliases so that the tests run on Python 2 and 3.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

try:
    from collections import OrderedDict
except ImportError:  # Python 2.6 or 3.0
    # fall back to a plain dict; tests relying on order must then skip
    OrderedDict = dict

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
62
63
def DB():
    """Open and return a pg.DB wrapper for the test database.

    The connection is made with the module-level dbname, dbhost and
    dbport settings. The wrapper's debug attribute is set when the
    module-level debug flag is true, and client_min_messages is set
    to 'warning' for the session.
    """
    wrapper = pg.DB(dbname, dbhost, dbport)
    if debug:
        wrapper.debug = debug
    wrapper.query("set client_min_messages=warning")
    return wrapper
71
72
class TestDBClassBasic(unittest.TestCase):
    """Test existence of the DB class wrapped pg connection methods."""

    def setUp(self):
        self.db = DB()

    def tearDown(self):
        # testMethodClose closes the connection itself, and closing an
        # already closed connection raises InternalError (ignored here)
        try:
            self.db.close()
        except pg.InternalError:
            pass

    def testAllDBAttributes(self):
        # all public attributes of the wrapper, in the order dir() lists them
        attributes = [
            'begin',
            'cancel',
            'clear',
            'close',
            'commit',
            'db',
            'dbname',
            'debug',
            'delete',
            'end',
            'endcopy',
            'error',
            'escape_bytea',
            'escape_identifier',
            'escape_literal',
            'escape_string',
            'fileno',
            'get',
            'get_attnames',
            'get_databases',
            'get_notice_receiver',
            'get_relations',
            'get_tables',
            'getline',
            'getlo',
            'getnotify',
            'has_table_privilege',
            'host',
            'insert',
            'inserttable',
            'locreate',
            'loimport',
            'notification_handler',
            'options',
            'parameter',
            'pkey',
            'port',
            'protocol_version',
            'putline',
            'query',
            'release',
            'reopen',
            'reset',
            'rollback',
            'savepoint',
            'server_version',
            'set_notice_receiver',
            'source',
            'start',
            'status',
            'transaction',
            'unescape_bytea',
            'update',
            'upsert',
            'use_regtypes',
            'user',
        ]
        db_attributes = [a for a in dir(self.db)
            if not a.startswith('_')]
        self.assertEqual(attributes, db_attributes)

    def testAttributeDb(self):
        self.assertEqual(self.db.db.db, dbname)

    def testAttributeDbname(self):
        self.assertEqual(self.db.dbname, dbname)

    def testAttributeError(self):
        error = self.db.error
        self.assertTrue(not error or 'krb5_' in error)
        self.assertEqual(self.db.error, self.db.db.error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        host = self.db.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or def_host)
        self.assertEqual(host, self.db.db.host)

    def testAttributeOptions(self):
        no_options = ''
        options = self.db.options
        self.assertEqual(options, no_options)
        self.assertEqual(options, self.db.db.options)

    def testAttributePort(self):
        def_port = 5432
        port = self.db.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or def_port)
        self.assertEqual(port, self.db.db.port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.db.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)
        self.assertEqual(protocol_version, self.db.db.protocol_version)

    def testAttributeServerVersion(self):
        server_version = self.db.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)
        self.assertEqual(server_version, self.db.db.server_version)

    def testAttributeStatus(self):
        status_ok = 1
        status = self.db.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, status_ok)
        self.assertEqual(status, self.db.db.status)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.db.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)
        self.assertEqual(user, self.db.db.user)

    def testMethodEscapeLiteral(self):
        self.assertEqual(self.db.escape_literal(''), "''")

    def testMethodEscapeIdentifier(self):
        self.assertEqual(self.db.escape_identifier(''), '""')

    def testMethodEscapeString(self):
        self.assertEqual(self.db.escape_string(''), '')

    def testMethodEscapeBytea(self):
        self.assertEqual(self.db.escape_bytea('').replace(
            '\\x', '').replace('\\', ''), '')

    def testMethodUnescapeBytea(self):
        self.assertEqual(self.db.unescape_bytea(''), b'')

    def testMethodQuery(self):
        query = self.db.query
        query("select 1+1")
        query("select 1+$1+$2", 2, 3)
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])
        query("select 1+$1", 1)

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.db.query, '')

    def testMethodQueryProgrammingError(self):
        # assertRaises makes the test fail when no error is raised at all;
        # the former try/except silently passed in that case
        with self.assertRaises(pg.ProgrammingError) as context:
            self.db.query("select 1/0")
        # 22012 is the SQLSTATE code for division_by_zero
        self.assertEqual(context.exception.sqlstate, '22012')

    def testMethodEndcopy(self):
        # calling endcopy() without a running copy may raise IOError,
        # which is accepted here
        try:
            self.db.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.db.close()
        try:
            self.db.reset()
        except pg.Error:
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.db.close)
        self.assertRaises(pg.InternalError, self.db.query, 'select 1')

    def testExistingConnection(self):
        # a DB wrapper can be built on an existing connection object,
        # on another DB wrapper, or on anything having a _cnx attribute
        db = pg.DB(self.db.db)
        self.assertEqual(self.db.db, db.db)
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db.reopen()
        self.assertTrue(db.db)
        db.close()
        self.assertTrue(db.db)
        db = pg.DB(self.db)
        self.assertEqual(self.db.db, db.db)
        db = pg.DB(db=self.db.db)
        self.assertEqual(self.db.db, db.db)

        class DB2:
            pass

        db2 = DB2()
        db2._cnx = self.db.db
        db = pg.DB(db2)
        self.assertEqual(self.db.db, db.db)
279
280
281class TestDBClass(unittest.TestCase):
282    """Test the methods of the DB class wrapped pg connection."""
283
284    @classmethod
285    def setUpClass(cls):
286        db = DB()
287        db.query("drop table if exists test cascade")
288        db.query("create table test ("
289            "i2 smallint, i4 integer, i8 bigint,"
290            "d numeric, f4 real, f8 double precision, m money, "
291            "v4 varchar(4), c4 char(4), t text)")
292        db.query("create or replace view test_view as"
293            " select i4, v4 from test")
294        db.close()
295
296    @classmethod
297    def tearDownClass(cls):
298        db = DB()
299        db.query("drop table test cascade")
300        db.close()
301
302    def setUp(self):
303        self.db = DB()
304        query = self.db.query
305        query('set client_encoding=utf8')
306        query('set standard_conforming_strings=on')
307        query("set lc_monetary='C'")
308        query("set datestyle='ISO,YMD'")
309        query('set bytea_output=hex')
310
311    def tearDown(self):
312        self.db.close()
313
314    def testClassName(self):
315        self.assertEqual(self.db.__class__.__name__, 'DB')
316
317    def testModuleName(self):
318        self.assertEqual(self.db.__module__, 'pg')
319        self.assertEqual(self.db.__class__.__module__, 'pg')
320
321    def testEscapeLiteral(self):
322        f = self.db.escape_literal
323        r = f(b"plain")
324        self.assertIsInstance(r, bytes)
325        self.assertEqual(r, b"'plain'")
326        r = f(u"plain")
327        self.assertIsInstance(r, unicode)
328        self.assertEqual(r, u"'plain'")
329        r = f(u"that's kÀse".encode('utf-8'))
330        self.assertIsInstance(r, bytes)
331        self.assertEqual(r, u"'that''s kÀse'".encode('utf-8'))
332        r = f(u"that's kÀse")
333        self.assertIsInstance(r, unicode)
334        self.assertEqual(r, u"'that''s kÀse'")
335        self.assertEqual(f(r"It's fine to have a \ inside."),
336            r" E'It''s fine to have a \\ inside.'")
337        self.assertEqual(f('No "quotes" must be escaped.'),
338            "'No \"quotes\" must be escaped.'")
339
340    def testEscapeIdentifier(self):
341        f = self.db.escape_identifier
342        r = f(b"plain")
343        self.assertIsInstance(r, bytes)
344        self.assertEqual(r, b'"plain"')
345        r = f(u"plain")
346        self.assertIsInstance(r, unicode)
347        self.assertEqual(r, u'"plain"')
348        r = f(u"that's kÀse".encode('utf-8'))
349        self.assertIsInstance(r, bytes)
350        self.assertEqual(r, u'"that\'s kÀse"'.encode('utf-8'))
351        r = f(u"that's kÀse")
352        self.assertIsInstance(r, unicode)
353        self.assertEqual(r, u'"that\'s kÀse"')
354        self.assertEqual(f(r"It's fine to have a \ inside."),
355            '"It\'s fine to have a \\ inside."')
356        self.assertEqual(f('All "quotes" must be escaped.'),
357            '"All ""quotes"" must be escaped."')
358
359    def testEscapeString(self):
360        f = self.db.escape_string
361        r = f(b"plain")
362        self.assertIsInstance(r, bytes)
363        self.assertEqual(r, b"plain")
364        r = f(u"plain")
365        self.assertIsInstance(r, unicode)
366        self.assertEqual(r, u"plain")
367        r = f(u"that's kÀse".encode('utf-8'))
368        self.assertIsInstance(r, bytes)
369        self.assertEqual(r, u"that''s kÀse".encode('utf-8'))
370        r = f(u"that's kÀse")
371        self.assertIsInstance(r, unicode)
372        self.assertEqual(r, u"that''s kÀse")
373        self.assertEqual(f(r"It's fine to have a \ inside."),
374            r"It''s fine to have a \ inside.")
375
376    def testEscapeBytea(self):
377        f = self.db.escape_bytea
378        # note that escape_byte always returns hex output since Pg 9.0,
379        # regardless of the bytea_output setting
380        r = f(b'plain')
381        self.assertIsInstance(r, bytes)
382        self.assertEqual(r, b'\\x706c61696e')
383        r = f(u'plain')
384        self.assertIsInstance(r, unicode)
385        self.assertEqual(r, u'\\x706c61696e')
386        r = f(u"das is' kÀse".encode('utf-8'))
387        self.assertIsInstance(r, bytes)
388        self.assertEqual(r, b'\\x64617320697327206bc3a47365')
389        r = f(u"das is' kÀse")
390        self.assertIsInstance(r, unicode)
391        self.assertEqual(r, u'\\x64617320697327206bc3a47365')
392        self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21')
393
394    def testUnescapeBytea(self):
395        f = self.db.unescape_bytea
396        r = f(b'plain')
397        self.assertIsInstance(r, bytes)
398        self.assertEqual(r, b'plain')
399        r = f(u'plain')
400        self.assertIsInstance(r, bytes)
401        self.assertEqual(r, b'plain')
402        r = f(b"das is' k\\303\\244se")
403        self.assertIsInstance(r, bytes)
404        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
405        r = f(u"das is' k\\303\\244se")
406        self.assertIsInstance(r, bytes)
407        self.assertEqual(r, u"das is' kÀse".encode('utf8'))
408        self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!')
409        self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e')
410        self.assertEqual(f(r'\\x746861742773206be47365'),
411            b'\\x746861742773206be47365')
412        self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21')
413
414    def testGetAttnames(self):
415        get_attnames = self.db.get_attnames
416        query = self.db.query
417        query("drop table if exists test_table")
418        query("create table test_table("
419            " n int, alpha smallint, beta bool,"
420            " gamma char(5), tau text, v varchar(3))")
421        r = get_attnames("test_table")
422        self.assertIsInstance(r, dict)
423        self.assertEquals(r, dict(
424            n='int', alpha='int', beta='bool',
425            gamma='text', tau='text', v='text'))
426        query("drop table test_table")
427
428    def testGetAttnamesWithQuotes(self):
429        get_attnames = self.db.get_attnames
430        query = self.db.query
431        table = 'test table for get_attnames()'
432        query('drop table if exists "%s"' % table)
433        query('create table "%s"('
434            '"Prime!" smallint,'
435            '"much space" integer, "Questions?" text)' % table)
436        r = get_attnames(table)
437        self.assertIsInstance(r, dict)
438        self.assertEquals(r, {
439            'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'})
440        query('drop table "%s"' % table)
441
442    def testGetAttnamesWithRegtypes(self):
443        get_attnames = self.db.get_attnames
444        query = self.db.query
445        query("drop table if exists test_table")
446        query("create table test_table("
447            " n int, alpha smallint, beta bool,"
448            " gamma char(5), tau text, v varchar(3))")
449        self.db.use_regtypes(True)
450        try:
451            r = get_attnames("test_table")
452            self.assertIsInstance(r, dict)
453        finally:
454            self.db.use_regtypes(False)
455        self.assertEquals(r, dict(
456            n='integer', alpha='smallint', beta='boolean',
457            gamma='character', tau='text', v='character varying'))
458        query("drop table test_table")
459
460    def testGetAttnamesIsCached(self):
461        get_attnames = self.db.get_attnames
462        query = self.db.query
463        query("drop table if exists test_table")
464        query("create table test_table(col int)")
465        r = get_attnames("test_table")
466        self.assertIsInstance(r, dict)
467        self.assertEquals(r, dict(col='int'))
468        query("drop table test_table")
469        query("create table test_table(col text)")
470        r = get_attnames("test_table")
471        self.assertEquals(r, dict(col='int'))
472        r = get_attnames("test_table", flush=True)
473        self.assertEquals(r, dict(col='text'))
474        query("drop table test_table")
475        r = get_attnames("test_table")
476        self.assertEquals(r, dict(col='text'))
477        self.assertRaises(pg.ProgrammingError,
478            get_attnames, "test_table", flush=True)
479
480    def testGetAttnamesIsOrdered(self):
481        get_attnames = self.db.get_attnames
482        query = self.db.query
483        query("drop table if exists test_table")
484        query("create table test_table("
485            " n int, alpha smallint, v varchar(3),"
486            " gamma char(5), tau text, beta bool)")
487        r = get_attnames("test_table")
488        self.assertIsInstance(r, OrderedDict)
489        self.assertEquals(r, OrderedDict([
490            ('n', 'int'), ('alpha', 'int'), ('v', 'text'),
491            ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')]))
492        query("drop table test_table")
493        if OrderedDict is dict:
494            self.skipTest('OrderedDict is not supported')
495        r = ' '.join(list(r.keys()))
496        self.assertEquals(r, 'n alpha v gamma tau beta')
497
498    def testQuery(self):
499        query = self.db.query
500        query("drop table if exists test_table")
501        q = "create table test_table (n integer) with oids"
502        r = query(q)
503        self.assertIsNone(r)
504        q = "insert into test_table values (1)"
505        r = query(q)
506        self.assertIsInstance(r, int)
507        q = "insert into test_table select 2"
508        r = query(q)
509        self.assertIsInstance(r, int)
510        oid = r
511        q = "select oid from test_table where n=2"
512        r = query(q).getresult()
513        self.assertEqual(len(r), 1)
514        r = r[0]
515        self.assertEqual(len(r), 1)
516        r = r[0]
517        self.assertEqual(r, oid)
518        q = "insert into test_table select 3 union select 4 union select 5"
519        r = query(q)
520        self.assertIsInstance(r, str)
521        self.assertEqual(r, '3')
522        q = "update test_table set n=4 where n<5"
523        r = query(q)
524        self.assertIsInstance(r, str)
525        self.assertEqual(r, '4')
526        q = "delete from test_table"
527        r = query(q)
528        self.assertIsInstance(r, str)
529        self.assertEqual(r, '5')
530        query("drop table test_table")
531
532    def testMultipleQueries(self):
533        self.assertEqual(self.db.query(
534            "create temporary table test_multi (n integer);"
535            "insert into test_multi values (4711);"
536            "select n from test_multi").getresult()[0][0], 4711)
537
538    def testQueryWithParams(self):
539        query = self.db.query
540        query("drop table if exists test_table")
541        q = "create table test_table (n1 integer, n2 integer) with oids"
542        query(q)
543        q = "insert into test_table values ($1, $2)"
544        r = query(q, (1, 2))
545        self.assertIsInstance(r, int)
546        r = query(q, [3, 4])
547        self.assertIsInstance(r, int)
548        r = query(q, [5, 6])
549        self.assertIsInstance(r, int)
550        q = "select * from test_table order by 1, 2"
551        self.assertEqual(query(q).getresult(),
552            [(1, 2), (3, 4), (5, 6)])
553        q = "select * from test_table where n1=$1 and n2=$2"
554        self.assertEqual(query(q, 3, 4).getresult(), [(3, 4)])
555        q = "update test_table set n2=$2 where n1=$1"
556        r = query(q, 3, 7)
557        self.assertEqual(r, '1')
558        q = "select * from test_table order by 1, 2"
559        self.assertEqual(query(q).getresult(),
560            [(1, 2), (3, 7), (5, 6)])
561        q = "delete from test_table where n2!=$1"
562        r = query(q, 4)
563        self.assertEqual(r, '3')
564        query("drop table test_table")
565
566    def testEmptyQuery(self):
567        self.assertRaises(ValueError, self.db.query, '')
568
569    def testQueryProgrammingError(self):
570        try:
571            self.db.query("select 1/0")
572        except pg.ProgrammingError as error:
573            self.assertEqual(error.sqlstate, '22012')
574
    def testPkey(self):
        """Test that pkey() finds single and composite primary keys.

        The test runs twice, once with plain table names and once with
        names containing spaces. It also checks that pkey() caches its
        results until it is called with flush=True.
        """
        query = self.db.query
        pkey = self.db.pkey
        for t in ('pkeytest', 'primary key test'):
            # tables "<t>0" to "<t>6" are recreated for each name prefix
            for n in range(7):
                query('drop table if exists "%s%d"' % (t, n))
            query('create table "%s0" ('
                "a smallint)" % t)
            query('create table "%s1" ('
                "b smallint primary key)" % t)
            query('create table "%s2" ('
                "c smallint, d smallint primary key)" % t)
            query('create table "%s3" ('
                "e smallint, f smallint, g smallint, "
                "h smallint, i smallint, "
                "primary key (f, h))" % t)
            query('create table "%s4" ('
                "more_than_one_letter varchar primary key)" % t)
            query('create table "%s5" ('
                '"with space" date primary key)' % t)
            query('create table "%s6" ('
                'a_very_long_column_name varchar, '
                '"with space" date, '
                '"42" int, '
                "primary key (a_very_long_column_name, "
                '"with space", "42"))' % t)
            # a table without a primary key raises KeyError
            self.assertRaises(KeyError, pkey, '%s0' % t)
            # a single-column key is returned as the plain column name
            self.assertEqual(pkey('%s1' % t), 'b')
            self.assertEqual(pkey('%s2' % t), 'd')
            # a composite key is returned as a frozenset of column names
            r = pkey('%s3' % t)
            self.assertIsInstance(r, frozenset)
            self.assertEqual(r, frozenset('fh'))
            self.assertEqual(pkey('%s4' % t), 'more_than_one_letter')
            self.assertEqual(pkey('%s5' % t), 'with space')
            r = pkey('%s6' % t)
            self.assertIsInstance(r, frozenset)
            self.assertEqual(r, frozenset([
                'a_very_long_column_name', 'with space', '42']))
            # a newly added primary key will be detected
            query('alter table "%s0" add primary key (a)' % t)
            self.assertEqual(pkey('%s0' % t), 'a')
            # a changed primary key will not be detected,
            # indicating that the internal cache is operating
            query('alter table "%s1" rename column b to x' % t)
            self.assertEqual(pkey('%s1' % t), 'b')
            # we get the changed primary key when the cache is flushed
            self.assertEqual(pkey('%s1' % t, flush=True), 'x')
            for n in range(7):
                query('drop table "%s%d"' % (t, n))
624
625    def testGetDatabases(self):
626        databases = self.db.get_databases()
627        self.assertIn('template0', databases)
628        self.assertIn('template1', databases)
629        self.assertNotIn('not existing database', databases)
630        self.assertIn('postgres', databases)
631        self.assertIn(dbname, databases)
632
633    def testGetTables(self):
634        get_tables = self.db.get_tables
635        result1 = get_tables()
636        self.assertIsInstance(result1, list)
637        for t in result1:
638            t = t.split('.', 1)
639            self.assertGreaterEqual(len(t), 2)
640            if len(t) > 2:
641                self.assertTrue(t[1].startswith('"'))
642            t = t[0]
643            self.assertNotEqual(t, 'information_schema')
644            self.assertFalse(t.startswith('pg_'))
645        tables = ('"A very Special Name"',
646            '"A_MiXeD_quoted_NaMe"', 'a1', 'a2',
647            'A_MiXeD_NaMe', '"another special name"',
648            'averyveryveryveryveryveryverylongtablename',
649            'b0', 'b3', 'x', 'xx', 'xXx', 'y', 'z')
650        for t in tables:
651            self.db.query('drop table if exists %s' % t)
652            self.db.query("create table %s"
653                " as select 0" % t)
654        result3 = get_tables()
655        result2 = []
656        for t in result3:
657            if t not in result1:
658                result2.append(t)
659        result3 = []
660        for t in tables:
661            if not t.startswith('"'):
662                t = t.lower()
663            result3.append('public.' + t)
664        self.assertEqual(result2, result3)
665        for t in result2:
666            self.db.query('drop table %s' % t)
667        result2 = get_tables()
668        self.assertEqual(result2, result1)
669
670    def testGetRelations(self):
671        get_relations = self.db.get_relations
672        result = get_relations()
673        self.assertIn('public.test', result)
674        self.assertIn('public.test_view', result)
675        result = get_relations('rv')
676        self.assertIn('public.test', result)
677        self.assertIn('public.test_view', result)
678        result = get_relations('r')
679        self.assertIn('public.test', result)
680        self.assertNotIn('public.test_view', result)
681        result = get_relations('v')
682        self.assertNotIn('public.test', result)
683        self.assertIn('public.test_view', result)
684        result = get_relations('cisSt')
685        self.assertNotIn('public.test', result)
686        self.assertNotIn('public.test_view', result)
687
688    def testAttnames(self):
689        self.assertRaises(pg.ProgrammingError,
690            self.db.get_attnames, 'does_not_exist')
691        self.assertRaises(pg.ProgrammingError,
692            self.db.get_attnames, 'has.too.many.dots')
693        for table in ('attnames_test_table', 'test table for attnames'):
694            self.db.query('drop table if exists "%s"' % table)
695            self.db.query('create table "%s" ('
696                'a smallint, b integer, c bigint, '
697                'e numeric, f float, f2 double precision, m money, '
698                'x smallint, y smallint, z smallint, '
699                'Normal_NaMe smallint, "Special Name" smallint, '
700                't text, u char(2), v varchar(2), '
701                'primary key (y, u)) with oids' % table)
702            attributes = self.db.get_attnames(table)
703            result = {'a': 'int', 'c': 'int', 'b': 'int',
704                'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money',
705                'normal_name': 'int', 'Special Name': 'int',
706                'u': 'text', 't': 'text', 'v': 'text',
707                'y': 'int', 'x': 'int', 'z': 'int', 'oid': 'int'}
708            self.assertEqual(attributes, result)
709            self.db.query('drop table "%s"' % table)
710
711    def testHasTablePrivilege(self):
712        can = self.db.has_table_privilege
713        self.assertEqual(can('test'), True)
714        self.assertEqual(can('test', 'select'), True)
715        self.assertEqual(can('test', 'SeLeCt'), True)
716        self.assertEqual(can('test', 'SELECT'), True)
717        self.assertEqual(can('test', 'insert'), True)
718        self.assertEqual(can('test', 'update'), True)
719        self.assertEqual(can('test', 'delete'), True)
720        self.assertEqual(can('pg_views', 'select'), True)
721        self.assertEqual(can('pg_views', 'delete'), False)
722        self.assertRaises(pg.ProgrammingError, can, 'test', 'foobar')
723        self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist')
724
    def testGet(self):
        """Test fetching single rows with get().

        Covers get() by an explicit key column, by oid and by a row
        dict, first on a table without a primary key and then after a
        primary key has been added.
        """
        get = self.db.get
        query = self.db.query
        table = 'get_test_table'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer, t text) with oids" % table)
        # the adjacent literals below concatenate to
        # 'insert into "%s" values(%d, \'%s\')'
        for n, t in enumerate('xyz'):
            query('insert into "%s" values('"%d, '%s')"
                % (table, n + 1, t))
        # without a primary key, the key column must be given explicitly
        self.assertRaises(pg.ProgrammingError, get, table, 2)
        r = get(table, 2, 'n')
        # the returned row also carries its oid under "oid(<table>)"
        oid_table = 'oid(%s)' % table
        self.assertIn(oid_table, r)
        oid = r[oid_table]
        self.assertIsInstance(oid, int)
        result = {'t': 'y', 'n': 2, oid_table: oid}
        self.assertEqual(r, result)
        # NOTE(review): the ' *' suffix presumably selects the table
        # including inherited tables; here it must give the same row
        self.assertEqual(get(table + ' *', 2, 'n'), r)
        self.assertEqual(get(table, oid, 'oid')['t'], 'y')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        self.assertEqual(get(table, 3, 'n')['t'], 'z')
        self.assertEqual(get(table, 2, 'n')['t'], 'y')
        # a missing row raises DatabaseError
        self.assertRaises(pg.DatabaseError, get, table, 4, 'n')
        # a row dict can be passed instead of a key value
        r['n'] = 3
        self.assertEqual(get(table, r, 'n')['t'], 'z')
        self.assertEqual(get(table, 1, 'n')['t'], 'x')
        query('alter table "%s" alter n set not null' % table)
        query('alter table "%s" add primary key (n)' % table)
        # with a primary key, the key column can be omitted
        self.assertEqual(get(table, 3)['t'], 'z')
        self.assertEqual(get(table, 1)['t'], 'x')
        self.assertEqual(get(table, 2)['t'], 'y')
        r['n'] = 1
        self.assertEqual(get(table, r)['t'], 'x')
        r['n'] = 3
        self.assertEqual(get(table, r)['t'], 'z')
        r['n'] = 2
        self.assertEqual(get(table, r)['t'], 'y')
        query('drop table "%s"' % table)
764
765    def testGetWithCompositeKey(self):
766        get = self.db.get
767        query = self.db.query
768        table = 'get_test_table_1'
769        query('drop table if exists "%s"' % table)
770        query('create table "%s" ('
771            "n integer, t text, primary key (n))" % table)
772        for n, t in enumerate('abc'):
773            query('insert into "%s" values('
774                "%d, '%s')" % (table, n + 1, t))
775        self.assertEqual(get(table, 2)['t'], 'b')
776        query('drop table "%s"' % table)
777        table = 'get_test_table_2'
778        query('drop table if exists "%s"' % table)
779        query('create table "%s" ('
780            "n integer, m integer, t text, primary key (n, m))" % table)
781        for n in range(3):
782            for m in range(2):
783                t = chr(ord('a') + 2 * n + m)
784                query('insert into "%s" values('
785                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
786        self.assertRaises(pg.ProgrammingError, get, table, 2)
787        self.assertEqual(get(table, dict(n=2, m=2))['t'], 'd')
788        r = get(table, dict(n=1, m=2), ('n', 'm'))
789        self.assertEqual(r['t'], 'b')
790        r = get(table, dict(n=3, m=2), frozenset(['n', 'm']))
791        self.assertEqual(r['t'], 'f')
792        query('drop table "%s"' % table)
793
794    def testGetWithQuotedNames(self):
795        get = self.db.get
796        query = self.db.query
797        table = 'test table for get()'
798        query('drop table if exists "%s"' % table)
799        query('create table "%s" ('
800            '"Prime!" smallint primary key,'
801            '"much space" integer, "Questions?" text)' % table)
802        query('insert into "%s"'
803              " values(17, 1001, 'No!')" % table)
804        r = get(table, 17)
805        self.assertIsInstance(r, dict)
806        self.assertEqual(r['Prime!'], 17)
807        self.assertEqual(r['much space'], 1001)
808        self.assertEqual(r['Questions?'], 'No!')
809        query('drop table "%s"' % table)
810
811    def testGetFromView(self):
812        self.db.query('delete from test where i4=14')
813        self.db.query('insert into test (i4, v4) values('
814            "14, 'abc4')")
815        r = self.db.get('test_view', 14, 'i4')
816        self.assertIn('v4', r)
817        self.assertEqual(r['v4'], 'abc4')
818
    def testInsert(self):
        """Check insert() for many column types and input values.

        Each entry in tests is either a dict of values to insert that are
        expected to come back unchanged, or a pair (data, change) where
        change holds the values expected to differ from data after the
        insert.  Every case is checked against the dict returned by
        insert() and against the row read back from the table; the table
        is emptied after each case.
        """
        insert = self.db.insert
        query = self.db.query
        # the expected Python values depend on these global pg options
        bool_on = pg.get_bool()
        decimal = pg.get_decimal()
        table = 'insert_test_table'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "i2 smallint, i4 integer, i8 bigint,"
            " d numeric, f4 real, f8 double precision, m money,"
            " v4 varchar(4), c4 char(4), t text,"
            " b boolean, ts timestamp) with oids" % table)
        # key under which insert() stores the oid of the new row
        oid_table = 'oid(%s)' % table
        tests = [dict(i2=None, i4=None, i8=None),
            (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)),
            (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)),
            dict(i2=42, i4=123456, i8=9876543210),
            dict(i2=2 ** 15 - 1,
                i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)),
            dict(d=None), (dict(d=''), dict(d=None)),
            dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))),
            dict(f4=None, f8=None), dict(f4=0, f8=0),
            (dict(f4='', f8=''), dict(f4=None, f8=None)),
            (dict(d=1234.5, f4=1234.5, f8=1234.5),
                  dict(d=Decimal('1234.5'))),
            dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875),
            dict(d=Decimal('123456789.9876543212345678987654321')),
            dict(m=None), (dict(m=''), dict(m=None)),
            dict(m=Decimal('-1234.56')),
            (dict(m=('-1234.56')), dict(m=Decimal('-1234.56'))),
            dict(m=Decimal('1234.56')), dict(m=Decimal('123456')),
            (dict(m='1234.56'), dict(m=Decimal('1234.56'))),
            (dict(m=1234.5), dict(m=Decimal('1234.5'))),
            (dict(m=-1234.5), dict(m=Decimal('-1234.5'))),
            (dict(m=123456), dict(m=Decimal('123456'))),
            (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))),
            dict(b=None), (dict(b=''), dict(b=None)),
            dict(b='f'), dict(b='t'),
            (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')),
            (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')),
            (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')),
            (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')),
            (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')),
            (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')),
            dict(v4=None, c4=None, t=None),
            (dict(v4='', c4='', t=''), dict(c4=' ' * 4)),
            dict(v4='1234', c4='1234', t='1234' * 10),
            dict(v4='abcd', c4='abcd', t='abcdefg'),
            (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')),
            dict(ts=None), (dict(ts=''), dict(ts=None)),
            (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)),
            dict(ts='2012-12-21 00:00:00'),
            (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')),
            dict(ts='2012-12-21 12:21:12'),
            dict(ts='2013-01-05 12:13:14'),
            dict(ts='current_timestamp')]
        for test in tests:
            if isinstance(test, dict):
                data = test
                change = {}
            else:
                data, change = test
            expect = data.copy()
            expect.update(change)
            # with the bool option set, booleans are expected as Python bools
            if bool_on:
                b = expect.get('b')
                if b is not None:
                    expect['b'] = b == 't'
            # numeric and money values are expected in the configured type
            if decimal is not Decimal:
                d = expect.get('d')
                if d is not None:
                    expect['d'] = decimal(d)
                m = expect.get('m')
                if m is not None:
                    expect['m'] = decimal(m)
            # insert() updates the passed dict (adding the oid) and
            # returns a dict equal to it
            self.assertEqual(insert(table, data), data)
            self.assertIn(oid_table, data)
            oid = data[oid_table]
            self.assertIsInstance(oid, int)
            # compare only the columns this test case is about
            data = dict(item for item in data.items()
                if item[0] in expect)
            ts = expect.get('ts')
            if ts == 'current_timestamp':
                # the server supplies the time, so accept the returned
                # value after a rough check of its date/time format
                ts = expect['ts'] = data['ts']
                if len(ts) > 19:
                    self.assertEqual(ts[19], '.')
                    ts = ts[:19]
                else:
                    self.assertEqual(len(ts), 19)
                self.assertTrue(ts[:4].isdigit())
                self.assertEqual(ts[4], '-')
                self.assertEqual(ts[10], ' ')
                self.assertTrue(ts[11:13].isdigit())
                self.assertEqual(ts[13], ':')
            self.assertEqual(data, expect)
            # the same values must be read back from the table, same oid
            data = query(
                'select oid,* from "%s"' % table).dictresult()[0]
            self.assertEqual(data['oid'], oid)
            data = dict(item for item in data.items()
                if item[0] in expect)
            self.assertEqual(data, expect)
            query('delete from "%s"' % table)
        query('drop table "%s"' % table)
922
923    def testInsertWithQuotedNames(self):
924        insert = self.db.insert
925        query = self.db.query
926        table = 'test table for insert()'
927        query('drop table if exists "%s"' % table)
928        query('create table "%s" ('
929            '"Prime!" smallint primary key,'
930            '"much space" integer, "Questions?" text)' % table)
931        r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'}
932        r = insert(table, r)
933        self.assertIsInstance(r, dict)
934        self.assertEqual(r['Prime!'], 11)
935        self.assertEqual(r['much space'], 2002)
936        self.assertEqual(r['Questions?'], 'What?')
937        r = query('select * from "%s" limit 2' % table).dictresult()
938        self.assertEqual(len(r), 1)
939        r = r[0]
940        self.assertEqual(r['Prime!'], 11)
941        self.assertEqual(r['much space'], 2002)
942        self.assertEqual(r['Questions?'], 'What?')
943        query('drop table "%s"' % table)
944
945    def testUpdate(self):
946        update = self.db.update
947        query = self.db.query
948        table = 'update_test_table'
949        query('drop table if exists "%s"' % table)
950        query('create table "%s" ('
951            "n integer, t text) with oids" % table)
952        for n, t in enumerate('xyz'):
953            query('insert into "%s" values('
954                "%d, '%s')" % (table, n + 1, t))
955        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
956        r = self.db.get(table, 2, 'n')
957        r['t'] = 'u'
958        s = update(table, r)
959        self.assertEqual(s, r)
960        q = 'select t from "%s" where n=2' % table
961        r = query(q).getresult()[0][0]
962        self.assertEqual(r, 'u')
963        query('drop table "%s"' % table)
964
965    def testUpdateWithCompositeKey(self):
966        update = self.db.update
967        query = self.db.query
968        table = 'update_test_table_1'
969        query('drop table if exists "%s"' % table)
970        query('create table "%s" ('
971            "n integer, t text, primary key (n))" % table)
972        for n, t in enumerate('abc'):
973            query('insert into "%s" values('
974                "%d, '%s')" % (table, n + 1, t))
975        self.assertRaises(pg.ProgrammingError, update,
976                          table, dict(t='b'))
977        s = dict(n=2, t='d')
978        r = update(table, s)
979        self.assertIs(r, s)
980        self.assertEqual(r['n'], 2)
981        self.assertEqual(r['t'], 'd')
982        q = 'select t from "%s" where n=2' % table
983        r = query(q).getresult()[0][0]
984        self.assertEqual(r, 'd')
985        s.update(dict(n=4, t='e'))
986        r = update(table, s)
987        self.assertEqual(r['n'], 4)
988        self.assertEqual(r['t'], 'e')
989        q = 'select t from "%s" where n=2' % table
990        r = query(q).getresult()[0][0]
991        self.assertEqual(r, 'd')
992        q = 'select t from "%s" where n=4' % table
993        r = query(q).getresult()
994        self.assertEqual(len(r), 0)
995        query('drop table "%s"' % table)
996        table = 'update_test_table_2'
997        query('drop table if exists "%s"' % table)
998        query('create table "%s" ('
999            "n integer, m integer, t text, primary key (n, m))" % table)
1000        for n in range(3):
1001            for m in range(2):
1002                t = chr(ord('a') + 2 * n + m)
1003                query('insert into "%s" values('
1004                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1005        self.assertRaises(pg.ProgrammingError, update,
1006                          table, dict(n=2, t='b'))
1007        self.assertEqual(update(table,
1008                                dict(n=2, m=2, t='x'))['t'], 'x')
1009        q = 'select t from "%s" where n=2 order by m' % table
1010        r = [r[0] for r in query(q).getresult()]
1011        self.assertEqual(r, ['c', 'x'])
1012        query('drop table "%s"' % table)
1013
1014    def testUpdateWithQuotedNames(self):
1015        update = self.db.update
1016        query = self.db.query
1017        table = 'test table for update()'
1018        query('drop table if exists "%s"' % table)
1019        query('create table "%s" ('
1020            '"Prime!" smallint primary key,'
1021            '"much space" integer, "Questions?" text)' % table)
1022        query('insert into "%s"'
1023              " values(13, 3003, 'Why!')" % table)
1024        r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'}
1025        r = update(table, r)
1026        self.assertIsInstance(r, dict)
1027        self.assertEqual(r['Prime!'], 13)
1028        self.assertEqual(r['much space'], 7007)
1029        self.assertEqual(r['Questions?'], 'When?')
1030        r = query('select * from "%s" limit 2' % table).dictresult()
1031        self.assertEqual(len(r), 1)
1032        r = r[0]
1033        self.assertEqual(r['Prime!'], 13)
1034        self.assertEqual(r['much space'], 7007)
1035        self.assertEqual(r['Questions?'], 'When?')
1036        query('drop table "%s"' % table)
1037
    def testUpsert(self):
        """Check upsert() with a single-column primary key.

        The steps build on each other's database state, so their order
        matters.  The test is skipped on servers older than 9.5.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer primary key, t text) with oids" % table)
        s = dict(n=1, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        # upsert() returns the very dict that was passed in
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x')
        # a new key (passing None for every column as keyword) inserts
        s.update(n=2, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        q = 'select n, t from "%s" order by n limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # an existing key updates the row
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=False: column t is not updated, and the dict gets the
        # stored value back
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'z')])
        # t=True: column t is updated
        s.update(t='y')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y')])
        # an SQL expression string: as the results show, "included"
        # refers to the stored row and "excluded" to the passed values
        s.update(t='n')
        r = upsert(table, s, t="included.t || '2'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y2')])
        s.update(t='y')
        r = upsert(table, s, t="excluded.t || '3'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['t'], 'y3')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x'), (2, 'y3')])
        s.update(n=1, t='2')
        r = upsert(table, s, t="included.t || excluded.t")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['t'], 'x2')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 'x2'), (2, 'y3')])
        query('drop table "%s"' % table)
1106
    def testUpsertWithCompositeKey(self):
        """Check upsert() with a two-column primary key (n, m).

        The steps build on each other's database state, so their order
        matters.  The test is skipped on servers older than 9.5.
        """
        upsert = self.db.upsert
        query = self.db.query
        table = 'upsert_test_table_2'
        query('drop table if exists "%s"' % table)
        query('create table "%s" ('
            "n integer, m integer, t text, primary key (n, m))" % table)
        s = dict(n=1, m=2, t='x')
        try:
            r = upsert(table, s)
        except pg.ProgrammingError as error:
            if self.db.server_version < 90500:
                self.skipTest('database does not support upsert')
            self.fail(str(error))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 2)
        self.assertEqual(r['t'], 'x')
        # a new (n, m) combination inserts a second row
        s.update(m=3, t='y')
        r = upsert(table, s, **dict.fromkeys(s))
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        q = 'select n, m, t from "%s" order by n, m limit 3' % table
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')])
        # an existing (n, m) combination updates the row
        s.update(t='z')
        r = upsert(table, s)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=False keeps the stored value of t
        s.update(t='n')
        r = upsert(table, s, t=False)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'z')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'z')])
        # t=True updates t
        s.update(t='n')
        r = upsert(table, s, t=True)
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'n')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n')])
        # for a new key the row is inserted with the passed value;
        # the expression given for t is not applied then
        s.update(n=2, t='y')
        r = upsert(table, s, t="'z'")
        self.assertIs(r, s)
        self.assertEqual(r['n'], 2)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'y')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'n'), (2, 3, 'y')])
        # stored value ("included") concatenated with passed ("excluded")
        s.update(n=1, t='m')
        r = upsert(table, s, t='included.t || excluded.t')
        self.assertIs(r, s)
        self.assertEqual(r['n'], 1)
        self.assertEqual(r['m'], 3)
        self.assertEqual(r['t'], 'nm')
        r = query(q).getresult()
        self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')])
        query('drop table "%s"' % table)
1175
1176    def testUpsertWithQuotedNames(self):
1177        upsert = self.db.upsert
1178        query = self.db.query
1179        table = 'test table for upsert()'
1180        query('drop table if exists "%s"' % table)
1181        query('create table "%s" ('
1182            '"Prime!" smallint primary key,'
1183            '"much space" integer, "Questions?" text)' % table)
1184        s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'}
1185        try:
1186            r = upsert(table, s)
1187        except pg.ProgrammingError as error:
1188            if self.db.server_version < 90500:
1189                self.skipTest('database does not support upsert')
1190            self.fail(str(error))
1191        self.assertIs(r, s)
1192        self.assertEqual(r['Prime!'], 31)
1193        self.assertEqual(r['much space'], 9009)
1194        self.assertEqual(r['Questions?'], 'Yes.')
1195        q = 'select * from "%s" limit 2' % table
1196        r = query(q).getresult()
1197        self.assertEqual(r, [(31, 9009, 'Yes.')])
1198        s.update({'Questions?': 'No.'})
1199        r = upsert(table, s)
1200        self.assertIs(r, s)
1201        self.assertEqual(r['Prime!'], 31)
1202        self.assertEqual(r['much space'], 9009)
1203        self.assertEqual(r['Questions?'], 'No.')
1204        r = query(q).getresult()
1205        self.assertEqual(r, [(31, 9009, 'No.')])
1206
1207    def testClear(self):
1208        clear = self.db.clear
1209        query = self.db.query
1210        f = False if pg.get_bool() else 'f'
1211        table = 'clear_test_table'
1212        query('drop table if exists "%s"' % table)
1213        query('create table "%s" ('
1214            "n integer, b boolean, d date, t text)" % table)
1215        r = clear(table)
1216        result = {'n': 0, 'b': f, 'd': '', 't': ''}
1217        self.assertEqual(r, result)
1218        r['a'] = r['n'] = 1
1219        r['d'] = r['t'] = 'x'
1220        r['b'] = 't'
1221        r['oid'] = long(1)
1222        r = clear(table, r)
1223        result = {'a': 1, 'n': 0, 'b': f, 'd': '', 't': '',
1224            'oid': long(1)}
1225        self.assertEqual(r, result)
1226        query('drop table "%s"' % table)
1227
1228    def testClearWithQuotedNames(self):
1229        clear = self.db.clear
1230        query = self.db.query
1231        table = 'test table for clear()'
1232        query('drop table if exists "%s"' % table)
1233        query('create table "%s" ('
1234            '"Prime!" smallint primary key,'
1235            '"much space" integer, "Questions?" text)' % table)
1236        r = clear(table)
1237        self.assertIsInstance(r, dict)
1238        self.assertEqual(r['Prime!'], 0)
1239        self.assertEqual(r['much space'], 0)
1240        self.assertEqual(r['Questions?'], '')
1241        query('drop table "%s"' % table)
1242
1243    def testDelete(self):
1244        delete = self.db.delete
1245        query = self.db.query
1246        table = 'delete_test_table'
1247        query('drop table if exists "%s"' % table)
1248        query('create table "%s" ('
1249            "n integer, t text) with oids" % table)
1250        for n, t in enumerate('xyz'):
1251            query('insert into "%s" values('
1252                "%d, '%s')" % (table, n + 1, t))
1253        self.assertRaises(pg.ProgrammingError, self.db.get, table, 2)
1254        r = self.db.get(table, 1, 'n')
1255        s = delete(table, r)
1256        self.assertEqual(s, 1)
1257        r = self.db.get(table, 3, 'n')
1258        s = delete(table, r)
1259        self.assertEqual(s, 1)
1260        s = delete(table, r)
1261        self.assertEqual(s, 0)
1262        r = query('select * from "%s"' % table).dictresult()
1263        self.assertEqual(len(r), 1)
1264        r = r[0]
1265        result = {'n': 2, 't': 'y'}
1266        self.assertEqual(r, result)
1267        r = self.db.get(table, 2, 'n')
1268        s = delete(table, r)
1269        self.assertEqual(s, 1)
1270        s = delete(table, r)
1271        self.assertEqual(s, 0)
1272        self.assertRaises(pg.DatabaseError, self.db.get, table, 2, 'n')
1273        query('drop table "%s"' % table)
1274
1275    def testDeleteWithCompositeKey(self):
1276        query = self.db.query
1277        table = 'delete_test_table_1'
1278        query('drop table if exists "%s"' % table)
1279        query('create table "%s" ('
1280            "n integer, t text, primary key (n))" % table)
1281        for n, t in enumerate('abc'):
1282            query("insert into %s values("
1283                "%d, '%s')" % (table, n + 1, t))
1284        self.assertRaises(pg.ProgrammingError, self.db.delete,
1285            table, dict(t='b'))
1286        self.assertEqual(self.db.delete(table, dict(n=2)), 1)
1287        r = query('select t from "%s" where n=2' % table
1288                  ).getresult()
1289        self.assertEqual(r, [])
1290        self.assertEqual(self.db.delete(table, dict(n=2)), 0)
1291        r = query('select t from "%s" where n=3' % table
1292                  ).getresult()[0][0]
1293        self.assertEqual(r, 'c')
1294        query('drop table "%s"' % table)
1295        table = 'delete_test_table_2'
1296        query('drop table if exists "%s"' % table)
1297        query('create table "%s" ('
1298            "n integer, m integer, t text, primary key (n, m))" % table)
1299        for n in range(3):
1300            for m in range(2):
1301                t = chr(ord('a') + 2 * n + m)
1302                query('insert into "%s" values('
1303                    "%d, %d, '%s')" % (table, n + 1, m + 1, t))
1304        self.assertRaises(pg.ProgrammingError, self.db.delete,
1305            table, dict(n=2, t='b'))
1306        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 1)
1307        r = [r[0] for r in query('select t from "%s" where n=2'
1308            ' order by m' % table).getresult()]
1309        self.assertEqual(r, ['c'])
1310        self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0)
1311        r = [r[0] for r in query('select t from "%s" where n=3'
1312            ' order by m' % table).getresult()]
1313        self.assertEqual(r, ['e', 'f'])
1314        self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1)
1315        r = [r[0] for r in query('select t from "%s" where n=3'
1316            ' order by m' % table).getresult()]
1317        self.assertEqual(r, ['f'])
1318        query('drop table "%s"' % table)
1319
1320    def testDeleteWithQuotedNames(self):
1321        delete = self.db.delete
1322        query = self.db.query
1323        table = 'test table for delete()'
1324        query('drop table if exists "%s"' % table)
1325        query('create table "%s" ('
1326            '"Prime!" smallint primary key,'
1327            '"much space" integer, "Questions?" text)' % table)
1328        query('insert into "%s"'
1329              " values(19, 5005, 'Yes!')" % table)
1330        r = {'Prime!': 17}
1331        r = delete(table, r)
1332        self.assertEqual(r, 0)
1333        r = query('select count(*) from "%s"' % table).getresult()
1334        self.assertEqual(r[0][0], 1)
1335        r = {'Prime!': 19}
1336        r = delete(table, r)
1337        self.assertEqual(r, 1)
1338        r = query('select count(*) from "%s"' % table).getresult()
1339        self.assertEqual(r[0][0], 0)
1340        query('drop table "%s"' % table)
1341
1342    def testTransaction(self):
1343        query = self.db.query
1344        query("drop table if exists test_table")
1345        query("create table test_table (n integer)")
1346        self.db.begin()
1347        query("insert into test_table values (1)")
1348        query("insert into test_table values (2)")
1349        self.db.commit()
1350        self.db.begin()
1351        query("insert into test_table values (3)")
1352        query("insert into test_table values (4)")
1353        self.db.rollback()
1354        self.db.begin()
1355        query("insert into test_table values (5)")
1356        self.db.savepoint('before6')
1357        query("insert into test_table values (6)")
1358        self.db.rollback('before6')
1359        query("insert into test_table values (7)")
1360        self.db.commit()
1361        self.db.begin()
1362        self.db.savepoint('before8')
1363        query("insert into test_table values (8)")
1364        self.db.release('before8')
1365        self.assertRaises(pg.ProgrammingError, self.db.rollback, 'before8')
1366        self.db.commit()
1367        self.db.start()
1368        query("insert into test_table values (9)")
1369        self.db.end()
1370        r = [r[0] for r in query(
1371            "select * from test_table order by 1").getresult()]
1372        self.assertEqual(r, [1, 2, 5, 7, 9])
1373        query("drop table test_table")
1374
1375    def testContextManager(self):
1376        query = self.db.query
1377        query("drop table if exists test_table")
1378        query("create table test_table (n integer check(n>0))")
1379        with self.db:
1380            query("insert into test_table values (1)")
1381            query("insert into test_table values (2)")
1382        try:
1383            with self.db:
1384                query("insert into test_table values (3)")
1385                query("insert into test_table values (4)")
1386                raise ValueError('test transaction should rollback')
1387        except ValueError as error:
1388            self.assertEqual(str(error), 'test transaction should rollback')
1389        with self.db:
1390            query("insert into test_table values (5)")
1391        try:
1392            with self.db:
1393                query("insert into test_table values (6)")
1394                query("insert into test_table values (-1)")
1395        except pg.ProgrammingError as error:
1396            self.assertTrue('check' in str(error))
1397        with self.db:
1398            query("insert into test_table values (7)")
1399        r = [r[0] for r in query(
1400            "select * from test_table order by 1").getresult()]
1401        self.assertEqual(r, [1, 2, 5, 7])
1402        query("drop table test_table")
1403
1404    def testBytea(self):
1405        query = self.db.query
1406        query('drop table if exists bytea_test')
1407        query('create table bytea_test (n smallint primary key, data bytea)')
1408        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1409        r = self.db.escape_bytea(s)
1410        query('insert into bytea_test values(3,$1)', (r,))
1411        r = query('select * from bytea_test where n=3').getresult()
1412        self.assertEqual(len(r), 1)
1413        r = r[0]
1414        self.assertEqual(len(r), 2)
1415        self.assertEqual(r[0], 3)
1416        r = r[1]
1417        self.assertIsInstance(r, str)
1418        r = self.db.unescape_bytea(r)
1419        self.assertIsInstance(r, bytes)
1420        self.assertEqual(r, s)
1421        query('drop table bytea_test')
1422
1423    def testInsertUpdateGetBytea(self):
1424        query = self.db.query
1425        query('drop table if exists bytea_test')
1426        query('create table bytea_test (n smallint primary key, data bytea)')
1427        # insert null value
1428        r = self.db.insert('bytea_test', n=0, data=None)
1429        self.assertIsInstance(r, dict)
1430        self.assertIn('n', r)
1431        self.assertEqual(r['n'], 0)
1432        self.assertIn('data', r)
1433        self.assertIsNone(r['data'])
1434        s = b'None'
1435        r = self.db.update('bytea_test', n=0, data=s)
1436        self.assertIsInstance(r, dict)
1437        self.assertIn('n', r)
1438        self.assertEqual(r['n'], 0)
1439        self.assertIn('data', r)
1440        r = r['data']
1441        self.assertIsInstance(r, bytes)
1442        self.assertEqual(r, s)
1443        r = self.db.update('bytea_test', n=0, data=None)
1444        self.assertIsNone(r['data'])
1445        # insert as bytes
1446        s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n"
1447        r = self.db.insert('bytea_test', n=5, data=s)
1448        self.assertIsInstance(r, dict)
1449        self.assertIn('n', r)
1450        self.assertEqual(r['n'], 5)
1451        self.assertIn('data', r)
1452        r = r['data']
1453        self.assertIsInstance(r, bytes)
1454        self.assertEqual(r, s)
1455        # update as bytes
1456        s += b"and now even more \x00 nasty \t stuff!\f"
1457        r = self.db.update('bytea_test', n=5, data=s)
1458        self.assertIsInstance(r, dict)
1459        self.assertIn('n', r)
1460        self.assertEqual(r['n'], 5)
1461        self.assertIn('data', r)
1462        r = r['data']
1463        self.assertIsInstance(r, bytes)
1464        self.assertEqual(r, s)
1465        r = query('select * from bytea_test where n=5').getresult()
1466        self.assertEqual(len(r), 1)
1467        r = r[0]
1468        self.assertEqual(len(r), 2)
1469        self.assertEqual(r[0], 5)
1470        r = r[1]
1471        self.assertIsInstance(r, str)
1472        r = self.db.unescape_bytea(r)
1473        self.assertIsInstance(r, bytes)
1474        self.assertEqual(r, s)
1475        r = self.db.get('bytea_test', dict(n=5))
1476        self.assertIsInstance(r, dict)
1477        self.assertIn('n', r)
1478        self.assertEqual(r['n'], 5)
1479        self.assertIn('data', r)
1480        r = r['data']
1481        self.assertIsInstance(r, bytes)
1482        self.assertEqual(r, s)
1483        query('drop table bytea_test')
1484
1485    def testDebugWithCallable(self):
1486        if debug:
1487            self.assertEqual(self.db.debug, debug)
1488        else:
1489            self.assertIsNone(self.db.debug)
1490        s = []
1491        self.db.debug = s.append
1492        try:
1493            self.db.query("select 1")
1494            self.db.query("select 2")
1495            self.assertEqual(s, ["select 1", "select 2"])
1496        finally:
1497            self.db.debug = debug
1498
1499
class TestDBClassNonStdOpts(TestDBClass):
    """Test the methods of the DB class with non-standard global options.

    All tests inherited from TestDBClass run again with the global pg
    options 'decimal', 'bool' and 'namedresult' changed.  The original
    option values are restored afterwards.
    """

    @classmethod
    def setUpClass(cls):
        cls.saved_options = {}
        cls.set_option('decimal', float)
        not_bool = not pg.get_bool()
        cls.set_option('bool', not_bool)

        def unnamed_result(q):
            return q.getresult()

        cls.set_option('namedresult', unnamed_result)
        try:
            super(TestDBClassNonStdOpts, cls).setUpClass()
        except Exception:
            # unittest does not call tearDownClass() when setUpClass()
            # fails, so restore the global options here, otherwise they
            # would leak into all test classes running after this one
            cls.reset_options()
            raise

    @classmethod
    def tearDownClass(cls):
        try:
            super(TestDBClassNonStdOpts, cls).tearDownClass()
        finally:
            # restore the options even if the inherited cleanup fails
            cls.reset_options()

    @classmethod
    def set_option(cls, option, value):
        """Save the current value of a global pg option, then set it."""
        cls.saved_options[option] = getattr(pg, 'get_' + option)()
        return getattr(pg, 'set_' + option)(value)

    @classmethod
    def reset_option(cls, option):
        """Restore a global pg option to the value saved by set_option()."""
        return getattr(pg, 'set_' + option)(cls.saved_options[option])

    @classmethod
    def reset_options(cls):
        """Restore all saved options in the reverse order of setting them."""
        for option in ('namedresult', 'bool', 'decimal'):
            if option in cls.saved_options:
                cls.reset_option(option)
1528
1529
class TestSchemas(unittest.TestCase):
    """Test correct handling of schemas (namespaces).

    setUpClass() uses the public schema and creates the schemas s1 to s4.
    Schema number k gets the two tables "t" and "t<k>" (t0 in public),
    both created with oids and holding one row with n = 1 and d = k.
    """

    # Number of schemas used by the tests: public (number 0) and s1 to s4.
    _num_schemas = 5

    @staticmethod
    def _schema_name(num_schema):
        """Return the name of the test schema with the given number."""
        return "s%d" % num_schema if num_schema else "public"

    @classmethod
    def setUpClass(cls):
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(cls._num_schemas):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    # Recreate the schema; cascade also removes its tables.
                    query("drop schema if exists %s cascade" % (schema,))
                    try:
                        query("create schema %s" % (schema,))
                    except pg.ProgrammingError:
                        raise RuntimeError(
                            "The test user cannot create schemas.\n"
                            "Grant create on database %s to the user"
                            " for running these tests." % dbname)
                else:
                    # The public schema is kept, so only drop the tables
                    # that may have been left over from an earlier run.
                    query("drop table if exists %s.t" % (schema,))
                    query("drop table if exists %s.t%d" % (schema, num_schema))
                query("create table %s.t with oids as select 1 as n, %d as d"
                      % (schema, num_schema))
                query("create table %s.t%d with oids as select 1 as n, %d as d"
                      % (schema, num_schema, num_schema))
        finally:
            # Always close the connection: unittest does not call
            # tearDownClass() when setUpClass() raises.
            db.close()

    @classmethod
    def tearDownClass(cls):
        db = DB()
        try:
            query = db.query
            query("set client_min_messages=warning")
            for num_schema in range(cls._num_schemas):
                schema = cls._schema_name(num_schema)
                if num_schema:
                    query("drop schema %s cascade" % (schema,))
                else:
                    query("drop table %s.t" % (schema,))
                    query("drop table %s.t%d" % (schema, num_schema))
        finally:
            # Close the connection even if one of the drops fails.
            db.close()

    def setUp(self):
        self.db = DB()
        self.db.query("set client_min_messages=warning")

    def tearDown(self):
        self.db.close()

    def testGetTables(self):
        tables = self.db.get_tables()
        for num_schema in range(self._num_schemas):
            schema = self._schema_name(num_schema)
            for table in (schema + ".t", schema + ".t" + str(num_schema)):
                self.assertIn(table, tables)

    def testGetAttnames(self):
        get_attnames = self.db.get_attnames
        query = self.db.query
        result = {'oid': 'int', 'd': 'int', 'n': 'int'}
        r = get_attnames("t")
        self.assertEqual(r, result)
        r = get_attnames("s4.t4")
        self.assertEqual(r, result)
        query("drop table if exists s3.t3m")
        query("create table s3.t3m with oids as select 1 as m")
        try:
            result_m = {'oid': 'int', 'm': 'int'}
            r = get_attnames("s3.t3m")
            self.assertEqual(r, result_m)
            query("set search_path to s1,s3")
            r = get_attnames("t3")
            self.assertEqual(r, result)
            r = get_attnames("t3m")
            self.assertEqual(r, result_m)
        finally:
            # Remove the extra table even when one of the checks fails;
            # the schema-qualified name works regardless of search_path.
            query("drop table s3.t3m")

    def testGet(self):
        get = self.db.get
        query = self.db.query
        PrgError = pg.ProgrammingError
        self.assertEqual(get("t", 1, 'n')['d'], 0)
        self.assertEqual(get("t0", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t", 1, 'n')['d'], 0)
        self.assertEqual(get("public.t0", 1, 'n')['d'], 0)
        self.assertRaises(PrgError, get, "public.t1", 1, 'n')
        self.assertEqual(get("s1.t1", 1, 'n')['d'], 1)
        self.assertEqual(get("s3.t", 1, 'n')['d'], 3)
        query("set search_path to s2,s4")
        self.assertRaises(PrgError, get, "t1", 1, 'n')
        self.assertEqual(get("t4", 1, 'n')['d'], 4)
        self.assertRaises(PrgError, get, "t3", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 2)
        self.assertEqual(get("s3.t3", 1, 'n')['d'], 3)
        query("set search_path to s1,s3")
        self.assertRaises(PrgError, get, "t2", 1, 'n')
        self.assertEqual(get("t3", 1, 'n')['d'], 3)
        self.assertRaises(PrgError, get, "t4", 1, 'n')
        self.assertEqual(get("t", 1, 'n')['d'], 1)
        self.assertEqual(get("s4.t4", 1, 'n')['d'], 4)

    def testMunging(self):
        get = self.db.get
        query = self.db.query
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
        query("set search_path to s2")
        r = get("t2", 1, 'n')
        self.assertIn('oid(t2)', r)
        query("set search_path to s3")
        r = get("t", 1, 'n')
        self.assertIn('oid(t)', r)
1647
if __name__ == '__main__':
    # Executed as a script rather than imported: run every TestCase
    # defined in this module with the unittest command-line runner.
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.