source: trunk/module/tests/test_classic_connection.py @ 655

Last change on this file since 655 was 655, checked in by cito, 4 years ago

Fix garbage collection issues

This patch fixes memory management problems, particularly on Windows.

The query() method and the handling of object references inside the method have
been greatly simplified and corrected. Note that we don't calculate and pass
the paramLengths any more since this is only needed when also passing info about
binary data in paramFormats. We also use PyMem_Malloc() instead of malloc() to
allocate memory to make sure the memory is allocated in the Python heap.

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 59.3 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import threading
19import time
20import os
21
22import pg  # the module under test
23
24from decimal import Decimal
25
26# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
27# get our information from that.  Otherwise we use the defaults.
28# These tests should be run with various PostgreSQL versions and databases
29# created with different encodings and locales.  Particularly, make sure the
30# tests are running against databases created with both SQL_ASCII and UTF8.
31dbname = 'unittest'
32dbhost = None
33dbport = 5432
34
35try:
36    from .LOCAL_PyGreSQL import *
37except (ImportError, ValueError):
38    try:
39        from LOCAL_PyGreSQL import *
40    except ImportError:
41        pass
42
43try:
44    long
45except NameError:  # Python >= 3.0
46    long = int
47
48try:
49    unicode
50except NameError:  # Python >= 3.0
51    unicode = str
52
53unicode_strings = str is not bytes
54
55windows = os.name == 'nt'
56
57# There is a known bug in libpq under Windows which can cause
58# the interface to crash when calling PQhost():
59do_not_ask_for_host = windows
60do_not_ask_for_host_reason = 'libpq issue on Windows'
61
62
def connect():
    """Open a pg connection to the test database.

    The module-level ``dbname``, ``dbhost`` and ``dbport`` settings select
    the database; client_min_messages is set to warning on the new
    connection before it is returned.
    """
    db = pg.connect(dbname, dbhost, dbport)
    db.query("set client_min_messages=warning")
    return db
68
69
class TestCanConnect(unittest.TestCase):
    """Check that the test database can be reached at all."""

    def testCanConnect(self):
        # Opening and closing are checked separately so that the failure
        # message names the step that went wrong.
        try:
            db = connect()
        except pg.Error as exc:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, exc))
        else:
            try:
                db.close()
            except pg.Error:
                self.fail('Cannot close the database connection')
82
83
class TestConnectObject(unittest.TestCase):
    """Check the attributes and methods exposed by a pg connection."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # A test may already have closed the connection; a second close()
        # raises pg.InternalError, which is harmless during cleanup.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def is_method(self, attribute):
        """Return whether the named connection attribute is callable.

        When do_not_ask_for_host is set, 'host' is reported as a non-method
        without ever being read from the connection.
        """
        if attribute == 'host' and do_not_ask_for_host:
            return False
        value = getattr(self.connection, attribute)
        return callable(value)

    def testClassName(self):
        connection_class = self.connection.__class__
        self.assertEqual(connection_class.__name__, 'Connection')

    def testModuleName(self):
        connection_class = self.connection.__class__
        self.assertEqual(connection_class.__module__, 'pg')

    def testStr(self):
        text = str(self.connection)
        self.assertTrue(text.startswith('<pg.Connection object'), text)

    def testRepr(self):
        text = repr(self.connection)
        self.assertTrue(text.startswith('<pg.Connection object'), text)

    def testAllConnectAttributes(self):
        expected = '''db error host options port
            protocol_version server_version status tty user'''.split()
        found = []
        for name in dir(self.connection):
            # dunder names are dropped before is_method() ever looks at them
            if name.startswith('__') or self.is_method(name):
                continue
            found.append(name)
        self.assertEqual(expected, found)

    def testAllConnectMethods(self):
        expected = '''cancel close endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_notice_receiver source transaction'''.split()
        found = []
        for name in dir(self.connection):
            if name.startswith('__'):
                continue
            if self.is_method(name):
                found.append(name)
        self.assertEqual(expected, found)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        # Either no error text at all, or one that mentions 'krb5_'.
        message = self.connection.error
        self.assertTrue(not message or 'krb5_' in message)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        host = self.connection.host
        self.assertIsInstance(host, str)
        self.assertEqual(host, dbhost or 'localhost')

    def testAttributeOptions(self):
        self.assertEqual(self.connection.options, '')

    def testAttributePort(self):
        port = self.connection.port
        self.assertIsInstance(port, int)
        self.assertEqual(port, dbport or 5432)

    def testAttributeProtocolVersion(self):
        version = self.connection.protocol_version
        self.assertIsInstance(version, int)
        self.assertTrue(2 <= version < 4)

    def testAttributeServerVersion(self):
        version = self.connection.server_version
        self.assertIsInstance(version, int)
        self.assertTrue(70400 <= version < 100000)

    def testAttributeStatus(self):
        status = self.connection.status
        self.assertIsInstance(status, int)
        self.assertEqual(status, 1)

    def testAttributeTty(self):
        tty = self.connection.tty
        self.assertIsInstance(tty, str)
        self.assertEqual(tty, '')

    def testAttributeUser(self):
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, 'Deprecated facility')

    def testMethodQuery(self):
        # Parameters may be given as a tuple or as a list.
        run = self.connection.query
        run("select 1+1")
        run("select 1+$1", (1,))
        run("select 1+$1+$2", (2, 3))
        run("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        with self.assertRaises(ValueError):
            self.connection.query('')

    def testMethodEndcopy(self):
        # An IOError is tolerated here; the test only exercises the call.
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        connection = self.connection
        connection.close()
        try:
            connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, connection.close)
        try:
            connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # leave a fresh connection behind for tearDown
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query

        def client_encoding():
            return query('show client_encoding').getresult()[0][0].upper()

        # check that client encoding gets reset
        original = client_encoding()
        changed = 'LATIN1' if original == 'UTF8' else 'UTF8'
        self.assertNotEqual(original, changed)
        query("set client_encoding=%s" % changed)
        self.assertEqual(client_encoding(), changed)
        self.connection.reset()
        restored = client_encoding()
        self.assertNotEqual(restored, changed)
        self.assertEqual(restored, original)

    def testMethodCancel(self):
        code = self.connection.cancel()
        self.assertIsInstance(code, int)
        self.assertEqual(code, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            # runs in the worker thread; records the error raised on cancel
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.ProgrammingError as error:
                errors.append(str(error))

        worker = threading.Thread(target=sleep)
        started = time.time()
        worker.start()  # run the query
        while True:  # make sure the query is really running
            time.sleep(0.1)
            if worker.is_alive() or time.time() - started > 5:
                break
        code = self.connection.cancel()  # cancel the running query
        worker.join()  # wait for the thread to end
        elapsed = time.time() - started

        self.assertIsInstance(code, int)
        self.assertEqual(code, 1)  # return code should be 1
        self.assertLessEqual(elapsed, 3)  # time should be under 3 seconds
        self.assertTrue(errors)

    def testMethodFileNo(self):
        fd = self.connection.fileno()
        self.assertIsInstance(fd, int)
        self.assertGreaterEqual(fd, 0)
263
264
class TestSimpleQueries(unittest.TestCase):
    """Run plain SQL through a basic pg connection and check the results."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        self.c.close()

    def testClassName(self):
        result_class = self.c.query("select 1").__class__
        self.assertEqual(result_class.__name__, 'Query')

    def testModuleName(self):
        result_class = self.c.query("select 1").__class__
        self.assertEqual(result_class.__module__, 'pg')

    def testStr(self):
        sql = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        expected = (
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')
        self.assertEqual(str(self.c.query(sql)), expected)

    def testRepr(self):
        text = repr(self.c.query("select 1"))
        self.assertTrue(text.startswith('<pg.Query object'), text)

    def testSelect0(self):
        self.c.query("select 0")

    def testSelect0Semicolon(self):
        self.c.query("select 0;")

    def testSelectDotSemicolon(self):
        with self.assertRaises(pg.ProgrammingError):
            self.c.query("select .;")

    def testGetresult(self):
        rows = self.c.query("select 0").getresult()
        self.assertIsInstance(rows, list)
        first = rows[0]
        self.assertIsInstance(first, tuple)
        self.assertIsInstance(first[0], int)
        self.assertEqual(rows, [(0,)])

    def testGetresultLong(self):
        expected = long(9876543210)
        self.assertIsInstance(expected, long)
        value = self.c.query("select 9876543210").getresult()[0][0]
        self.assertIsInstance(value, long)
        self.assertEqual(value, expected)

    def testGetresultDecimal(self):
        expected = Decimal(98765432109876543210)
        value = self.c.query("select 98765432109876543210").getresult()[0][0]
        self.assertIsInstance(value, Decimal)
        self.assertEqual(value, expected)

    def testGetresultString(self):
        expected = 'Hello, world!'
        value = self.c.query("select '%s'" % expected).getresult()[0][0]
        self.assertIsInstance(value, str)
        self.assertEqual(value, expected)

    def testDictresult(self):
        rows = self.c.query("select 0 as alias0").dictresult()
        self.assertIsInstance(rows, list)
        first = rows[0]
        self.assertIsInstance(first, dict)
        self.assertIsInstance(first['alias0'], int)
        self.assertEqual(rows, [{'alias0': 0}])

    def testDictresultLong(self):
        expected = long(9876543210)
        self.assertIsInstance(expected, long)
        sql = "select 9876543210 as longjohnsilver"
        value = self.c.query(sql).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(value, long)
        self.assertEqual(value, expected)

    def testDictresultDecimal(self):
        expected = Decimal(98765432109876543210)
        sql = "select 98765432109876543210 as longjohnsilver"
        value = self.c.query(sql).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(value, Decimal)
        self.assertEqual(value, expected)

    def testDictresultString(self):
        expected = 'Hello, world!'
        sql = "select '%s' as greeting" % expected
        value = self.c.query(sql).dictresult()[0]['greeting']
        self.assertIsInstance(value, str)
        self.assertEqual(value, expected)

    def testNamedresult(self):
        rows = self.c.query("select 0 as alias0").namedresult()
        self.assertEqual(rows, [(0,)])
        first = rows[0]
        self.assertEqual(first._fields, ('alias0',))
        self.assertEqual(first.alias0, 0)

    def testGet3Cols(self):
        rows = self.c.query("select 1,2,3").getresult()
        self.assertEqual(rows, [(1, 2, 3)])

    def testGet3DictCols(self):
        rows = self.c.query("select 1 as a,2 as b,3 as c").dictresult()
        self.assertEqual(rows, [dict(a=1, b=2, c=3)])

    def testGet3NamedCols(self):
        rows = self.c.query("select 1 as a,2 as b,3 as c").namedresult()
        self.assertEqual(rows, [(1, 2, 3)])
        first = rows[0]
        self.assertEqual(first._fields, ('a', 'b', 'c'))
        self.assertEqual(first.b, 2)

    def testGet3Rows(self):
        sql = "select 3 union select 1 union select 2 order by 1"
        self.assertEqual(self.c.query(sql).getresult(), [(1,), (2,), (3,)])

    def testGet3DictRows(self):
        sql = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        expected = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        self.assertEqual(self.c.query(sql).dictresult(), expected)

    def testGet3NamedRows(self):
        sql = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        rows = self.c.query(sql).namedresult()
        self.assertEqual(rows, [(1,), (2,), (3,)])
        for row in rows:
            self.assertEqual(row._fields, ('alias3',))

    def testDictresultNames(self):
        # an unquoted alias comes back lower-cased
        rows = self.c.query(
            "select 'MixedCase' as MixedCaseAlias").dictresult()
        self.assertEqual(rows, [{'mixedcasealias': 'MixedCase'}])
        # a quoted alias keeps its case
        rows = self.c.query(
            "select 'MixedCase' as \"MixedCaseAlias\"").dictresult()
        self.assertEqual(rows, [{'MixedCaseAlias': 'MixedCase'}])

    def testNamedresultNames(self):
        expected = [('MixedCase',)]
        # an unquoted alias comes back lower-cased
        rows = self.c.query(
            "select 'MixedCase' as MixedCaseAlias").namedresult()
        self.assertEqual(rows, expected)
        first = rows[0]
        self.assertEqual(first._fields, ('mixedcasealias',))
        self.assertEqual(first.mixedcasealias, 'MixedCase')
        # a quoted alias keeps its case
        rows = self.c.query(
            "select 'MixedCase' as \"MixedCaseAlias\"").namedresult()
        self.assertEqual(rows, expected)
        first = rows[0]
        self.assertEqual(first._fields, ('MixedCaseAlias',))
        self.assertEqual(first.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        one_row = "select " + ','.join(str(i) for i in range(num_cols))
        sql = ' union all '.join([one_row] * num_rows)
        rows = self.c.query(sql).getresult()
        expected = [tuple(range(num_cols)) for _ in range(num_rows)]
        self.assertEqual(rows, expected)

    def testListfields(self):
        sql = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        expected = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(self.c.query(sql).listfields(), expected)

    def testFieldname(self):
        sql = "select 0 as z, 0 as a, 0 as x, 0 as y"
        for index, expected in ((2, 'x'), (3, 'y')):
            self.assertEqual(self.c.query(sql).fieldname(index), expected)

    def testFieldnum(self):
        fieldnum = self.c.query("select 1 as x").fieldnum
        with self.assertRaises(ValueError):
            fieldnum('y')
        four_cols = "select 0 as z, 0 as a, 0 as x, 0 as y"
        cases = (
            ("select 1 as x", 'x', 0),
            (four_cols, 'x', 2),
            (four_cols, 'y', 3))
        for sql, name, expected in cases:
            position = self.c.query(sql).fieldnum(name)
            self.assertIsInstance(position, int)
            self.assertEqual(position, expected)

    def testNtuples(self):
        cases = (
            ("select 1 where false", 0),
            ("select 1 as a, 2 as b, 3 as c, 4 as d"
                " union select 5 as a, 6 as b, 7 as c, 8 as d", 2),
            ("select 1 union select 2 union select 3"
                " union select 4 union select 5 union select 6", 6))
        for sql, expected in cases:
            count = self.c.query(sql).ntuples()
            self.assertIsInstance(count, int)
            self.assertEqual(count, expected)

    def testQuery(self):
        query = self.c.query
        query("drop table if exists test_table")
        created = query("create table test_table (n integer) with oids")
        self.assertIsNone(created)
        inserted = query("insert into test_table values (1)")
        self.assertIsInstance(inserted, int)
        # the int returned by this insert must match the row's stored oid
        oid = query("insert into test_table select 2")
        self.assertIsInstance(oid, int)
        rows = query("select oid from test_table where n=2").getresult()
        self.assertEqual(len(rows), 1)
        row = rows[0]
        self.assertEqual(len(row), 1)
        stored_oid = row[0]
        self.assertIsInstance(stored_oid, int)
        self.assertEqual(stored_oid, oid)
        # statements touching several rows give the row count as a string
        statements = (
            ("insert into test_table select 3 union select 4 union select 5",
                '3'),
            ("update test_table set n=4 where n<5", '4'),
            ("delete from test_table", '5'))
        for sql, expected in statements:
            affected = query(sql)
            self.assertIsInstance(affected, str)
            self.assertEqual(affected, expected)
        query("drop table test_table")
545
546
547class TestUnicodeQueries(unittest.TestCase):
548    """Test unicode strings as queries via a basic pg connection."""
549
550    def setUp(self):
551        self.c = connect()
552        self.c.query('set client_encoding=utf8')
553
554    def tearDown(self):
555        self.c.close()
556
557    def testGetresulAscii(self):
558        result = u'Hello, world!'
559        q = u"select '%s'" % result
560        v = self.c.query(q).getresult()[0][0]
561        self.assertIsInstance(v, str)
562        self.assertEqual(v, result)
563
564    def testDictresulAscii(self):
565        result = u'Hello, world!'
566        q = u"select '%s' as greeting" % result
567        v = self.c.query(q).dictresult()[0]['greeting']
568        self.assertIsInstance(v, str)
569        self.assertEqual(v, result)
570
571    def testGetresultUtf8(self):
572        result = u'Hello, wörld & ЌОр!'
573        q = u"select '%s'" % result
574        if not unicode_strings:
575            result = result.encode('utf8')
576        # pass the query as unicode
577        try:
578            v = self.c.query(q).getresult()[0][0]
579        except pg.ProgrammingError:
580            self.skipTest("database does not support utf8")
581        self.assertIsInstance(v, str)
582        self.assertEqual(v, result)
583        q = q.encode('utf8')
584        # pass the query as bytes
585        v = self.c.query(q).getresult()[0][0]
586        self.assertIsInstance(v, str)
587        self.assertEqual(v, result)
588
589    def testDictresultUtf8(self):
590        result = u'Hello, wörld & ЌОр!'
591        q = u"select '%s' as greeting" % result
592        if not unicode_strings:
593            result = result.encode('utf8')
594        try:
595            v = self.c.query(q).dictresult()[0]['greeting']
596        except pg.ProgrammingError:
597            self.skipTest("database does not support utf8")
598        self.assertIsInstance(v, str)
599        self.assertEqual(v, result)
600        q = q.encode('utf8')
601        v = self.c.query(q).dictresult()[0]['greeting']
602        self.assertIsInstance(v, str)
603        self.assertEqual(v, result)
604
605    def testDictresultLatin1(self):
606        try:
607            self.c.query('set client_encoding=latin1')
608        except pg.ProgrammingError:
609            self.skipTest("database does not support latin1")
610        result = u'Hello, wörld!'
611        q = u"select '%s'" % result
612        if not unicode_strings:
613            result = result.encode('latin1')
614        v = self.c.query(q).getresult()[0][0]
615        self.assertIsInstance(v, str)
616        self.assertEqual(v, result)
617        q = q.encode('latin1')
618        v = self.c.query(q).getresult()[0][0]
619        self.assertIsInstance(v, str)
620        self.assertEqual(v, result)
621
622    def testDictresultLatin1(self):
623        try:
624            self.c.query('set client_encoding=latin1')
625        except pg.ProgrammingError:
626            self.skipTest("database does not support latin1")
627        result = u'Hello, wörld!'
628        q = u"select '%s' as greeting" % result
629        if not unicode_strings:
630            result = result.encode('latin1')
631        v = self.c.query(q).dictresult()[0]['greeting']
632        self.assertIsInstance(v, str)
633        self.assertEqual(v, result)
634        q = q.encode('latin1')
635        v = self.c.query(q).dictresult()[0]['greeting']
636        self.assertIsInstance(v, str)
637        self.assertEqual(v, result)
638
639    def testGetresultCyrillic(self):
640        try:
641            self.c.query('set client_encoding=iso_8859_5')
642        except pg.ProgrammingError:
643            self.skipTest("database does not support cyrillic")
644        result = u'Hello, ЌОр!'
645        q = u"select '%s'" % result
646        if not unicode_strings:
647            result = result.encode('cyrillic')
648        v = self.c.query(q).getresult()[0][0]
649        self.assertIsInstance(v, str)
650        self.assertEqual(v, result)
651        q = q.encode('cyrillic')
652        v = self.c.query(q).getresult()[0][0]
653        self.assertIsInstance(v, str)
654        self.assertEqual(v, result)
655
656    def testDictresultCyrillic(self):
657        try:
658            self.c.query('set client_encoding=iso_8859_5')
659        except pg.ProgrammingError:
660            self.skipTest("database does not support cyrillic")
661        result = u'Hello, ЌОр!'
662        q = u"select '%s' as greeting" % result
663        if not unicode_strings:
664            result = result.encode('cyrillic')
665        v = self.c.query(q).dictresult()[0]['greeting']
666        self.assertIsInstance(v, str)
667        self.assertEqual(v, result)
668        q = q.encode('cyrillic')
669        v = self.c.query(q).dictresult()[0]['greeting']
670        self.assertIsInstance(v, str)
671        self.assertEqual(v, result)
672
673    def testGetresultLatin9(self):
674        try:
675            self.c.query('set client_encoding=latin9')
676        except pg.ProgrammingError:
677            self.skipTest("database does not support latin9")
678        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
679        q = u"select '%s'" % result
680        if not unicode_strings:
681            result = result.encode('latin9')
682        v = self.c.query(q).getresult()[0][0]
683        self.assertIsInstance(v, str)
684        self.assertEqual(v, result)
685        q = q.encode('latin9')
686        v = self.c.query(q).getresult()[0][0]
687        self.assertIsInstance(v, str)
688        self.assertEqual(v, result)
689
690    def testDictresultLatin9(self):
691        try:
692            self.c.query('set client_encoding=latin9')
693        except pg.ProgrammingError:
694            self.skipTest("database does not support latin9")
695        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
696        q = u"select '%s' as menu" % result
697        if not unicode_strings:
698            result = result.encode('latin9')
699        v = self.c.query(q).dictresult()[0]['menu']
700        self.assertIsInstance(v, str)
701        self.assertEqual(v, result)
702        q = q.encode('latin9')
703        v = self.c.query(q).dictresult()[0]['menu']
704        self.assertIsInstance(v, str)
705        self.assertEqual(v, result)
706
707
708class TestParamQueries(unittest.TestCase):
709    """Test queries with parameters via a basic pg connection."""
710
711    def setUp(self):
712        self.c = connect()
713        self.c.query('set client_encoding=utf8')
714
715    def tearDown(self):
716        self.c.close()
717
718    def testQueryWithNoneParam(self):
719        self.assertRaises(TypeError, self.c.query, "select $1", None)
720        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
721        self.assertEqual(self.c.query("select $1::integer", (None,)
722            ).getresult(), [(None,)])
723        self.assertEqual(self.c.query("select $1::text", [None]
724            ).getresult(), [(None,)])
725        self.assertEqual(self.c.query("select $1::text", [[None]]
726            ).getresult(), [(None,)])
727
728    def testQueryWithBoolParams(self, use_bool=None):
729        query = self.c.query
730        if use_bool is not None:
731            use_bool_default = pg.get_bool()
732            pg.set_bool(use_bool)
733        try:
734            v_false, v_true = (False, True) if use_bool else 'ft'
735            r_false, r_true = [(v_false,)], [(v_true,)]
736            self.assertEqual(query("select false").getresult(), r_false)
737            self.assertEqual(query("select true").getresult(), r_true)
738            q = "select $1::bool"
739            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
740            self.assertEqual(query(q, ('f',)).getresult(), r_false)
741            self.assertEqual(query(q, ('t',)).getresult(), r_true)
742            self.assertEqual(query(q, ('false',)).getresult(), r_false)
743            self.assertEqual(query(q, ('true',)).getresult(), r_true)
744            self.assertEqual(query(q, ('n',)).getresult(), r_false)
745            self.assertEqual(query(q, ('y',)).getresult(), r_true)
746            self.assertEqual(query(q, (0,)).getresult(), r_false)
747            self.assertEqual(query(q, (1,)).getresult(), r_true)
748            self.assertEqual(query(q, (False,)).getresult(), r_false)
749            self.assertEqual(query(q, (True,)).getresult(), r_true)
750        finally:
751            if use_bool is not None:
752                pg.set_bool(use_bool_default)
753
754    def testQueryWithBoolParamsAndUseBool(self):
755        self.testQueryWithBoolParams(use_bool=True)
756
757    def testQueryWithIntParams(self):
758        query = self.c.query
759        self.assertEqual(query("select 1+1").getresult(), [(2,)])
760        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
761        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
762        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
763        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
764        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
765            [(Decimal('2'),)])
766        self.assertEqual(query("select 1, $1::integer", (2,)
767            ).getresult(), [(1, 2)])
768        self.assertEqual(query("select 1 union select $1", (2,)
769            ).getresult(), [(1,), (2,)])
770        self.assertEqual(query("select $1::integer+$2", (1, 2)
771            ).getresult(), [(3,)])
772        self.assertEqual(query("select $1::integer+$2", [1, 2]
773            ).getresult(), [(3,)])
774        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
775            ).getresult(), [(15,)])
776
777    def testQueryWithStrParams(self):
778        query = self.c.query
779        self.assertEqual(query("select $1||', world!'", ('Hello',)
780            ).getresult(), [('Hello, world!',)])
781        self.assertEqual(query("select $1||', world!'", ['Hello']
782            ).getresult(), [('Hello, world!',)])
783        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
784            ).getresult(), [('Hello, world!',)])
785        self.assertEqual(query("select $1::text", ('Hello, world!',)
786            ).getresult(), [('Hello, world!',)])
787        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
788            ).getresult(), [('Hello', 'world')])
789        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
790            ).getresult(), [('Hello', 'world')])
791        self.assertEqual(query("select $1::text union select $2::text",
792            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
793        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
794            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
795
796    def testQueryWithUnicodeParams(self):
797        query = self.c.query
798        try:
799            query('set client_encoding=utf8')
800            query("select 'wörld'").getresult()[0][0] == 'wörld'
801        except pg.ProgrammingError:
802            self.skipTest("database does not support utf8")
803        self.assertEqual(query("select $1||', '||$2||'!'",
804            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
805
806    def testQueryWithUnicodeParamsLatin1(self):
807        query = self.c.query
808        try:
809            query('set client_encoding=latin1')
810            query("select 'wörld'").getresult()[0][0] == 'wörld'
811        except pg.ProgrammingError:
812            self.skipTest("database does not support latin1")
813        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
814        if unicode_strings:
815            self.assertEqual(r, [('Hello, wörld!',)])
816        else:
817            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
818        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
819            ('Hello', u'ЌОр'))
820        query('set client_encoding=iso_8859_1')
821        r = query("select $1||', '||$2||'!'",
822            ('Hello', u'wörld')).getresult()
823        if unicode_strings:
824            self.assertEqual(r, [('Hello, wörld!',)])
825        else:
826            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
827        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
828            ('Hello', u'ЌОр'))
829        query('set client_encoding=sql_ascii')
830        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
831            ('Hello', u'wörld'))
832
833    def testQueryWithUnicodeParamsCyrillic(self):
834        query = self.c.query
835        try:
836            query('set client_encoding=iso_8859_5')
837            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
838        except pg.ProgrammingError:
839            self.skipTest("database does not support cyrillic")
840        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
841            ('Hello', u'wörld'))
842        r = query("select $1||', '||$2||'!'",
843            ('Hello', u'ЌОр')).getresult()
844        if unicode_strings:
845            self.assertEqual(r, [('Hello, ЌОр!',)])
846        else:
847            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
848        query('set client_encoding=sql_ascii')
849        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
850            ('Hello', u'ЌОр!'))
851
852    def testQueryWithMixedParams(self):
853        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
854            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
855        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
856            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
857
858    def testQueryWithDuplicateParams(self):
859        self.assertRaises(pg.ProgrammingError,
860            self.c.query, "select $1+$1", (1,))
861        self.assertRaises(pg.ProgrammingError,
862            self.c.query, "select $1+$1", (1, 2))
863
864    def testQueryWithZeroParams(self):
865        self.assertEqual(self.c.query("select 1+1", []
866            ).getresult(), [(2,)])
867
868    def testQueryWithGarbage(self):
869        garbage = r"'\{}+()-#[]oo324"
870        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
871            ).dictresult(), [{'garbage': garbage}])
872
873
class TestInserttable(unittest.TestCase):
    """Test inserttable method.

    NOTE(review): some non-ASCII literals in this class look mis-encoded
    in this copy of the file - confirm them against the repository.
    """

    @classmethod
    def setUpClass(cls):
        """Create the test table and detect whether lengths use encoding."""
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test ("
                "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
                "d numeric, f4 real, f8 double precision, m money,"
                "c char(1), v4 varchar(4), c4 char(4), t text)")
            # Check whether the test database uses SQL_ASCII - this means
            # that it does not consider encoding when calculating lengths.
            c.query("set client_encoding=utf8")
            cls.has_encoding = c.query(
                "select length('À') - length('a')").getresult()[0][0] == 0
        finally:
            # do not leak the connection when one of the queries fails
            c.close()

    @classmethod
    def tearDownClass(cls):
        """Drop the test table."""
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        """Open a connection with fixed encoding, date and money settings."""
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        """Empty the test table and close the connection."""
        self.c.query("truncate table test")
        self.c.close()

    # sample rows, one value for each of the 14 columns of the test table
    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s in the way the test database counts it.

        If the database considers the encoding, this is the number of
        characters, otherwise it is the number of encoded bytes.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def get_back(self, encoding='utf-8'):
        """Fetch all rows of the test table in a form comparable to input.

        The type of every non-null value is checked, boolean, numeric and
        money values are converted back to the types used in the test
        data, and the padding of the char(4) column is stripped.
        """
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], str)
                row[3] = {'f': False, 't': True}.get(row[3])
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
                row[8] = float(row[8])
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            row = tuple(row)
            data.append(row)
        return data

    def _unicode_row(self, text):
        """Return a row of unicode test values with text in the last column."""
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        return (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0', c, u'bÀd', u'bÀd', text)

    @staticmethod
    def _encode_row(row, encoding):
        """Return a copy of row with all unicode values encoded to bytes."""
        return tuple(s.encode(encoding)
            if isinstance(s, unicode) else s for s in row)

    def testInserttable1Row(self):
        """Insert a single row and read it back."""
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        """Insert all sample rows and read them back."""
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        """Insert the same row many times with one call and count rows."""
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        """Insert one row with each of several calls and count rows."""
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        """Insert rows consisting only of None values."""
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        """Insert a row with large values in the numeric columns."""
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        """Insert utf-8 encoded byte strings and read them back."""
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            # the values come back as str, which is unicode here
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        """Insert unicode strings with the utf8 client encoding."""
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            # the values come back as encoded byte strings
            data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        """Insert unicode strings with the latin1 client encoding."""
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select 'Â¥'")
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
        row_unicode = tuple(s.replace(u'€', u'Â¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin1')] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        """Insert unicode strings with the latin9 client encoding."""
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except pg.ProgrammingError:
            # skipTest() raises, so nothing after it would be executed
            self.skipTest("database does not support latin9")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin9')] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        """Check that unicode cannot be inserted with sql_ascii encoding."""
        self.c.query("set client_encoding=sql_ascii")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode]
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data)
1103
1104
class TestDirectSocketAccess(unittest.TestCase):
    """Test the copy command using putline, getline and endcopy."""

    @classmethod
    def setUpClass(cls):
        """Create a fresh test table with an int and a varchar column."""
        con = connect()
        for sql in ("drop table if exists test cascade",
                "create table test (i int, v varchar(16))"):
            con.query(sql)
        con.close()

    @classmethod
    def tearDownClass(cls):
        """Remove the test table again."""
        con = connect()
        con.query("drop table test cascade")
        con.close()

    def setUp(self):
        """Connect and switch the client encoding to utf8."""
        con = connect()
        self.c = con
        con.query("set client_encoding=utf8")

    def tearDown(self):
        """Truncate the test table and disconnect."""
        con = self.c
        con.query("truncate table test")
        con.close()

    def testPutline(self):
        """Copy numbered rows into the table line by line."""
        con = self.c
        expected = list(enumerate("apple pear plum cherry banana".split()))
        con.query("copy test from stdin")
        try:
            for row in expected:
                con.putline("%d\t%s\n" % row)
            # end-of-data marker of the copy command
            con.putline("\\.\n")
        finally:
            con.endcopy()
        rows = con.query("select * from test").getresult()
        self.assertEqual(rows, expected)

    def testPutlineBytesAndUnicode(self):
        """Copy lines given as encoded bytes and as native strings."""
        con = self.c
        con.query("copy test from stdin")
        try:
            for line in (u"47\tkÀse\n".encode('utf8'),
                    "35\twÃŒrstel\n", b"\\.\n"):
                con.putline(line)
        finally:
            con.endcopy()
        rows = con.query("select * from test").getresult()
        self.assertEqual(rows, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        """Copy the table out line by line, up to and past the end marker."""
        con = self.c
        rows = list(enumerate("apple banana pear plum strawberry".split()))
        con.inserttable('test', rows)
        con.query("copy test to stdout")
        try:
            for row in rows:
                self.assertEqual(con.getline(), '%d\t%s' % row)
            # after the data comes the end marker, then nothing more
            self.assertEqual(con.getline(), '\\.')
            self.assertIsNone(con.getline())
        finally:
            try:
                con.endcopy()
            except IOError:
                pass

    def testGetlineBytesAndUnicode(self):
        """Copy out rows that were inserted as bytes and as unicode."""
        con = self.c
        rows = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        con.inserttable('test', rows)
        con.query("copy test to stdout")
        try:
            for expected in ('54\tkÀse', '73\twÃŒrstel'):
                line = con.getline()
                self.assertIsInstance(line, str)
                self.assertEqual(line, expected)
            self.assertEqual(con.getline(), '\\.')
            self.assertIsNone(con.getline())
        finally:
            try:
                con.endcopy()
            except IOError:
                pass

    def testParameterChecks(self):
        """Check that wrong numbers of arguments raise a TypeError."""
        con = self.c
        with self.assertRaises(TypeError):
            con.putline()
        with self.assertRaises(TypeError):
            con.getline('invalid')
        with self.assertRaises(TypeError):
            con.endcopy('invalid')
1203
1204
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        """Open the connection used by the test."""
        self.c = connect()

    def tearDown(self):
        """Close the connection used by the test."""
        self.c.close()

    def _check_notification(self, r, payload):
        """Check one result of getnotify() for the test_notify channel.

        The result must be a 3-tuple (str, int, str) whose first item
        is 'test_notify' and whose last item is the given payload.
        """
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        """Check getnotify() before, between and after notifications."""
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self._check_notification(getnotify(), '')
            self.assertIsNone(getnotify())
            try:
                query("notify test_notify, 'test_payload'")
            except pg.ProgrammingError:  # PostgreSQL < 9.0
                pass
            else:
                self._check_notification(getnotify(), 'test_payload')
                self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        """Check that no notice receiver is set on a new connection."""
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        """Check that only a callable is accepted as notice receiver."""
        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))

    def testSetAndGetNoticeReceiver(self):
        """Check that the notice receiver that was set is returned again."""
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)

    def testNoticeReceiver(self):
        """Check the attributes of a notice passed to the receiver.

        A temporary PL/pgSQL function raises a warning; the receiver
        records all non-dunder attributes of the notice it gets.
        """
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        try:
            received = {}

            def notice_receiver(notice):
                for attr in dir(notice):
                    if attr.startswith('__'):
                        continue
                    value = getattr(notice, attr)
                    if isinstance(value, str):
                        # presumably copes with German server messages
                        value = value.replace('WARNUNG', 'WARNING')
                    received[attr] = value

            self.c.set_notice_receiver(notice_receiver)
            self.c.query('''select bilbo_notice()''')
            self.assertEqual(received, dict(
                pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
                severity='WARNING', primary='Bilbo was here!',
                detail=None, hint=None))
        finally:
            self.c.query('''drop function bilbo_notice();''')
1287
1288
class TestConfigFunctions(unittest.TestCase):
    """Test the functions for changing default settings.

    To test the effect of most of these functions, we need a database
    connection.  That's why they are covered in this test module.

    """

    def setUp(self):
        """Open a connection with fixed encoding and money settings."""
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        """Close the connection used by the test."""
        self.c.close()

    @staticmethod
    def _call_with_setting(setter, value, default, func):
        """Return func() evaluated while setter(value) is in effect.

        The setting is always reset with setter(default) afterwards,
        even if func() raises an exception.
        """
        setter(value)
        try:
            return func()
        finally:
            setter(default)

    def _fetch_with_setting(self, setter, value, default, q):
        """Return the first value of query result q under a setting.

        The setting is changed only after q has been created and is
        reset to default after the value has been fetched.
        """
        return self._call_with_setting(
            setter, value, default, lambda: q.getresult()[0][0])

    def _set_money_locale(self, locales):
        """Set lc_monetary to the first locale the database accepts.

        Returns True if one of the locales could be set, else False.
        """
        for lc in locales:
            try:
                self.c.query("set lc_monetary='%s'" % lc)
            except pg.ProgrammingError:
                continue
            return True
        return False

    def testGetDecimalPoint(self):
        """Check that get_decimal_point() reports the current setting."""
        point = pg.get_decimal_point()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal_point, point)
        self.assertIsInstance(point, str)
        self.assertEqual(point, '.')  # the default setting
        for value in (',', "'"):
            r = self._call_with_setting(
                pg.set_decimal_point, value, point, pg.get_decimal_point)
            self.assertIsInstance(r, str)
            self.assertEqual(r, value)
        # after setting an empty string or None, None is reported
        for value in ('', None):
            r = self._call_with_setting(
                pg.set_decimal_point, value, point, pg.get_decimal_point)
            self.assertIsNone(r)

    def testSetDecimalPoint(self):
        """Check the effect of set_decimal_point() on money values."""
        d = pg.Decimal
        point = pg.get_decimal_point()
        set_point = pg.set_decimal_point
        self.assertRaises(TypeError, set_point)
        # error if decimal point is not a string
        self.assertRaises(TypeError, set_point, 0)
        # error if more than one decimal point passed
        self.assertRaises(TypeError, set_point, '.', ',')
        self.assertRaises(TypeError, set_point, '.,')
        # error if decimal point is not a punctuation character
        self.assertRaises(TypeError, set_point, '0')
        query = self.c.query
        # check that money values are interpreted as decimal values
        # only if decimal_point is set, and that the result is correct
        # only if it is set suitable for the current lc_monetary setting
        select_money = "select '34.25'::money"
        proper_money = d('34.25')
        bad_money = d('3425')
        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
        # every format is a separate item (a missing comma would silently
        # concatenate two adjacent string literals into one)
        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
            '34,25 EUR', '34,25 Euro', '34,25 DM')
        # first try with English localization (using the point)
        if not self._set_money_locale(en_locales):
            self.skipTest("cannot set English money locale")
        try:
            q = query(select_money)
        except pg.ProgrammingError:
            # this can happen if the currency signs cannot be
            # converted using the encoding of the test database
            self.skipTest("database does not support English money")
        # without decimal point, money comes back as formatted string
        r = self._fetch_with_setting(set_point, None, point, q)
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        r = self._fetch_with_setting(
            set_point, '', point, query(select_money))
        self.assertIsInstance(r, str)
        self.assertIn(r, en_money)
        # with the matching decimal point, the value is correct
        r = self._fetch_with_setting(
            set_point, '.', point, query(select_money))
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        # with a wrong decimal point, the value is misinterpreted
        for wrong_point in (',', "'"):
            r = self._fetch_with_setting(
                set_point, wrong_point, point, query(select_money))
            self.assertIsInstance(r, d)
            self.assertEqual(r, bad_money)
        # then try with German localization (using the comma)
        if not self._set_money_locale(de_locales):
            self.skipTest("cannot set German money locale")
        select_money = select_money.replace('.', ',')
        try:
            q = query(select_money)
        except pg.ProgrammingError:
            self.skipTest("database does not support German money")
        r = self._fetch_with_setting(set_point, None, point, q)
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        r = self._fetch_with_setting(
            set_point, '', point, query(select_money))
        self.assertIsInstance(r, str)
        self.assertIn(r, de_money)
        r = self._fetch_with_setting(
            set_point, ',', point, query(select_money))
        self.assertIsInstance(r, d)
        self.assertEqual(r, proper_money)
        for wrong_point in ('.', "'"):
            r = self._fetch_with_setting(
                set_point, wrong_point, point, query(select_money))
            self.assertEqual(r, bad_money)

    def testGetDecimal(self):
        """Check that get_decimal() reports the current decimal class."""
        decimal_class = pg.get_decimal()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
        self.assertIs(decimal_class, pg.Decimal)  # the default setting
        r = self._call_with_setting(
            pg.set_decimal, int, decimal_class, pg.get_decimal)
        self.assertIs(r, int)
        r = pg.get_decimal()
        self.assertIs(r, decimal_class)

    def testSetDecimal(self):
        """Check the effect of set_decimal() on numeric values."""
        decimal_class = pg.get_decimal()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_decimal)
        query = self.c.query
        try:
            r = query("select 3425::numeric")
        except pg.ProgrammingError:
            self.skipTest('database does not support numeric')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, decimal_class)
        self.assertEqual(r, decimal_class('3425'))
        r = self._fetch_with_setting(pg.set_decimal, int, decimal_class,
            query("select 3425::numeric"))
        self.assertNotIsInstance(r, decimal_class)
        self.assertIsInstance(r, int)
        self.assertEqual(r, int(3425))

    def testGetBool(self):
        """Check that get_bool() reports the current setting as a bool."""
        use_bool = pg.get_bool()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_bool, use_bool)
        self.assertIsInstance(use_bool, bool)
        self.assertIs(use_bool, False)  # the default setting
        # truth values that are not bool must be reported as bool
        for value in (True, False, 1, 0):
            r = self._call_with_setting(
                pg.set_bool, value, use_bool, pg.get_bool)
            self.assertIsInstance(r, bool)
            self.assertIs(r, bool(value))

    def testSetBool(self):
        """Check the effect of set_bool() on boolean values."""
        use_bool = pg.get_bool()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_bool)
        query = self.c.query
        try:
            r = query("select true::bool")
        except pg.ProgrammingError:
            self.skipTest('database does not support bool')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, str)
        self.assertEqual(r, 't')
        r = self._fetch_with_setting(
            pg.set_bool, True, use_bool, query("select true::bool"))
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        r = self._fetch_with_setting(
            pg.set_bool, False, use_bool, query("select true::bool"))
        self.assertIsInstance(r, str)
        # compare by value - the identity of equal strings is not defined
        self.assertEqual(r, 't')

    def testGetNamedresult(self):
        """Check that get_namedresult() reports the default function."""
        namedresult = pg.get_namedresult()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
        self.assertIs(namedresult, pg._namedresult)  # the default setting

    def testSetNamedresult(self):
        """Check that set_namedresult() changes what namedresult() returns."""
        namedresult = pg.get_namedresult()
        self.assertTrue(callable(namedresult))

        query = self.c.query

        # with the default function, rows are named tuples called 'Row'
        r = query("select 1 as x, 2 as y").namedresult()[0]
        self.assertIsInstance(r, tuple)
        self.assertEqual(r, (1, 2))
        self.assertIsNot(type(r), tuple)
        self.assertEqual(r._fields, ('x', 'y'))
        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
        self.assertEqual(r.__class__.__name__, 'Row')

        def listresult(q):
            return [list(row) for row in q.getresult()]

        pg.set_namedresult(listresult)
        try:
            r = pg.get_namedresult()
            self.assertIs(r, listresult)
            r = query("select 1 as x, 2 as y").namedresult()[0]
            self.assertIsInstance(r, list)
            self.assertEqual(r, [1, 2])
            self.assertIsNot(type(r), tuple)
            self.assertFalse(hasattr(r, '_fields'))
            self.assertNotEqual(r.__class__.__name__, 'Row')
        finally:
            pg.set_namedresult(namedresult)

        r = pg.get_namedresult()
        self.assertIs(r, namedresult)
1607
1608
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.

    """

    @classmethod
    def setUpClass(cls):
        """Open a connection, fix the relevant settings and close it again."""
        db = connect()
        try:
            query = db.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            query('set bytea_output=escape')
        finally:
            # The connection is only needed for setting the parameters,
            # so do not leave it open for the rest of the test run.
            # NOTE(review): assumes connect() returns a classic pg
            # connection providing close() -- confirm against the helper.
            db.close()

    def testEscapeString(self):
        """Check pg.escape_string() with bytes and unicode input.

        The result must have the same string type as the input, and
        single quotes and backslashes must come back doubled.

        """
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' käse".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(r"It's bad to have a \ inside.")
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        """Check pg.escape_bytea() with bytes and unicode input.

        The result must have the same string type as the input.  Single
        quotes must come back doubled, and zero or non-ASCII bytes must
        come back as octal escapes preceded by two backslashes.

        """
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        # "ä" is 0xC3 0xA4 in UTF-8, i.e. octal 303 244 in the output
        r = f(u"das is' käse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1661
# Run the tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Note: See TracBrowser for help on using the repository browser.