source: trunk/module/TEST_PyGreSQL_classic_connection.py @ 553

Last change on this file since 553 was 553, checked in by cito, 4 years ago

Require at least Python 2.6 for the trunk (5.x)

Support for even older Python versions is maintained in the 4.x branch.
The goal for 5.x is to be a single-source code for both Python 2 and 3,
and this is only possible by dropping support for Python 2.5 and older.
For instance, the new except .. as syntax works only since Python 2.6.
Otherwise we would need to use 2to3 and things would be very ugly.
Note that Python 2.6 is now 7 years old. We may want to drop Python 2.6
as well at some point if it turns out to be a burden.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 35.0 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18
19import sys
20import threading
21import time
22
23import pg  # the module under test
24
25from decimal import Decimal
26
27from collections import namedtuple
28
29from StringIO import StringIO
30
# Settings for the database the tests run against.  The values below are
# only defaults: if a LOCAL_PyGreSQL module can be imported, the names it
# exports replace them.
dbname, dbhost, dbport = 'unittest', None, 5432

try:
    from LOCAL_PyGreSQL import *
except ImportError:
    pass  # no local settings available, stay with the defaults
41
42
def connect():
    """Open a plain pg connection to the test database.

    The connection is made with the module-level ``dbname``, ``dbhost``
    and ``dbport`` settings, and ``client_min_messages`` is set to
    ``warning`` on it before it is returned.
    """
    conn = pg.connect(dbname, dbhost, dbport)
    conn.query("set client_min_messages=warning")
    return conn
48
49
class TestCanConnect(unittest.TestCase):
    """Test whether a basic connection to PostgreSQL is possible."""

    def testCanConnect(self):
        # A pg.Error while connecting or closing is turned into a regular
        # test failure with a readable message.
        try:
            connection = connect()
        except pg.Error as error:  # "as" form: valid on Python 2.6+ and 3
            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
        try:
            connection.close()
        except pg.Error:
            self.fail('Cannot close the database connection')
62
63
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # Closing an already closed connection raises pg.InternalError
        # (see testMethodClose); that is of no interest during cleanup.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def testAllConnectAttributes(self):
        # The names are compared as lists, so they must be given in the
        # same order in which dir() reports them.
        attributes = '''db error host options port
            protocol_version server_version status tty user'''.split()
        # getattr() rather than eval("self.connection." + a): on Python 3
        # a list comprehension has its own scope, so a "self" that only
        # occurs inside the evaluated string is not visible there.
        connection_attributes = [a for a in dir(self.connection)
            if not callable(getattr(self.connection, a))]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        methods = '''cancel close endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_notice_receiver source transaction'''.split()
        # Same reasoning as in testAllConnectAttributes for using getattr().
        connection_methods = [a for a in dir(self.connection)
            if callable(getattr(self.connection, a))]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        # Either there is no error text at all, or it mentions 'krb5_'.
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # Only a lower bound is checked: PostgreSQL 10 and later report
        # major * 10000 + minor, i.e. 100000 or more, which the former
        # "< 100000" upper bound rejected.
        self.assertGreaterEqual(server_version, 70400)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeTty(self):
        def_tty = ''
        self.assertIsInstance(self.connection.tty, str)
        self.assertEqual(self.connection.tty, def_tty)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        # Parameters may be passed as a tuple or as a list.
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # No copy operation is in progress here, so an IOError is tolerated.
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # Leave an open connection behind for tearDown.
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        self.connection.query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            # Runs in the worker thread; the error raised for the
            # cancelled query is recorded for the assertion below.
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.ProgrammingError as error:  # valid on 2.6+ and 3
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # make sure the query is really running
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)
222
223
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    # NOTE(review): the syntax of this class is valid on Python 2.6+ and
    # Python 3, but ``long`` is only a builtin on Python 2; the two *Long
    # tests need a module-level alias before they can run on Python 3.

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.ProgrammingError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 1234567890123456790"
        # long(...) instead of an "L" suffix, which Python 3 cannot parse
        result = long(1234567890123456790)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 1234567890123456790 as longjohnsilver"
        # long(...) instead of an "L" suffix, which Python 3 cannot parse
        result = long(1234567890123456790)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        # An unquoted alias comes back in lower case, a quoted one keeps
        # its case.
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testNamedresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        # range() instead of the Python-2-only xrange(); with 100 items
        # the difference does not matter.
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        q = "select 1 as x"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testQuery(self):
        # Checks what query() returns for the different statement types:
        # None for DDL, an int (the oid) for a single-row insert, and the
        # number of affected rows as a string otherwise.
        query = self.c.query
        query("drop table if exists test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        self.assertIsNone(r)
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, oid)
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')
        query("drop table test_table")

    def testPrint(self):
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        s = StringIO()
        # Capture what printing the query result writes to stdout.
        stdout, sys.stdout = sys.stdout, s
        try:
            # Parenthesized single argument: prints the same text on
            # Python 2 (print statement) and Python 3 (print function).
            print(r)
        except Exception:
            # A failure while printing shows up below as a mismatch in
            # the captured output.
            pass
        finally:
            sys.stdout = stdout
        r = s.getvalue()
        s.close()
        self.assertEqual(r,
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)\n')
490
491
class TestParamQueries(unittest.TestCase):
    """Test queries with parameters via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testQueryWithNoneParam(self):
        # None is passed once inside a tuple and once inside a list.
        cases = (
            ("select $1::integer", (None,)),
            ("select $1::text", [None]),
        )
        for sql, params in cases:
            rows = self.c.query(sql, params).getresult()
            self.assertEqual(rows, [(None,)])

    def testQueryWithBoolParams(self):
        query = self.c.query
        self.assertEqual(query("select false").getresult(), [('f',)])
        self.assertEqual(query("select true").getresult(), [('t',)])
        # (parameter value, expected column value), in execution order
        cases = [
            (None, None),
            ('f', 'f'), ('t', 't'),
            ('false', 'f'), ('true', 't'),
            ('n', 'f'), ('y', 't'),
            (0, 'f'), (1, 't'),
            (False, 'f'), (True, 't'),
        ]
        for param, expected in cases:
            rows = query("select $1::bool", (param,)).getresult()
            self.assertEqual(rows, [(expected,)])

    def testQueryWithIntParams(self):
        query = self.c.query

        def fetch(sql, *params):
            # run the query and hand back the rows as a list of tuples
            return query(sql, *params).getresult()

        two = [(2,)]
        three = [(3,)]
        self.assertEqual(fetch("select 1+1"), two)
        self.assertEqual(fetch("select 1+$1", (1,)), two)
        self.assertEqual(fetch("select 1+$1", [1]), two)
        self.assertEqual(fetch("select $1::integer", (2,)), two)
        self.assertEqual(fetch("select $1::text", (2,)), [('2',)])
        self.assertEqual(fetch("select 1+$1::numeric", [1]),
            [(Decimal('2'),)])
        self.assertEqual(fetch("select 1, $1::integer", (2,)), [(1, 2)])
        self.assertEqual(fetch("select 1 union select $1", (2,)),
            [(1,), (2,)])
        self.assertEqual(fetch("select $1::integer+$2", (1, 2)), three)
        self.assertEqual(fetch("select $1::integer+$2", [1, 2]), three)
        self.assertEqual(fetch("select 0+$1+$2+$3+$4+$5+$6", range(6)),
            [(15,)])

    def testQueryWithStrParams(self):
        query = self.c.query
        greeting = [('Hello, world!',)]
        pair = [('Hello', 'world')]
        # (statement, parameters, expected rows), in execution order
        cases = [
            ("select $1||', world!'", ('Hello',), greeting),
            ("select $1||', world!'", ['Hello'], greeting),
            ("select $1||', '||$2||'!'", ('Hello', 'world'), greeting),
            ("select $1::text", ('Hello, world!',), greeting),
            ("select $1::text,$2::text", ('Hello', 'world'), pair),
            ("select $1::text,$2::text", ['Hello', 'world'], pair),
            ("select $1::text union select $2::text",
                ('Hello', 'world'), [('Hello',), ('world',)]),
            ("select $1||', '||$2||'!'", ('Hello', 'w\xc3\xb6rld'),
                [('Hello, w\xc3\xb6rld!',)]),
        ]
        for sql, params, expected in cases:
            self.assertEqual(query(sql, params).getresult(), expected)

    def testQueryWithUnicodeParams(self):
        query = self.c.query
        sql = "select $1||', '||$2||'!'"
        german = ('Hello', u'w\xf6rld')
        russian = ('Hello', u'\u043c\u0438\u0440')

        def fetch(params):
            return query(sql, params).getresult()

        # The statements below depend on the client encoding that was set
        # last, so their order must not be changed.
        query('set client_encoding = utf8')
        self.assertEqual(fetch(german), [('Hello, w\xc3\xb6rld!',)])
        self.assertEqual(fetch(russian),
            [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])

        query('set client_encoding = latin1')
        self.assertEqual(fetch(german), [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, sql, russian)

        query('set client_encoding = iso_8859_1')
        self.assertEqual(fetch(german), [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, sql, russian)

        query('set client_encoding = iso_8859_5')
        self.assertRaises(UnicodeError, query, sql, german)
        self.assertEqual(fetch(russian), [('Hello, \xdc\xd8\xe0!',)])

        query('set client_encoding = sql_ascii')
        self.assertRaises(UnicodeError, query, sql, german)

    def testQueryWithMixedParams(self):
        query = self.c.query
        rows = query("select $1+2,$2||', world!'",
            (1, 'Hello')).getresult()
        self.assertEqual(rows, [(3, 'Hello, world!')])
        rows = query("select $1::integer,$2::date,$3::text",
            (4711, None, 'Hello!')).getresult()
        self.assertEqual(rows, [(4711, None, 'Hello!')])

    def testQueryWithDuplicateParams(self):
        # Using $1 twice is expected to fail, whether one or two values
        # are supplied.
        for params in ((1,), (1, 2)):
            self.assertRaises(pg.ProgrammingError,
                self.c.query, "select $1+$1", params)

    def testQueryWithZeroParams(self):
        rows = self.c.query("select 1+1", []).getresult()
        self.assertEqual(rows, [(2,)])

    def testQueryWithGarbage(self):
        garbage = r"'\{}+()-#[]oo324"
        rows = self.c.query(
            "select $1::text AS garbage", (garbage,)).dictresult()
        self.assertEqual(rows, [{'garbage': garbage}])

    def testUnicodeQuery(self):
        query = self.c.query
        rows = query(u"select 1+1").getresult()
        self.assertEqual(rows, [(2,)])
        self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
620
621
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    # NOTE(review): ``long`` is only a builtin on Python 2; running this
    # class on Python 3 needs a module-level alias for it.

    @classmethod
    def setUpClass(cls):
        # One table with a column for each type under test, shared by all
        # tests of this class.
        c = connect()
        c.query("drop table if exists test cascade")
        c.query("create table test ("
            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
            "d numeric, f4 real, f8 double precision, m money,"
            "c char(1), v4 varchar(4), c4 char(4), t text)")
        c.close()

    @classmethod
    def tearDownClass(cls):
        c = connect()
        c.query("drop table test cascade")
        c.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        # Empty the shared table so that the next test starts clean.
        self.c.query("truncate table test")
        self.c.close()

    # Sample rows, one value per column of the test table.  The bigint
    # values are built with long(...) like in testInserttableMaxValues;
    # the former "L" suffix is a syntax error on Python 3.
    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    def get_back(self):
        """Convert boolean and decimal values back.

        Reads all rows of the test table ordered by the first column,
        checks the Python type of every non-null value, and returns the
        rows as a list of tuples converted so that they can be compared
        with the tuples that were passed to inserttable().
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], str)
                # comes back as 'f' or 't'
                row[3] = {'f': False, 't': True}.get(row[3])
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
                row[8] = float(row[8])
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                # the sample data holds money values as strings
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(len(row[10]), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(len(row[11]), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(len(row[12]), 4)
                # drop the padding so the value matches the inserted one
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            row = tuple(row)
            data.append(row)
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable("test", data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable("test", data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        # one call inserting many copies of the same row
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable("test", data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        # many calls inserting a single row each
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable("test", data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable("test", data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable("test", data)
        self.assertEqual(self.get_back(), data)
744
745
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access.

    All tests work on a table named "test" that is created once for the
    whole class and truncated after each individual test.

    """

    @classmethod
    def setUpClass(cls):
        """Create an empty "test" table shared by all tests of the class."""
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test (i int, v varchar(16))")
        finally:
            # close the helper connection even if one of the queries fails
            c.close()

    @classmethod
    def tearDownClass(cls):
        """Drop the "test" table again."""
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        """Open a new connection for each test."""
        self.c = connect()
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        """Empty the "test" table and close the connection."""
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        """Feed rows to "copy ... from stdin" using putline()."""
        putline = self.c.putline
        query = self.c.query
        data = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            for i, v in data:
                putline("%d\t%s\n" % (i, v))
            # a line consisting of backslash-period ends the copy data
            putline("\\.\n")
        finally:
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, data)

    # This test was formerly also named testPutline; it thereby replaced
    # the method above in the class namespace, so that the putline test
    # was never executed.
    def testGetline(self):
        """Read rows from "copy ... to stdout" using getline()."""
        getline = self.c.getline
        query = self.c.query
        data = list(enumerate("apple banana pear plum strawberry".split()))
        n = len(data)
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            # expect n data lines, then the end marker, then None
            for i in range(n + 2):
                v = getline()
                if i < n:
                    self.assertEqual(v, '%d\t%s' % data[i])
                elif i == n:
                    self.assertEqual(v, '\\.')
                else:
                    self.assertIsNone(v)
        finally:
            # ending the copy is best effort only here: an IOError raised
            # by endcopy() is deliberately ignored
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        """Check that wrong argument counts are rejected with TypeError."""
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
810
811
# NOTE: the spelling of the class name is kept as is, since the name may
# be used to select these tests from the command line.
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        """Open a new connection for each test."""
        self.c = connect()

    def tearDown(self):
        """Close the connection used by the test."""
        self.c.close()

    def _assertNotification(self, r, payload):
        """Check a value returned by getnotify() for 'test_notify'.

        The value r must be a 3-tuple made up of a str, an int and a str,
        where the first item is 'test_notify' and the last is the payload.

        """
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        """Check getnotify() without and with a notification payload."""
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self._assertNotification(getnotify(), '')
            # the notification must have been consumed by the call above
            self.assertIsNone(getnotify())
            try:
                query("notify test_notify, 'test_payload'")
            except pg.ProgrammingError:  # PostgreSQL < 9.0
                pass
            else:
                self._assertNotification(getnotify(), 'test_payload')
                self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        """A new connection must not have a notice receiver."""
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        """Only a function is accepted as notice receiver here."""
        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))

    def testSetAndGetNoticeReceiver(self):
        """The getter must return the very object passed to the setter."""
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)

    def testNoticeReceiver(self):
        """Check the attributes of a notice passed to the receiver."""
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        try:
            received = {}

            def notice_receiver(notice):
                # record all attributes of the notice for later inspection
                for attr in dir(notice):
                    value = getattr(notice, attr)
                    if isinstance(value, str):
                        # presumably copes with a German-localized server
                        value = value.replace('WARNUNG', 'WARNING')
                    received[attr] = value

            self.c.set_notice_receiver(notice_receiver)
            self.c.query('''select bilbo_notice()''')
            self.assertEqual(received, dict(
                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
                severity='WARNING', primary='Bilbo was here!',
                detail=None, hint=None))
        finally:
            self.c.query('''drop function bilbo_notice();''')
892
893
class TestConfigFunctions(unittest.TestCase):
    """Test the functions for changing default settings.

    To test the effect of most of these functions, we need a database
    connection.  That's why they are covered in this test module.

    """

    def setUp(self):
        """Open a new connection for each test."""
        self.c = connect()

    def tearDown(self):
        """Close the connection used by the test."""
        self.c.close()

    def testGetDecimalPoint(self):
        """The decimal point must be a period when nothing was changed."""
        point = pg.get_decimal_point()
        self.assertIsInstance(point, str)
        self.assertEqual(point, '.')

    def testSetDecimalPoint(self):
        """Check the effect of set_decimal_point() on money values."""
        d = pg.Decimal
        point = pg.get_decimal_point()
        query = self.c.query
        # The decimal point is set on the pg module, not on the connection,
        # so it must be restored even when an assertion fails; otherwise
        # the changed value would be seen by the tests running afterwards.
        try:
            # check that money values can be interpreted correctly
            # if and only if the decimal point is set appropriately
            # for the current lc_monetary setting
            query("set lc_monetary='en_US'")
            pg.set_decimal_point('.')
            r = query("select '34.25'::money").getresult()[0][0]
            self.assertIsInstance(r, d)
            self.assertEqual(r, d('34.25'))
            pg.set_decimal_point(',')
            r = query("select '34.25'::money").getresult()[0][0]
            self.assertNotEqual(r, d('34.25'))
            query("set lc_monetary='de_DE'")
            pg.set_decimal_point(',')
            r = query("select '34,25'::money").getresult()[0][0]
            self.assertIsInstance(r, d)
            self.assertEqual(r, d('34.25'))
            pg.set_decimal_point('.')
            r = query("select '34,25'::money").getresult()[0][0]
            self.assertNotEqual(r, d('34.25'))
        finally:
            pg.set_decimal_point(point)

    def testSetDecimal(self):
        """Check that set_decimal() changes the type of numeric values."""
        d = pg.Decimal
        query = self.c.query
        r = query("select 3425::numeric").getresult()[0][0]
        self.assertIsInstance(r, d)
        self.assertEqual(r, d('3425'))
        pg.set_decimal(long)
        # switch back to the decimal type even when an assertion fails
        try:
            r = query("select 3425::numeric").getresult()[0][0]
            self.assertNotIsInstance(r, d)
            self.assertIsInstance(r, long)
            # long(3425) instead of the literal 3425L, which is the same
            # value but avoids the Python 2 only literal syntax
            self.assertEqual(r, long(3425))
        finally:
            pg.set_decimal(d)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testSetNamedresult(self):
        """Check that set_namedresult() replaces the row factory."""
        query = self.c.query

        def check_named_row():
            # the first row must be a named tuple of type "Row"
            r = query("select 1 as x, 2 as y").namedresult()[0]
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (1, 2))
            self.assertIsNot(type(r), tuple)
            self.assertEqual(r._fields, ('x', 'y'))
            self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
            self.assertEqual(r.__class__.__name__, 'Row')

        check_named_row()

        # setting the function that pg provides itself changes nothing
        _namedresult = pg._namedresult
        self.assertTrue(callable(_namedresult))
        pg.set_namedresult(_namedresult)

        check_named_row()

        def _listresult(q):
            # build a real list, whatever map() returns on this Python
            return [list(row) for row in q.getresult()]

        pg.set_namedresult(_listresult)

        try:
            r = query("select 1 as x, 2 as y").namedresult()[0]
            self.assertIsInstance(r, list)
            self.assertEqual(r, [1, 2])
            self.assertIsNot(type(r), tuple)
            self.assertFalse(hasattr(r, '_fields'))
            self.assertNotEqual(r.__class__.__name__, 'Row')
        finally:
            pg.set_namedresult(_namedresult)
989
990
if __name__ == '__main__':
    # executed as a script: run all test cases defined in this module
    unittest.main()
Note: See TracBrowser for help on using the repository browser.