source: branches/4.x/module/tests/test_classic_connection.py @ 690

Last change to this file was revision 690, checked in by cito 4 years ago.

Amend tests so that they can run with PostgreSQL < 9.0

Note that we do not need to make these amendments in the trunk,
because we assume PostgreSQL >= 9.0 for PyGreSQL version 5.0.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 43.4 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import sys
19import tempfile
20import threading
21import time
22
23import pg  # the module under test
24
25from decimal import Decimal
26try:
27    from collections import namedtuple
28except ImportError:  # Python < 2.6
29    namedtuple = None
30
# Connection parameters for the test database.  Defaults are set here;
# an optional LOCAL_PyGreSQL.py module on the import path may override
# them, since every name it exports replaces the one defined below.
dbname, dbhost, dbport = 'unittest', None, 5432

try:  # optional local overrides of the settings above
    from LOCAL_PyGreSQL import *
except ImportError:
    pass  # no local settings available, keep the defaults
41
42
def connect():
    """Return a new pg connection to the test database.

    The connection is opened with the module-level settings ``dbname``,
    ``dbhost`` and ``dbport``.  The session's ``client_min_messages`` is
    then set to ``warning``, so the server does not send messages of a
    lower level (such as notices) to the client.
    """
    db = pg.connect(dbname, dbhost, dbport)
    db.query("set client_min_messages=warning")
    return db
48
49
50class TestCanConnect(unittest.TestCase):
51    """Test whether a basic connection to PostgreSQL is possible."""
52
53    def testCanConnect(self):
54        try:
55            connection = connect()
56        except pg.Error, error:
57            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
58        try:
59            connection.close()
60        except pg.Error:
61            self.fail('Cannot close the database connection')
62
63
64class TestConnectObject(unittest.TestCase):
65    """Test existence of basic pg connection methods."""
66
67    def setUp(self):
68        self.connection = connect()
69
70    def tearDown(self):
71        try:
72            self.connection.close()
73        except pg.InternalError:
74            pass
75
76    def testAllConnectAttributes(self):
77        attributes = '''db error host options port
78            protocol_version server_version status tty user'''.split()
79        connection_attributes = [a for a in dir(self.connection)
80            if not callable(eval("self.connection." + a))]
81        self.assertEqual(attributes, connection_attributes)
82
83    def testAllConnectMethods(self):
84        methods = '''cancel close endcopy
85            escape_bytea escape_identifier escape_literal escape_string
86            fileno get_notice_receiver getline getlo getnotify
87            inserttable locreate loimport parameter putline query reset
88            set_notice_receiver source transaction'''.split()
89        if self.connection.server_version < 90000:  # PostgreSQL < 9.0
90            methods.remove('escape_identifier')
91            methods.remove('escape_literal')
92        connection_methods = [a for a in dir(self.connection)
93            if callable(eval("self.connection." + a))]
94        self.assertEqual(methods, connection_methods)
95
96    def testAttributeDb(self):
97        self.assertEqual(self.connection.db, dbname)
98
99    def testAttributeError(self):
100        error = self.connection.error
101        self.assertTrue(not error or 'krb5_' in error)
102
103    def testAttributeHost(self):
104        def_host = 'localhost'
105        self.assertIsInstance(self.connection.host, str)
106        self.assertEqual(self.connection.host, dbhost or def_host)
107
108    def testAttributeOptions(self):
109        no_options = ''
110        self.assertEqual(self.connection.options, no_options)
111
112    def testAttributePort(self):
113        def_port = 5432
114        self.assertIsInstance(self.connection.port, int)
115        self.assertEqual(self.connection.port, dbport or def_port)
116
117    def testAttributeProtocolVersion(self):
118        protocol_version = self.connection.protocol_version
119        self.assertIsInstance(protocol_version, int)
120        self.assertTrue(2 <= protocol_version < 4)
121
122    def testAttributeServerVersion(self):
123        server_version = self.connection.server_version
124        self.assertIsInstance(server_version, int)
125        self.assertTrue(70400 <= server_version < 100000)
126
127    def testAttributeStatus(self):
128        status_ok = 1
129        self.assertIsInstance(self.connection.status, int)
130        self.assertEqual(self.connection.status, status_ok)
131
132    def testAttributeTty(self):
133        def_tty = ''
134        self.assertIsInstance(self.connection.tty, str)
135        self.assertEqual(self.connection.tty, def_tty)
136
137    def testAttributeUser(self):
138        no_user = 'Deprecated facility'
139        user = self.connection.user
140        self.assertTrue(user)
141        self.assertIsInstance(user, str)
142        self.assertNotEqual(user, no_user)
143
144    def testMethodQuery(self):
145        query = self.connection.query
146        query("select 1+1")
147        query("select 1+$1", (1,))
148        query("select 1+$1+$2", (2, 3))
149        query("select 1+$1+$2", [2, 3])
150
151    def testMethodQueryEmpty(self):
152        self.assertRaises(ValueError, self.connection.query, '')
153
154    def testMethodEndcopy(self):
155        try:
156            self.connection.endcopy()
157        except IOError:
158            pass
159
160    def testMethodClose(self):
161        self.connection.close()
162        try:
163            self.connection.reset()
164        except (pg.Error, TypeError):
165            pass
166        else:
167            self.fail('Reset should give an error for a closed connection')
168        self.assertRaises(pg.InternalError, self.connection.close)
169        try:
170            self.connection.query('select 1')
171        except (pg.Error, TypeError):
172            pass
173        else:
174            self.fail('Query should give an error for a closed connection')
175        self.connection = connect()
176
177    def testMethodReset(self):
178        query = self.connection.query
179        # check that client encoding gets reset
180        encoding = query('show client_encoding').getresult()[0][0].upper()
181        changed_encoding = encoding == 'UTF8' and 'LATIN1' or 'UTF8'
182        self.assertNotEqual(encoding, changed_encoding)
183        self.connection.query("set client_encoding=%s" % changed_encoding)
184        new_encoding = query('show client_encoding').getresult()[0][0].upper()
185        self.assertEqual(new_encoding, changed_encoding)
186        self.connection.reset()
187        new_encoding = query('show client_encoding').getresult()[0][0].upper()
188        self.assertNotEqual(new_encoding, changed_encoding)
189        self.assertEqual(new_encoding, encoding)
190
191    def testMethodCancel(self):
192        r = self.connection.cancel()
193        self.assertIsInstance(r, int)
194        self.assertEqual(r, 1)
195
196    def testCancelLongRunningThread(self):
197        errors = []
198
199        def sleep():
200            try:
201                self.connection.query('select pg_sleep(5)').getresult()
202            except pg.ProgrammingError, error:
203                errors.append(str(error))
204
205        thread = threading.Thread(target=sleep)
206        t1 = time.time()
207        thread.start()  # run the query
208        while 1:  # make sure the query is really running
209            time.sleep(0.1)
210            if thread.isAlive() or time.time() - t1 > 5:
211                break
212        r = self.connection.cancel()  # cancel the running query
213        thread.join()  # wait for the thread to end
214        t2 = time.time()
215
216        self.assertIsInstance(r, int)
217        self.assertEqual(r, 1)  # return code should be 1
218        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
219        self.assertTrue(errors)
220
221    def testMethodFileNo(self):
222        r = self.connection.fileno()
223        self.assertIsInstance(r, int)
224        self.assertGreaterEqual(r, 0)
225
226
227class TestSimpleQueries(unittest.TestCase):
228    """Test simple queries via a basic pg connection."""
229
230    def setUp(self):
231        self.c = connect()
232
233    def tearDown(self):
234        self.c.close()
235
236    def testSelect0(self):
237        q = "select 0"
238        self.c.query(q)
239
240    def testSelect0Semicolon(self):
241        q = "select 0;"
242        self.c.query(q)
243
244    def testSelectDotSemicolon(self):
245        q = "select .;"
246        self.assertRaises(pg.ProgrammingError, self.c.query, q)
247
248    def testGetresult(self):
249        q = "select 0"
250        result = [(0,)]
251        r = self.c.query(q).getresult()
252        self.assertIsInstance(r, list)
253        v = r[0]
254        self.assertIsInstance(v, tuple)
255        self.assertIsInstance(v[0], int)
256        self.assertEqual(r, result)
257
258    def testGetresultLong(self):
259        q = "select 9876543210"
260        result = 9876543210L
261        v = self.c.query(q).getresult()[0][0]
262        self.assertIsInstance(v, long)
263        self.assertEqual(v, result)
264
265    def testGetresultDecimal(self):
266        q = "select 98765432109876543210"
267        result = Decimal(98765432109876543210L)
268        v = self.c.query(q).getresult()[0][0]
269        self.assertIsInstance(v, Decimal)
270        self.assertEqual(v, result)
271
272    def testGetresultString(self):
273        result = 'Hello, world!'
274        q = "select '%s'" % result
275        v = self.c.query(q).getresult()[0][0]
276        self.assertIsInstance(v, str)
277        self.assertEqual(v, result)
278
279    def testDictresult(self):
280        q = "select 0 as alias0"
281        result = [{'alias0': 0}]
282        r = self.c.query(q).dictresult()
283        self.assertIsInstance(r, list)
284        v = r[0]
285        self.assertIsInstance(v, dict)
286        self.assertIsInstance(v['alias0'], int)
287        self.assertEqual(r, result)
288
289    def testDictresultLong(self):
290        q = "select 9876543210 as longjohnsilver"
291        result = 9876543210L
292        v = self.c.query(q).dictresult()[0]['longjohnsilver']
293        self.assertIsInstance(v, long)
294        self.assertEqual(v, result)
295
296    def testDictresultDecimal(self):
297        q = "select 98765432109876543210 as longjohnsilver"
298        result = Decimal(98765432109876543210L)
299        v = self.c.query(q).dictresult()[0]['longjohnsilver']
300        self.assertIsInstance(v, Decimal)
301        self.assertEqual(v, result)
302
303    def testDictresultString(self):
304        result = 'Hello, world!'
305        q = "select '%s' as greeting" % result
306        v = self.c.query(q).dictresult()[0]['greeting']
307        self.assertIsInstance(v, str)
308        self.assertEqual(v, result)
309
310    @unittest.skipUnless(namedtuple, 'Named tuples not available')
311    def testNamedresult(self):
312        q = "select 0 as alias0"
313        result = [(0,)]
314        r = self.c.query(q).namedresult()
315        self.assertEqual(r, result)
316        v = r[0]
317        self.assertEqual(v._fields, ('alias0',))
318        self.assertEqual(v.alias0, 0)
319
320    def testGet3Cols(self):
321        q = "select 1,2,3"
322        result = [(1, 2, 3)]
323        r = self.c.query(q).getresult()
324        self.assertEqual(r, result)
325
326    def testGet3DictCols(self):
327        q = "select 1 as a,2 as b,3 as c"
328        result = [dict(a=1, b=2, c=3)]
329        r = self.c.query(q).dictresult()
330        self.assertEqual(r, result)
331
332    @unittest.skipUnless(namedtuple, 'Named tuples not available')
333    def testGet3NamedCols(self):
334        q = "select 1 as a,2 as b,3 as c"
335        result = [(1, 2, 3)]
336        r = self.c.query(q).namedresult()
337        self.assertEqual(r, result)
338        v = r[0]
339        self.assertEqual(v._fields, ('a', 'b', 'c'))
340        self.assertEqual(v.b, 2)
341
342    def testGet3Rows(self):
343        q = "select 3 union select 1 union select 2 order by 1"
344        result = [(1,), (2,), (3,)]
345        r = self.c.query(q).getresult()
346        self.assertEqual(r, result)
347
348    def testGet3DictRows(self):
349        q = ("select 3 as alias3"
350            " union select 1 union select 2 order by 1")
351        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
352        r = self.c.query(q).dictresult()
353        self.assertEqual(r, result)
354
355    @unittest.skipUnless(namedtuple, 'Named tuples not available')
356    def testGet3NamedRows(self):
357        q = ("select 3 as alias3"
358            " union select 1 union select 2 order by 1")
359        result = [(1,), (2,), (3,)]
360        r = self.c.query(q).namedresult()
361        self.assertEqual(r, result)
362        for v in r:
363            self.assertEqual(v._fields, ('alias3',))
364
365    def testDictresultNames(self):
366        q = "select 'MixedCase' as MixedCaseAlias"
367        result = [{'mixedcasealias': 'MixedCase'}]
368        r = self.c.query(q).dictresult()
369        self.assertEqual(r, result)
370        q = "select 'MixedCase' as \"MixedCaseAlias\""
371        result = [{'MixedCaseAlias': 'MixedCase'}]
372        r = self.c.query(q).dictresult()
373        self.assertEqual(r, result)
374
375    @unittest.skipUnless(namedtuple, 'Named tuples not available')
376    def testNamedresultNames(self):
377        q = "select 'MixedCase' as MixedCaseAlias"
378        result = [('MixedCase',)]
379        r = self.c.query(q).namedresult()
380        self.assertEqual(r, result)
381        v = r[0]
382        self.assertEqual(v._fields, ('mixedcasealias',))
383        self.assertEqual(v.mixedcasealias, 'MixedCase')
384        q = "select 'MixedCase' as \"MixedCaseAlias\""
385        r = self.c.query(q).namedresult()
386        self.assertEqual(r, result)
387        v = r[0]
388        self.assertEqual(v._fields, ('MixedCaseAlias',))
389        self.assertEqual(v.MixedCaseAlias, 'MixedCase')
390
391    def testBigGetresult(self):
392        num_cols = 100
393        num_rows = 100
394        q = "select " + ','.join(map(str, xrange(num_cols)))
395        q = ' union all '.join((q,) * num_rows)
396        r = self.c.query(q).getresult()
397        result = [tuple(range(num_cols))] * num_rows
398        self.assertEqual(r, result)
399
400    def testListfields(self):
401        q = ('select 0 as a, 0 as b, 0 as c,'
402            ' 0 as c, 0 as b, 0 as a,'
403            ' 0 as lowercase, 0 as UPPERCASE,'
404            ' 0 as MixedCase, 0 as "MixedCase",'
405            ' 0 as a_long_name_with_underscores,'
406            ' 0 as "A long name with Blanks"')
407        r = self.c.query(q).listfields()
408        result = ('a', 'b', 'c', 'c', 'b', 'a',
409            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
410            'a_long_name_with_underscores',
411            'A long name with Blanks')
412        self.assertEqual(r, result)
413
414    def testFieldname(self):
415        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
416        r = self.c.query(q).fieldname(2)
417        self.assertEqual(r, 'x')
418        r = self.c.query(q).fieldname(3)
419        self.assertEqual(r, 'y')
420
421    def testFieldnum(self):
422        q = "select 1 as x"
423        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
424        q = "select 1 as x"
425        r = self.c.query(q).fieldnum('x')
426        self.assertIsInstance(r, int)
427        self.assertEqual(r, 0)
428        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
429        r = self.c.query(q).fieldnum('x')
430        self.assertIsInstance(r, int)
431        self.assertEqual(r, 2)
432        r = self.c.query(q).fieldnum('y')
433        self.assertIsInstance(r, int)
434        self.assertEqual(r, 3)
435
436    def testNtuples(self):
437        q = "select 1 where false"
438        r = self.c.query(q).ntuples()
439        self.assertIsInstance(r, int)
440        self.assertEqual(r, 0)
441        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
442            " union select 5 as a, 6 as b, 7 as c, 8 as d")
443        r = self.c.query(q).ntuples()
444        self.assertIsInstance(r, int)
445        self.assertEqual(r, 2)
446        q = ("select 1 union select 2 union select 3"
447            " union select 4 union select 5 union select 6")
448        r = self.c.query(q).ntuples()
449        self.assertIsInstance(r, int)
450        self.assertEqual(r, 6)
451
452    def testQuery(self):
453        query = self.c.query
454        query("drop table if exists test_table")
455        q = "create table test_table (n integer) with oids"
456        r = query(q)
457        self.assertIsNone(r)
458        q = "insert into test_table values (1)"
459        r = query(q)
460        self.assertIsInstance(r, int)
461        q = "insert into test_table select 2"
462        r = query(q)
463        self.assertIsInstance(r, int)
464        oid = r
465        q = "select oid from test_table where n=2"
466        r = query(q).getresult()
467        self.assertEqual(len(r), 1)
468        r = r[0]
469        self.assertEqual(len(r), 1)
470        r = r[0]
471        self.assertIsInstance(r, int)
472        self.assertEqual(r, oid)
473        q = "insert into test_table select 3 union select 4 union select 5"
474        r = query(q)
475        self.assertIsInstance(r, str)
476        self.assertEqual(r, '3')
477        q = "update test_table set n=4 where n<5"
478        r = query(q)
479        self.assertIsInstance(r, str)
480        self.assertEqual(r, '4')
481        q = "delete from test_table"
482        r = query(q)
483        self.assertIsInstance(r, str)
484        self.assertEqual(r, '5')
485        query("drop table test_table")
486
487    def testPrint(self):
488        q = ("select 1 as a, 'hello' as h, 'w' as world"
489            " union select 2, 'xyz', 'uvw'")
490        r = self.c.query(q)
491        f = tempfile.TemporaryFile()
492        stdout, sys.stdout = sys.stdout, f
493        try:
494            print r
495        except Exception:
496            pass
497        sys.stdout = stdout
498        f.seek(0)
499        r = f.read()
500        f.close()
501        self.assertEqual(r,
502            'a|  h  |world\n'
503            '-+-----+-----\n'
504            '1|hello|w    \n'
505            '2|xyz  |uvw  \n'
506            '(2 rows)\n')
507
508
class TestParamQueries(unittest.TestCase):
    """Test queries with parameters via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testQueryWithNoneParam(self):
        # None is passed as NULL, and NULL is returned as None
        self.assertEqual(self.c.query("select $1::integer", (None,)
            ).getresult(), [(None,)])
        self.assertEqual(self.c.query("select $1::text", [None]
            ).getresult(), [(None,)])

    def testQueryWithBoolParams(self, use_bool=None):
        # Check boolean parameters and results.  If use_bool is None, the
        # current module setting is tested; otherwise the setting is
        # switched to use_bool for the duration of the test and restored.
        query = self.c.query
        if use_bool is not None:
            use_bool_default = pg.get_bool()
            pg.set_bool(use_bool)
        try:
            # The expected values follow the effective setting: Python
            # bools if it is enabled, the strings 'f' and 't' otherwise.
            if pg.get_bool():
                v_false, v_true = False, True
            else:
                v_false, v_true = 'f', 't'
            r_false, r_true = [(v_false,)], [(v_true,)]
            self.assertEqual(query("select false").getresult(), r_false)
            self.assertEqual(query("select true").getresult(), r_true)
            q = "select $1::bool"
            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
            self.assertEqual(query(q, ('f',)).getresult(), r_false)
            self.assertEqual(query(q, ('t',)).getresult(), r_true)
            self.assertEqual(query(q, ('false',)).getresult(), r_false)
            self.assertEqual(query(q, ('true',)).getresult(), r_true)
            self.assertEqual(query(q, ('n',)).getresult(), r_false)
            self.assertEqual(query(q, ('y',)).getresult(), r_true)
            self.assertEqual(query(q, (0,)).getresult(), r_false)
            self.assertEqual(query(q, (1,)).getresult(), r_true)
            self.assertEqual(query(q, (False,)).getresult(), r_false)
            self.assertEqual(query(q, (True,)).getresult(), r_true)
        finally:
            if use_bool is not None:
                pg.set_bool(use_bool_default)

    def testQueryWithBoolParamsAndUseBool(self):
        self.testQueryWithBoolParams(use_bool=True)

    def testQueryWithIntParams(self):
        query = self.c.query
        self.assertEqual(query("select 1+1").getresult(), [(2,)])
        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
            [(Decimal('2'),)])
        self.assertEqual(query("select 1, $1::integer", (2,)
            ).getresult(), [(1, 2)])
        self.assertEqual(query("select 1 union select $1::integer", (2,)
            ).getresult(), [(1,), (2,)])
        self.assertEqual(query("select $1::integer+$2", (1, 2)
            ).getresult(), [(3,)])
        self.assertEqual(query("select $1::integer+$2", [1, 2]
            ).getresult(), [(3,)])
        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", range(6)
            ).getresult(), [(15,)])

    def testQueryWithStrParams(self):
        query = self.c.query
        self.assertEqual(query("select $1||', world!'", ('Hello',)
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1||', world!'", ['Hello']
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1::text", ('Hello, world!',)
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
            ).getresult(), [('Hello', 'world')])
        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
            ).getresult(), [('Hello', 'world')])
        self.assertEqual(query("select $1::text union select $2::text",
            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
        # byte strings are passed through unchanged (here: UTF-8 bytes)
        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])

    def testQueryWithUnicodeParams(self):
        # Unicode parameters are encoded with the client encoding; a
        # UnicodeError is raised if a character cannot be encoded with it.
        query = self.c.query
        query('set client_encoding = utf8')
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440')).getresult(),
            [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
        query('set client_encoding = latin1')
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440'))
        query('set client_encoding = iso_8859_1')
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440'))
        query('set client_encoding = iso_8859_5')
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld'))
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440')).getresult(),
            [('Hello, \xdc\xd8\xe0!',)])
        query('set client_encoding = sql_ascii')
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld'))

    def testQueryWithMixedParams(self):
        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])

    def testQueryWithDuplicateParams(self):
        # the number of parameters must match the placeholders used
        self.assertRaises(pg.ProgrammingError,
            self.c.query, "select $1+$1", (1,))
        self.assertRaises(pg.ProgrammingError,
            self.c.query, "select $1+$1", (1, 2))

    def testQueryWithZeroParams(self):
        self.assertEqual(self.c.query("select 1+1", []
            ).getresult(), [(2,)])

    def testQueryWithGarbage(self):
        # special characters in parameters must not need any escaping
        garbage = r"'\{}+()-#[]oo324"
        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
            ).dictresult(), [{'garbage': garbage}])

    def testUnicodeQuery(self):
        # unicode queries are accepted only if they are pure ASCII
        query = self.c.query
        self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
        self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
645
646
647class TestInserttable(unittest.TestCase):
648    """Test inserttable method."""
649
650    @classmethod
651    def setUpClass(cls):
652        c = connect()
653        c.query("drop table if exists test cascade")
654        c.query("create table test ("
655            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
656            "d numeric, f4 real, f8 double precision, m money,"
657            "c char(1), v4 varchar(4), c4 char(4), t text)")
658        c.close()
659
660    @classmethod
661    def tearDownClass(cls):
662        c = connect()
663        c.query("drop table test cascade")
664        c.close()
665
666    def setUp(self):
667        self.c = connect()
668        self.c.query("set lc_monetary='C'")
669        self.c.query("set datestyle='ISO,YMD'")
670
671    def tearDown(self):
672        self.c.query("truncate table test")
673        self.c.close()
674
675    data = [
676        (-1, -1, -1L, True, '1492-10-12', '08:30:00',
677            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
678        (0, 0, 0L, False, '1607-04-14', '09:00:00',
679            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
680        (1, 1, 1L, True, '1801-03-04', '03:45:00',
681            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
682        (2, 2, 2L, False, '1903-12-17', '11:22:00',
683            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]
684
685    def get_back(self):
686        """Convert boolean and decimal values back."""
687        data = []
688        for row in self.c.query("select * from test order by 1").getresult():
689            self.assertIsInstance(row, tuple)
690            row = list(row)
691            if row[0] is not None:  # smallint
692                self.assertIsInstance(row[0], int)
693            if row[1] is not None:  # integer
694                self.assertIsInstance(row[1], int)
695            if row[2] is not None:  # bigint
696                self.assertIsInstance(row[2], long)
697            if row[3] is not None:  # boolean
698                self.assertIsInstance(row[3], str)
699                row[3] = {'f': False, 't': True}.get(row[3])
700            if row[4] is not None:  # date
701                self.assertIsInstance(row[4], str)
702                self.assertTrue(row[4].replace('-', '').isdigit())
703            if row[5] is not None:  # time
704                self.assertIsInstance(row[5], str)
705                self.assertTrue(row[5].replace(':', '').isdigit())
706            if row[6] is not None:  # numeric
707                self.assertIsInstance(row[6], Decimal)
708                row[6] = float(row[6])
709            if row[7] is not None:  # real
710                self.assertIsInstance(row[7], float)
711            if row[8] is not None:  # double precision
712                self.assertIsInstance(row[8], float)
713                row[8] = float(row[8])
714            if row[9] is not None:  # money
715                self.assertIsInstance(row[9], Decimal)
716                row[9] = str(float(row[9]))
717            if row[10] is not None:  # char(1)
718                self.assertIsInstance(row[10], str)
719                self.assertEqual(len(row[10]), 1)
720            if row[11] is not None:  # varchar(4)
721                self.assertIsInstance(row[11], str)
722                self.assertLessEqual(len(row[11]), 4)
723            if row[12] is not None:  # char(4)
724                self.assertIsInstance(row[12], str)
725                self.assertEqual(len(row[12]), 4)
726                row[12] = row[12].rstrip()
727            if row[13] is not None:  # text
728                self.assertIsInstance(row[13], str)
729            row = tuple(row)
730            data.append(row)
731        return data
732
733    def testInserttable1Row(self):
734        data = self.data[2:3]
735        self.c.inserttable("test", data)
736        self.assertEqual(self.get_back(), data)
737
738    def testInserttable4Rows(self):
739        data = self.data
740        self.c.inserttable("test", data)
741        self.assertEqual(self.get_back(), data)
742
743    def testInserttableMultipleRows(self):
744        num_rows = 100
745        data = self.data[2:3] * num_rows
746        self.c.inserttable("test", data)
747        r = self.c.query("select count(*) from test").getresult()[0][0]
748        self.assertEqual(r, num_rows)
749
750    def testInserttableMultipleCalls(self):
751        num_rows = 10
752        data = self.data[2:3]
753        for _i in range(num_rows):
754            self.c.inserttable("test", data)
755        r = self.c.query("select count(*) from test").getresult()[0][0]
756        self.assertEqual(r, num_rows)
757
758    def testInserttableNullValues(self):
759        data = [(None,) * 14] * 100
760        self.c.inserttable("test", data)
761        self.assertEqual(self.get_back(), data)
762
763    def testInserttableMaxValues(self):
764        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
765            True, '2999-12-31', '11:59:59', 1e99,
766            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
767            "1", "1234", "1234", "1234" * 100)]
768        self.c.inserttable("test", data)
769        self.assertEqual(self.get_back(), data)
770
771
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    @classmethod
    def setUpClass(cls):
        # (re)create the table shared by all tests of this class
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test (i int, v varchar(16))")
        finally:
            c.close()

    @classmethod
    def tearDownClass(cls):
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        # empty the table for the next test, but always close the connection
        try:
            self.c.query("truncate table test")
        finally:
            self.c.close()

    def testPutline(self):
        # feed rows to "copy from stdin" line by line, ending with "\."
        putline = self.c.putline
        query = self.c.query
        data = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            for i, v in data:
                putline("%d\t%s\n" % (i, v))
            putline("\\.\n")
        finally:
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, data)

    def testGetline(self):
        # read the rows of "copy to stdout" line by line; after the data the
        # end marker "\." is returned, and further calls return None on
        # servers >= 9.0 but the end marker again on older servers
        getline = self.c.getline
        query = self.c.query
        data = list(enumerate("apple banana pear plum strawberry".split()))
        n = len(data)
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            for i in range(n + 2):
                v = getline()
                if i < n:
                    self.assertEqual(v, '%d\t%s' % data[i])
                elif i == n or self.c.server_version < 90000:
                    self.assertEqual(v, '\\.')
                else:
                    self.assertIsNone(v)
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
836
837
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def _check_notification(self, r, channel, payload):
        # A notification must be a 3-tuple (str, int, str) whose first
        # item is the channel name and whose last item is the payload.
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)  # presumably the sender's PID
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], channel)
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        if self.c.server_version < 90000:  # PostgreSQL < 9.0
            self.skipTest('Notify with payload not supported')
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            # a notification without payload has an empty payload string
            query("notify test_notify")
            r = getnotify()
            self._check_notification(r, 'test_notify', '')
            self.assertIsNone(getnotify())
            query("notify test_notify, 'test_payload'")
            r = getnotify()
            self._check_notification(r, 'test_notify', 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))

    def testSetAndGetNoticeReceiver(self):
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)

    def testNoticeReceiver(self):
        # "or replace" lets the test recover from a function that was
        # left over by an earlier, aborted test run
        self.c.query('''create or replace function bilbo_notice()
            returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        try:
            received = {}

            def notice_receiver(notice):
                # Copy all attributes of the notice.  'WARNUNG' is mapped
                # to 'WARNING' so that the test also passes when the
                # server reports its messages in German.
                for attr in dir(notice):
                    value = getattr(notice, attr)
                    if isinstance(value, str):
                        value = value.replace('WARNUNG', 'WARNING')
                    received[attr] = value

            self.c.set_notice_receiver(notice_receiver)
            self.c.query('''select bilbo_notice()''')
            self.assertEqual(received, dict(
                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
                severity='WARNING', primary='Bilbo was here!',
                detail=None, hint=None))
        finally:
            self.c.query('''drop function bilbo_notice();''')
916
917
class TestConfigFunctions(unittest.TestCase):
    """Test the functions for changing default settings.

    To test the effect of most of these functions, we need a database
    connection.  That's why they are covered in this test module.

    """

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def _set_lc_monetary(self, locales):
        # Set lc_monetary to the first of the given locales that the
        # server accepts and return that locale.  Return None if the
        # server rejected all of them.
        for lc in locales:
            try:
                self.c.query("set lc_monetary='%s'" % lc)
            except pg.ProgrammingError:
                continue
            return lc
        return None

    def _get_money(self, select_money, decimal_point):
        # Run the given money query and return its single value, fetched
        # while the given decimal point setting is active.  The value is
        # converted when the result is fetched, so only getresult() runs
        # under the temporary setting.  The previous setting is always
        # restored.
        q = self.c.query(select_money)
        saved_point = pg.get_decimal_point()
        pg.set_decimal_point(decimal_point)
        try:
            return q.getresult()[0][0]
        finally:
            pg.set_decimal_point(saved_point)

    def testGetDecimalPoint(self):
        point = pg.get_decimal_point()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal_point, point)
        self.assertIsInstance(point, str)
        self.assertEqual(point, '.')  # the default setting
        pg.set_decimal_point(',')
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertEqual(r, ',')
        pg.set_decimal_point("'")
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertEqual(r, "'")
        # both '' and None unset the decimal point
        pg.set_decimal_point('')
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsNone(r)
        pg.set_decimal_point(None)
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsNone(r)

    def testSetDecimalPoint(self):
        d = pg.Decimal
        get_money = self._get_money
        self.assertRaises(TypeError, pg.set_decimal_point)
        # error if decimal point is not a string
        self.assertRaises(TypeError, pg.set_decimal_point, 0)
        # error if more than one decimal point passed
        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
        # error if decimal point is not a punctuation character
        self.assertRaises(TypeError, pg.set_decimal_point, '0')
        query = self.c.query
        # check that money values are interpreted as decimal values
        # only if decimal_point is set, and that the result is correct
        # only if it is set suitable for the current lc_monetary setting
        select_money = "select '34.25'::money"
        proper_money = d('34.25')
        bad_money = d('3425')
        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
        # first try with English localization (using the point)
        if self._set_lc_monetary(en_locales) is None:
            self.skipTest("cannot set English money locale")
        try:
            query(select_money)
        except pg.ProgrammingError:
            # this can happen if the currency signs cannot be
            # converted using the encoding of the test database
            self.skipTest("database does not support English money")
        # without a decimal point, money is returned as a string
        r = get_money(select_money, None)
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        r = get_money(select_money, '')
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        # with the matching decimal point, money becomes the proper decimal
        r = get_money(select_money, '.')
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        # with a non-matching decimal point, the fraction separator is lost
        r = get_money(select_money, ',')
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        r = get_money(select_money, "'")
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        # then try with German localization (using the comma)
        if self._set_lc_monetary(de_locales) is None:
            self.skipTest("cannot set German money locale")
        select_money = select_money.replace('.', ',')
        try:
            query(select_money)
        except pg.ProgrammingError:
            self.skipTest("database does not support German money")
        r = get_money(select_money, None)
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        r = get_money(select_money, '')
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        r = get_money(select_money, ',')
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        r = get_money(select_money, '.')
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        r = get_money(select_money, "'")
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)

    def testGetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
        self.assertIs(decimal_class, pg.Decimal)  # the default setting
        pg.set_decimal(int)
        try:
            r = pg.get_decimal()
        finally:
            pg.set_decimal(decimal_class)
        self.assertIs(r, int)
        r = pg.get_decimal()
        self.assertIs(r, decimal_class)

    def testSetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_decimal)
        query = self.c.query
        try:
            r = query("select 3425::numeric")
        except pg.ProgrammingError:
            self.skipTest('database does not support numeric')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, decimal_class)
        self.assertEqual(r, decimal_class('3425'))
        # the decimal class is applied when the result is fetched
        r = query("select 3425::numeric")
        pg.set_decimal(int)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal(decimal_class)
        self.assertNotIsInstance(r, decimal_class)
        self.assertIsInstance(r, int)
        self.assertEqual(r, int(3425))

    def testGetBool(self):
        use_bool = pg.get_bool()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_bool, use_bool)
        self.assertIsInstance(use_bool, bool)
        self.assertIs(use_bool, False)  # the default setting
        pg.set_bool(True)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        pg.set_bool(False)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)
        # non-bool arguments are stored as proper bools
        pg.set_bool(1)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        pg.set_bool(0)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)

    def testSetBool(self):
        use_bool = pg.get_bool()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_bool)
        query = self.c.query
        try:
            r = query("select true::bool")
        except pg.ProgrammingError:
            self.skipTest('database does not support bool')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, str)
        self.assertEqual(r, 't')
        r = query("select true::bool")
        pg.set_bool(True)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        r = query("select true::bool")
        pg.set_bool(False)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, str)
        # compare strings by value: an identity check would only pass
        # by accident of string interning
        self.assertEqual(r, 't')

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testGetNamedresult(self):
        namedresult = pg.get_namedresult()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
        self.assertIs(namedresult, pg._namedresult)  # the default setting

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testSetNamedresult(self):
        namedresult = pg.get_namedresult()
        self.assertTrue(callable(namedresult))

        query = self.c.query

        r = query("select 1 as x, 2 as y").namedresult()[0]
        self.assertIsInstance(r, tuple)
        self.assertEqual(r, (1, 2))
        self.assertIsNot(type(r), tuple)
        self.assertEqual(r._fields, ('x', 'y'))
        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
        self.assertEqual(r.__class__.__name__, 'Row')

        def listresult(q):
            return [list(row) for row in q.getresult()]

        pg.set_namedresult(listresult)
        try:
            r = pg.get_namedresult()
            self.assertIs(r, listresult)
            r = query("select 1 as x, 2 as y").namedresult()[0]
            self.assertIsInstance(r, list)
            self.assertEqual(r, [1, 2])
            self.assertIsNot(type(r), tuple)
            self.assertFalse(hasattr(r, '_fields'))
            self.assertNotEqual(r.__class__.__name__, 'Row')
        finally:
            pg.set_namedresult(namedresult)

        r = pg.get_namedresult()
        self.assertIs(r, namedresult)
1236
1237
if __name__ == "__main__":
    # run all test cases of this module when executed as a script
    unittest.main()
Note: See TracBrowser for help on using the repository browser.