source: branches/4.x/module/TEST_PyGreSQL_classic_connection.py @ 565

Last change on this file since 565 was 565, checked in by cito, 4 years ago

Using a better locale name for tests

In this test we need to force PostgreSQL to use a money format that does
not use a dot as the decimal point. Maybe we should try several common
locale names to get this test running on different platforms?

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 35.1 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import sys
19import tempfile
20import threading
21import time
22
23import pg  # the module under test
24
25from decimal import Decimal
26try:
27    from collections import namedtuple
28except ImportError:  # Python < 2.6
29    namedtuple = None
30
# Connection settings for the test database.  These defaults may be
# overridden by an optional module LOCAL_PyGreSQL.py on the import path,
# whose public names are star-imported below.
dbname, dbhost, dbport = 'unittest', None, 5432

try:
    from LOCAL_PyGreSQL import *
except ImportError:  # no local overrides available, keep the defaults
    pass
41
42
def connect():
    """Open and return a pg connection to the test database.

    The connection parameters come from the module-level settings
    dbname, dbhost and dbport.  Server messages below WARNING level
    are suppressed for this session via client_min_messages.
    """
    conn = pg.connect(dbname, dbhost, dbport)
    conn.query("set client_min_messages=warning")
    return conn
48
49
50class TestCanConnect(unittest.TestCase):
51    """Test whether a basic connection to PostgreSQL is possible."""
52
53    def testCanConnect(self):
54        try:
55            connection = connect()
56        except pg.Error, error:
57            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
58        try:
59            connection.close()
60        except pg.Error:
61            self.fail('Cannot close the database connection')
62
63
64class TestConnectObject(unittest.TestCase):
65    """"Test existence of basic pg connection methods."""
66
67    def setUp(self):
68        self.connection = connect()
69
70    def tearDown(self):
71        try:
72            self.connection.close()
73        except pg.InternalError:
74            pass
75
76    def testAllConnectAttributes(self):
77        attributes = '''db error host options port
78            protocol_version server_version status tty user'''.split()
79        connection_attributes = [a for a in dir(self.connection)
80            if not callable(eval("self.connection." + a))]
81        self.assertEqual(attributes, connection_attributes)
82
83    def testAllConnectMethods(self):
84        methods = '''cancel close endcopy
85            escape_bytea escape_identifier escape_literal escape_string
86            fileno get_notice_receiver getline getlo getnotify
87            inserttable locreate loimport parameter putline query reset
88            set_notice_receiver source transaction'''.split()
89        connection_methods = [a for a in dir(self.connection)
90            if callable(eval("self.connection." + a))]
91        self.assertEqual(methods, connection_methods)
92
93    def testAttributeDb(self):
94        self.assertEqual(self.connection.db, dbname)
95
96    def testAttributeError(self):
97        error = self.connection.error
98        self.assertTrue(not error or 'krb5_' in error)
99
100    def testAttributeHost(self):
101        def_host = 'localhost'
102        self.assertIsInstance(self.connection.host, str)
103        self.assertEqual(self.connection.host, dbhost or def_host)
104
105    def testAttributeOptions(self):
106        no_options = ''
107        self.assertEqual(self.connection.options, no_options)
108
109    def testAttributePort(self):
110        def_port = 5432
111        self.assertIsInstance(self.connection.port, int)
112        self.assertEqual(self.connection.port, dbport or def_port)
113
114    def testAttributeProtocolVersion(self):
115        protocol_version = self.connection.protocol_version
116        self.assertIsInstance(protocol_version, int)
117        self.assertTrue(2 <= protocol_version < 4)
118
119    def testAttributeServerVersion(self):
120        server_version = self.connection.server_version
121        self.assertIsInstance(server_version, int)
122        self.assertTrue(70400 <= server_version < 100000)
123
124    def testAttributeStatus(self):
125        status_ok = 1
126        self.assertIsInstance(self.connection.status, int)
127        self.assertEqual(self.connection.status, status_ok)
128
129    def testAttributeTty(self):
130        def_tty = ''
131        self.assertIsInstance(self.connection.tty, str)
132        self.assertEqual(self.connection.tty, def_tty)
133
134    def testAttributeUser(self):
135        no_user = 'Deprecated facility'
136        user = self.connection.user
137        self.assertTrue(user)
138        self.assertIsInstance(user, str)
139        self.assertNotEqual(user, no_user)
140
141    def testMethodQuery(self):
142        query = self.connection.query
143        query("select 1+1")
144        query("select 1+$1", (1,))
145        query("select 1+$1+$2", (2, 3))
146        query("select 1+$1+$2", [2, 3])
147
148    def testMethodQueryEmpty(self):
149        self.assertRaises(ValueError, self.connection.query, '')
150
151    def testMethodEndcopy(self):
152        try:
153            self.connection.endcopy()
154        except IOError:
155            pass
156
157    def testMethodClose(self):
158        self.connection.close()
159        try:
160            self.connection.reset()
161        except (pg.Error, TypeError):
162            pass
163        else:
164            self.fail('Reset should give an error for a closed connection')
165        self.assertRaises(pg.InternalError, self.connection.close)
166        try:
167            self.connection.query('select 1')
168        except (pg.Error, TypeError):
169            pass
170        else:
171            self.fail('Query should give an error for a closed connection')
172        self.connection = connect()
173
174    def testMethodReset(self):
175        query = self.connection.query
176        # check that client encoding gets reset
177        encoding = query('show client_encoding').getresult()[0][0].upper()
178        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
179        self.assertNotEqual(encoding, changed_encoding)
180        self.connection.query("set client_encoding=%s" % changed_encoding)
181        new_encoding = query('show client_encoding').getresult()[0][0].upper()
182        self.assertEqual(new_encoding, changed_encoding)
183        self.connection.reset()
184        new_encoding = query('show client_encoding').getresult()[0][0].upper()
185        self.assertNotEqual(new_encoding, changed_encoding)
186        self.assertEqual(new_encoding, encoding)
187
188    def testMethodCancel(self):
189        r = self.connection.cancel()
190        self.assertIsInstance(r, int)
191        self.assertEqual(r, 1)
192
193    def testCancelLongRunningThread(self):
194        errors = []
195
196        def sleep():
197            try:
198                self.connection.query('select pg_sleep(5)').getresult()
199            except pg.ProgrammingError, error:
200                errors.append(str(error))
201
202        thread = threading.Thread(target=sleep)
203        t1 = time.time()
204        thread.start()  # run the query
205        while 1:  # make sure the query is really running
206            time.sleep(0.1)
207            if thread.is_alive() or time.time() - t1 > 5:
208                break
209        r = self.connection.cancel()  # cancel the running query
210        thread.join()  # wait for the thread to end
211        t2 = time.time()
212
213        self.assertIsInstance(r, int)
214        self.assertEqual(r, 1)  # return code should be 1
215        self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
216        self.assertTrue(errors)
217
218    def testMethodFileNo(self):
219        r = self.connection.fileno()
220        self.assertIsInstance(r, int)
221        self.assertGreaterEqual(r, 0)
222
223
224class TestSimpleQueries(unittest.TestCase):
225    """"Test simple queries via a basic pg connection."""
226
227    def setUp(self):
228        self.c = connect()
229
230    def tearDown(self):
231        self.c.close()
232
233    def testSelect0(self):
234        q = "select 0"
235        self.c.query(q)
236
237    def testSelect0Semicolon(self):
238        q = "select 0;"
239        self.c.query(q)
240
241    def testSelectDotSemicolon(self):
242        q = "select .;"
243        self.assertRaises(pg.ProgrammingError, self.c.query, q)
244
245    def testGetresult(self):
246        q = "select 0"
247        result = [(0,)]
248        r = self.c.query(q).getresult()
249        self.assertIsInstance(r, list)
250        v = r[0]
251        self.assertIsInstance(v, tuple)
252        self.assertIsInstance(v[0], int)
253        self.assertEqual(r, result)
254
255    def testGetresultLong(self):
256        q = "select 1234567890123456790"
257        result = 1234567890123456790L
258        v = self.c.query(q).getresult()[0][0]
259        self.assertIsInstance(v, long)
260        self.assertEqual(v, result)
261
262    def testGetresultString(self):
263        result = 'Hello, world!'
264        q = "select '%s'" % result
265        v = self.c.query(q).getresult()[0][0]
266        self.assertIsInstance(v, str)
267        self.assertEqual(v, result)
268
269    def testDictresult(self):
270        q = "select 0 as alias0"
271        result = [{'alias0': 0}]
272        r = self.c.query(q).dictresult()
273        self.assertIsInstance(r, list)
274        v = r[0]
275        self.assertIsInstance(v, dict)
276        self.assertIsInstance(v['alias0'], int)
277        self.assertEqual(r, result)
278
279    def testDictresultLong(self):
280        q = "select 1234567890123456790 as longjohnsilver"
281        result = 1234567890123456790L
282        v = self.c.query(q).dictresult()[0]['longjohnsilver']
283        self.assertIsInstance(v, long)
284        self.assertEqual(v, result)
285
286    def testDictresultString(self):
287        result = 'Hello, world!'
288        q = "select '%s' as greeting" % result
289        v = self.c.query(q).dictresult()[0]['greeting']
290        self.assertIsInstance(v, str)
291        self.assertEqual(v, result)
292
293    @unittest.skipUnless(namedtuple, 'Named tuples not available')
294    def testNamedresult(self):
295        q = "select 0 as alias0"
296        result = [(0,)]
297        r = self.c.query(q).namedresult()
298        self.assertEqual(r, result)
299        v = r[0]
300        self.assertEqual(v._fields, ('alias0',))
301        self.assertEqual(v.alias0, 0)
302
303    def testGet3Cols(self):
304        q = "select 1,2,3"
305        result = [(1, 2, 3)]
306        r = self.c.query(q).getresult()
307        self.assertEqual(r, result)
308
309    def testGet3DictCols(self):
310        q = "select 1 as a,2 as b,3 as c"
311        result = [dict(a=1, b=2, c=3)]
312        r = self.c.query(q).dictresult()
313        self.assertEqual(r, result)
314
315    @unittest.skipUnless(namedtuple, 'Named tuples not available')
316    def testGet3NamedCols(self):
317        q = "select 1 as a,2 as b,3 as c"
318        result = [(1, 2, 3)]
319        r = self.c.query(q).namedresult()
320        self.assertEqual(r, result)
321        v = r[0]
322        self.assertEqual(v._fields, ('a', 'b', 'c'))
323        self.assertEqual(v.b, 2)
324
325    def testGet3Rows(self):
326        q = "select 3 union select 1 union select 2 order by 1"
327        result = [(1,), (2,), (3,)]
328        r = self.c.query(q).getresult()
329        self.assertEqual(r, result)
330
331    def testGet3DictRows(self):
332        q = ("select 3 as alias3"
333            " union select 1 union select 2 order by 1")
334        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
335        r = self.c.query(q).dictresult()
336        self.assertEqual(r, result)
337
338    @unittest.skipUnless(namedtuple, 'Named tuples not available')
339    def testGet3NamedRows(self):
340        q = ("select 3 as alias3"
341            " union select 1 union select 2 order by 1")
342        result = [(1,), (2,), (3,)]
343        r = self.c.query(q).namedresult()
344        self.assertEqual(r, result)
345        for v in r:
346            self.assertEqual(v._fields, ('alias3',))
347
348    def testDictresultNames(self):
349        q = "select 'MixedCase' as MixedCaseAlias"
350        result = [{'mixedcasealias': 'MixedCase'}]
351        r = self.c.query(q).dictresult()
352        self.assertEqual(r, result)
353        q = "select 'MixedCase' as \"MixedCaseAlias\""
354        result = [{'MixedCaseAlias': 'MixedCase'}]
355        r = self.c.query(q).dictresult()
356        self.assertEqual(r, result)
357
358    @unittest.skipUnless(namedtuple, 'Named tuples not available')
359    def testNamedresultNames(self):
360        q = "select 'MixedCase' as MixedCaseAlias"
361        result = [('MixedCase',)]
362        r = self.c.query(q).namedresult()
363        self.assertEqual(r, result)
364        v = r[0]
365        self.assertEqual(v._fields, ('mixedcasealias',))
366        self.assertEqual(v.mixedcasealias, 'MixedCase')
367        q = "select 'MixedCase' as \"MixedCaseAlias\""
368        r = self.c.query(q).namedresult()
369        self.assertEqual(r, result)
370        v = r[0]
371        self.assertEqual(v._fields, ('MixedCaseAlias',))
372        self.assertEqual(v.MixedCaseAlias, 'MixedCase')
373
374    def testBigGetresult(self):
375        num_cols = 100
376        num_rows = 100
377        q = "select " + ','.join(map(str, xrange(num_cols)))
378        q = ' union all '.join((q,) * num_rows)
379        r = self.c.query(q).getresult()
380        result = [tuple(range(num_cols))] * num_rows
381        self.assertEqual(r, result)
382
383    def testListfields(self):
384        q = ('select 0 as a, 0 as b, 0 as c,'
385            ' 0 as c, 0 as b, 0 as a,'
386            ' 0 as lowercase, 0 as UPPERCASE,'
387            ' 0 as MixedCase, 0 as "MixedCase",'
388            ' 0 as a_long_name_with_underscores,'
389            ' 0 as "A long name with Blanks"')
390        r = self.c.query(q).listfields()
391        result = ('a', 'b', 'c', 'c', 'b', 'a',
392            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
393            'a_long_name_with_underscores',
394            'A long name with Blanks')
395        self.assertEqual(r, result)
396
397    def testFieldname(self):
398        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
399        r = self.c.query(q).fieldname(2)
400        self.assertEqual(r, 'x')
401        r = self.c.query(q).fieldname(3)
402        self.assertEqual(r, 'y')
403
404    def testFieldnum(self):
405        q = "select 1 as x"
406        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
407        q = "select 1 as x"
408        r = self.c.query(q).fieldnum('x')
409        self.assertIsInstance(r, int)
410        self.assertEqual(r, 0)
411        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
412        r = self.c.query(q).fieldnum('x')
413        self.assertIsInstance(r, int)
414        self.assertEqual(r, 2)
415        r = self.c.query(q).fieldnum('y')
416        self.assertIsInstance(r, int)
417        self.assertEqual(r, 3)
418
419    def testNtuples(self):
420        q = "select 1 where false"
421        r = self.c.query(q).ntuples()
422        self.assertIsInstance(r, int)
423        self.assertEqual(r, 0)
424        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
425            " union select 5 as a, 6 as b, 7 as c, 8 as d")
426        r = self.c.query(q).ntuples()
427        self.assertIsInstance(r, int)
428        self.assertEqual(r, 2)
429        q = ("select 1 union select 2 union select 3"
430            " union select 4 union select 5 union select 6")
431        r = self.c.query(q).ntuples()
432        self.assertIsInstance(r, int)
433        self.assertEqual(r, 6)
434
435    def testQuery(self):
436        query = self.c.query
437        query("drop table if exists test_table")
438        q = "create table test_table (n integer) with oids"
439        r = query(q)
440        self.assertIsNone(r)
441        q = "insert into test_table values (1)"
442        r = query(q)
443        self.assertIsInstance(r, int)
444        q = "insert into test_table select 2"
445        r = query(q)
446        self.assertIsInstance(r, int)
447        oid = r
448        q = "select oid from test_table where n=2"
449        r = query(q).getresult()
450        self.assertEqual(len(r), 1)
451        r = r[0]
452        self.assertEqual(len(r), 1)
453        r = r[0]
454        self.assertIsInstance(r, int)
455        self.assertEqual(r, oid)
456        q = "insert into test_table select 3 union select 4 union select 5"
457        r = query(q)
458        self.assertIsInstance(r, str)
459        self.assertEqual(r, '3')
460        q = "update test_table set n=4 where n<5"
461        r = query(q)
462        self.assertIsInstance(r, str)
463        self.assertEqual(r, '4')
464        q = "delete from test_table"
465        r = query(q)
466        self.assertIsInstance(r, str)
467        self.assertEqual(r, '5')
468        query("drop table test_table")
469
470    def testPrint(self):
471        q = ("select 1 as a, 'hello' as h, 'w' as world"
472            " union select 2, 'xyz', 'uvw'")
473        r = self.c.query(q)
474        f = tempfile.TemporaryFile()
475        stdout, sys.stdout = sys.stdout, f
476        try:
477            print r
478        except Exception:
479            pass
480        finally:
481            sys.stdout = stdout
482        f.seek(0)
483        r = f.read()
484        f.close()
485        self.assertEqual(r,
486            'a|  h  |world\n'
487            '-+-----+-----\n'
488            '1|hello|w    \n'
489            '2|xyz  |uvw  \n'
490            '(2 rows)\n')
491
492
493class TestParamQueries(unittest.TestCase):
494    """"Test queries with parameters via a basic pg connection."""
495
496    def setUp(self):
497        self.c = connect()
498
499    def tearDown(self):
500        self.c.close()
501
502    def testQueryWithNoneParam(self):
503        self.assertEqual(self.c.query("select $1::integer", (None,)
504            ).getresult(), [(None,)])
505        self.assertEqual(self.c.query("select $1::text", [None]
506            ).getresult(), [(None,)])
507
508    def testQueryWithBoolParams(self):
509        query = self.c.query
510        self.assertEqual(query("select false").getresult(), [('f',)])
511        self.assertEqual(query("select true").getresult(), [('t',)])
512        self.assertEqual(query("select $1::bool", (None,)).getresult(),
513            [(None,)])
514        self.assertEqual(query("select $1::bool", ('f',)).getresult(), [('f',)])
515        self.assertEqual(query("select $1::bool", ('t',)).getresult(), [('t',)])
516        self.assertEqual(query("select $1::bool", ('false',)).getresult(),
517            [('f',)])
518        self.assertEqual(query("select $1::bool", ('true',)).getresult(),
519            [('t',)])
520        self.assertEqual(query("select $1::bool", ('n',)).getresult(), [('f',)])
521        self.assertEqual(query("select $1::bool", ('y',)).getresult(), [('t',)])
522        self.assertEqual(query("select $1::bool", (0,)).getresult(), [('f',)])
523        self.assertEqual(query("select $1::bool", (1,)).getresult(), [('t',)])
524        self.assertEqual(query("select $1::bool", (False,)).getresult(),
525            [('f',)])
526        self.assertEqual(query("select $1::bool", (True,)).getresult(),
527            [('t',)])
528
529    def testQueryWithIntParams(self):
530        query = self.c.query
531        self.assertEqual(query("select 1+1").getresult(), [(2,)])
532        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
533        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
534        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
535        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
536        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
537            [(Decimal('2'),)])
538        self.assertEqual(query("select 1, $1::integer", (2,)
539            ).getresult(), [(1, 2)])
540        self.assertEqual(query("select 1 union select $1", (2,)
541            ).getresult(), [(1,), (2,)])
542        self.assertEqual(query("select $1::integer+$2", (1, 2)
543            ).getresult(), [(3,)])
544        self.assertEqual(query("select $1::integer+$2", [1, 2]
545            ).getresult(), [(3,)])
546        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", range(6)
547            ).getresult(), [(15,)])
548
549    def testQueryWithStrParams(self):
550        query = self.c.query
551        self.assertEqual(query("select $1||', world!'", ('Hello',)
552            ).getresult(), [('Hello, world!',)])
553        self.assertEqual(query("select $1||', world!'", ['Hello']
554            ).getresult(), [('Hello, world!',)])
555        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
556            ).getresult(), [('Hello, world!',)])
557        self.assertEqual(query("select $1::text", ('Hello, world!',)
558            ).getresult(), [('Hello, world!',)])
559        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
560            ).getresult(), [('Hello', 'world')])
561        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
562            ).getresult(), [('Hello', 'world')])
563        self.assertEqual(query("select $1::text union select $2::text",
564            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
565        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
566            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
567
568    def testQueryWithUnicodeParams(self):
569        query = self.c.query
570        query('set client_encoding = utf8')
571        self.assertEqual(query("select $1||', '||$2||'!'",
572            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
573        self.assertEqual(query("select $1||', '||$2||'!'",
574            ('Hello', u'\u043c\u0438\u0440')).getresult(),
575            [('Hello, \xd0\xbc\xd0\xb8\xd1\x80!',)])
576        query('set client_encoding = latin1')
577        self.assertEqual(query("select $1||', '||$2||'!'",
578            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
579        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
580            ('Hello', u'\u043c\u0438\u0440'))
581        query('set client_encoding = iso_8859_1')
582        self.assertEqual(query("select $1||', '||$2||'!'",
583            ('Hello', u'w\xf6rld')).getresult(), [('Hello, w\xf6rld!',)])
584        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
585            ('Hello', u'\u043c\u0438\u0440'))
586        query('set client_encoding = iso_8859_5')
587        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
588            ('Hello', u'w\xf6rld'))
589        self.assertEqual(query("select $1||', '||$2||'!'",
590            ('Hello', u'\u043c\u0438\u0440')).getresult(),
591            [('Hello, \xdc\xd8\xe0!',)])
592        query('set client_encoding = sql_ascii')
593        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
594            ('Hello', u'w\xf6rld'))
595
596    def testQueryWithMixedParams(self):
597        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
598            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
599        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
600            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
601
602    def testQueryWithDuplicateParams(self):
603        self.assertRaises(pg.ProgrammingError,
604            self.c.query, "select $1+$1", (1,))
605        self.assertRaises(pg.ProgrammingError,
606            self.c.query, "select $1+$1", (1, 2))
607
608    def testQueryWithZeroParams(self):
609        self.assertEqual(self.c.query("select 1+1", []
610            ).getresult(), [(2,)])
611
612    def testQueryWithGarbage(self):
613        garbage = r"'\{}+()-#[]oo324"
614        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
615            ).dictresult(), [{'garbage': garbage}])
616
617    def testUnicodeQuery(self):
618        query = self.c.query
619        self.assertEqual(query(u"select 1+1").getresult(), [(2,)])
620        self.assertRaises(TypeError, query, u"select 'Hello, w\xf6rld!'")
621
622
623class TestInserttable(unittest.TestCase):
624    """"Test inserttable method."""
625
626    @classmethod
627    def setUpClass(cls):
628        c = connect()
629        c.query("drop table if exists test cascade")
630        c.query("create table test ("
631            "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
632            "d numeric, f4 real, f8 double precision, m money,"
633            "c char(1), v4 varchar(4), c4 char(4), t text)")
634        c.close()
635
636    @classmethod
637    def tearDownClass(cls):
638        c = connect()
639        c.query("drop table test cascade")
640        c.close()
641
642    def setUp(self):
643        self.c = connect()
644        self.c.query("set datestyle='ISO,YMD'")
645
646    def tearDown(self):
647        self.c.query("truncate table test")
648        self.c.close()
649
650    data = [
651        (-1, -1, -1L, True, '1492-10-12', '08:30:00',
652            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
653        (0, 0, 0L, False, '1607-04-14', '09:00:00',
654            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
655        (1, 1, 1L, True, '1801-03-04', '03:45:00',
656            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
657        (2, 2, 2L, False, '1903-12-17', '11:22:00',
658            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]
659
660    def get_back(self):
661        """Convert boolean and decimal values back."""
662        data = []
663        for row in self.c.query("select * from test order by 1").getresult():
664            self.assertIsInstance(row, tuple)
665            row = list(row)
666            if row[0] is not None:  # smallint
667                self.assertIsInstance(row[0], int)
668            if row[1] is not None:  # integer
669                self.assertIsInstance(row[1], int)
670            if row[2] is not None:  # bigint
671                self.assertIsInstance(row[2], long)
672            if row[3] is not None:  # boolean
673                self.assertIsInstance(row[3], str)
674                row[3] = {'f': False, 't': True}.get(row[3])
675            if row[4] is not None:  # date
676                self.assertIsInstance(row[4], str)
677                self.assertTrue(row[4].replace('-', '').isdigit())
678            if row[5] is not None:  # time
679                self.assertIsInstance(row[5], str)
680                self.assertTrue(row[5].replace(':', '').isdigit())
681            if row[6] is not None:  # numeric
682                self.assertIsInstance(row[6], Decimal)
683                row[6] = float(row[6])
684            if row[7] is not None:  # real
685                self.assertIsInstance(row[7], float)
686            if row[8] is not None:  # double precision
687                self.assertIsInstance(row[8], float)
688                row[8] = float(row[8])
689            if row[9] is not None:  # money
690                self.assertIsInstance(row[9], Decimal)
691                row[9] = str(float(row[9]))
692            if row[10] is not None:  # char(1)
693                self.assertIsInstance(row[10], str)
694                self.assertEqual(len(row[10]), 1)
695            if row[11] is not None:  # varchar(4)
696                self.assertIsInstance(row[11], str)
697                self.assertLessEqual(len(row[11]), 4)
698            if row[12] is not None:  # char(4)
699                self.assertIsInstance(row[12], str)
700                self.assertEqual(len(row[12]), 4)
701                row[12] = row[12].rstrip()
702            if row[13] is not None:  # text
703                self.assertIsInstance(row[13], str)
704            row = tuple(row)
705            data.append(row)
706        return data
707
708    def testInserttable1Row(self):
709        data = self.data[2:3]
710        self.c.inserttable("test", data)
711        self.assertEqual(self.get_back(), data)
712
713    def testInserttable4Rows(self):
714        data = self.data
715        self.c.inserttable("test", data)
716        self.assertEqual(self.get_back(), data)
717
718    def testInserttableMultipleRows(self):
719        num_rows = 100
720        data = self.data[2:3] * num_rows
721        self.c.inserttable("test", data)
722        r = self.c.query("select count(*) from test").getresult()[0][0]
723        self.assertEqual(r, num_rows)
724
725    def testInserttableMultipleCalls(self):
726        num_rows = 10
727        data = self.data[2:3]
728        for _i in range(num_rows):
729            self.c.inserttable("test", data)
730        r = self.c.query("select count(*) from test").getresult()[0][0]
731        self.assertEqual(r, num_rows)
732
733    def testInserttableNullValues(self):
734        data = [(None,) * 14] * 100
735        self.c.inserttable("test", data)
736        self.assertEqual(self.get_back(), data)
737
738    def testInserttableMaxValues(self):
739        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
740            True, '2999-12-31', '11:59:59', 1e99,
741            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
742            "1", "1234", "1234", "1234" * 100)]
743        self.c.inserttable("test", data)
744        self.assertEqual(self.get_back(), data)
745
746
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    @classmethod
    def setUpClass(cls):
        # Create the table shared by all tests of this class.
        c = connect()
        c.query("drop table if exists test cascade")
        c.query("create table test (i int, v varchar(16))")
        c.close()

    @classmethod
    def tearDownClass(cls):
        c = connect()
        c.query("drop table test cascade")
        c.close()

    def setUp(self):
        self.c = connect()
        self.c.query("set datestyle='ISO,YMD'")

    def tearDown(self):
        # Remove all rows so that every test starts with an empty table.
        self.c.query("truncate table test")
        self.c.close()

    def testPutline(self):
        """Send rows with putline() while running "copy test from stdin"."""
        putline = self.c.putline
        query = self.c.query
        data = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            for i, v in data:
                putline("%d\t%s\n" % (i, v))
            # a line consisting of backslash and period ends the copy data
            putline("\\.\n")
        finally:
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, data)

    def testGetline(self):
        """Read rows with getline() while running "copy test to stdout".

        Previously this method was also named testPutline, which made it
        shadow (and silently disable) the putline test above.
        """
        getline = self.c.getline
        query = self.c.query
        data = list(enumerate("apple banana pear plum strawberry".split()))
        n = len(data)
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            # read n data lines, then the end marker, then one more call
            # which is expected to return None
            for i in range(n + 2):
                v = getline()
                if i < n:
                    self.assertEqual(v, '%d\t%s' % data[i])
                elif i == n:
                    self.assertEqual(v, '\\.')
                else:
                    self.assertIsNone(v)
        finally:
            try:
                self.c.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
811
812
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def _checkNotify(self, r, payload):
        """Check that r is a notification received on test_notify.

        The notification must be a 3-tuple of (str, int, str) whose first
        element is the channel name 'test_notify' and whose last element
        equals the given payload.  (The int is presumably the PID of the
        notifying backend; it is not checked further.)
        """
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self._checkNotify(getnotify(), '')
            self.assertIsNone(getnotify())
            try:
                query("notify test_notify, 'test_payload'")
            except pg.ProgrammingError:  # PostgreSQL < 9.0
                pass
            else:
                self._checkNotify(getnotify(), 'test_payload')
                self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))

    def testSetAndGetNoticeReceiver(self):
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)

    def testNoticeReceiver(self):
        # "or replace" so that a function left over from an aborted
        # earlier run does not make this test fail
        self.c.query('''create or replace function bilbo_notice()
            returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        try:
            received = {}

            def notice_receiver(notice):
                # collect all attributes of the notice object
                for attr in dir(notice):
                    value = getattr(notice, attr)
                    if isinstance(value, str):
                        # normalize a German-localized server message
                        value = value.replace('WARNUNG', 'WARNING')
                    received[attr] = value

            self.c.set_notice_receiver(notice_receiver)
            self.c.query('''select bilbo_notice()''')
            self.assertEqual(received, dict(
                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
                severity='WARNING', primary='Bilbo was here!',
                detail=None, hint=None))
        finally:
            self.c.query('''drop function bilbo_notice();''')
893
894
class TestConfigFunctions(unittest.TestCase):
    """Test the functions for changing default settings.

    To test the effect of most of these functions, we need a database
    connection.  That's why they are covered in this test module.

    Tests that change a module-wide setting restore it in a finally
    clause, so that a failing assertion cannot leak the changed setting
    into other tests.

    """

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def _setMonetaryLocale(self, *locales):
        """Set lc_monetary to the first of the given locales that works.

        Locale names are spelled differently on different platforms,
        so several candidates can be passed; they are tried in order.
        Returns the locale that was set.  If the server rejects all of
        them, the current test is skipped.

        """
        for locale in locales:
            try:
                self.c.query("set lc_monetary='%s'" % locale)
            except pg.Error:  # locale not available on this server
                continue
            return locale
        self.skipTest('None of the locales %s is available'
            % ', '.join(locales))

    def testGetDecimalPoint(self):
        point = pg.get_decimal_point()
        self.assertIsInstance(point, str)
        self.assertEqual(point, '.')

    def testSetDecimalPoint(self):
        d = pg.Decimal
        point = pg.get_decimal_point()
        query = self.c.query
        # check that money values can be interpreted correctly
        # if and only if the decimal point is set appropriately
        # for the current lc_monetary setting
        try:
            # a locale whose money format uses a dot as decimal point
            self._setMonetaryLocale('en_US.UTF-8', 'en_US.utf8',
                'English_United States.1252', 'C')
            pg.set_decimal_point('.')
            r = query("select '34.25'::money").getresult()[0][0]
            self.assertIsInstance(r, d)
            self.assertEqual(r, d('34.25'))
            pg.set_decimal_point(',')
            r = query("select '34.25'::money").getresult()[0][0]
            self.assertNotEqual(r, d('34.25'))
            # a locale whose money format uses a comma as decimal point
            self._setMonetaryLocale('de_DE.UTF-8', 'de_DE.utf8',
                'German_Germany.1252')
            pg.set_decimal_point(',')
            r = query("select '34,25'::money").getresult()[0][0]
            self.assertIsInstance(r, d)
            self.assertEqual(r, d('34.25'))
            pg.set_decimal_point('.')
            r = query("select '34,25'::money").getresult()[0][0]
            self.assertNotEqual(r, d('34.25'))
        finally:
            # restore the module-wide setting even if an assertion failed
            pg.set_decimal_point(point)

    def testSetDecimal(self):
        d = pg.Decimal
        query = self.c.query
        r = query("select 3425::numeric").getresult()[0][0]
        self.assertIsInstance(r, d)
        self.assertEqual(r, d('3425'))
        pg.set_decimal(long)
        try:
            r = query("select 3425::numeric").getresult()[0][0]
            self.assertNotIsInstance(r, d)
            self.assertIsInstance(r, long)
            self.assertEqual(r, 3425)
        finally:
            # restore the module-wide setting even if an assertion failed
            pg.set_decimal(d)

    @unittest.skipUnless(namedtuple, 'Named tuples not available')
    def testSetNamedresult(self):
        query = self.c.query

        def check_row(r):
            # r must be a named tuple of class Row with fields x=1, y=2
            self.assertIsInstance(r, tuple)
            self.assertEqual(r, (1, 2))
            self.assertIsNot(type(r), tuple)
            self.assertEqual(r._fields, ('x', 'y'))
            self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
            self.assertEqual(r.__class__.__name__, 'Row')

        check_row(query("select 1 as x, 2 as y").namedresult()[0])

        _namedresult = pg._namedresult
        self.assertTrue(callable(_namedresult))
        # setting the default function again must not change anything
        pg.set_namedresult(_namedresult)

        check_row(query("select 1 as x, 2 as y").namedresult()[0])

        def _listresult(q):
            return [list(row) for row in q.getresult()]

        pg.set_namedresult(_listresult)

        try:
            r = query("select 1 as x, 2 as y").namedresult()[0]
            self.assertIsInstance(r, list)
            self.assertEqual(r, [1, 2])
            self.assertIsNot(type(r), tuple)
            self.assertFalse(hasattr(r, '_fields'))
            self.assertNotEqual(r.__class__.__name__, 'Row')
        finally:
            pg.set_namedresult(_namedresult)
990
991
if __name__ == "__main__":
    # Run every test case in this module when executed as a script.
    unittest.main()
Note: See TracBrowser for help on using the repository browser.