source: branches/4.x/module/tests/test_classic_connection.py @ 650

Last change on this file since 650 was 650, checked in by cito, 4 years ago

Make tests compatible with Python 2.4

We want to be able to test the 4.x branch with older Python versions
so that we can specify eligible minimum requirements.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 43.2 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import sys
19import tempfile
20import threading
21import time
22
23import pg  # the module under test
24
25from decimal import Decimal
26try:
27    from collections import namedtuple
28except ImportError:  # Python < 2.6
29    namedtuple = None
30
# Connection parameters of the database the tests run against.  A module
# named LOCAL_PyGreSQL on the import path may override any of them; when
# no such module can be imported, these defaults stay in effect.
dbname, dbhost, dbport = 'unittest', None, 5432

try:
    from LOCAL_PyGreSQL import *
except ImportError:
    pass  # no local settings, keep the defaults
41
42
def connect():
    """Open and return a new pg connection to the test database.

    The connection uses the module-level dbname, dbhost and dbport
    settings.  client_min_messages is set to 'warning', so messages
    below that level are not sent to the client.
    """
    con = pg.connect(dbname, dbhost, dbport)
    con.query("set client_min_messages=warning")
    return con
48
49
class TestCanConnect(unittest.TestCase):
    """Test whether a basic connection to PostgreSQL is possible."""

    def testCanConnect(self):
        # sys.exc_info() fetches the exception in a way that works with
        # Python 2.4 as well as with newer Python versions.
        try:
            connection = connect()
        except pg.Error:
            self.fail('Cannot connect to database %s:\n%s'
                % (dbname, sys.exc_info()[1]))
        try:
            connection.close()
        except pg.Error:
            self.fail('Cannot close the database connection')
62
63
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # Closing an already closed connection raises pg.InternalError
        # (see testMethodClose), which is ignored here.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def testAllConnectAttributes(self):
        attributes = '''db error host options port
            protocol_version server_version status tty user'''.split()
        # getattr() does the same attribute lookup as evaluating a
        # generated "self.connection.<name>" string, without eval().
        connection_attributes = [a for a in dir(self.connection)
            if not callable(getattr(self.connection, a))]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        methods = '''cancel close endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_notice_receiver source transaction'''.split()
        connection_methods = [a for a in dir(self.connection)
            if callable(getattr(self.connection, a))]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # libpq encodes e.g. 9.4.1 as 90401 and, since PostgreSQL 10,
        # 10.1 as 100001, so the upper bound must admit versions >= 10.
        self.assertTrue(70400 <= server_version < 200000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeTty(self):
        def_tty = ''
        self.assertIsInstance(self.connection.tty, str)
        self.assertEqual(self.connection.tty, def_tty)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # Without a running copy command, endcopy() may raise IOError.
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # reconnect so that tearDown has an open connection to close
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = encoding == 'UTF8' and 'LATIN1' or 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.ProgrammingError:
                # the cancelled query is expected to end up here
                errors.append(str(sys.exc_info()[1]))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # make sure the query is really running
            time.sleep(0.1)
            # NOTE(review): this only waits until the thread is alive,
            # not until the server has actually started the query.
            if thread.isAlive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)
222
223
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.ProgrammingError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        # the literal exceeds any machine int, so Python 2 makes it a long
        result = Decimal(98765432109876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testNamedresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, xrange(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testQuery(self):
        """Check the return values of query() for different statements.

        DDL returns None, a single-row insert into a table with oids
        returns the oid as int, and multi-row insert, update and delete
        return the number of affected rows as a string.
        """
        query = self.c.query
        query("drop table if exists test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        # drop the table even when an assertion below fails
        try:
            q = "insert into test_table values (1)"
            r = query(q)
            self.assertIsInstance(r, int)
            q = "insert into test_table select 2"
            r = query(q)
            self.assertIsInstance(r, int)
            oid = r
            q = "select oid from test_table where n=2"
            r = query(q).getresult()
            self.assertEqual(len(r), 1)
            r = r[0]
            self.assertEqual(len(r), 1)
            r = r[0]
            self.assertIsInstance(r, int)
            self.assertEqual(r, oid)
            q = "insert into test_table select 3 union select 4 union select 5"
            r = query(q)
            self.assertIsInstance(r, str)
            self.assertEqual(r, '3')
            q = "update test_table set n=4 where n<5"
            r = query(q)
            self.assertIsInstance(r, str)
            self.assertEqual(r, '4')
            q = "delete from test_table"
            r = query(q)
            self.assertIsInstance(r, str)
            self.assertEqual(r, '5')
        finally:
            query("drop table test_table")

    def testPrint(self):
        """Check the text that printing a query result writes to stdout."""
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        f = tempfile.TemporaryFile()
        try:
            stdout, sys.stdout = sys.stdout, f
            try:
                # an error while printing now propagates instead of being
                # swallowed, and the real stdout is always restored
                print(r)
            finally:
                sys.stdout = stdout
            f.seek(0)
            r = f.read()
        finally:
            f.close()
        self.assertEqual(r,
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)\n')
504
505
class TestParamQueries(unittest.TestCase):
    """Run parameterized queries ($1, $2, ...) on a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testQueryWithNoneParam(self):
        fetch = lambda sql, params: self.c.query(sql, params).getresult()
        self.assertEqual(fetch("select $1::integer", (None,)), [(None,)])
        self.assertEqual(fetch("select $1::text", [None]), [(None,)])

    def testQueryWithBoolParams(self, use_bool=None):
        """Check boolean results, optionally with pg.set_bool(use_bool).

        When use_bool is given, the previous pg.get_bool() setting is
        restored afterwards.
        """
        query = self.c.query
        if use_bool is not None:
            use_bool_default = pg.get_bool()
            pg.set_bool(use_bool)
        try:
            if use_bool:
                v_false, v_true = False, True
            else:
                v_false, v_true = 'f', 't'
            r_false, r_true = [(v_false,)], [(v_true,)]
            self.assertEqual(query("select false").getresult(), r_false)
            self.assertEqual(query("select true").getresult(), r_true)
            q = "select $1::bool"
            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
            # each pair is checked as false value first, then true value
            for p_false, p_true in (('f', 't'), ('false', 'true'),
                    ('n', 'y'), (0, 1), (False, True)):
                self.assertEqual(query(q, (p_false,)).getresult(), r_false)
                self.assertEqual(query(q, (p_true,)).getresult(), r_true)
        finally:
            if use_bool is not None:
                pg.set_bool(use_bool_default)

    def testQueryWithBoolParamsAndUseBool(self):
        self.testQueryWithBoolParams(use_bool=True)

    def testQueryWithIntParams(self):
        query = self.c.query

        def fetch(sql, *params):
            return query(sql, *params).getresult()

        self.assertEqual(fetch("select 1+1"), [(2,)])
        self.assertEqual(fetch("select 1+$1", (1,)), [(2,)])
        self.assertEqual(fetch("select 1+$1", [1]), [(2,)])
        self.assertEqual(fetch("select $1::integer", (2,)), [(2,)])
        self.assertEqual(fetch("select $1::text", (2,)), [('2',)])
        self.assertEqual(fetch("select 1+$1::numeric", [1]),
            [(Decimal('2'),)])
        self.assertEqual(fetch("select 1, $1::integer", (2,)), [(1, 2)])
        self.assertEqual(fetch("select 1 union select $1", (2,)),
            [(1,), (2,)])
        self.assertEqual(fetch("select $1::integer+$2", (1, 2)), [(3,)])
        self.assertEqual(fetch("select $1::integer+$2", [1, 2]), [(3,)])
        self.assertEqual(fetch("select 0+$1+$2+$3+$4+$5+$6", range(6)),
            [(15,)])

    def testQueryWithStrParams(self):
        query = self.c.query

        def fetch(sql, params):
            return query(sql, params).getresult()

        greeting = [('Hello, world!',)]
        self.assertEqual(fetch("select $1||', world!'", ('Hello',)), greeting)
        self.assertEqual(fetch("select $1||', world!'", ['Hello']), greeting)
        self.assertEqual(fetch("select $1||', '||$2||'!'",
            ('Hello', 'world')), greeting)
        self.assertEqual(fetch("select $1::text", ('Hello, world!',)),
            greeting)
        self.assertEqual(fetch("select $1::text,$2::text",
            ('Hello', 'world')), [('Hello', 'world')])
        self.assertEqual(fetch("select $1::text,$2::text",
            ['Hello', 'world']), [('Hello', 'world')])
        self.assertEqual(fetch("select $1::text union select $2::text",
            ('Hello', 'world')), [('Hello',), ('world',)])
        self.assertEqual(fetch("select $1||', '||$2||'!'",
            ('Hello', 'w\xc3\xb6rld')), [('Hello, w\xc3\xb6rld!',)])

    def testQueryWithUnicodeParams(self):
        """Unicode parameters are encoded with the client encoding."""
        query = self.c.query
        sql = "select $1||', '||$2||'!'"
        german, russian = u'w\xf6rld', u'\u043c\u0438\u0440'
        query('set client_encoding = utf8')
        self.assertEqual(query(sql, ('Hello', german)).getresult(),
            [('Hello, w\xc3\xb6rld!',)])
        self.assertEqual(query(sql, ('Hello', russian)).getresult(),
            [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
        query('set client_encoding = latin1')
        self.assertEqual(query(sql, ('Hello', german)).getresult(),
            [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, sql, ('Hello', russian))
        query('set client_encoding = iso_8859_1')
        self.assertEqual(query(sql, ('Hello', german)).getresult(),
            [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, sql, ('Hello', russian))
        query('set client_encoding = iso_8859_5')
        self.assertRaises(UnicodeError, query, sql, ('Hello', german))
        self.assertEqual(query(sql, ('Hello', russian)).getresult(),
            [('Hello, \xdc\xd8\xe0!',)])
        query('set client_encoding = sql_ascii')
        self.assertRaises(UnicodeError, query, sql, ('Hello', german))

    def testQueryWithMixedParams(self):
        query = self.c.query
        r = query("select $1+2,$2||', world!'", (1, 'Hello')).getresult()
        self.assertEqual(r, [(3, 'Hello, world!')])
        r = query("select $1::integer,$2::date,$3::text",
            (4711, None, 'Hello!')).getresult()
        self.assertEqual(r, [(4711, None, 'Hello!')])

    def testQueryWithDuplicateParams(self):
        for params in ((1,), (1, 2)):
            self.assertRaises(pg.ProgrammingError,
                self.c.query, "select $1+$1", params)

    def testQueryWithZeroParams(self):
        r = self.c.query("select 1+1", []).getresult()
        self.assertEqual(r, [(2,)])

    def testQueryWithGarbage(self):
        garbage = r"'\{}+()-#[]oo324"
        r = self.c.query("select $1::text AS garbage", (garbage,)).dictresult()
        self.assertEqual(r, [{'garbage': garbage}])

    def testUnicodeQuery(self):
        query = self.c.query
        self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
        self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
642
643
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    @classmethod
    def setUpClass(cls):
        con = connect()
        con.query("drop table if exists test cascade")
        con.query("create table test ("
            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
            "d numeric, f4 real, f8 double precision, m money,"
            "c char(1), v4 varchar(4), c4 char(4), t text)")
        con.close()

    @classmethod
    def tearDownClass(cls):
        con = connect()
        con.query("drop table test cascade")
        con.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set lc_monetary='C'")
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    # sample rows, one value per column of the test table, in column order
    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    def get_back(self):
        """Fetch the test table ordered by its first column.

        Every non-null value is checked against the Python type expected
        for its column.  Booleans 'f'/'t' become False/True, numeric values
        become floats, money becomes the string of its float value, and
        char(4) values lose their trailing blanks, so that the returned
        list of tuples can be compared with the inserted data.
        """
        rows = []
        for fetched in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(fetched, tuple)
            row = list(fetched)
            for col in (0, 1):  # smallint, integer
                if row[col] is not None:
                    self.assertIsInstance(row[col], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], str)
                row[3] = {'f': False, 't': True}.get(row[3])
            for col, sep in ((4, '-'), (5, ':')):  # date, time
                if row[col] is not None:
                    self.assertIsInstance(row[col], str)
                    self.assertTrue(row[col].replace(sep, '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            for col in (7, 8):  # real, double precision
                if row[col] is not None:
                    self.assertIsInstance(row[col], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(len(row[10]), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(len(row[11]), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(len(row[12]), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            rows.append(tuple(row))
        return rows

    def testInserttable1Row(self):
        rows = self.data[2:3]
        self.c.inserttable("test", rows)
        self.assertEqual(self.get_back(), rows)

    def testInserttable4Rows(self):
        rows = self.data
        self.c.inserttable("test", rows)
        self.assertEqual(self.get_back(), rows)

    def testInserttableMultipleRows(self):
        num_rows = 100
        self.c.inserttable("test", self.data[2:3] * num_rows)
        count = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(count, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        rows = self.data[2:3]
        for _ in xrange(num_rows):
            self.c.inserttable("test", rows)
        count = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(count, num_rows)

    def testInserttableNullValues(self):
        rows = [(None,) * 14] * 100
        self.c.inserttable("test", rows)
        self.assertEqual(self.get_back(), rows)

    def testInserttableMaxValues(self):
        rows = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable("test", rows)
        self.assertEqual(self.get_back(), rows)
767
768
class TestDirectSocketAccess(unittest.TestCase):
    """Test the copy command with direct socket access (putline/getline)."""

    @classmethod
    def setUpClass(cls):
        con = connect()
        con.query("drop table if exists test cascade")
        con.query("create table test (i int, v varchar(16))")
        con.close()

    @classmethod
    def tearDownClass(cls):
        con = connect()
        con.query("drop table test cascade")
        con.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        putline, query = self.c.putline, self.c.query
        rows = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            for row in rows:
                putline("%d\t%s\n" % row)
            putline("\\.\n")  # end-of-data marker
        finally:
            self.c.endcopy()
        self.assertEqual(query("select * from test").getresult(), rows)

    def testGetline(self):
        getline, query = self.c.getline, self.c.query
        rows = list(enumerate("apple banana pear plum strawberry".split()))
        self.c.inserttable('test', rows)
        query("copy test to stdout")
        try:
            # one line per row, then the end marker, then None
            for row in rows:
                self.assertEqual(getline(), '%d\t%s' % row)
            self.assertEqual(getline(), '\\.')
            self.assertIsNone(getline())
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        con = self.c
        self.assertRaises(TypeError, con.putline)
        self.assertRaises(TypeError, con.getline, 'invalid')
        self.assertRaises(TypeError, con.endcopy, 'invalid')
833
834
class TestNotificatons(unittest.TestCase):
    """Test notification support.

    Covers LISTEN/NOTIFY through getnotify() and the notice receiver
    callback API (set_notice_receiver / get_notice_receiver).
    """

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def _checkNotify(self, r, payload):
        """Check that r is a notification on test_notify with payload.

        The notification is expected to be a 3-tuple of (channel name,
        an int, payload string); the int is presumably the PID of the
        notifying backend -- not checked here beyond its type.
        """
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            # a notification without payload has an empty payload string
            self._checkNotify(getnotify(), '')
            self.assertIsNone(getnotify())
            try:
                query("notify test_notify, 'test_payload'")
            except pg.ProgrammingError:  # PostgreSQL < 9.0
                pass
            else:
                self._checkNotify(getnotify(), 'test_payload')
                self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))

    def testSetAndGetNoticeReceiver(self):
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)

    def testNoticeReceiver(self):
        # "or replace" so that a bilbo_notice() function left over from
        # an earlier, aborted run cannot make this test fail
        self.c.query('''create or replace function bilbo_notice()
            returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        try:
            received = {}

            def notice_receiver(notice):
                # record every attribute of the received notice object
                for attr in dir(notice):
                    value = getattr(notice, attr)
                    if isinstance(value, str):
                        # normalize messages of a German-localized server
                        value = value.replace('WARNUNG', 'WARNING')
                    received[attr] = value

            self.c.set_notice_receiver(notice_receiver)
            self.c.query('''select bilbo_notice()''')
            self.assertEqual(received, dict(
                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
                severity='WARNING', primary='Bilbo was here!',
                detail=None, hint=None))
        finally:
            self.c.query('''drop function bilbo_notice();''')
915
916
class TestConfigFunctions(unittest.TestCase):
    """Test the functions for changing default settings.

    To test the effect of most of these functions, we need a database
    connection.  That's why they are covered in this test module.

    """

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def _getresultWithPoint(self, r, point):
        """Return the single value of query result r using a decimal point.

        The module-wide decimal point is set to point only while the value
        is fetched from r; the previously active setting is restored
        afterwards, even if fetching fails.

        """
        old_point = pg.get_decimal_point()
        pg.set_decimal_point(point)
        try:
            return r.getresult()[0][0]
        finally:
            pg.set_decimal_point(old_point)

    def testGetDecimalPoint(self):
        point = pg.get_decimal_point()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal_point, point)
        self.assertIsInstance(point, str)
        self.assertEqual(point, '.')  # the default setting
        pg.set_decimal_point(',')
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertEqual(r, ',')
        pg.set_decimal_point("'")
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsInstance(r, str)
        self.assertEqual(r, "'")
        pg.set_decimal_point('')
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsNone(r)
        pg.set_decimal_point(None)
        try:
            r = pg.get_decimal_point()
        finally:
            pg.set_decimal_point(point)
        self.assertIsNone(r)

    def testSetDecimalPoint(self):
        d = pg.Decimal
        self.assertRaises(TypeError, pg.set_decimal_point)
        # error if decimal point is not a string
        self.assertRaises(TypeError, pg.set_decimal_point, 0)
        # error if more than one decimal point passed
        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
        # error if decimal point is not a punctuation character
        self.assertRaises(TypeError, pg.set_decimal_point, '0')
        query = self.c.query
        fetch = self._getresultWithPoint
        # check that money values are interpreted as decimal values
        # only if decimal_point is set, and that the result is correct
        # only if it is set suitable for the current lc_monetary setting
        select_money = "select '34.25'::money"
        proper_money = d('34.25')
        bad_money = d('3425')
        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
            '34,25 EUR', '34,25 Euro', '34,25 DM')
        # first try with English localization (using the point)
        for lc in en_locales:
            try:
                query("set lc_monetary='%s'" % lc)
            except pg.ProgrammingError:
                pass
            else:
                break
        else:
            self.skipTest("cannot set English money locale")
        try:
            r = query(select_money)
        except pg.ProgrammingError:
            # this can happen if the currency signs cannot be
            # converted using the encoding of the test database
            self.skipTest("database does not support English money")
        # without a decimal point, money is returned as a string
        r = fetch(r, None)
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        r = fetch(query(select_money), '')
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        # with the matching decimal point, money is converted correctly
        r = fetch(query(select_money), '.')
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        # with a non-matching decimal point, the fraction is lost
        r = fetch(query(select_money), ',')
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        r = fetch(query(select_money), "'")
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        # then try with German localization (using the comma)
        for lc in de_locales:
            try:
                query("set lc_monetary='%s'" % lc)
            except pg.ProgrammingError:
                pass
            else:
                break
        else:
            self.skipTest("cannot set German money locale")
        select_money = select_money.replace('.', ',')
        try:
            r = query(select_money)
        except pg.ProgrammingError:
            self.skipTest("database does not support German money")
        # without a decimal point, money is returned as a string
        r = fetch(r, None)
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        r = fetch(query(select_money), '')
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        # with the matching decimal point, money is converted correctly
        r = fetch(query(select_money), ',')
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        # with a non-matching decimal point, the fraction is lost
        r = fetch(query(select_money), '.')
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)
        r = fetch(query(select_money), "'")
        self.assertIsInstance(r, d)
        self.assertEqual(r, bad_money)

    def testGetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
        self.assertIs(decimal_class, pg.Decimal)  # the default setting
        pg.set_decimal(int)
        try:
            r = pg.get_decimal()
        finally:
            pg.set_decimal(decimal_class)
        self.assertIs(r, int)
        r = pg.get_decimal()
        self.assertIs(r, decimal_class)

    def testSetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_decimal)
        query = self.c.query
        try:
            r = query("select 3425::numeric")
        except pg.ProgrammingError:
            self.skipTest('database does not support numeric')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, decimal_class)
        self.assertEqual(r, decimal_class('3425'))
        r = query("select 3425::numeric")
        pg.set_decimal(int)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal(decimal_class)
        self.assertNotIsInstance(r, decimal_class)
        self.assertIsInstance(r, int)
        self.assertEqual(r, int(3425))

    def testGetBool(self):
        use_bool = pg.get_bool()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_bool, use_bool)
        self.assertIsInstance(use_bool, bool)
        self.assertIs(use_bool, False)  # the default setting
        pg.set_bool(True)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        pg.set_bool(False)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)
        pg.set_bool(1)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        pg.set_bool(0)
        try:
            r = pg.get_bool()
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, False)

    def testSetBool(self):
        use_bool = pg.get_bool()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_bool)
        query = self.c.query
        try:
            r = query("select true::bool")
        except pg.ProgrammingError:
            self.skipTest('database does not support bool')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, str)
        self.assertEqual(r, 't')
        r = query("select true::bool")
        pg.set_bool(True)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        r = query("select true::bool")
        pg.set_bool(False)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, str)
        # compare by value: string identity is an interpreter detail
        self.assertEqual(r, 't')

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testGetNamedresult(self):
        namedresult = pg.get_namedresult()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
        self.assertIs(namedresult, pg._namedresult)  # the default setting

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testSetNamedresult(self):
        namedresult = pg.get_namedresult()
        self.assertTrue(callable(namedresult))

        query = self.c.query

        r = query("select 1 as x, 2 as y").namedresult()[0]
        self.assertIsInstance(r, tuple)
        self.assertEqual(r, (1, 2))
        self.assertIsNot(type(r), tuple)
        self.assertEqual(r._fields, ('x', 'y'))
        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
        self.assertEqual(r.__class__.__name__, 'Row')

        def listresult(q):
            return [list(row) for row in q.getresult()]

        pg.set_namedresult(listresult)
        try:
            r = pg.get_namedresult()
            self.assertIs(r, listresult)
            r = query("select 1 as x, 2 as y").namedresult()[0]
            self.assertIsInstance(r, list)
            self.assertEqual(r, [1, 2])
            self.assertIsNot(type(r), tuple)
            self.assertFalse(hasattr(r, '_fields'))
            self.assertNotEqual(r.__class__.__name__, 'Row')
        finally:
            pg.set_namedresult(namedresult)

        r = pg.get_namedresult()
        self.assertIs(r, namedresult)
1235
1236
def _main():
    """Run all test cases of this module with the unittest runner."""
    unittest.main()


if __name__ == '__main__':
    _main()
Note: See TracBrowser for help on using the repository browser.