source: branches/4.x/module/tests/test_classic_connection.py @ 642

Last change on this file since 642 was 642, checked in by cito, 4 years ago

Skip test if German locale cannot be set

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 37.1 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import sys
19import tempfile
20import threading
21import time
22
23import pg  # the module under test
24
25from decimal import Decimal
26try:
27    from collections import namedtuple
28except ImportError:  # Python < 2.6
29    namedtuple = None
30
# Connection parameters of the database the tests run against.  The
# defaults below can be overridden by a LOCAL_PyGreSQL.py module, which is
# star-imported when it can be found and silently ignored otherwise.
dbname, dbhost, dbport = 'unittest', None, 5432

try:
    from LOCAL_PyGreSQL import *
except ImportError:
    pass
41
42
def connect():
    """Create a basic pg connection to the test database.

    Uses the module-level dbname, dbhost and dbport settings and lowers
    client_min_messages to 'warning' on the new connection.  If that
    setup query fails, the connection is closed before the error is
    re-raised so that it is not leaked.
    """
    connection = pg.connect(dbname, dbhost, dbport)
    try:
        connection.query("set client_min_messages=warning")
    except Exception:
        connection.close()
        raise
    return connection
48
49
class TestCanConnect(unittest.TestCase):
    """Test whether a basic connection to PostgreSQL is possible."""

    def testCanConnect(self):
        """Open a connection to the test database and close it again."""
        try:
            conn = connect()
        except pg.Error as err:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, err))
        else:
            try:
                conn.close()
            except pg.Error:
                self.fail('Cannot close the database connection')
62
63
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # Some tests close the connection themselves; closing it a second
        # time raises pg.InternalError, which is expected and ignored here.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def testAllConnectAttributes(self):
        """The non-callable names of a connection are exactly these."""
        attributes = '''db error host options port
            protocol_version server_version status tty user'''.split()
        connection = self.connection
        # getattr performs the same lookup as the former eval() of
        # "self.connection.<name>" without compiling code from dir() output.
        connection_attributes = [a for a in dir(connection)
            if not callable(getattr(connection, a))]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        """The callable names of a connection are exactly these."""
        methods = '''cancel close endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_notice_receiver source transaction'''.split()
        connection = self.connection
        connection_methods = [a for a in dir(connection)
            if callable(getattr(connection, a))]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        self.assertTrue(70400 <= server_version < 100000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeTty(self):
        def_tty = ''
        self.assertIsInstance(self.connection.tty, str)
        self.assertEqual(self.connection.tty, def_tty)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # No copy is in progress, so endcopy() may raise IOError.
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        """After close(), reset, close and query must all fail."""
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # Give tearDown a live connection to close.
        self.connection = connect()

    def testMethodReset(self):
        """reset() restores the client encoding changed in the session."""
        query = self.connection.query
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        """cancel() interrupts a pg_sleep query running in another thread."""
        errors = []

        def sleep():
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.ProgrammingError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # make sure the query is really running
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)
212
213        self.assertIsInstance(r, int)
214        self.assertEqual(r, 1)  # return code should be 1
215        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
216        self.assertTrue(errors)
217
218    def testMethodFileNo(self):
219        r = self.connection.fileno()
220        self.assertIsInstance(r, int)
221        self.assertGreaterEqual(r, 0)
222
223
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.ProgrammingError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        result = Decimal('98765432109876543210')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal('98765432109876543210')
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        """Unquoted aliases are folded to lower case, quoted ones kept."""
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testNamedresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        """A 100 x 100 result set is returned completely and in order."""
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testQuery(self):
        """Check the return values of query() for DDL and DML statements.

        DDL returns None, a single-row insert returns the new oid as an int,
        and multi-row insert/update/delete return the row count as a str.
        """
        query = self.c.query
        query("drop table if exists test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, oid)
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')
        query("drop table test_table")

    def testPrint(self):
        """Printing a query result renders it as an aligned text table."""
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        f = tempfile.TemporaryFile()
        try:
            stdout, sys.stdout = sys.stdout, f
            try:
                # A failure while printing now surfaces as the real error
                # instead of being swallowed and showing up only as a
                # confusing output mismatch below.
                print(r)
            finally:
                sys.stdout = stdout
            f.seek(0)
            r = f.read()
        finally:
            # Always release the temporary file, even if seek/read fails.
            f.close()
        self.assertEqual(r,
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)\n')
495            sys.stdout = stdout
496        f.seek(0)
497        r = f.read()
498        f.close()
499        self.assertEqual(r,
500            'a|  h  |world\n'
501            '-+-----+-----\n'
502            '1|hello|w    \n'
503            '2|xyz  |uvw  \n'
504            '(2 rows)\n')
505
506
class TestParamQueries(unittest.TestCase):
    """Test queries with parameters via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testQueryWithNoneParam(self):
        """A None parameter arrives as SQL NULL and comes back as None."""
        run = self.c.query
        rows = run("select $1::integer", (None,)).getresult()
        self.assertEqual(rows, [(None,)])
        rows = run("select $1::text", [None]).getresult()
        self.assertEqual(rows, [(None,)])

    def testQueryWithBoolParams(self, use_bool=None):
        """Boolean parameters in various spellings map to SQL booleans.

        If use_bool is given, pg.set_bool() is switched to it for the
        duration of the test and restored afterwards.
        """
        run = self.c.query
        if use_bool is not None:
            saved_use_bool = pg.get_bool()
            pg.set_bool(use_bool)
        try:
            if use_bool:
                v_false, v_true = False, True
            else:
                v_false, v_true = 'f', 't'
            self.assertEqual(run("select false").getresult(), [(v_false,)])
            self.assertEqual(run("select true").getresult(), [(v_true,)])
            sql = "select $1::bool"
            cases = [
                (None, None),
                ('f', v_false), ('t', v_true),
                ('false', v_false), ('true', v_true),
                ('n', v_false), ('y', v_true),
                (0, v_false), (1, v_true),
                (False, v_false), (True, v_true)]
            for param, expected in cases:
                self.assertEqual(run(sql, (param,)).getresult(),
                    [(expected,)])
        finally:
            if use_bool is not None:
                pg.set_bool(saved_use_bool)

    def testQueryWithBoolParamsAndUseBool(self):
        self.testQueryWithBoolParams(use_bool=True)

    def testQueryWithIntParams(self):
        """Integer parameters, passed as tuple or list, in various queries."""
        run = self.c.query
        self.assertEqual(run("select 1+1").getresult(), [(2,)])
        cases = [
            ("select 1+$1", (1,), [(2,)]),
            ("select 1+$1", [1], [(2,)]),
            ("select $1::integer", (2,), [(2,)]),
            ("select $1::text", (2,), [('2',)]),
            ("select 1+$1::numeric", [1], [(Decimal('2'),)]),
            ("select 1, $1::integer", (2,), [(1, 2)]),
            ("select 1 union select $1", (2,), [(1,), (2,)]),
            ("select $1::integer+$2", (1, 2), [(3,)]),
            ("select $1::integer+$2", [1, 2], [(3,)]),
            ("select 0+$1+$2+$3+$4+$5+$6", range(6), [(15,)])]
        for sql, params, expected in cases:
            self.assertEqual(run(sql, params).getresult(), expected)

    def testQueryWithStrParams(self):
        """String parameters, including UTF-8 encoded byte strings."""
        run = self.c.query
        cases = [
            ("select $1||', world!'", ('Hello',),
                [('Hello, world!',)]),
            ("select $1||', world!'", ['Hello'],
                [('Hello, world!',)]),
            ("select $1||', '||$2||'!'", ('Hello', 'world'),
                [('Hello, world!',)]),
            ("select $1::text", ('Hello, world!',),
                [('Hello, world!',)]),
            ("select $1::text,$2::text", ('Hello', 'world'),
                [('Hello', 'world')]),
            ("select $1::text,$2::text", ['Hello', 'world'],
                [('Hello', 'world')]),
            ("select $1::text union select $2::text", ('Hello', 'world'),
                [('Hello',), ('world',)]),
            ("select $1||', '||$2||'!'", ('Hello', 'w\xc3\xb6rld'),
                [('Hello, w\xc3\xb6rld!',)])]
        for sql, params, expected in cases:
            self.assertEqual(run(sql, params).getresult(), expected)

    def testQueryWithUnicodeParams(self):
        """Unicode parameters are encoded with the current client encoding.

        Characters the encoding cannot represent raise UnicodeError.
        """
        run = self.c.query
        sql = "select $1||', '||$2||'!'"
        latin_word = u'w\xf6rld'
        cyrillic_word = u'\u043c\u0438\u0440'
        run('set client_encoding = utf8')
        self.assertEqual(run(sql, ('Hello', latin_word)).getresult(),
            [('Hello, w\xc3\xb6rld!',)])
        self.assertEqual(run(sql, ('Hello', cyrillic_word)).getresult(),
            [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
        run('set client_encoding = latin1')
        self.assertEqual(run(sql, ('Hello', latin_word)).getresult(),
            [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, run, sql, ('Hello', cyrillic_word))
        run('set client_encoding = iso_8859_1')
        self.assertEqual(run(sql, ('Hello', latin_word)).getresult(),
            [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, run, sql, ('Hello', cyrillic_word))
        run('set client_encoding = iso_8859_5')
        self.assertRaises(UnicodeError, run, sql, ('Hello', latin_word))
        self.assertEqual(run(sql, ('Hello', cyrillic_word)).getresult(),
            [('Hello, \xdc\xd8\xe0!',)])
        run('set client_encoding = sql_ascii')
        self.assertRaises(UnicodeError, run, sql, ('Hello', latin_word))

    def testQueryWithMixedParams(self):
        run = self.c.query
        rows = run("select $1+2,$2||', world!'", (1, 'Hello')).getresult()
        self.assertEqual(rows, [(3, 'Hello, world!')])
        rows = run("select $1::integer,$2::date,$3::text",
            (4711, None, 'Hello!')).getresult()
        self.assertEqual(rows, [(4711, None, 'Hello!')])

    def testQueryWithDuplicateParams(self):
        """A parameter count that does not match the query is an error."""
        for params in ((1,), (1, 2)):
            self.assertRaises(pg.ProgrammingError,
                self.c.query, "select $1+$1", params)

    def testQueryWithZeroParams(self):
        rows = self.c.query("select 1+1", []).getresult()
        self.assertEqual(rows, [(2,)])

    def testQueryWithGarbage(self):
        """Special characters in a parameter are passed through unchanged."""
        garbage = r"'\{}+()-#[]oo324"
        rows = self.c.query("select $1::text AS garbage",
            (garbage,)).dictresult()
        self.assertEqual(rows, [{'garbage': garbage}])

    def testUnicodeQuery(self):
        """Unicode SQL is accepted only when it is plain ASCII."""
        run = self.c.query
        self.assertEqual(run(u"select 1+1").getresult(), [(2,)])
        self.assertRaises(TypeError, run, u"select 'Hello, w\xf6rld!'")
633
634    def testQueryWithGarbage(self):
635        garbage = r"'\{}+()-#[]oo324"
636        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
637            ).dictresult(), [{'garbage': garbage}])
638
639    def testUnicodeQuery(self):
640        query = self.c.query
641        self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
642        self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
643
644
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    # Sample rows, one value per column of the test table, in the order
    # i2, i4, i8, b, dt, ti, d, f4, f8, m, c, v4, c4, t.
    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def setUpClass(cls):
        conn = connect()
        conn.query("drop table if exists test cascade")
        conn.query("create table test ("
            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
            "d numeric, f4 real, f8 double precision, m money,"
            "c char(1), v4 varchar(4), c4 char(4), t text)")
        conn.close()

    @classmethod
    def tearDownClass(cls):
        conn = connect()
        conn.query("drop table test cascade")
        conn.close()

    def setUp(self):
        self.c = connect()
        # Fixed money and date formats so fetched values are predictable.
        self.c.query("set lc_monetary='C'")
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def get_back(self):
        """Read the test table back, ordered by the first column.

        The type of every non-NULL value is checked, and boolean, numeric,
        money and char(4) values are converted to the representation used
        in the data attribute so the rows can be compared with it.
        """
        rows = []
        for record in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(record, tuple)
            converted = []
            for col, value in enumerate(record):
                if value is not None:
                    if col in (0, 1):  # smallint, integer
                        self.assertIsInstance(value, int)
                    elif col == 2:  # bigint
                        self.assertIsInstance(value, long)
                    elif col == 3:  # boolean, fetched as 'f' or 't'
                        self.assertIsInstance(value, str)
                        value = {'f': False, 't': True}.get(value)
                    elif col == 4:  # date
                        self.assertIsInstance(value, str)
                        self.assertTrue(value.replace('-', '').isdigit())
                    elif col == 5:  # time
                        self.assertIsInstance(value, str)
                        self.assertTrue(value.replace(':', '').isdigit())
                    elif col == 6:  # numeric
                        self.assertIsInstance(value, Decimal)
                        value = float(value)
                    elif col == 7:  # real
                        self.assertIsInstance(value, float)
                    elif col == 8:  # double precision
                        self.assertIsInstance(value, float)
                        value = float(value)
                    elif col == 9:  # money
                        self.assertIsInstance(value, Decimal)
                        value = str(float(value))
                    elif col == 10:  # char(1)
                        self.assertIsInstance(value, str)
                        self.assertEqual(len(value), 1)
                    elif col == 11:  # varchar(4)
                        self.assertIsInstance(value, str)
                        self.assertLessEqual(len(value), 4)
                    elif col == 12:  # char(4), blank-padded by the server
                        self.assertIsInstance(value, str)
                        self.assertEqual(len(value), 4)
                        value = value.rstrip()
                    elif col == 13:  # text
                        self.assertIsInstance(value, str)
                converted.append(value)
            rows.append(tuple(converted))
        return rows

    def testInserttable1Row(self):
        rows = self.data[2:3]
        self.c.inserttable("test", rows)
        self.assertEqual(self.get_back(), rows)

    def testInserttable4Rows(self):
        rows = self.data
        self.c.inserttable("test", rows)
        self.assertEqual(self.get_back(), rows)

    def testInserttableMultipleRows(self):
        count = 100
        self.c.inserttable("test", self.data[2:3] * count)
        stored = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(stored, count)

    def testInserttableMultipleCalls(self):
        count = 10
        row = self.data[2:3]
        for _ in range(count):
            self.c.inserttable("test", row)
        stored = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(stored, count)

    def testInserttableNullValues(self):
        rows = [(None,) * 14] * 100
        self.c.inserttable("test", rows)
        self.assertEqual(self.get_back(), rows)

    def testInserttableMaxValues(self):
        rows = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable("test", rows)
        self.assertEqual(self.get_back(), rows)
758        self.c.inserttable("test", data)
759        self.assertEqual(self.get_back(), data)
760
761    def testInserttableMaxValues(self):
762        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
763            True, '2999-12-31', '11:59:59', 1e99,
764            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
765            "1", "1234", "1234", "1234" * 100)]
766        self.c.inserttable("test", data)
767        self.assertEqual(self.get_back(), data)
768
769
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    @classmethod
    def setUpClass(cls):
        conn = connect()
        conn.query("drop table if exists test cascade")
        conn.query("create table test (i int, v varchar(16))")
        conn.close()

    @classmethod
    def tearDownClass(cls):
        conn = connect()
        conn.query("drop table test cascade")
        conn.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        """Rows sent with putline() during COPY FROM STDIN are stored."""
        put = self.c.putline
        fruits = list(enumerate("apple pear plum cherry banana".split()))
        self.c.query("copy test from stdin")
        try:
            for num, name in fruits:
                put("%d\t%s\n" % (num, name))
            put("\\.\n")  # end-of-data marker
        finally:
            self.c.endcopy()
        stored = self.c.query("select * from test").getresult()
        self.assertEqual(stored, fruits)

    def testGetline(self):
        """getline() yields each row, then the end marker, then None."""
        get = self.c.getline
        fruits = list(enumerate("apple banana pear plum strawberry".split()))
        self.c.inserttable('test', fruits)
        self.c.query("copy test to stdout")
        try:
            for expected in ['%d\t%s' % row for row in fruits]:
                self.assertEqual(get(), expected)
            self.assertEqual(get(), '\\.')
            self.assertIsNone(get())
        finally:
            # The copy may already be finished, in which case
            # endcopy() raises IOError.
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
824        finally:
825            try:
826                self.c.endcopy()
827            except IOError:
828                pass
829
830    def testParameterChecks(self):
831        self.assertRaises(TypeError, self.c.putline)
832        self.assertRaises(TypeError, self.c.getline, 'invalid')
833        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
834
835
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def _checkNotify(self, r, payload):
        """Assert that r is a 3-tuple notification for 'test_notify'.

        The first and last elements must be strings (the channel name and
        the given payload); the middle element is only checked to be an int.

        """
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self._checkNotify(getnotify(), '')
            self.assertIsNone(getnotify())
            try:
                query("notify test_notify, 'test_payload'")
            except pg.ProgrammingError:  # PostgreSQL < 9.0
                pass
            else:
                self._checkNotify(getnotify(), 'test_payload')
                self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))

    def testSetAndGetNoticeReceiver(self):
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)

    def testNoticeReceiver(self):
        # "or replace" so that a function left over from an aborted
        # earlier run does not make this test fail permanently
        self.c.query('''create or replace function bilbo_notice()
            returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        try:
            received = {}

            def notice_receiver(notice):
                # record every attribute of the notice object
                for attr in dir(notice):
                    value = getattr(notice, attr)
                    if isinstance(value, str):
                        # normalize the German severity word so the
                        # comparison also works with a German server locale
                        value = value.replace('WARNUNG', 'WARNING')
                    received[attr] = value

            self.c.set_notice_receiver(notice_receiver)
            self.c.query('''select bilbo_notice()''')
            self.assertEqual(received, dict(
                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
                severity='WARNING', primary='Bilbo was here!',
                detail=None, hint=None))
        finally:
            self.c.query('''drop function bilbo_notice();''')
916
917
918class TestConfigFunctions(unittest.TestCase):
919    """Test the functions for changing default settings.
920
921    To test the effect of most of these functions, we need a database
922    connection.  That's why they are covered in this test module.
923
924    """
925
926    def setUp(self):
927        self.c = connect()
928
929    def tearDown(self):
930        self.c.close()
931
932    def testGetDecimalPoint(self):
933        point = pg.get_decimal_point()
934        self.assertIsInstance(point, str)
935        self.assertEqual(point, '.')
936
937    def testSetDecimalPoint(self):
938        d = pg.Decimal
939        point = pg.get_decimal_point()
940        query = self.c.query
941        # check that money values can be interpreted correctly
942        # if and only if the decimal point is set appropriately
943        # for the current lc_monetary setting
944        try:
945            query("set lc_monetary='en_US.UTF-8'")
946        except pg.ProgrammingError:
947            self.skipTest("cannot set English money locale")
948        pg.set_decimal_point(None)
949        try:
950            r = query("select '34.25'::money").getresult()[0][0]
951        finally:
952            pg.set_decimal_point(point)
953        self.assertIsInstance(r, str)
954        self.assertIn(r, (
955            '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'))
956        pg.set_decimal_point('.')
957        try:
958            r = query("select '34.25'::money").getresult()[0][0]
959        finally:
960            pg.set_decimal_point(point)
961        self.assertIsInstance(r, d)
962        self.assertEqual(r, d('34.25'))
963        pg.set_decimal_point(',')
964        try:
965            r = query("select '34.25'::money").getresult()[0][0]
966        finally:
967            pg.set_decimal_point(point)
968        self.assertNotEqual(r, d('34.25'))
969        try:
970            query("set lc_monetary='de_DE.UTF-8'")
971        except pg.ProgrammingError:
972            self.skipTest("cannot set German money locale")
973        pg.set_decimal_point(None)
974        try:
975            r = query("select '34,25'::money").getresult()[0][0]
976        finally:
977            pg.set_decimal_point(point)
978        self.assertIsInstance(r, str)
979        self.assertIn(r, ('34,25€', '34,25 €', '€34,25' '€ 34,25',
980            '34,25 EUR', '34,25 Euro', '34,25 DM'))
981        pg.set_decimal_point(',')
982        try:
983            r = query("select '34,25'::money").getresult()[0][0]
984        finally:
985            pg.set_decimal_point(point)
986        self.assertIsInstance(r, d)
987        self.assertEqual(r, d('34.25'))
988        try:
989            pg.set_decimal_point('.')
990        finally:
991            pg.set_decimal_point(point)
992        r = query("select '34,25'::money").getresult()[0][0]
993        self.assertNotEqual(r, d('34.25'))
994
995    def testSetDecimal(self):
996        d = pg.Decimal
997        query = self.c.query
998        r = query("select 3425::numeric").getresult()[0][0]
999        self.assertIsInstance(r, d)
1000        self.assertEqual(r, d('3425'))
1001        pg.set_decimal(long)
1002        try:
1003            r = query("select 3425::numeric").getresult()[0][0]
1004        finally:
1005            pg.set_decimal(d)
1006        self.assertNotIsInstance(r, d)
1007        self.assertIsInstance(r, long)
1008        self.assertEqual(r, 3425L)
1009
1010    @unittest.skipUnless(namedtuple, 'Named tuples not available')
1011    def testSetNamedresult(self):
1012        query = self.c.query
1013
1014        r = query("select 1 as x, 2 as y").namedresult()[0]
1015        self.assertIsInstance(r, tuple)
1016        self.assertEqual(r, (1, 2))
1017        self.assertIsNot(type(r), tuple)
1018        self.assertEqual(r._fields, ('x', 'y'))
1019        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1020        self.assertEqual(r.__class__.__name__, 'Row')
1021
1022        _namedresult = pg._namedresult
1023        self.assertTrue(callable(_namedresult))
1024        pg.set_namedresult(_namedresult)
1025
1026        r = query("select 1 as x, 2 as y").namedresult()[0]
1027        self.assertIsInstance(r, tuple)
1028        self.assertEqual(r, (1, 2))
1029        self.assertIsNot(type(r), tuple)
1030        self.assertEqual(r._fields, ('x', 'y'))
1031        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
1032        self.assertEqual(r.__class__.__name__, 'Row')
1033
1034        def _listresult(q):
1035            return map(list, q.getresult())
1036
1037        pg.set_namedresult(_listresult)
1038
1039        try:
1040            r = query("select 1 as x, 2 as y").namedresult()[0]
1041            self.assertIsInstance(r, list)
1042            self.assertEqual(r, [1, 2])
1043            self.assertIsNot(type(r), tuple)
1044            self.assertFalse(hasattr(r, '_fields'))
1045            self.assertNotEqual(r.__class__.__name__, 'Row')
1046        finally:
1047            pg.set_namedresult(_namedresult)
1048
1049
if __name__ == '__main__':
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.