source: branches/4.x/tests/test_classic_connection.py @ 771

Last change on this file since 771 was 771, checked in by cito, 4 years ago

Back port some minor fixes from the trunk

This also gives a better error message if the test runner does not support unittest2.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 44.2 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import sys
19import tempfile
20import threading
21import time
22import os
23
24import pg  # the module under test
25
26from decimal import Decimal
27try:
28    from collections import namedtuple
29except ImportError:  # Python < 2.6
30    namedtuple = None
31
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
dbname = 'unittest'  # name of the test database
dbhost = None  # no explicit host; testAttributeHost then expects 'localhost'
dbport = 5432  # port passed to pg.connect()

# A local configuration module may override any of the settings above.
# This import must stay after the defaults so that its values take precedence.
try:
    from LOCAL_PyGreSQL import *
except ImportError:
    pass

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost(), so the tests
# avoid touching the connection's host attribute on Windows:
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
49
50
def connect():
    """Create a basic pg connection to the test database.

    The connection uses the module-level dbname, dbhost and dbport
    settings and sets client_min_messages to 'warning'.  If that setup
    query fails, the new connection is closed again before the error
    is propagated, so that no connection is leaked.
    """
    connection = pg.connect(dbname, dbhost, dbport)
    try:
        connection.query("set client_min_messages=warning")
    except Exception:
        # do not leave a half-initialized connection open
        connection.close()
        raise
    return connection
56
57
58class TestCanConnect(unittest.TestCase):
59    """Test whether a basic connection to PostgreSQL is possible."""
60
61    def testCanConnect(self):
62        try:
63            connection = connect()
64        except pg.Error, error:
65            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
66        try:
67            connection.close()
68        except pg.Error:
69            self.fail('Cannot close the database connection')
70
71
72class TestConnectObject(unittest.TestCase):
73    """Test existence of basic pg connection methods."""
74
75    def setUp(self):
76        self.connection = connect()
77
78    def tearDown(self):
79        try:
80            self.connection.close()
81        except pg.InternalError:
82            pass
83
84    def is_method(self, attribute):
85        """Check if given attribute on the connection is a method."""
86        if do_not_ask_for_host and attribute == 'host':
87            return False
88        return callable(getattr(self.connection, attribute))
89
90    def testAllConnectAttributes(self):
91        attributes = '''db error host options port
92            protocol_version server_version status tty user'''.split()
93        connection_attributes = [a for a in dir(self.connection)
94            if not a.startswith('__') and not self.is_method(a)]
95        self.assertEqual(attributes, connection_attributes)
96
97    def testAllConnectMethods(self):
98        methods = '''cancel close endcopy
99            escape_bytea escape_identifier escape_literal escape_string
100            fileno get_notice_receiver getline getlo getnotify
101            inserttable locreate loimport parameter putline query reset
102            set_notice_receiver source transaction'''.split()
103        if self.connection.server_version < 90000:  # PostgreSQL < 9.0
104            methods.remove('escape_identifier')
105            methods.remove('escape_literal')
106        connection_methods = [a for a in dir(self.connection)
107            if not a.startswith('__') and self.is_method(a)]
108        self.assertEqual(methods, connection_methods)
109
110    def testAttributeDb(self):
111        self.assertEqual(self.connection.db, dbname)
112
113    def testAttributeError(self):
114        error = self.connection.error
115        self.assertTrue(not error or 'krb5_' in error)
116
117    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
118    def testAttributeHost(self):
119        def_host = 'localhost'
120        self.assertIsInstance(self.connection.host, str)
121        self.assertEqual(self.connection.host, dbhost or def_host)
122
123    def testAttributeOptions(self):
124        no_options = ''
125        self.assertEqual(self.connection.options, no_options)
126
127    def testAttributePort(self):
128        def_port = 5432
129        self.assertIsInstance(self.connection.port, int)
130        self.assertEqual(self.connection.port, dbport or def_port)
131
132    def testAttributeProtocolVersion(self):
133        protocol_version = self.connection.protocol_version
134        self.assertIsInstance(protocol_version, int)
135        self.assertTrue(2 <= protocol_version < 4)
136
137    def testAttributeServerVersion(self):
138        server_version = self.connection.server_version
139        self.assertIsInstance(server_version, int)
140        self.assertTrue(70400 <= server_version < 100000)
141
142    def testAttributeStatus(self):
143        status_ok = 1
144        self.assertIsInstance(self.connection.status, int)
145        self.assertEqual(self.connection.status, status_ok)
146
147    def testAttributeTty(self):
148        def_tty = ''
149        self.assertIsInstance(self.connection.tty, str)
150        self.assertEqual(self.connection.tty, def_tty)
151
152    def testAttributeUser(self):
153        no_user = 'Deprecated facility'
154        user = self.connection.user
155        self.assertTrue(user)
156        self.assertIsInstance(user, str)
157        self.assertNotEqual(user, no_user)
158
159    def testMethodQuery(self):
160        query = self.connection.query
161        query("select 1+1")
162        query("select 1+$1", (1,))
163        query("select 1+$1+$2", (2, 3))
164        query("select 1+$1+$2", [2, 3])
165
166    def testMethodQueryEmpty(self):
167        self.assertRaises(ValueError, self.connection.query, '')
168
169    def testMethodEndcopy(self):
170        try:
171            self.connection.endcopy()
172        except IOError:
173            pass
174
175    def testMethodClose(self):
176        self.connection.close()
177        try:
178            self.connection.reset()
179        except (pg.Error, TypeError):
180            pass
181        else:
182            self.fail('Reset should give an error for a closed connection')
183        self.assertRaises(pg.InternalError, self.connection.close)
184        try:
185            self.connection.query('select 1')
186        except (pg.Error, TypeError):
187            pass
188        else:
189            self.fail('Query should give an error for a closed connection')
190        self.connection = connect()
191
192    def testMethodReset(self):
193        query = self.connection.query
194        # check that client encoding gets reset
195        encoding = query('show client_encoding').getresult()[0][0].upper()
196        changed_encoding = encoding == 'UTF8' and 'LATIN1' or 'UTF8'
197        self.assertNotEqual(encoding, changed_encoding)
198        self.connection.query("set client_encoding=%s" % changed_encoding)
199        new_encoding = query('show client_encoding').getresult()[0][0].upper()
200        self.assertEqual(new_encoding, changed_encoding)
201        self.connection.reset()
202        new_encoding = query('show client_encoding').getresult()[0][0].upper()
203        self.assertNotEqual(new_encoding, changed_encoding)
204        self.assertEqual(new_encoding, encoding)
205
206    def testMethodCancel(self):
207        r = self.connection.cancel()
208        self.assertIsInstance(r, int)
209        self.assertEqual(r, 1)
210
211    def testCancelLongRunningThread(self):
212        errors = []
213
214        def sleep():
215            try:
216                self.connection.query('select pg_sleep(5)').getresult()
217            except pg.ProgrammingError, error:
218                errors.append(str(error))
219
220        thread = threading.Thread(target=sleep)
221        t1 = time.time()
222        thread.start()  # run the query
223        while 1:  # make sure the query is really running
224            time.sleep(0.1)
225            if thread.isAlive() or time.time() - t1 > 5:
226                break
227        r = self.connection.cancel()  # cancel the running query
228        thread.join()  # wait for the thread to end
229        t2 = time.time()
230
231        self.assertIsInstance(r, int)
232        self.assertEqual(r, 1)  # return code should be 1
233        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
234        self.assertTrue(errors)
235
236    def testMethodFileNo(self):
237        r = self.connection.fileno()
238        self.assertIsInstance(r, int)
239        self.assertGreaterEqual(r, 0)
240
241
242class TestSimpleQueries(unittest.TestCase):
243    """Test simple queries via a basic pg connection."""
244
245    def setUp(self):
246        self.c = connect()
247
248    def tearDown(self):
249        self.c.close()
250
251    def testSelect0(self):
252        q = "select 0"
253        self.c.query(q)
254
255    def testSelect0Semicolon(self):
256        q = "select 0;"
257        self.c.query(q)
258
259    def testSelectDotSemicolon(self):
260        q = "select .;"
261        self.assertRaises(pg.ProgrammingError, self.c.query, q)
262
263    def testGetresult(self):
264        q = "select 0"
265        result = [(0,)]
266        r = self.c.query(q).getresult()
267        self.assertIsInstance(r, list)
268        v = r[0]
269        self.assertIsInstance(v, tuple)
270        self.assertIsInstance(v[0], int)
271        self.assertEqual(r, result)
272
273    def testGetresultLong(self):
274        q = "select 9876543210"
275        result = 9876543210L
276        v = self.c.query(q).getresult()[0][0]
277        self.assertIsInstance(v, long)
278        self.assertEqual(v, result)
279
280    def testGetresultDecimal(self):
281        q = "select 98765432109876543210"
282        result = Decimal(98765432109876543210L)
283        v = self.c.query(q).getresult()[0][0]
284        self.assertIsInstance(v, Decimal)
285        self.assertEqual(v, result)
286
287    def testGetresultString(self):
288        result = 'Hello, world!'
289        q = "select '%s'" % result
290        v = self.c.query(q).getresult()[0][0]
291        self.assertIsInstance(v, str)
292        self.assertEqual(v, result)
293
294    def testDictresult(self):
295        q = "select 0 as alias0"
296        result = [{'alias0': 0}]
297        r = self.c.query(q).dictresult()
298        self.assertIsInstance(r, list)
299        v = r[0]
300        self.assertIsInstance(v, dict)
301        self.assertIsInstance(v['alias0'], int)
302        self.assertEqual(r, result)
303
304    def testDictresultLong(self):
305        q = "select 9876543210 as longjohnsilver"
306        result = 9876543210L
307        v = self.c.query(q).dictresult()[0]['longjohnsilver']
308        self.assertIsInstance(v, long)
309        self.assertEqual(v, result)
310
311    def testDictresultDecimal(self):
312        q = "select 98765432109876543210 as longjohnsilver"
313        result = Decimal(98765432109876543210L)
314        v = self.c.query(q).dictresult()[0]['longjohnsilver']
315        self.assertIsInstance(v, Decimal)
316        self.assertEqual(v, result)
317
318    def testDictresultString(self):
319        result = 'Hello, world!'
320        q = "select '%s' as greeting" % result
321        v = self.c.query(q).dictresult()[0]['greeting']
322        self.assertIsInstance(v, str)
323        self.assertEqual(v, result)
324
325    @unittest.skipUnless(namedtuple, 'Named tuples not available')
326    def testNamedresult(self):
327        q = "select 0 as alias0"
328        result = [(0,)]
329        r = self.c.query(q).namedresult()
330        self.assertEqual(r, result)
331        v = r[0]
332        self.assertEqual(v._fields, ('alias0',))
333        self.assertEqual(v.alias0, 0)
334
335    def testGet3Cols(self):
336        q = "select 1,2,3"
337        result = [(1, 2, 3)]
338        r = self.c.query(q).getresult()
339        self.assertEqual(r, result)
340
341    def testGet3DictCols(self):
342        q = "select 1 as a,2 as b,3 as c"
343        result = [dict(a=1, b=2, c=3)]
344        r = self.c.query(q).dictresult()
345        self.assertEqual(r, result)
346
347    @unittest.skipUnless(namedtuple, 'Named tuples not available')
348    def testGet3NamedCols(self):
349        q = "select 1 as a,2 as b,3 as c"
350        result = [(1, 2, 3)]
351        r = self.c.query(q).namedresult()
352        self.assertEqual(r, result)
353        v = r[0]
354        self.assertEqual(v._fields, ('a', 'b', 'c'))
355        self.assertEqual(v.b, 2)
356
357    def testGet3Rows(self):
358        q = "select 3 union select 1 union select 2 order by 1"
359        result = [(1,), (2,), (3,)]
360        r = self.c.query(q).getresult()
361        self.assertEqual(r, result)
362
363    def testGet3DictRows(self):
364        q = ("select 3 as alias3"
365            " union select 1 union select 2 order by 1")
366        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
367        r = self.c.query(q).dictresult()
368        self.assertEqual(r, result)
369
370    @unittest.skipUnless(namedtuple, 'Named tuples not available')
371    def testGet3NamedRows(self):
372        q = ("select 3 as alias3"
373            " union select 1 union select 2 order by 1")
374        result = [(1,), (2,), (3,)]
375        r = self.c.query(q).namedresult()
376        self.assertEqual(r, result)
377        for v in r:
378            self.assertEqual(v._fields, ('alias3',))
379
380    def testDictresultNames(self):
381        q = "select 'MixedCase' as MixedCaseAlias"
382        result = [{'mixedcasealias': 'MixedCase'}]
383        r = self.c.query(q).dictresult()
384        self.assertEqual(r, result)
385        q = "select 'MixedCase' as \"MixedCaseAlias\""
386        result = [{'MixedCaseAlias': 'MixedCase'}]
387        r = self.c.query(q).dictresult()
388        self.assertEqual(r, result)
389
390    @unittest.skipUnless(namedtuple, 'Named tuples not available')
391    def testNamedresultNames(self):
392        q = "select 'MixedCase' as MixedCaseAlias"
393        result = [('MixedCase',)]
394        r = self.c.query(q).namedresult()
395        self.assertEqual(r, result)
396        v = r[0]
397        self.assertEqual(v._fields, ('mixedcasealias',))
398        self.assertEqual(v.mixedcasealias, 'MixedCase')
399        q = "select 'MixedCase' as \"MixedCaseAlias\""
400        r = self.c.query(q).namedresult()
401        self.assertEqual(r, result)
402        v = r[0]
403        self.assertEqual(v._fields, ('MixedCaseAlias',))
404        self.assertEqual(v.MixedCaseAlias, 'MixedCase')
405
406    def testBigGetresult(self):
407        num_cols = 100
408        num_rows = 100
409        q = "select " + ','.join(map(str, xrange(num_cols)))
410        q = ' union all '.join((q,) * num_rows)
411        r = self.c.query(q).getresult()
412        result = [tuple(range(num_cols))] * num_rows
413        self.assertEqual(r, result)
414
415    def testListfields(self):
416        q = ('select 0 as a, 0 as b, 0 as c,'
417            ' 0 as c, 0 as b, 0 as a,'
418            ' 0 as lowercase, 0 as UPPERCASE,'
419            ' 0 as MixedCase, 0 as "MixedCase",'
420            ' 0 as a_long_name_with_underscores,'
421            ' 0 as "A long name with Blanks"')
422        r = self.c.query(q).listfields()
423        result = ('a', 'b', 'c', 'c', 'b', 'a',
424            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
425            'a_long_name_with_underscores',
426            'A long name with Blanks')
427        self.assertEqual(r, result)
428
429    def testFieldname(self):
430        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
431        r = self.c.query(q).fieldname(2)
432        self.assertEqual(r, 'x')
433        r = self.c.query(q).fieldname(3)
434        self.assertEqual(r, 'y')
435
436    def testFieldnum(self):
437        q = "select 1 as x"
438        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
439        q = "select 1 as x"
440        r = self.c.query(q).fieldnum('x')
441        self.assertIsInstance(r, int)
442        self.assertEqual(r, 0)
443        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
444        r = self.c.query(q).fieldnum('x')
445        self.assertIsInstance(r, int)
446        self.assertEqual(r, 2)
447        r = self.c.query(q).fieldnum('y')
448        self.assertIsInstance(r, int)
449        self.assertEqual(r, 3)
450
451    def testNtuples(self):
452        q = "select 1 where false"
453        r = self.c.query(q).ntuples()
454        self.assertIsInstance(r, int)
455        self.assertEqual(r, 0)
456        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
457            " union select 5 as a, 6 as b, 7 as c, 8 as d")
458        r = self.c.query(q).ntuples()
459        self.assertIsInstance(r, int)
460        self.assertEqual(r, 2)
461        q = ("select 1 union select 2 union select 3"
462            " union select 4 union select 5 union select 6")
463        r = self.c.query(q).ntuples()
464        self.assertIsInstance(r, int)
465        self.assertEqual(r, 6)
466
467    def testQuery(self):
468        query = self.c.query
469        query("drop table if exists test_table")
470        q = "create table test_table (n integer) with oids"
471        r = query(q)
472        self.assertIsNone(r)
473        q = "insert into test_table values (1)"
474        r = query(q)
475        self.assertIsInstance(r, int)
476        q = "insert into test_table select 2"
477        r = query(q)
478        self.assertIsInstance(r, int)
479        oid = r
480        q = "select oid from test_table where n=2"
481        r = query(q).getresult()
482        self.assertEqual(len(r), 1)
483        r = r[0]
484        self.assertEqual(len(r), 1)
485        r = r[0]
486        self.assertIsInstance(r, int)
487        self.assertEqual(r, oid)
488        q = "insert into test_table select 3 union select 4 union select 5"
489        r = query(q)
490        self.assertIsInstance(r, str)
491        self.assertEqual(r, '3')
492        q = "update test_table set n=4 where n<5"
493        r = query(q)
494        self.assertIsInstance(r, str)
495        self.assertEqual(r, '4')
496        q = "delete from test_table"
497        r = query(q)
498        self.assertIsInstance(r, str)
499        self.assertEqual(r, '5')
500        query("drop table test_table")
501
502    def testPrint(self):
503        q = ("select 1 as a, 'hello' as h, 'w' as world"
504            " union select 2, 'xyz', 'uvw'")
505        r = self.c.query(q)
506        f = tempfile.TemporaryFile()
507        stdout, sys.stdout = sys.stdout, f
508        try:
509            print r
510        except Exception:
511            pass
512        sys.stdout = stdout
513        f.seek(0)
514        r = f.read()
515        f.close()
516        self.assertEqual(r,
517            'a|  h  |world\n'
518            '-+-----+-----\n'
519            '1|hello|w    \n'
520            '2|xyz  |uvw  \n'
521            '(2 rows)\n')
522
523
class TestParamQueries(unittest.TestCase):
    """Run parameterized queries ($1, $2, ...) over a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testQueryWithNoneParam(self):
        # None must be passed as SQL NULL, both in tuples and in lists
        for sql, params in (("select $1::integer", (None,)),
                ("select $1::text", [None])):
            self.assertEqual(self.c.query(sql, params).getresult(), [(None,)])

    def testQueryWithBoolParams(self, use_bool=None):
        """Check boolean results and parameters.

        If use_bool is given, it is set via pg.set_bool() for the duration
        of the test and the previous setting is restored afterwards.
        Booleans come back as Python bools if use_bool is true, and as
        the strings 'f' and 't' otherwise.
        """
        query = self.c.query
        if use_bool is not None:
            saved_use_bool = pg.get_bool()
            pg.set_bool(use_bool)
        try:
            if use_bool:
                false_value, true_value = False, True
            else:
                false_value, true_value = 'f', 't'
            expect_false = [(false_value,)]
            expect_true = [(true_value,)]
            self.assertEqual(query("select false").getresult(), expect_false)
            self.assertEqual(query("select true").getresult(), expect_true)
            sql = "select $1::bool"
            self.assertEqual(query(sql, (None,)).getresult(), [(None,)])
            # every pair is checked as false value first, then true value
            for false_param, true_param in (('f', 't'), ('false', 'true'),
                    ('n', 'y'), (0, 1), (False, True)):
                self.assertEqual(
                    query(sql, (false_param,)).getresult(), expect_false)
                self.assertEqual(
                    query(sql, (true_param,)).getresult(), expect_true)
        finally:
            if use_bool is not None:
                pg.set_bool(saved_use_bool)

    def testQueryWithBoolParamsAndUseBool(self):
        self.testQueryWithBoolParams(use_bool=True)

    def testQueryWithIntParams(self):
        query = self.c.query
        self.assertEqual(query("select 1+1").getresult(), [(2,)])
        cases = (
            ("select 1+$1", (1,), [(2,)]),
            ("select 1+$1", [1], [(2,)]),
            ("select $1::integer", (2,), [(2,)]),
            ("select $1::text", (2,), [('2',)]),
            ("select 1+$1::numeric", [1], [(Decimal('2'),)]),
            ("select 1, $1::integer", (2,), [(1, 2)]),
            ("select 1 union select $1::integer", (2,), [(1,), (2,)]),
            ("select $1::integer+$2", (1, 2), [(3,)]),
            ("select $1::integer+$2", [1, 2], [(3,)]),
            ("select 0+$1+$2+$3+$4+$5+$6", range(6), [(15,)]),
        )
        for sql, params, expected in cases:
            self.assertEqual(query(sql, params).getresult(), expected)

    def testQueryWithStrParams(self):
        cases = (
            ("select $1||', world!'", ('Hello',), [('Hello, world!',)]),
            ("select $1||', world!'", ['Hello'], [('Hello, world!',)]),
            ("select $1||', '||$2||'!'", ('Hello', 'world'),
                [('Hello, world!',)]),
            ("select $1::text", ('Hello, world!',), [('Hello, world!',)]),
            ("select $1::text,$2::text", ('Hello', 'world'),
                [('Hello', 'world')]),
            ("select $1::text,$2::text", ['Hello', 'world'],
                [('Hello', 'world')]),
            ("select $1::text union select $2::text", ('Hello', 'world'),
                [('Hello',), ('world',)]),
            ("select $1||', '||$2||'!'", ('Hello', 'w\xc3\xb6rld'),
                [('Hello, w\xc3\xb6rld!',)]),
        )
        for sql, params, expected in cases:
            self.assertEqual(self.c.query(sql, params).getresult(), expected)

    def testQueryWithUnicodeParams(self):
        query = self.c.query
        concat = "select $1||', '||$2||'!'"
        german, russian = u'w\xf6rld', u'\u043c\u0438\u0440'

        def check(world, expected):
            # expected is either the result rows or UnicodeError when the
            # parameter cannot be encoded in the current client encoding
            params = ('Hello', world)
            if expected is UnicodeError:
                self.assertRaises(UnicodeError, query, concat, params)
            else:
                self.assertEqual(query(concat, params).getresult(), expected)

        query('set client_encoding = utf8')
        check(german, [('Hello, w\xc3\xb6rld!',)])
        check(russian, [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
        query('set client_encoding = latin1')
        check(german, [('Hello, w\xf6rld!',)])
        check(russian, UnicodeError)
        query('set client_encoding = iso_8859_1')
        check(german, [('Hello, w\xf6rld!',)])
        check(russian, UnicodeError)
        query('set client_encoding = iso_8859_5')
        check(german, UnicodeError)
        check(russian, [('Hello, \xdc\xd8\xe0!',)])
        query('set client_encoding = sql_ascii')
        check(german, UnicodeError)

    def testQueryWithMixedParams(self):
        for sql, params, expected in (
                ("select $1+2,$2||', world!'", (1, 'Hello'),
                    [(3, 'Hello, world!')]),
                ("select $1::integer,$2::date,$3::text",
                    (4711, None, 'Hello!'), [(4711, None, 'Hello!')])):
            self.assertEqual(self.c.query(sql, params).getresult(), expected)

    def testQueryWithDuplicateParams(self):
        for params in ((1,), (1, 2)):
            self.assertRaises(pg.ProgrammingError,
                self.c.query, "select $1+$1", params)

    def testQueryWithZeroParams(self):
        rows = self.c.query("select 1+1", []).getresult()
        self.assertEqual(rows, [(2,)])

    def testQueryWithGarbage(self):
        garbage = r"'\{}+()-#[]oo324"
        rows = self.c.query("select $1::text AS garbage",
            (garbage,)).dictresult()
        self.assertEqual(rows, [{'garbage': garbage}])

    def testUnicodeQuery(self):
        query = self.c.query
        # pure ASCII unicode queries work, non-ASCII ones are rejected
        self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
        self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
660
661
662class TestInserttable(unittest.TestCase):
663    """Test inserttable method."""
664
665    cls_set_up = False
666
667    @classmethod
668    def setUpClass(cls):
669        c = connect()
670        c.query("drop table if exists test cascade")
671        c.query("create table test ("
672            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
673            "d numeric, f4 real, f8 double precision, m money,"
674            "c char(1), v4 varchar(4), c4 char(4), t text)")
675        c.close()
676        cls.cls_set_up = True
677
678    @classmethod
679    def tearDownClass(cls):
680        c = connect()
681        c.query("drop table test cascade")
682        c.close()
683
684    def setUp(self):
685        self.assertTrue(self.cls_set_up)
686        self.c = connect()
687        self.c.query("set lc_monetary='C'")
688        self.c.query("set datestyle='ISO,YMD'")
689
690    def tearDown(self):
691        self.c.query("truncate table test")
692        self.c.close()
693
694    data = [
695        (-1, -1, -1L, True, '1492-10-12', '08:30:00',
696            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
697        (0, 0, 0L, False, '1607-04-14', '09:00:00',
698            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
699        (1, 1, 1L, True, '1801-03-04', '03:45:00',
700            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
701        (2, 2, 2L, False, '1903-12-17', '11:22:00',
702            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]
703
704    def get_back(self):
705        """Convert boolean and decimal values back."""
706        data = []
707        for row in self.c.query("select * from test order by 1").getresult():
708            self.assertIsInstance(row, tuple)
709            row = list(row)
710            if row[0] is not None:  # smallint
711                self.assertIsInstance(row[0], int)
712            if row[1] is not None:  # integer
713                self.assertIsInstance(row[1], int)
714            if row[2] is not None:  # bigint
715                self.assertIsInstance(row[2], long)
716            if row[3] is not None:  # boolean
717                self.assertIsInstance(row[3], str)
718                row[3] = {'f': False, 't': True}.get(row[3])
719            if row[4] is not None:  # date
720                self.assertIsInstance(row[4], str)
721                self.assertTrue(row[4].replace('-', '').isdigit())
722            if row[5] is not None:  # time
723                self.assertIsInstance(row[5], str)
724                self.assertTrue(row[5].replace(':', '').isdigit())
725            if row[6] is not None:  # numeric
726                self.assertIsInstance(row[6], Decimal)
727                row[6] = float(row[6])
728            if row[7] is not None:  # real
729                self.assertIsInstance(row[7], float)
730            if row[8] is not None:  # double precision
731                self.assertIsInstance(row[8], float)
732                row[8] = float(row[8])
733            if row[9] is not None:  # money
734                self.assertIsInstance(row[9], Decimal)
735                row[9] = str(float(row[9]))
736            if row[10] is not None:  # char(1)
737                self.assertIsInstance(row[10], str)
738                self.assertEqual(len(row[10]), 1)
739            if row[11] is not None:  # varchar(4)
740                self.assertIsInstance(row[11], str)
741                self.assertLessEqual(len(row[11]), 4)
742            if row[12] is not None:  # char(4)
743                self.assertIsInstance(row[12], str)
744                self.assertEqual(len(row[12]), 4)
745                row[12] = row[12].rstrip()
746            if row[13] is not None:  # text
747                self.assertIsInstance(row[13], str)
748            row = tuple(row)
749            data.append(row)
750        return data
751
752    def testInserttable1Row(self):
753        data = self.data[2:3]
754        self.c.inserttable("test", data)
755        self.assertEqual(self.get_back(), data)
756
757    def testInserttable4Rows(self):
758        data = self.data
759        self.c.inserttable("test", data)
760        self.assertEqual(self.get_back(), data)
761
762    def testInserttableMultipleRows(self):
763        num_rows = 100
764        data = self.data[2:3] * num_rows
765        self.c.inserttable("test", data)
766        r = self.c.query("select count(*) from test").getresult()[0][0]
767        self.assertEqual(r, num_rows)
768
769    def testInserttableMultipleCalls(self):
770        num_rows = 10
771        data = self.data[2:3]
772        for _i in range(num_rows):
773            self.c.inserttable("test", data)
774        r = self.c.query("select count(*) from test").getresult()[0][0]
775        self.assertEqual(r, num_rows)
776
777    def testInserttableNullValues(self):
778        data = [(None,) * 14] * 100
779        self.c.inserttable("test", data)
780        self.assertEqual(self.get_back(), data)
781
782    def testInserttableMaxValues(self):
783        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
784            True, '2999-12-31', '11:59:59', 1e99,
785            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
786            "1", "1234", "1234", "1234" * 100)]
787        self.c.inserttable("test", data)
788        self.assertEqual(self.get_back(), data)
789
790
class TestDirectSocketAccess(unittest.TestCase):
    """Test the copy command using putline(), getline() and endcopy()."""

    # Set to True by setUpClass; setUp asserts it so that a test runner
    # without class fixtures fails with a clear message.
    cls_set_up = False

    @classmethod
    def setUpClass(cls):
        """Create a fresh two-column table for the copy tests."""
        conn = connect()
        for sql in ("drop table if exists test cascade",
                "create table test (i int, v varchar(16))"):
            conn.query(sql)
        conn.close()
        cls.cls_set_up = True

    @classmethod
    def tearDownClass(cls):
        conn = connect()
        conn.query("drop table test cascade")
        conn.close()

    def setUp(self):
        self.assertTrue(self.cls_set_up)
        self.c = connect()
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        fruits = list(enumerate("apple pear plum cherry banana".split()))
        self.c.query("copy test from stdin")
        try:
            for number, name in fruits:
                self.c.putline("%d\t%s\n" % (number, name))
            self.c.putline("\\.\n")  # end-of-data marker
        finally:
            self.c.endcopy()
        rows = self.c.query("select * from test").getresult()
        self.assertEqual(rows, fruits)

    def testGetline(self):
        fruits = list(enumerate("apple banana pear plum strawberry".split()))
        self.c.inserttable('test', fruits)
        self.c.query("copy test to stdout")
        try:
            # one line per row, then the end-of-data marker
            for expected in fruits:
                self.assertEqual(self.c.getline(), '%d\t%s' % expected)
            self.assertEqual(self.c.getline(), '\\.')
            # reading past the marker: servers older than 9.0 repeat the
            # marker, newer ones return None
            trailing = self.c.getline()
            if self.c.server_version < 90000:
                self.assertEqual(trailing, '\\.')
            else:
                self.assertIsNone(trailing)
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        # wrong number of arguments must raise TypeError
        for method, args in ((self.c.putline, ()),
                (self.c.getline, ('invalid',)),
                (self.c.endcopy, ('invalid',))):
            self.assertRaises(TypeError, method, *args)
859
860
861class TestNotificatons(unittest.TestCase):
862    """Test notification support."""
863
864    def setUp(self):
865        self.c = connect()
866
867    def tearDown(self):
868        self.c.close()
869
870    def testGetNotify(self):
871        if self.c.server_version < 90000:  # PostgreSQL < 9.0
872            self.skipTest('Notify with payload not supported')
873        getnotify = self.c.getnotify
874        query = self.c.query
875        self.assertIsNone(getnotify())
876        query('listen test_notify')
877        try:
878            self.assertIsNone(self.c.getnotify())
879            query("notify test_notify")
880            r = getnotify()
881            self.assertIsInstance(r, tuple)
882            self.assertEqual(len(r), 3)
883            self.assertIsInstance(r[0], str)
884            self.assertIsInstance(r[1], int)
885            self.assertIsInstance(r[2], str)
886            self.assertEqual(r[0], 'test_notify')
887            self.assertEqual(r[2], '')
888            self.assertIsNone(self.c.getnotify())
889            query("notify test_notify, 'test_payload'")
890            r = getnotify()
891            self.assertTrue(isinstance(r, tuple))
892            self.assertEqual(len(r), 3)
893            self.assertIsInstance(r[0], str)
894            self.assertIsInstance(r[1], int)
895            self.assertIsInstance(r[2], str)
896            self.assertEqual(r[0], 'test_notify')
897            self.assertEqual(r[2], 'test_payload')
898            self.assertIsNone(getnotify())
899        finally:
900            query('unlisten test_notify')
901
902    def testGetNoticeReceiver(self):
903        self.assertIsNone(self.c.get_notice_receiver())
904
905    def testSetNoticeReceiver(self):
906        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
907        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
908        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
909
910    def testSetAndGetNoticeReceiver(self):
911        r = lambda notice: None
912        self.assertIsNone(self.c.set_notice_receiver(r))
913        self.assertIs(self.c.get_notice_receiver(), r)
914
915    def testNoticeReceiver(self):
916        self.c.query('''create function bilbo_notice() returns void AS $$
917            begin
918                raise warning 'Bilbo was here!';
919            end;
920            $$ language plpgsql''')
921        try:
922            received = {}
923
924            def notice_receiver(notice):
925                for attr in dir(notice):
926                    value = getattr(notice, attr)
927                    if isinstance(value, str):
928                        value = value.replace('WARNUNG', 'WARNING')
929                    received[attr] = value
930
931            self.c.set_notice_receiver(notice_receiver)
932            self.c.query('''select bilbo_notice()''')
933            self.assertEqual(received, dict(
934                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
935                severity='WARNING', primary='Bilbo was here!',
936                detail=None, hint=None))
937        finally:
938            self.c.query('''drop function bilbo_notice();''')
939
940
941class TestConfigFunctions(unittest.TestCase):
942    """Test the functions for changing default settings.
943
944    To test the effect of most of these functions, we need a database
945    connection.  That's why they are covered in this test module.
946
947    """
948
949    def setUp(self):
950        self.c = connect()
951
952    def tearDown(self):
953        self.c.close()
954
955    def testGetDecimalPoint(self):
956        point = pg.get_decimal_point()
957        # error if a parameter is passed
958        self.assertRaises(TypeError, pg.get_decimal_point, point)
959        self.assertIsInstance(point, str)
960        self.assertEqual(point, '.')  # the default setting
961        pg.set_decimal_point(',')
962        try:
963            r = pg.get_decimal_point()
964        finally:
965            pg.set_decimal_point(point)
966        self.assertIsInstance(r, str)
967        self.assertEqual(r, ',')
968        pg.set_decimal_point("'")
969        try:
970            r = pg.get_decimal_point()
971        finally:
972            pg.set_decimal_point(point)
973        self.assertIsInstance(r, str)
974        self.assertEqual(r, "'")
975        pg.set_decimal_point('')
976        try:
977            r = pg.get_decimal_point()
978        finally:
979            pg.set_decimal_point(point)
980        self.assertIsNone(r)
981        pg.set_decimal_point(None)
982        try:
983            r = pg.get_decimal_point()
984        finally:
985            pg.set_decimal_point(point)
986        self.assertIsNone(r)
987
988    def testSetDecimalPoint(self):
989        d = pg.Decimal
990        point = pg.get_decimal_point()
991        self.assertRaises(TypeError, pg.set_decimal_point)
992        # error if decimal point is not a string
993        self.assertRaises(TypeError, pg.set_decimal_point, 0)
994        # error if more than one decimal point passed
995        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
996        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
997        # error if decimal point is not a punctuation character
998        self.assertRaises(TypeError, pg.set_decimal_point, '0')
999        query = self.c.query
1000        # check that money values are interpreted as decimal values
1001        # only if decimal_point is set, and that the result is correct
1002        # only if it is set suitable for the current lc_monetary setting
1003        select_money = "select '34.25'::money"
1004        proper_money = d('34.25')
1005        bad_money = d('3425')
1006        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
1007        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
1008        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
1009        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
1010            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
1011        # first try with English localization (using the point)
1012        for lc in en_locales:
1013            try:
1014                query("set lc_monetary='%s'" % lc)
1015            except pg.ProgrammingError:
1016                pass
1017            else:
1018                break
1019        else:
1020            self.skipTest("cannot set English money locale")
1021        try:
1022            r = query(select_money)
1023        except pg.ProgrammingError:
1024            # this can happen if the currency signs cannot be
1025            # converted using the encoding of the test database
1026            self.skipTest("database does not support English money")
1027        pg.set_decimal_point(None)
1028        try:
1029            r = r.getresult()[0][0]
1030        finally:
1031            pg.set_decimal_point(point)
1032        self.assertIsInstance(r, str)
1033        self.assertIn(r, en_money)
1034        r = query(select_money)
1035        pg.set_decimal_point('')
1036        try:
1037            r = r.getresult()[0][0]
1038        finally:
1039            pg.set_decimal_point(point)
1040        self.assertIsInstance(r, str)
1041        self.assertIn(r, en_money)
1042        r = query(select_money)
1043        pg.set_decimal_point('.')
1044        try:
1045            r = r.getresult()[0][0]
1046        finally:
1047            pg.set_decimal_point(point)
1048        self.assertIsInstance(r, d)
1049        self.assertEqual(r, proper_money)
1050        r = query(select_money)
1051        pg.set_decimal_point(',')
1052        try:
1053            r = r.getresult()[0][0]
1054        finally:
1055            pg.set_decimal_point(point)
1056        self.assertIsInstance(r, d)
1057        self.assertEqual(r, bad_money)
1058        r = query(select_money)
1059        pg.set_decimal_point("'")
1060        try:
1061            r = r.getresult()[0][0]
1062        finally:
1063            pg.set_decimal_point(point)
1064        self.assertIsInstance(r, d)
1065        self.assertEqual(r, bad_money)
1066        # then try with German localization (using the comma)
1067        for lc in de_locales:
1068            try:
1069                query("set lc_monetary='%s'" % lc)
1070            except pg.ProgrammingError:
1071                pass
1072            else:
1073                break
1074        else:
1075            self.skipTest("cannot set German money locale")
1076        select_money = select_money.replace('.', ',')
1077        try:
1078            r = query(select_money)
1079        except pg.ProgrammingError:
1080            self.skipTest("database does not support English money")
1081        pg.set_decimal_point(None)
1082        try:
1083            r = r.getresult()[0][0]
1084        finally:
1085            pg.set_decimal_point(point)
1086        self.assertIsInstance(r, str)
1087        self.assertIn(r, de_money)
1088        r = query(select_money)
1089        pg.set_decimal_point('')
1090        try:
1091            r = r.getresult()[0][0]
1092        finally:
1093            pg.set_decimal_point(point)
1094        self.assertIsInstance(r, str)
1095        self.assertIn(r, de_money)
1096        r = query(select_money)
1097        pg.set_decimal_point(',')
1098        try:
1099            r = r.getresult()[0][0]
1100        finally:
1101            pg.set_decimal_point(point)
1102        self.assertIsInstance(r, d)
1103        self.assertEqual(r, proper_money)
1104        r = query(select_money)
1105        pg.set_decimal_point('.')
1106        try:
1107            r = r.getresult()[0][0]
1108        finally:
1109            pg.set_decimal_point(point)
1110        self.assertEqual(r, bad_money)
1111        r = query(select_money)
1112        pg.set_decimal_point("'")
1113        try:
1114            r = r.getresult()[0][0]
1115        finally:
1116            pg.set_decimal_point(point)
1117        self.assertEqual(r, bad_money)
1118
1119    def testGetDecimal(self):
1120        decimal_class = pg.get_decimal()
1121        # error if a parameter is passed
1122        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
1123        self.assertIs(decimal_class, pg.Decimal)  # the default setting
1124        pg.set_decimal(int)
1125        try:
1126            r = pg.get_decimal()
1127        finally:
1128            pg.set_decimal(decimal_class)
1129        self.assertIs(r, int)
1130        r = pg.get_decimal()
1131        self.assertIs(r, decimal_class)
1132
1133    def testSetDecimal(self):
1134        decimal_class = pg.get_decimal()
1135        # error if no parameter is passed
1136        self.assertRaises(TypeError, pg.set_decimal)
1137        query = self.c.query
1138        try:
1139            r = query("select 3425::numeric")
1140        except pg.ProgrammingError:
1141            self.skipTest('database does not support numeric')
1142        r = r.getresult()[0][0]
1143        self.assertIsInstance(r, decimal_class)
1144        self.assertEqual(r, decimal_class('3425'))
1145        r = query("select 3425::numeric")
1146        pg.set_decimal(int)
1147        try:
1148            r = r.getresult()[0][0]
1149        finally:
1150            pg.set_decimal(decimal_class)
1151        self.assertNotIsInstance(r, decimal_class)
1152        self.assertIsInstance(r, int)
1153        self.assertEqual(r, int(3425))
1154
1155    def testGetBool(self):
1156        use_bool = pg.get_bool()
1157        # error if a parameter is passed
1158        self.assertRaises(TypeError, pg.get_bool, use_bool)
1159        self.assertIsInstance(use_bool, bool)
1160        self.assertIs(use_bool, False)  # the default setting
1161        pg.set_bool(True)
1162        try:
1163            r = pg.get_bool()
1164        finally:
1165            pg.set_bool(use_bool)
1166        self.assertIsInstance(r, bool)
1167        self.assertIs(r, True)
1168        pg.set_bool(False)
1169        try:
1170            r = pg.get_bool()
1171        finally:
1172            pg.set_bool(use_bool)
1173        self.assertIsInstance(r, bool)
1174        self.assertIs(r, False)
1175        pg.set_bool(1)
1176        try:
1177            r = pg.get_bool()
1178        finally:
1179            pg.set_bool(use_bool)
1180        self.assertIsInstance(r, bool)
1181        self.assertIs(r, True)
1182        pg.set_bool(0)
1183        try:
1184            r = pg.get_bool()
1185        finally:
1186            pg.set_bool(use_bool)
1187        self.assertIsInstance(r, bool)
1188        self.assertIs(r, False)
1189
1190    def testSetBool(self):
1191        use_bool = pg.get_bool()
1192        # error if no parameter is passed
1193        self.assertRaises(TypeError, pg.set_bool)
1194        query = self.c.query
1195        try:
1196            r = query("select true::bool")
1197        except pg.ProgrammingError:
1198            self.skipTest('database does not support bool')
1199        r = r.getresult()[0][0]
1200        self.assertIsInstance(r, str)
1201        self.assertEqual(r, 't')
1202        r = query("select true::bool")
1203        pg.set_bool(True)
1204        try:
1205            r = r.getresult()[0][0]
1206        finally:
1207            pg.set_bool(use_bool)
1208        self.assertIsInstance(r, bool)
1209        self.assertIs(r, True)
1210        r = query("select true::bool")
1211        pg.set_bool(False)
1212        try:
1213            r = r.getresult()[0][0]
1214        finally:
1215            pg.set_bool(use_bool)
1216        self.assertIsInstance(r, str)
1217        self.assertIs(r, 't')
1218
1219    @unittest.skipUnless(namedtuple, 'Named tuples not available')
1220    def testGetNamedresult(self):
1221        namedresult = pg.get_namedresult()
1222        # error if a parameter is passed
1223        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
1224        self.assertIs(namedresult, pg._namedresult)  # the default setting
1225
1226    @unittest.skipUnless(namedtuple, 'Named tuples not available')
1227    def testSetNamedresult(self):
1228        namedresult = pg.get_namedresult()
1229        self.assertTrue(callable(namedresult))
1230
1231        query = self.c.query
1232
1233        r = query("select 1 as x, 2 as y").namedresult()[0]
1234        self.assertIsInstance(r, tuple)
1235        self.assertEqual(r, (1, 2))
1236        self.assertIsNot(type(r), tuple)
1237        self.assertEqual(r._fields, ('x', 'y'))
1238        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1239        self.assertEqual(r.__class__.__name__, 'Row')
1240
1241        def listresult(q):
1242            return [list(row) for row in q.getresult()]
1243
1244        pg.set_namedresult(listresult)
1245        try:
1246            r = pg.get_namedresult()
1247            self.assertIs(r, listresult)
1248            r = query("select 1 as x, 2 as y").namedresult()[0]
1249            self.assertIsInstance(r, list)
1250            self.assertEqual(r, [1, 2])
1251            self.assertIsNot(type(r), tuple)
1252            self.assertFalse(hasattr(r, '_fields'))
1253            self.assertNotEqual(r.__class__.__name__, 'Row')
1254        finally:
1255            pg.set_namedresult(namedresult)
1256
1257        r = pg.get_namedresult()
1258        self.assertIs(r, namedresult)
1259
1260
# Run all test cases of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.