source: branches/4.x/module/TEST_PyGreSQL_classic_connection.py @ 570

Last change on this file since 570 was 570, checked in by cito, 4 years ago

Improve tests for long ints under Python 2

PyGreSQL returns longs for Postgres bigints, even though theoretically
it could return ints in the case of a 64bit Python. But that's ok
because it makes the behavior more consistent and because the int/long
split becomes irrelevant in Python 3 anyway.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 35.6 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import sys
19import tempfile
20import threading
21import time
22
23import pg  # the module under test
24
25from decimal import Decimal
26try:
27    from collections import namedtuple
28except ImportError:  # Python < 2.6
29    namedtuple = None
30
# Connection parameters for the test database.  The defaults below may be
# overridden by an optional LOCAL_PyGreSQL.py module on the import path;
# any public names it defines replace the ones set here.
dbname, dbhost, dbport = 'unittest', None, 5432

try:  # pick up optional local overrides
    from LOCAL_PyGreSQL import *
except ImportError:  # no local settings module: keep the defaults
    pass
41
42
def connect():
    """Open a pg connection to the test database.

    The module-level dbname, dbhost and dbport settings are passed to
    pg.connect().  The session's client_min_messages is then set to
    'warning', so server messages below that level are not sent to the
    client during the tests.

    Returns the newly opened connection object.
    """
    cnx = pg.connect(dbname, dbhost, dbport)
    cnx.query("set client_min_messages=warning")
    return cnx
48
49
50class TestCanConnect(unittest.TestCase):
51    """Test whether a basic connection to PostgreSQL is possible."""
52
53    def testCanConnect(self):
54        try:
55            connection = connect()
56        except pg.Error, error:
57            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
58        try:
59            connection.close()
60        except pg.Error:
61            self.fail('Cannot close the database connection')
62
63
64class TestConnectObject(unittest.TestCase):
65    """"Test existence of basic pg connection methods."""
66
67    def setUp(self):
68        self.connection = connect()
69
70    def tearDown(self):
71        try:
72            self.connection.close()
73        except pg.InternalError:
74            pass
75
76    def testAllConnectAttributes(self):
77        attributes = '''db error host options port
78            protocol_version server_version status tty user'''.split()
79        connection_attributes = [a for a in dir(self.connection)
80            if not callable(eval("self.connection." + a))]
81        self.assertEqual(attributes, connection_attributes)
82
83    def testAllConnectMethods(self):
84        methods = '''cancel close endcopy
85            escape_bytea escape_identifier escape_literal escape_string
86            fileno get_notice_receiver getline getlo getnotify
87            inserttable locreate loimport parameter putline query reset
88            set_notice_receiver source transaction'''.split()
89        connection_methods = [a for a in dir(self.connection)
90            if callable(eval("self.connection." + a))]
91        self.assertEqual(methods, connection_methods)
92
93    def testAttributeDb(self):
94        self.assertEqual(self.connection.db, dbname)
95
96    def testAttributeError(self):
97        error = self.connection.error
98        self.assertTrue(not error or 'krb5_' in error)
99
100    def testAttributeHost(self):
101        def_host = 'localhost'
102        self.assertIsInstance(self.connection.host, str)
103        self.assertEqual(self.connection.host, dbhost or def_host)
104
105    def testAttributeOptions(self):
106        no_options = ''
107        self.assertEqual(self.connection.options, no_options)
108
109    def testAttributePort(self):
110        def_port = 5432
111        self.assertIsInstance(self.connection.port, int)
112        self.assertEqual(self.connection.port, dbport or def_port)
113
114    def testAttributeProtocolVersion(self):
115        protocol_version = self.connection.protocol_version
116        self.assertIsInstance(protocol_version, int)
117        self.assertTrue(2 <= protocol_version < 4)
118
119    def testAttributeServerVersion(self):
120        server_version = self.connection.server_version
121        self.assertIsInstance(server_version, int)
122        self.assertTrue(70400 <= server_version < 100000)
123
124    def testAttributeStatus(self):
125        status_ok = 1
126        self.assertIsInstance(self.connection.status, int)
127        self.assertEqual(self.connection.status, status_ok)
128
129    def testAttributeTty(self):
130        def_tty = ''
131        self.assertIsInstance(self.connection.tty, str)
132        self.assertEqual(self.connection.tty, def_tty)
133
134    def testAttributeUser(self):
135        no_user = 'Deprecated facility'
136        user = self.connection.user
137        self.assertTrue(user)
138        self.assertIsInstance(user, str)
139        self.assertNotEqual(user, no_user)
140
141    def testMethodQuery(self):
142        query = self.connection.query
143        query("select 1+1")
144        query("select 1+$1", (1,))
145        query("select 1+$1+$2", (2, 3))
146        query("select 1+$1+$2", [2, 3])
147
148    def testMethodQueryEmpty(self):
149        self.assertRaises(ValueError, self.connection.query, '')
150
151    def testMethodEndcopy(self):
152        try:
153            self.connection.endcopy()
154        except IOError:
155            pass
156
157    def testMethodClose(self):
158        self.connection.close()
159        try:
160            self.connection.reset()
161        except (pg.Error, TypeError):
162            pass
163        else:
164            self.fail('Reset should give an error for a closed connection')
165        self.assertRaises(pg.InternalError, self.connection.close)
166        try:
167            self.connection.query('select 1')
168        except (pg.Error, TypeError):
169            pass
170        else:
171            self.fail('Query should give an error for a closed connection')
172        self.connection = connect()
173
174    def testMethodReset(self):
175        query = self.connection.query
176        # check that client encoding gets reset
177        encoding = query('show client_encoding').getresult()[0][0].upper()
178        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
179        self.assertNotEqual(encoding, changed_encoding)
180        self.connection.query("set client_encoding=%s" % changed_encoding)
181        new_encoding = query('show client_encoding').getresult()[0][0].upper()
182        self.assertEqual(new_encoding, changed_encoding)
183        self.connection.reset()
184        new_encoding = query('show client_encoding').getresult()[0][0].upper()
185        self.assertNotEqual(new_encoding, changed_encoding)
186        self.assertEqual(new_encoding, encoding)
187
188    def testMethodCancel(self):
189        r = self.connection.cancel()
190        self.assertIsInstance(r, int)
191        self.assertEqual(r, 1)
192
193    def testCancelLongRunningThread(self):
194        errors = []
195
196        def sleep():
197            try:
198                self.connection.query('select pg_sleep(5)').getresult()
199            except pg.ProgrammingError, error:
200                errors.append(str(error))
201
202        thread = threading.Thread(target=sleep)
203        t1 = time.time()
204        thread.start()  # run the query
205        while 1:  # make sure the query is really running
206            time.sleep(0.1)
207            if thread.is_alive() or time.time() - t1 > 5:
208                break
209        r = self.connection.cancel()  # cancel the running query
210        thread.join()  # wait for the thread to end
211        t2 = time.time()
212
213        self.assertIsInstance(r, int)
214        self.assertEqual(r, 1)  # return code should be 1
215        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
216        self.assertTrue(errors)
217
218    def testMethodFileNo(self):
219        r = self.connection.fileno()
220        self.assertIsInstance(r, int)
221        self.assertGreaterEqual(r, 0)
222
223
224class TestSimpleQueries(unittest.TestCase):
225    """"Test simple queries via a basic pg connection."""
226
227    def setUp(self):
228        self.c = connect()
229
230    def tearDown(self):
231        self.c.close()
232
233    def testSelect0(self):
234        q = "select 0"
235        self.c.query(q)
236
237    def testSelect0Semicolon(self):
238        q = "select 0;"
239        self.c.query(q)
240
241    def testSelectDotSemicolon(self):
242        q = "select .;"
243        self.assertRaises(pg.ProgrammingError, self.c.query, q)
244
245    def testGetresult(self):
246        q = "select 0"
247        result = [(0,)]
248        r = self.c.query(q).getresult()
249        self.assertIsInstance(r, list)
250        v = r[0]
251        self.assertIsInstance(v, tuple)
252        self.assertIsInstance(v[0], int)
253        self.assertEqual(r, result)
254
255    def testGetresultLong(self):
256        q = "select 9876543210"
257        result = 9876543210L
258        v = self.c.query(q).getresult()[0][0]
259        self.assertIsInstance(v, long)
260        self.assertEqual(v, result)
261
262    def testGetresultDecimal(self):
263        q = "select 98765432109876543210"
264        result = Decimal(98765432109876543210L)
265        v = self.c.query(q).getresult()[0][0]
266        self.assertIsInstance(v, Decimal)
267        self.assertEqual(v, result)
268
269    def testGetresultString(self):
270        result = 'Hello, world!'
271        q = "select '%s'" % result
272        v = self.c.query(q).getresult()[0][0]
273        self.assertIsInstance(v, str)
274        self.assertEqual(v, result)
275
276    def testDictresult(self):
277        q = "select 0 as alias0"
278        result = [{'alias0': 0}]
279        r = self.c.query(q).dictresult()
280        self.assertIsInstance(r, list)
281        v = r[0]
282        self.assertIsInstance(v, dict)
283        self.assertIsInstance(v['alias0'], int)
284        self.assertEqual(r, result)
285
286    def testDictresultLong(self):
287        q = "select 9876543210 as longjohnsilver"
288        result = 9876543210L
289        v = self.c.query(q).dictresult()[0]['longjohnsilver']
290        self.assertIsInstance(v, long)
291        self.assertEqual(v, result)
292
293    def testDictresultDecimal(self):
294        q = "select 98765432109876543210 as longjohnsilver"
295        result = Decimal(98765432109876543210L)
296        v = self.c.query(q).dictresult()[0]['longjohnsilver']
297        self.assertIsInstance(v, Decimal)
298        self.assertEqual(v, result)
299
300    def testDictresultString(self):
301        result = 'Hello, world!'
302        q = "select '%s' as greeting" % result
303        v = self.c.query(q).dictresult()[0]['greeting']
304        self.assertIsInstance(v, str)
305        self.assertEqual(v, result)
306
307    @unittest.skipUnless(namedtuple, 'Named tuples not available')
308    def testNamedresult(self):
309        q = "select 0 as alias0"
310        result = [(0,)]
311        r = self.c.query(q).namedresult()
312        self.assertEqual(r, result)
313        v = r[0]
314        self.assertEqual(v._fields, ('alias0',))
315        self.assertEqual(v.alias0, 0)
316
317    def testGet3Cols(self):
318        q = "select 1,2,3"
319        result = [(1, 2, 3)]
320        r = self.c.query(q).getresult()
321        self.assertEqual(r, result)
322
323    def testGet3DictCols(self):
324        q = "select 1 as a,2 as b,3 as c"
325        result = [dict(a=1, b=2, c=3)]
326        r = self.c.query(q).dictresult()
327        self.assertEqual(r, result)
328
329    @unittest.skipUnless(namedtuple, 'Named tuples not available')
330    def testGet3NamedCols(self):
331        q = "select 1 as a,2 as b,3 as c"
332        result = [(1, 2, 3)]
333        r = self.c.query(q).namedresult()
334        self.assertEqual(r, result)
335        v = r[0]
336        self.assertEqual(v._fields, ('a', 'b', 'c'))
337        self.assertEqual(v.b, 2)
338
339    def testGet3Rows(self):
340        q = "select 3 union select 1 union select 2 order by 1"
341        result = [(1,), (2,), (3,)]
342        r = self.c.query(q).getresult()
343        self.assertEqual(r, result)
344
345    def testGet3DictRows(self):
346        q = ("select 3 as alias3"
347            " union select 1 union select 2 order by 1")
348        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
349        r = self.c.query(q).dictresult()
350        self.assertEqual(r, result)
351
352    @unittest.skipUnless(namedtuple, 'Named tuples not available')
353    def testGet3NamedRows(self):
354        q = ("select 3 as alias3"
355            " union select 1 union select 2 order by 1")
356        result = [(1,), (2,), (3,)]
357        r = self.c.query(q).namedresult()
358        self.assertEqual(r, result)
359        for v in r:
360            self.assertEqual(v._fields, ('alias3',))
361
362    def testDictresultNames(self):
363        q = "select 'MixedCase' as MixedCaseAlias"
364        result = [{'mixedcasealias': 'MixedCase'}]
365        r = self.c.query(q).dictresult()
366        self.assertEqual(r, result)
367        q = "select 'MixedCase' as \"MixedCaseAlias\""
368        result = [{'MixedCaseAlias': 'MixedCase'}]
369        r = self.c.query(q).dictresult()
370        self.assertEqual(r, result)
371
372    @unittest.skipUnless(namedtuple, 'Named tuples not available')
373    def testNamedresultNames(self):
374        q = "select 'MixedCase' as MixedCaseAlias"
375        result = [('MixedCase',)]
376        r = self.c.query(q).namedresult()
377        self.assertEqual(r, result)
378        v = r[0]
379        self.assertEqual(v._fields, ('mixedcasealias',))
380        self.assertEqual(v.mixedcasealias, 'MixedCase')
381        q = "select 'MixedCase' as \"MixedCaseAlias\""
382        r = self.c.query(q).namedresult()
383        self.assertEqual(r, result)
384        v = r[0]
385        self.assertEqual(v._fields, ('MixedCaseAlias',))
386        self.assertEqual(v.MixedCaseAlias, 'MixedCase')
387
388    def testBigGetresult(self):
389        num_cols = 100
390        num_rows = 100
391        q = "select " + ','.join(map(str, xrange(num_cols)))
392        q = ' union all '.join((q,) * num_rows)
393        r = self.c.query(q).getresult()
394        result = [tuple(range(num_cols))] * num_rows
395        self.assertEqual(r, result)
396
397    def testListfields(self):
398        q = ('select 0 as a, 0 as b, 0 as c,'
399            ' 0 as c, 0 as b, 0 as a,'
400            ' 0 as lowercase, 0 as UPPERCASE,'
401            ' 0 as MixedCase, 0 as "MixedCase",'
402            ' 0 as a_long_name_with_underscores,'
403            ' 0 as "A long name with Blanks"')
404        r = self.c.query(q).listfields()
405        result = ('a', 'b', 'c', 'c', 'b', 'a',
406            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
407            'a_long_name_with_underscores',
408            'A long name with Blanks')
409        self.assertEqual(r, result)
410
411    def testFieldname(self):
412        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
413        r = self.c.query(q).fieldname(2)
414        self.assertEqual(r, 'x')
415        r = self.c.query(q).fieldname(3)
416        self.assertEqual(r, 'y')
417
418    def testFieldnum(self):
419        q = "select 1 as x"
420        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
421        q = "select 1 as x"
422        r = self.c.query(q).fieldnum('x')
423        self.assertIsInstance(r, int)
424        self.assertEqual(r, 0)
425        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
426        r = self.c.query(q).fieldnum('x')
427        self.assertIsInstance(r, int)
428        self.assertEqual(r, 2)
429        r = self.c.query(q).fieldnum('y')
430        self.assertIsInstance(r, int)
431        self.assertEqual(r, 3)
432
433    def testNtuples(self):
434        q = "select 1 where false"
435        r = self.c.query(q).ntuples()
436        self.assertIsInstance(r, int)
437        self.assertEqual(r, 0)
438        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
439            " union select 5 as a, 6 as b, 7 as c, 8 as d")
440        r = self.c.query(q).ntuples()
441        self.assertIsInstance(r, int)
442        self.assertEqual(r, 2)
443        q = ("select 1 union select 2 union select 3"
444            " union select 4 union select 5 union select 6")
445        r = self.c.query(q).ntuples()
446        self.assertIsInstance(r, int)
447        self.assertEqual(r, 6)
448
449    def testQuery(self):
450        query = self.c.query
451        query("drop table if exists test_table")
452        q = "create table test_table (n integer) with oids"
453        r = query(q)
454        self.assertIsNone(r)
455        q = "insert into test_table values (1)"
456        r = query(q)
457        self.assertIsInstance(r, int)
458        q = "insert into test_table select 2"
459        r = query(q)
460        self.assertIsInstance(r, int)
461        oid = r
462        q = "select oid from test_table where n=2"
463        r = query(q).getresult()
464        self.assertEqual(len(r), 1)
465        r = r[0]
466        self.assertEqual(len(r), 1)
467        r = r[0]
468        self.assertIsInstance(r, int)
469        self.assertEqual(r, oid)
470        q = "insert into test_table select 3 union select 4 union select 5"
471        r = query(q)
472        self.assertIsInstance(r, str)
473        self.assertEqual(r, '3')
474        q = "update test_table set n=4 where n<5"
475        r = query(q)
476        self.assertIsInstance(r, str)
477        self.assertEqual(r, '4')
478        q = "delete from test_table"
479        r = query(q)
480        self.assertIsInstance(r, str)
481        self.assertEqual(r, '5')
482        query("drop table test_table")
483
484    def testPrint(self):
485        q = ("select 1 as a, 'hello' as h, 'w' as world"
486            " union select 2, 'xyz', 'uvw'")
487        r = self.c.query(q)
488        f = tempfile.TemporaryFile()
489        stdout, sys.stdout = sys.stdout, f
490        try:
491            print r
492        except Exception:
493            pass
494        finally:
495            sys.stdout = stdout
496        f.seek(0)
497        r = f.read()
498        f.close()
499        self.assertEqual(r,
500            'a|  h  |world\n'
501            '-+-----+-----\n'
502            '1|hello|w    \n'
503            '2|xyz  |uvw  \n'
504            '(2 rows)\n')
505
506
class TestParamQueries(unittest.TestCase):
    """Test queries with parameters via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testQueryWithNoneParam(self):
        # None is passed to the server as SQL NULL.
        self.assertEqual(self.c.query("select $1::integer", (None,)
            ).getresult(), [(None,)])
        self.assertEqual(self.c.query("select $1::text", [None]
            ).getresult(), [(None,)])

    def testQueryWithBoolParams(self):
        # Booleans come back as the strings 'f' and 't'.
        query = self.c.query
        self.assertEqual(query("select false").getresult(), [('f',)])
        self.assertEqual(query("select true").getresult(), [('t',)])
        self.assertEqual(query("select $1::bool", (None,)).getresult(),
            [(None,)])
        self.assertEqual(query("select $1::bool", ('f',)).getresult(), [('f',)])
        self.assertEqual(query("select $1::bool", ('t',)).getresult(), [('t',)])
        self.assertEqual(query("select $1::bool", ('false',)).getresult(),
            [('f',)])
        self.assertEqual(query("select $1::bool", ('true',)).getresult(),
            [('t',)])
        self.assertEqual(query("select $1::bool", ('n',)).getresult(), [('f',)])
        self.assertEqual(query("select $1::bool", ('y',)).getresult(), [('t',)])
        self.assertEqual(query("select $1::bool", (0,)).getresult(), [('f',)])
        self.assertEqual(query("select $1::bool", (1,)).getresult(), [('t',)])
        self.assertEqual(query("select $1::bool", (False,)).getresult(),
            [('f',)])
        self.assertEqual(query("select $1::bool", (True,)).getresult(),
            [('t',)])

    def testQueryWithIntParams(self):
        query = self.c.query
        self.assertEqual(query("select 1+1").getresult(), [(2,)])
        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
            [(Decimal('2'),)])
        self.assertEqual(query("select 1, $1::integer", (2,)
            ).getresult(), [(1, 2)])
        self.assertEqual(query("select 1 union select $1", (2,)
            ).getresult(), [(1,), (2,)])
        self.assertEqual(query("select $1::integer+$2", (1, 2)
            ).getresult(), [(3,)])
        self.assertEqual(query("select $1::integer+$2", [1, 2]
            ).getresult(), [(3,)])
        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", range(6)
            ).getresult(), [(15,)])

    def testQueryWithStrParams(self):
        query = self.c.query
        self.assertEqual(query("select $1||', world!'", ('Hello',)
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1||', world!'", ['Hello']
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world')
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1::text", ('Hello, world!',)
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
            ).getresult(), [('Hello', 'world')])
        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
            ).getresult(), [('Hello', 'world')])
        self.assertEqual(query("select $1::text union select $2::text",
            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
        # Byte strings are passed through unchanged.
        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])

    def testQueryWithUnicodeParams(self):
        # Unicode parameters are encoded with the current client encoding;
        # characters that encoding cannot represent raise UnicodeError.
        query = self.c.query
        query('set client_encoding = utf8')
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440')).getresult(),
            [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
        query('set client_encoding = latin1')
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440'))
        query('set client_encoding = iso_8859_1')
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440'))
        query('set client_encoding = iso_8859_5')
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld'))
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440')).getresult(),
            [('Hello, \xdc\xd8\xe0!',)])
        query('set client_encoding = sql_ascii')
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld'))

    def testQueryWithMixedParams(self):
        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
            (1, 'Hello')).getresult(), [(3, 'Hello, world!')])
        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
            (4711, None, 'Hello!')).getresult(), [(4711, None, 'Hello!')])

    def testQueryWithDuplicateParams(self):
        self.assertRaises(pg.ProgrammingError,
            self.c.query, "select $1+$1", (1,))
        self.assertRaises(pg.ProgrammingError,
            self.c.query, "select $1+$1", (1, 2))

    def testQueryWithZeroParams(self):
        self.assertEqual(self.c.query("select 1+1", []
            ).getresult(), [(2,)])

    def testQueryWithGarbage(self):
        garbage = r"'\{}+()-#[]oo324"
        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
            ).dictresult(), [{'garbage': garbage}])

    def testUnicodeQuery(self):
        # A unicode query string works only if it is plain ASCII.
        query = self.c.query
        self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
        self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
635
636
637class TestInserttable(unittest.TestCase):
638    """"Test inserttable method."""
639
640    @classmethod
641    def setUpClass(cls):
642        c = connect()
643        c.query("drop table if exists test cascade")
644        c.query("create table test ("
645            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
646            "d numeric, f4 real, f8 double precision, m money,"
647            "c char(1), v4 varchar(4), c4 char(4), t text)")
648        c.close()
649
650    @classmethod
651    def tearDownClass(cls):
652        c = connect()
653        c.query("drop table test cascade")
654        c.close()
655
656    def setUp(self):
657        self.c = connect()
658        self.c.query("set datestyle='ISO,YMD'")
659
660    def tearDown(self):
661        self.c.query("truncate table test")
662        self.c.close()
663
664    data = [
665        (-1, -1, -1L, True, '1492-10-12', '08:30:00',
666            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
667        (0, 0, 0L, False, '1607-04-14', '09:00:00',
668            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
669        (1, 1, 1L, True, '1801-03-04', '03:45:00',
670            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
671        (2, 2, 2L, False, '1903-12-17', '11:22:00',
672            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]
673
674    def get_back(self):
675        """Convert boolean and decimal values back."""
676        data = []
677        for row in self.c.query("select * from test order by 1").getresult():
678            self.assertIsInstance(row, tuple)
679            row = list(row)
680            if row[0] is not None:  # smallint
681                self.assertIsInstance(row[0], int)
682            if row[1] is not None:  # integer
683                self.assertIsInstance(row[1], int)
684            if row[2] is not None:  # bigint
685                self.assertIsInstance(row[2], long)
686            if row[3] is not None:  # boolean
687                self.assertIsInstance(row[3], str)
688                row[3] = {'f': False, 't': True}.get(row[3])
689            if row[4] is not None:  # date
690                self.assertIsInstance(row[4], str)
691                self.assertTrue(row[4].replace('-', '').isdigit())
692            if row[5] is not None:  # time
693                self.assertIsInstance(row[5], str)
694                self.assertTrue(row[5].replace(':', '').isdigit())
695            if row[6] is not None:  # numeric
696                self.assertIsInstance(row[6], Decimal)
697                row[6] = float(row[6])
698            if row[7] is not None:  # real
699                self.assertIsInstance(row[7], float)
700            if row[8] is not None:  # double precision
701                self.assertIsInstance(row[8], float)
702                row[8] = float(row[8])
703            if row[9] is not None:  # money
704                self.assertIsInstance(row[9], Decimal)
705                row[9] = str(float(row[9]))
706            if row[10] is not None:  # char(1)
707                self.assertIsInstance(row[10], str)
708                self.assertEqual(len(row[10]), 1)
709            if row[11] is not None:  # varchar(4)
710                self.assertIsInstance(row[11], str)
711                self.assertLessEqual(len(row[11]), 4)
712            if row[12] is not None:  # char(4)
713                self.assertIsInstance(row[12], str)
714                self.assertEqual(len(row[12]), 4)
715                row[12] = row[12].rstrip()
716            if row[13] is not None:  # text
717                self.assertIsInstance(row[13], str)
718            row = tuple(row)
719            data.append(row)
720        return data
721
722    def testInserttable1Row(self):
723        data = self.data[2:3]
724        self.c.inserttable("test", data)
725        self.assertEqual(self.get_back(), data)
726
727    def testInserttable4Rows(self):
728        data = self.data
729        self.c.inserttable("test", data)
730        self.assertEqual(self.get_back(), data)
731
732    def testInserttableMultipleRows(self):
733        num_rows = 100
734        data = self.data[2:3] * num_rows
735        self.c.inserttable("test", data)
736        r = self.c.query("select count(*) from test").getresult()[0][0]
737        self.assertEqual(r, num_rows)
738
739    def testInserttableMultipleCalls(self):
740        num_rows = 10
741        data = self.data[2:3]
742        for _i in range(num_rows):
743            self.c.inserttable("test", data)
744        r = self.c.query("select count(*) from test").getresult()[0][0]
745        self.assertEqual(r, num_rows)
746
747    def testInserttableNullValues(self):
748        data = [(None,) * 14] * 100
749        self.c.inserttable("test", data)
750        self.assertEqual(self.get_back(), data)
751
752    def testInserttableMaxValues(self):
753        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
754            True, '2999-12-31', '11:59:59', 1e99,
755            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
756            "1", "1234", "1234", "1234" * 100)]
757        self.c.inserttable("test", data)
758        self.assertEqual(self.get_back(), data)
759
760
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    @classmethod
    def setUpClass(cls):
        c = connect()
        c.query("drop table if exists test cascade")
        c.query("create table test (i int, v varchar(16))")
        c.close()

    @classmethod
    def tearDownClass(cls):
        c = connect()
        c.query("drop table test cascade")
        c.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        # Feed rows to "copy from stdin" line by line, then read them back.
        putline = self.c.putline
        query = self.c.query
        data = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            for i, v in data:
                putline("%d\t%s\n" % (i, v))
            putline("\\.\n")  # end-of-data marker
        finally:
            self.c.endcopy()
        r = query("select * from test order by 1").getresult()
        self.assertEqual(r, data)

    def testGetline(self):
        # Previously also named testPutline, which silently replaced the
        # putline test above so it never ran.  Rows written by inserttable
        # must come out of "copy to stdout" one line each, followed by the
        # end-of-data marker and then None.
        getline = self.c.getline
        query = self.c.query
        data = list(enumerate("apple banana pear plum strawberry".split()))
        n = len(data)
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            for i in range(n + 2):
                v = getline()
                if i < n:
                    self.assertEqual(v, '%d\t%s' % data[i])
                elif i == n:
                    self.assertEqual(v, '\\.')
                else:
                    self.assertIsNone(v)
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
825
826
827class TestNotificatons(unittest.TestCase):
828    """"Test notification support."""
829
830    def setUp(self):
831        self.c = connect()
832
833    def tearDown(self):
834        self.c.close()
835
836    def testGetNotify(self):
837        getnotify = self.c.getnotify
838        query = self.c.query
839        self.assertIsNone(getnotify())
840        query('listen test_notify')
841        try:
842            self.assertIsNone(self.c.getnotify())
843            query("notify test_notify")
844            r = getnotify()
845            self.assertIsInstance(r, tuple)
846            self.assertEqual(len(r), 3)
847            self.assertIsInstance(r[0], str)
848            self.assertIsInstance(r[1], int)
849            self.assertIsInstance(r[2], str)
850            self.assertEqual(r[0], 'test_notify')
851            self.assertEqual(r[2], '')
852            self.assertIsNone(self.c.getnotify())
853            try:
854                query("notify test_notify, 'test_payload'")
855            except pg.ProgrammingError:  # PostgreSQL < 9.0
856                pass
857            else:
858                r = getnotify()
859                self.assertTrue(isinstance(r, tuple))
860                self.assertEqual(len(r), 3)
861                self.assertIsInstance(r[0], str)
862                self.assertIsInstance(r[1], int)
863                self.assertIsInstance(r[2], str)
864                self.assertEqual(r[0], 'test_notify')
865                self.assertEqual(r[2], 'test_payload')
866                self.assertIsNone(getnotify())
867        finally:
868            query('unlisten test_notify')
869
870    def testGetNoticeReceiver(self):
871        self.assertIsNone(self.c.get_notice_receiver())
872
873    def testSetNoticeReceiver(self):
874        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
875        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
876        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
877
878    def testSetAndGetNoticeReceiver(self):
879        r = lambda notice: None
880        self.assertIsNone(self.c.set_notice_receiver(r))
881        self.assertIs(self.c.get_notice_receiver(), r)
882
883    def testNoticeReceiver(self):
884        self.c.query('''create function bilbo_notice() returns void AS $$
885            begin
886                raise warning 'Bilbo was here!';
887            end;
888            $$ language plpgsql''')
889        try:
890            received = {}
891
892            def notice_receiver(notice):
893                for attr in dir(notice):
894                    value = getattr(notice, attr)
895                    if isinstance(value, str):
896                        value = value.replace('WARNUNG', 'WARNING')
897                    received[attr] = value
898
899            self.c.set_notice_receiver(notice_receiver)
900            self.c.query('''select bilbo_notice()''')
901            self.assertEqual(received, dict(
902                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
903                severity='WARNING', primary='Bilbo was here!',
904                detail=None, hint=None))
905        finally:
906            self.c.query('''drop function bilbo_notice();''')
907
908
909class TestConfigFunctions(unittest.TestCase):
910    """Test the functions for changing default settings.
911
912    To test the effect of most of these functions, we need a database
913    connection.  That's why they are covered in this test module.
914
915    """
916
917    def setUp(self):
918        self.c = connect()
919
920    def tearDown(self):
921        self.c.close()
922
923    def testGetDecimalPoint(self):
924        point = pg.get_decimal_point()
925        self.assertIsInstance(point, str)
926        self.assertEqual(point, '.')
927
928    def testSetDecimalPoint(self):
929        d = pg.Decimal
930        point = pg.get_decimal_point()
931        query = self.c.query
932        # check that money values can be interpreted correctly
933        # if and only if the decimal point is set appropriately
934        # for the current lc_monetary setting
935        query("set lc_monetary='en_US.UTF-8'")
936        pg.set_decimal_point('.')
937        r = query("select '34.25'::money").getresult()[0][0]
938        self.assertIsInstance(r, d)
939        self.assertEqual(r, d('34.25'))
940        pg.set_decimal_point(',')
941        r = query("select '34.25'::money").getresult()[0][0]
942        self.assertNotEqual(r, d('34.25'))
943        query("set lc_monetary='de_DE.UTF-8'")
944        pg.set_decimal_point(',')
945        r = query("select '34,25'::money").getresult()[0][0]
946        self.assertIsInstance(r, d)
947        self.assertEqual(r, d('34.25'))
948        pg.set_decimal_point('.')
949        r = query("select '34,25'::money").getresult()[0][0]
950        self.assertNotEqual(r, d('34.25'))
951        pg.set_decimal_point(point)
952
953    def testSetDecimal(self):
954        d = pg.Decimal
955        query = self.c.query
956        r = query("select 3425::numeric").getresult()[0][0]
957        self.assertIsInstance(r, d)
958        self.assertEqual(r, d('3425'))
959        pg.set_decimal(long)
960        r = query("select 3425::numeric").getresult()[0][0]
961        self.assertNotIsInstance(r, d)
962        self.assertIsInstance(r, long)
963        self.assertEqual(r, 3425L)
964        pg.set_decimal(d)
965
966    @unittest.skipUnless(namedtuple, 'Named tuples not available')
967    def testSetNamedresult(self):
968        query = self.c.query
969
970        r = query("select 1 as x, 2 as y").namedresult()[0]
971        self.assertIsInstance(r, tuple)
972        self.assertEqual(r, (1, 2))
973        self.assertIsNot(type(r), tuple)
974        self.assertEqual(r._fields, ('x', 'y'))
975        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
976        self.assertEqual(r.__class__.__name__, 'Row')
977
978        _namedresult = pg._namedresult
979        self.assertTrue(callable(_namedresult))
980        pg.set_namedresult(_namedresult)
981
982        r = query("select 1 as x, 2 as y").namedresult()[0]
983        self.assertIsInstance(r, tuple)
984        self.assertEqual(r, (1, 2))
985        self.assertIsNot(type(r), tuple)
986        self.assertEqual(r._fields, ('x', 'y'))
987        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
988        self.assertEqual(r.__class__.__name__, 'Row')
989
990        def _listresult(q):
991            return map(list, q.getresult())
992
993        pg.set_namedresult(_listresult)
994
995        try:
996            r = query("select 1 as x, 2 as y").namedresult()[0]
997            self.assertIsInstance(r, list)
998            self.assertEqual(r, [1, 2])
999            self.assertIsNot(type(r), tuple)
1000            self.assertFalse(hasattr(r, '_fields'))
1001            self.assertNotEqual(r.__class__.__name__, 'Row')
1002        finally:
1003            pg.set_namedresult(_namedresult)
1004
1005
if __name__ == '__main__':
    # When executed as a script, run every TestCase defined in this module.
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.