source: trunk/module/TEST_PyGreSQL_classic_connection.py @ 550

Last change on this file since 550 was 550, checked in by cito, 4 years ago

Merge changes from the 4.x branch into the trunk

This is mainly a refactoring of the tests for the classic module,
which have been split into the modules TEST_PyGreSQL_classic*.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 35.1 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.6
16except ImportError:
17    import unittest
18import sys
19import threading
20import time
21
22import pg  # the module under test
23
24from decimal import Decimal
25try:
26    from collections import namedtuple
27except ImportError:  # Python < 2.6
28    namedtuple = None
29
30from StringIO import StringIO
31
# Connection parameters of the database the tests run against.
# The defaults below can be overridden by an optional LOCAL_PyGreSQL.py
# module on the import path: the names it exports (via a star import)
# replace the defaults of the same name.
dbname, dbhost, dbport = 'unittest', None, 5432

try:
    from LOCAL_PyGreSQL import *
except ImportError:
    pass  # no local settings available, keep the defaults
42
43
def connect():
    """Create a basic pg connection to the test database.

    The connection is opened with the module-level dbname, dbhost and
    dbport settings, and client_min_messages is set to 'warning' on it.

    If setting up the session fails, the freshly opened connection is
    closed before the error is re-raised, so that it does not leak.
    """
    connection = pg.connect(dbname, dbhost, dbport)
    try:
        connection.query("set client_min_messages=warning")
    except Exception:
        connection.close()
        raise
    return connection
49
50
51class TestCanConnect(unittest.TestCase):
52    """Test whether a basic connection to PostgreSQL is possible."""
53
54    def testCanConnect(self):
55        try:
56            connection = connect()
57        except pg.Error, error:
58            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
59        try:
60            connection.close()
61        except pg.Error:
62            self.fail('Cannot close the database connection')
63
64
65class TestConnectObject(unittest.TestCase):
66    """"Test existence of basic pg connection methods."""
67
68    def setUp(self):
69        self.connection = connect()
70
71    def tearDown(self):
72        try:
73            self.connection.close()
74        except pg.InternalError:
75            pass
76
77    def testAllConnectAttributes(self):
78        attributes = '''db error host options port
79            protocol_version server_version status tty user'''.split()
80        connection_attributes = [a for a in dir(self.connection)
81            if not callable(eval("self.connection." + a))]
82        self.assertEqual(attributes, connection_attributes)
83
84    def testAllConnectMethods(self):
85        methods = '''cancel close endcopy
86            escape_bytea escape_identifier escape_literal escape_string
87            fileno get_notice_receiver getline getlo getnotify
88            inserttable locreate loimport parameter putline query reset
89            set_notice_receiver source transaction'''.split()
90        connection_methods = [a for a in dir(self.connection)
91            if callable(eval("self.connection." + a))]
92        self.assertEqual(methods, connection_methods)
93
94    def testAttributeDb(self):
95        self.assertEqual(self.connection.db, dbname)
96
97    def testAttributeError(self):
98        error = self.connection.error
99        self.assertTrue(not error or 'krb5_' in error)
100
101    def testAttributeHost(self):
102        def_host = 'localhost'
103        self.assertIsInstance(self.connection.host, str)
104        self.assertEqual(self.connection.host, dbhost or def_host)
105
106    def testAttributeOptions(self):
107        no_options = ''
108        self.assertEqual(self.connection.options, no_options)
109
110    def testAttributePort(self):
111        def_port = 5432
112        self.assertIsInstance(self.connection.port, int)
113        self.assertEqual(self.connection.port, dbport or def_port)
114
115    def testAttributeProtocolVersion(self):
116        protocol_version = self.connection.protocol_version
117        self.assertIsInstance(protocol_version, int)
118        self.assertTrue(2 <= protocol_version < 4)
119
120    def testAttributeServerVersion(self):
121        server_version = self.connection.server_version
122        self.assertIsInstance(server_version, int)
123        self.assertTrue(70400 <= server_version < 100000)
124
125    def testAttributeStatus(self):
126        status_ok = 1
127        self.assertIsInstance(self.connection.status, int)
128        self.assertEqual(self.connection.status, status_ok)
129
130    def testAttributeTty(self):
131        def_tty = ''
132        self.assertIsInstance(self.connection.tty, str)
133        self.assertEqual(self.connection.tty, def_tty)
134
135    def testAttributeUser(self):
136        no_user = 'Deprecated facility'
137        user = self.connection.user
138        self.assertTrue(user)
139        self.assertIsInstance(user, str)
140        self.assertNotEqual(user, no_user)
141
142    def testMethodQuery(self):
143        query = self.connection.query
144        query("select 1+1")
145        query("select 1+$1", (1,))
146        query("select 1+$1+$2", (2, 3))
147        query("select 1+$1+$2", [2, 3])
148
149    def testMethodQueryEmpty(self):
150        self.assertRaises(ValueError, self.connection.query, '')
151
152    def testMethodEndcopy(self):
153        try:
154            self.connection.endcopy()
155        except IOError:
156            pass
157
158    def testMethodClose(self):
159        self.connection.close()
160        try:
161            self.connection.reset()
162        except (pg.Error, TypeError):
163            pass
164        else:
165            self.fail('Reset should give an error for a closed connection')
166        self.assertRaises(pg.InternalError, self.connection.close)
167        try:
168            self.connection.query('select 1')
169        except (pg.Error, TypeError):
170            pass
171        else:
172            self.fail('Query should give an error for a closed connection')
173        self.connection = connect()
174
175    def testMethodReset(self):
176        query = self.connection.query
177        # check that client encoding gets reset
178        encoding = query('show client_encoding').getresult()[0][0].upper()
179        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
180        self.assertNotEqual(encoding, changed_encoding)
181        self.connection.query("set client_encoding=%s" % changed_encoding)
182        new_encoding = query('show client_encoding').getresult()[0][0].upper()
183        self.assertEqual(new_encoding, changed_encoding)
184        self.connection.reset()
185        new_encoding = query('show client_encoding').getresult()[0][0].upper()
186        self.assertNotEqual(new_encoding, changed_encoding)
187        self.assertEqual(new_encoding, encoding)
188
189    def testMethodCancel(self):
190        r = self.connection.cancel()
191        self.assertIsInstance(r, int)
192        self.assertEqual(r, 1)
193
194    def testCancelLongRunningThread(self):
195        errors = []
196
197        def sleep():
198            try:
199                self.connection.query('select pg_sleep(5)').getresult()
200            except pg.ProgrammingError, error:
201                errors.append(str(error))
202
203        thread = threading.Thread(target=sleep)
204        t1 = time.time()
205        thread.start()  # run the query
206        while 1:  # make sure the query is really running
207            time.sleep(0.1)
208            if thread.is_alive() or time.time() - t1 > 5:
209                break
210        r = self.connection.cancel()  # cancel the running query
211        thread.join()  # wait for the thread to end
212        t2 = time.time()
213
214        self.assertIsInstance(r, int)
215        self.assertEqual(r, 1)  # return code should be 1
216        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
217        self.assertTrue(errors)
218
219    def testMethodFileNo(self):
220        r = self.connection.fileno()
221        self.assertIsInstance(r, int)
222        self.assertGreaterEqual(r, 0)
223
224
225class TestSimpleQueries(unittest.TestCase):
226    """"Test simple queries via a basic pg connection."""
227
228    def setUp(self):
229        self.c = connect()
230
231    def tearDown(self):
232        self.c.close()
233
234    def testSelect0(self):
235        q = "select 0"
236        self.c.query(q)
237
238    def testSelect0Semicolon(self):
239        q = "select 0;"
240        self.c.query(q)
241
242    def testSelectDotSemicolon(self):
243        q = "select .;"
244        self.assertRaises(pg.ProgrammingError, self.c.query, q)
245
246    def testGetresult(self):
247        q = "select 0"
248        result = [(0,)]
249        r = self.c.query(q).getresult()
250        self.assertIsInstance(r, list)
251        v = r[0]
252        self.assertIsInstance(v, tuple)
253        self.assertIsInstance(v[0], int)
254        self.assertEqual(r, result)
255
256    def testGetresultLong(self):
257        q = "select 1234567890123456790"
258        result = 1234567890123456790L
259        v = self.c.query(q).getresult()[0][0]
260        self.assertIsInstance(v, long)
261        self.assertEqual(v, result)
262
263    def testGetresultString(self):
264        result = 'Hello, world!'
265        q = "select '%s'" % result
266        v = self.c.query(q).getresult()[0][0]
267        self.assertIsInstance(v, str)
268        self.assertEqual(v, result)
269
270    def testDictresult(self):
271        q = "select 0 as alias0"
272        result = [{'alias0': 0}]
273        r = self.c.query(q).dictresult()
274        self.assertIsInstance(r, list)
275        v = r[0]
276        self.assertIsInstance(v, dict)
277        self.assertIsInstance(v['alias0'], int)
278        self.assertEqual(r, result)
279
280    def testDictresultLong(self):
281        q = "select 1234567890123456790 as longjohnsilver"
282        result = 1234567890123456790L
283        v = self.c.query(q).dictresult()[0]['longjohnsilver']
284        self.assertIsInstance(v, long)
285        self.assertEqual(v, result)
286
287    def testDictresultString(self):
288        result = 'Hello, world!'
289        q = "select '%s' as greeting" % result
290        v = self.c.query(q).dictresult()[0]['greeting']
291        self.assertIsInstance(v, str)
292        self.assertEqual(v, result)
293
294    @unittest.skipUnless(namedtuple, 'Named tuples not available')
295    def testNamedresult(self):
296        q = "select 0 as alias0"
297        result = [(0,)]
298        r = self.c.query(q).namedresult()
299        self.assertEqual(r, result)
300        v = r[0]
301        self.assertEqual(v._fields, ('alias0',))
302        self.assertEqual(v.alias0, 0)
303
304    def testGet3Cols(self):
305        q = "select 1,2,3"
306        result = [(1, 2, 3)]
307        r = self.c.query(q).getresult()
308        self.assertEqual(r, result)
309
310    def testGet3DictCols(self):
311        q = "select 1 as a,2 as b,3 as c"
312        result = [dict(a=1, b=2, c=3)]
313        r = self.c.query(q).dictresult()
314        self.assertEqual(r, result)
315
316    @unittest.skipUnless(namedtuple, 'Named tuples not available')
317    def testGet3NamedCols(self):
318        q = "select 1 as a,2 as b,3 as c"
319        result = [(1, 2, 3)]
320        r = self.c.query(q).namedresult()
321        self.assertEqual(r, result)
322        v = r[0]
323        self.assertEqual(v._fields, ('a', 'b', 'c'))
324        self.assertEqual(v.b, 2)
325
326    def testGet3Rows(self):
327        q = "select 3 union select 1 union select 2 order by 1"
328        result = [(1,), (2,), (3,)]
329        r = self.c.query(q).getresult()
330        self.assertEqual(r, result)
331
332    def testGet3DictRows(self):
333        q = ("select 3 as alias3"
334            " union select 1 union select 2 order by 1")
335        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
336        r = self.c.query(q).dictresult()
337        self.assertEqual(r, result)
338
339    @unittest.skipUnless(namedtuple, 'Named tuples not available')
340    def testGet3NamedRows(self):
341        q = ("select 3 as alias3"
342            " union select 1 union select 2 order by 1")
343        result = [(1,), (2,), (3,)]
344        r = self.c.query(q).namedresult()
345        self.assertEqual(r, result)
346        for v in r:
347            self.assertEqual(v._fields, ('alias3',))
348
349    def testDictresultNames(self):
350        q = "select 'MixedCase' as MixedCaseAlias"
351        result = [{'mixedcasealias': 'MixedCase'}]
352        r = self.c.query(q).dictresult()
353        self.assertEqual(r, result)
354        q = "select 'MixedCase' as \"MixedCaseAlias\""
355        result = [{'MixedCaseAlias': 'MixedCase'}]
356        r = self.c.query(q).dictresult()
357        self.assertEqual(r, result)
358
359    @unittest.skipUnless(namedtuple, 'Named tuples not available')
360    def testNamedresultNames(self):
361        q = "select 'MixedCase' as MixedCaseAlias"
362        result = [('MixedCase',)]
363        r = self.c.query(q).namedresult()
364        self.assertEqual(r, result)
365        v = r[0]
366        self.assertEqual(v._fields, ('mixedcasealias',))
367        self.assertEqual(v.mixedcasealias, 'MixedCase')
368        q = "select 'MixedCase' as \"MixedCaseAlias\""
369        r = self.c.query(q).namedresult()
370        self.assertEqual(r, result)
371        v = r[0]
372        self.assertEqual(v._fields, ('MixedCaseAlias',))
373        self.assertEqual(v.MixedCaseAlias, 'MixedCase')
374
375    def testBigGetresult(self):
376        num_cols = 100
377        num_rows = 100
378        q = "select " + ','.join(map(str, xrange(num_cols)))
379        q = ' union all '.join((q,) * num_rows)
380        r = self.c.query(q).getresult()
381        result = [tuple(range(num_cols))] * num_rows
382        self.assertEqual(r, result)
383
384    def testListfields(self):
385        q = ('select 0 as a, 0 as b, 0 as c,'
386            ' 0 as c, 0 as b, 0 as a,'
387            ' 0 as lowercase, 0 as UPPERCASE,'
388            ' 0 as MixedCase, 0 as "MixedCase",'
389            ' 0 as a_long_name_with_underscores,'
390            ' 0 as "A long name with Blanks"')
391        r = self.c.query(q).listfields()
392        result = ('a', 'b', 'c', 'c', 'b', 'a',
393            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
394            'a_long_name_with_underscores',
395            'A long name with Blanks')
396        self.assertEqual(r, result)
397
398    def testFieldname(self):
399        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
400        r = self.c.query(q).fieldname(2)
401        self.assertEqual(r, 'x')
402        r = self.c.query(q).fieldname(3)
403        self.assertEqual(r, 'y')
404
405    def testFieldnum(self):
406        q = "select 1 as x"
407        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
408        q = "select 1 as x"
409        r = self.c.query(q).fieldnum('x')
410        self.assertIsInstance(r, int)
411        self.assertEqual(r, 0)
412        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
413        r = self.c.query(q).fieldnum('x')
414        self.assertIsInstance(r, int)
415        self.assertEqual(r, 2)
416        r = self.c.query(q).fieldnum('y')
417        self.assertIsInstance(r, int)
418        self.assertEqual(r, 3)
419
420    def testNtuples(self):
421        q = "select 1 where false"
422        r = self.c.query(q).ntuples()
423        self.assertIsInstance(r, int)
424        self.assertEqual(r, 0)
425        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
426            " union select 5 as a, 6 as b, 7 as c, 8 as d")
427        r = self.c.query(q).ntuples()
428        self.assertIsInstance(r, int)
429        self.assertEqual(r, 2)
430        q = ("select 1 union select 2 union select 3"
431            " union select 4 union select 5 union select 6")
432        r = self.c.query(q).ntuples()
433        self.assertIsInstance(r, int)
434        self.assertEqual(r, 6)
435
436    def testQuery(self):
437        query = self.c.query
438        query("drop table if exists test_table")
439        q = "create table test_table (n integer) with oids"
440        r = query(q)
441        self.assertIsNone(r)
442        q = "insert into test_table values (1)"
443        r = query(q)
444        self.assertIsInstance(r, int)
445        q = "insert into test_table select 2"
446        r = query(q)
447        self.assertIsInstance(r, int)
448        oid = r
449        q = "select oid from test_table where n=2"
450        r = query(q).getresult()
451        self.assertEqual(len(r), 1)
452        r = r[0]
453        self.assertEqual(len(r), 1)
454        r = r[0]
455        self.assertIsInstance(r, int)
456        self.assertEqual(r, oid)
457        q = "insert into test_table select 3 union select 4 union select 5"
458        r = query(q)
459        self.assertIsInstance(r, str)
460        self.assertEqual(r, '3')
461        q = "update test_table set n=4 where n<5"
462        r = query(q)
463        self.assertIsInstance(r, str)
464        self.assertEqual(r, '4')
465        q = "delete from test_table"
466        r = query(q)
467        self.assertIsInstance(r, str)
468        self.assertEqual(r, '5')
469        query("drop table test_table")
470
471    def testPrint(self):
472        q = ("select 1 as a, 'hello' as h, 'w' as world"
473            " union select 2, 'xyz', 'uvw'")
474        r = self.c.query(q)
475        s = StringIO()
476        stdout, sys.stdout = sys.stdout, s
477        try:
478            print r
479        except Exception:
480            pass
481        finally:
482            sys.stdout = stdout
483        r = s.getvalue()
484        s.close()
485        self.assertEqual(r,
486            'a|  h  |world\n'
487            '-+-----+-----\n'
488            '1|hello|w    \n'
489            '2|xyz  |uvw  \n'
490            '(2 rows)\n')
491
492
class TestParamQueries(unittest.TestCase):
    """Test queries with parameters via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testQueryWithNoneParam(self):
        self.assertEqual(self.c.query("select $1::integer", (None,)
            ).getresult(), [(None,)])
        self.assertEqual(self.c.query("select $1::text", [None]
            ).getresult(), [(None,)])

    def testQueryWithBoolParams(self):
        # boolean results come back as the strings 'f' and 't'
        query = self.c.query
        self.assertEqual(query("select false").getresult(), [('f',)])
        self.assertEqual(query("select true").getresult(), [('t',)])
        self.assertEqual(query("select $1::bool", (None,)).getresult(),
            [(None,)])
        self.assertEqual(query("select $1::bool", ('f',)).getresult(), [('f',)])
        self.assertEqual(query("select $1::bool", ('t',)).getresult(), [('t',)])
        self.assertEqual(query("select $1::bool", ('false',)).getresult(),
            [('f',)])
        self.assertEqual(query("select $1::bool", ('true',)).getresult(),
            [('t',)])
        self.assertEqual(query("select $1::bool", ('n',)).getresult(), [('f',)])
        self.assertEqual(query("select $1::bool", ('y',)).getresult(), [('t',)])
        self.assertEqual(query("select $1::bool", (0,)).getresult(), [('f',)])
        self.assertEqual(query("select $1::bool", (1,)).getresult(), [('t',)])
        self.assertEqual(query("select $1::bool", (False,)).getresult(),
            [('f',)])
        self.assertEqual(query("select $1::bool", (True,)).getresult(),
            [('t',)])

    def testQueryWithIntParams(self):
        query = self.c.query
        self.assertEqual(query("select 1+1").getresult(), [(2,)])
        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
            [(Decimal('2'),)])
        self.assertEqual(query("select 1, $1::integer", (2,)
            ).getresult(), [(1, 2)])
        self.assertEqual(query("select 1 union select $1", (2,)
            ).getresult(), [(1,), (2,)])
        self.assertEqual(query("select $1::integer+$2", (1, 2)
            ).getresult(), [(3,)])
        self.assertEqual(query("select $1::integer+$2", [1, 2]
            ).getresult(), [(3,)])
        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", range(6)
            ).getresult(), [(15,)])

    def testQueryWithStrParams(self):
        query = self.c.query
        self.assertEqual(query("select $1||', world!'", ('Hello',)
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1||', world!'", ['Hello']
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1::text", ('Hello, world!',)
            ).getresult(), [('Hello, world!',)])
        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
            ).getresult(), [('Hello', 'world')])
        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
            ).getresult(), [('Hello', 'world')])
        self.assertEqual(query("select $1::text union select $2::text",
            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])

    def testQueryWithUnicodeParams(self):
        # Unicode parameters are sent in the current client encoding;
        # characters that encoding cannot represent raise UnicodeError.
        query = self.c.query
        query('set client_encoding = utf8')
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440')).getresult(),
            [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
        query('set client_encoding = latin1')
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440'))
        query('set client_encoding = iso_8859_1')
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440'))
        query('set client_encoding = iso_8859_5')
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld'))
        self.assertEqual(query("select $1||', '||$2||'!'",
            ('Hello', u'\u043c\u0438\u0440')).getresult(),
            [('Hello, \xdc\xd8\xe0!',)])
        query('set client_encoding = sql_ascii')
        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
            ('Hello', u'w\xf6rld'))

    def testQueryWithMixedParams(self):
        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])

    def testQueryWithDuplicateParams(self):
        self.assertRaises(pg.ProgrammingError,
            self.c.query, "select $1+$1", (1,))
        self.assertRaises(pg.ProgrammingError,
            self.c.query, "select $1+$1", (1, 2))

    def testQueryWithZeroParams(self):
        self.assertEqual(self.c.query("select 1+1", []
            ).getresult(), [(2,)])

    def testQueryWithGarbage(self):
        # parameter values are passed through unchanged, without escaping
        garbage = r"'\{}+()-#[]oo324"
        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
            ).dictresult(), [{'garbage': garbage}])

    def testUnicodeQuery(self):
        # a unicode query string works if it is pure ASCII,
        # but raises TypeError if it contains other characters
        query = self.c.query
        self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
        self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
621
622
623class TestInserttable(unittest.TestCase):
624    """"Test inserttable method."""
625
626    @classmethod
627    def setUpClass(cls):
628        c = connect()
629        c.query("drop table if exists test cascade")
630        c.query("create table test ("
631            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
632            "d numeric, f4 real, f8 double precision, m money,"
633            "c char(1), v4 varchar(4), c4 char(4), t text)")
634        c.close()
635
636    @classmethod
637    def tearDownClass(cls):
638        c = connect()
639        c.query("drop table test cascade")
640        c.close()
641
642    def setUp(self):
643        self.c = connect()
644        self.c.query("set datestyle='ISO,YMD'")
645
646    def tearDown(self):
647        self.c.query("truncate table test")
648        self.c.close()
649
650    data = [
651        (-1, -1, -1L, True, '1492-10-12', '08:30:00',
652            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
653        (0, 0, 0L, False, '1607-04-14', '09:00:00',
654            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
655        (1, 1, 1L, True, '1801-03-04', '03:45:00',
656            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
657        (2, 2, 2L, False, '1903-12-17', '11:22:00',
658            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]
659
660    def get_back(self):
661        """Convert boolean and decimal values back."""
662        data = []
663        for row in self.c.query("select * from test order by 1").getresult():
664            self.assertIsInstance(row, tuple)
665            row = list(row)
666            if row[0] is not None:  # smallint
667                self.assertIsInstance(row[0], int)
668            if row[1] is not None:  # integer
669                self.assertIsInstance(row[1], int)
670            if row[2] is not None:  # bigint
671                self.assertIsInstance(row[2], long)
672            if row[3] is not None:  # boolean
673                self.assertIsInstance(row[3], str)
674                row[3] = {'f': False, 't': True}.get(row[3])
675            if row[4] is not None:  # date
676                self.assertIsInstance(row[4], str)
677                self.assertTrue(row[4].replace('-', '').isdigit())
678            if row[5] is not None:  # time
679                self.assertIsInstance(row[5], str)
680                self.assertTrue(row[5].replace(':', '').isdigit())
681            if row[6] is not None:  # numeric
682                self.assertIsInstance(row[6], Decimal)
683                row[6] = float(row[6])
684            if row[7] is not None:  # real
685                self.assertIsInstance(row[7], float)
686            if row[8] is not None:  # double precision
687                self.assertIsInstance(row[8], float)
688                row[8] = float(row[8])
689            if row[9] is not None:  # money
690                self.assertIsInstance(row[9], Decimal)
691                row[9] = str(float(row[9]))
692            if row[10] is not None:  # char(1)
693                self.assertIsInstance(row[10], str)
694                self.assertEqual(len(row[10]), 1)
695            if row[11] is not None:  # varchar(4)
696                self.assertIsInstance(row[11], str)
697                self.assertLessEqual(len(row[11]), 4)
698            if row[12] is not None:  # char(4)
699                self.assertIsInstance(row[12], str)
700                self.assertEqual(len(row[12]), 4)
701                row[12] = row[12].rstrip()
702            if row[13] is not None:  # text
703                self.assertIsInstance(row[13], str)
704            row = tuple(row)
705            data.append(row)
706        return data
707
708    def testInserttable1Row(self):
709        data = self.data[2:3]
710        self.c.inserttable("test", data)
711        self.assertEqual(self.get_back(), data)
712
713    def testInserttable4Rows(self):
714        data = self.data
715        self.c.inserttable("test", data)
716        self.assertEqual(self.get_back(), data)
717
718    def testInserttableMultipleRows(self):
719        num_rows = 100
720        data = self.data[2:3] * num_rows
721        self.c.inserttable("test", data)
722        r = self.c.query("select count(*) from test").getresult()[0][0]
723        self.assertEqual(r, num_rows)
724
725    def testInserttableMultipleCalls(self):
726        num_rows = 10
727        data = self.data[2:3]
728        for _i in range(num_rows):
729            self.c.inserttable("test", data)
730        r = self.c.query("select count(*) from test").getresult()[0][0]
731        self.assertEqual(r, num_rows)
732
733    def testInserttableNullValues(self):
734        data = [(None,) * 14] * 100
735        self.c.inserttable("test", data)
736        self.assertEqual(self.get_back(), data)
737
738    def testInserttableMaxValues(self):
739        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
740            True, '2999-12-31', '11:59:59', 1e99,
741            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
742            "1", "1234", "1234", "1234" * 100)]
743        self.c.inserttable("test", data)
744        self.assertEqual(self.get_back(), data)
745
746
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    @classmethod
    def setUpClass(cls):
        c = connect()
        c.query("drop table if exists test cascade")
        c.query("create table test (i int, v varchar(16))")
        c.close()

    @classmethod
    def tearDownClass(cls):
        c = connect()
        c.query("drop table test cascade")
        c.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        # empty the shared table so that every test starts from scratch
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        putline = self.c.putline
        query = self.c.query
        data = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            for i, v in data:
                putline("%d\t%s\n" % (i, v))
            putline("\\.\n")  # end-of-data marker
        finally:
            self.c.endcopy()
        # order the rows explicitly so that the comparison is deterministic
        r = query("select * from test order by 1").getresult()
        self.assertEqual(r, data)

    def testGetline(self):
        getline = self.c.getline
        query = self.c.query
        data = list(enumerate("apple banana pear plum strawberry".split()))
        n = len(data)
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            # expect one line per row, then the end-of-data marker,
            # then None once there is nothing more to read
            for i in range(n + 2):
                v = getline()
                if i < n:
                    self.assertEqual(v, '%d\t%s' % data[i])
                elif i == n:
                    self.assertEqual(v, '\\.')
                else:
                    self.assertIsNone(v)
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
811
812
813class TestNotificatons(unittest.TestCase):
814    """"Test notification support."""
815
816    def setUp(self):
817        self.c = connect()
818
819    def tearDown(self):
820        self.c.close()
821
822    def testGetNotify(self):
823        getnotify = self.c.getnotify
824        query = self.c.query
825        self.assertIsNone(getnotify())
826        query('listen test_notify')
827        try:
828            self.assertIsNone(self.c.getnotify())
829            query("notify test_notify")
830            r = getnotify()
831            self.assertIsInstance(r, tuple)
832            self.assertEqual(len(r), 3)
833            self.assertIsInstance(r[0], str)
834            self.assertIsInstance(r[1], int)
835            self.assertIsInstance(r[2], str)
836            self.assertEqual(r[0], 'test_notify')
837            self.assertEqual(r[2], '')
838            self.assertIsNone(self.c.getnotify())
839            try:
840                query("notify test_notify, 'test_payload'")
841            except pg.ProgrammingError:  # PostgreSQL < 9.0
842                pass
843            else:
844                r = getnotify()
845                self.assertTrue(isinstance(r, tuple))
846                self.assertEqual(len(r), 3)
847                self.assertIsInstance(r[0], str)
848                self.assertIsInstance(r[1], int)
849                self.assertIsInstance(r[2], str)
850                self.assertEqual(r[0], 'test_notify')
851                self.assertEqual(r[2], 'test_payload')
852                self.assertIsNone(getnotify())
853        finally:
854            query('unlisten test_notify')
855
856    def testGetNoticeReceiver(self):
857        self.assertIsNone(self.c.get_notice_receiver())
858
859    def testSetNoticeReceiver(self):
860        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
861        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
862        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))
863
864    def testSetAndGetNoticeReceiver(self):
865        r = lambda notice: None
866        self.assertIsNone(self.c.set_notice_receiver(r))
867        self.assertIs(self.c.get_notice_receiver(), r)
868
869    def testNoticeReceiver(self):
870        self.c.query('''create function bilbo_notice() returns void AS $$
871            begin
872                raise warning 'Bilbo was here!';
873            end;
874            $$ language plpgsql''')
875        try:
876            received = {}
877
878            def notice_receiver(notice):
879                for attr in dir(notice):
880                    value = getattr(notice, attr)
881                    if isinstance(value, str):
882                        value = value.replace('WARNUNG', 'WARNING')
883                    received[attr] = value
884
885            self.c.set_notice_receiver(notice_receiver)
886            self.c.query('''select bilbo_notice()''')
887            self.assertEqual(received, dict(
888                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
889                severity='WARNING', primary='Bilbo was here!',
890                detail=None, hint=None))
891        finally:
892            self.c.query('''drop function bilbo_notice();''')
893
894
895class TestConfigFunctions(unittest.TestCase):
896    """Test the functions for changing default settings.
897
898    To test the effect of most of these functions, we need a database
899    connection.  That's why they are covered in this test module.
900
901    """
902
903    def setUp(self):
904        self.c = connect()
905
906    def tearDown(self):
907        self.c.close()
908
909    def testGetDecimalPoint(self):
910        point = pg.get_decimal_point()
911        self.assertIsInstance(point, str)
912        self.assertEqual(point, '.')
913
914    def testSetDecimalPoint(self):
915        d = pg.Decimal
916        point = pg.get_decimal_point()
917        query = self.c.query
918        # check that money values can be interpreted correctly
919        # if and only if the decimal point is set appropriately
920        # for the current lc_monetary setting
921        query("set lc_monetary='en_US'")
922        pg.set_decimal_point('.')
923        r = query("select '34.25'::money").getresult()[0][0]
924        self.assertIsInstance(r, d)
925        self.assertEqual(r, d('34.25'))
926        pg.set_decimal_point(',')
927        r = query("select '34.25'::money").getresult()[0][0]
928        self.assertNotEqual(r, d('34.25'))
929        query("set lc_monetary='de_DE'")
930        pg.set_decimal_point(',')
931        r = query("select '34,25'::money").getresult()[0][0]
932        self.assertIsInstance(r, d)
933        self.assertEqual(r, d('34.25'))
934        pg.set_decimal_point('.')
935        r = query("select '34,25'::money").getresult()[0][0]
936        self.assertNotEqual(r, d('34.25'))
937        pg.set_decimal_point(point)
938
939    def testSetDecimal(self):
940        d = pg.Decimal
941        query = self.c.query
942        r = query("select 3425::numeric").getresult()[0][0]
943        self.assertIsInstance(r, d)
944        self.assertEqual(r, d('3425'))
945        pg.set_decimal(long)
946        r = query("select 3425::numeric").getresult()[0][0]
947        self.assertNotIsInstance(r, d)
948        self.assertIsInstance(r, long)
949        self.assertEqual(r, 3425L)
950        pg.set_decimal(d)
951
952    @unittest.skipUnless(namedtuple, 'Named tuples not available')
953    def testSetNamedresult(self):
954        query = self.c.query
955
956        r = query("select 1 as x, 2 as y").namedresult()[0]
957        self.assertIsInstance(r, tuple)
958        self.assertEqual(r, (1, 2))
959        self.assertIsNot(type(r), tuple)
960        self.assertEqual(r._fields, ('x', 'y'))
961        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
962        self.assertEqual(r.__class__.__name__, 'Row')
963
964        _namedresult = pg._namedresult
965        self.assertTrue(callable(_namedresult))
966        pg.set_namedresult(_namedresult)
967
968        r = query("select 1 as x, 2 as y").namedresult()[0]
969        self.assertIsInstance(r, tuple)
970        self.assertEqual(r, (1, 2))
971        self.assertIsNot(type(r), tuple)
972        self.assertEqual(r._fields, ('x', 'y'))
973        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
974        self.assertEqual(r.__class__.__name__, 'Row')
975
976        def _listresult(q):
977            return map(list, q.getresult())
978
979        pg.set_namedresult(_listresult)
980
981        try:
982            r = query("select 1 as x, 2 as y").namedresult()[0]
983            self.assertIsInstance(r, list)
984            self.assertEqual(r, [1, 2])
985            self.assertIsNot(type(r), tuple)
986            self.assertFalse(hasattr(r, '_fields'))
987            self.assertNotEqual(r.__class__.__name__, 'Row')
988        finally:
989            pg.set_namedresult(_namedresult)
990
991
if __name__ == '__main__':
    # Run the test cases of this module when it is executed as a script.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.