source: trunk/tests/test_classic_connection.py @ 744

Last change on this file since 744 was 744, checked in by cito, 4 years ago

Use cleanup feature in unittests

  • Property svn:executable set to *
  • Property svn:keywords set to Id
File size: 59.0 KB
Line 
1#! /usr/bin/python
2# -*- coding: utf-8 -*-
3
4"""Test the classic PyGreSQL interface.
5
6Sub-tests for the low-level connection object.
7
8Contributed by Christoph Zwerschke.
9
10These tests need a database to test against.
11
12"""
13
14try:
15    import unittest2 as unittest  # for Python < 2.7
16except ImportError:
17    import unittest
18import threading
19import time
20import os
21
22import pg  # the module under test
23
24from decimal import Decimal
25
# We need a database to test against.  If LOCAL_PyGreSQL.py exists we will
# get our information from that.  Otherwise we use the defaults.
# These tests should be run with various PostgreSQL versions and databases
# created with different encodings and locales.  Particularly, make sure the
# tests are running against databases created with both SQL_ASCII and UTF8.
dbname = 'unittest'
dbhost = None
dbport = 5432

# Optional local overrides of the settings above.  The relative import
# works when the tests are imported as a package.  When this file is run
# outside a package it fails with ValueError (Python 2), SystemError
# (Python 3.3 to 3.5) or ImportError (Python 3.6 and later).  In that
# case a top-level LOCAL_PyGreSQL module is tried instead.
try:
    from .LOCAL_PyGreSQL import *
except (ImportError, SystemError, ValueError):
    try:
        from LOCAL_PyGreSQL import *
    except ImportError:
        pass  # no local settings, keep the defaults above

# Python 2/3 compatibility aliases.
try:
    long
except NameError:  # Python >= 3.0
    long = int

try:
    unicode
except NameError:  # Python >= 3.0
    unicode = str

# True when str is a text type (Python 3), False when it is bytes (Python 2).
unicode_strings = str is not bytes

windows = os.name == 'nt'

# There is a known bug in libpq under Windows which can cause
# the interface to crash when calling PQhost():
do_not_ask_for_host = windows
do_not_ask_for_host_reason = 'libpq issue on Windows'
61
62
def connect():
    """Open a pg connection to the test database and return it.

    The session is switched to client_min_messages=warning so that
    notice-level server messages are not reported to the client.
    """
    conn = pg.connect(dbname, dbhost, dbport)
    conn.query("set client_min_messages=warning")
    return conn
68
69
class TestCanConnect(unittest.TestCase):
    """Check that the test database can be opened and closed again."""

    def testCanConnect(self):
        try:
            conn = connect()
        except pg.Error as error:
            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
        else:
            # Only reached when connecting succeeded.
            try:
                conn.close()
            except pg.Error:
                self.fail('Cannot close the database connection')
82
83
class TestConnectObject(unittest.TestCase):
    """Test existence of basic pg connection methods."""

    def setUp(self):
        self.connection = connect()

    def tearDown(self):
        # A test may leave the connection closed already.  Closing a closed
        # connection raises pg.InternalError (see testMethodClose), which
        # is deliberately ignored here.
        try:
            self.connection.close()
        except pg.InternalError:
            pass

    def is_method(self, attribute):
        """Check if given attribute on the connection is a method."""
        # Never read the host attribute where libpq may crash on it;
        # it is reported as a plain (non-method) attribute instead.
        if do_not_ask_for_host and attribute == 'host':
            return False
        return callable(getattr(self.connection, attribute))

    def testClassName(self):
        self.assertEqual(self.connection.__class__.__name__, 'Connection')

    def testModuleName(self):
        self.assertEqual(self.connection.__class__.__module__, 'pg')

    def testStr(self):
        r = str(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testRepr(self):
        r = repr(self.connection)
        self.assertTrue(r.startswith('<pg.Connection object'), r)

    def testAllConnectAttributes(self):
        # Names are listed in the alphabetical order returned by dir().
        attributes = '''db error host options port
            protocol_version server_version status user'''.split()
        connection_attributes = [a for a in dir(self.connection)
            if not a.startswith('__') and not self.is_method(a)]
        self.assertEqual(attributes, connection_attributes)

    def testAllConnectMethods(self):
        # Names are listed in the alphabetical order returned by dir().
        methods = '''cancel close endcopy
            escape_bytea escape_identifier escape_literal escape_string
            fileno get_notice_receiver getline getlo getnotify
            inserttable locreate loimport parameter putline query reset
            set_notice_receiver source transaction'''.split()
        connection_methods = [a for a in dir(self.connection)
            if not a.startswith('__') and self.is_method(a)]
        self.assertEqual(methods, connection_methods)

    def testAttributeDb(self):
        self.assertEqual(self.connection.db, dbname)

    def testAttributeError(self):
        error = self.connection.error
        self.assertTrue(not error or 'krb5_' in error)

    @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason)
    def testAttributeHost(self):
        def_host = 'localhost'
        self.assertIsInstance(self.connection.host, str)
        self.assertEqual(self.connection.host, dbhost or def_host)

    def testAttributeOptions(self):
        no_options = ''
        self.assertEqual(self.connection.options, no_options)

    def testAttributePort(self):
        def_port = 5432
        self.assertIsInstance(self.connection.port, int)
        self.assertEqual(self.connection.port, dbport or def_port)

    def testAttributeProtocolVersion(self):
        protocol_version = self.connection.protocol_version
        self.assertIsInstance(protocol_version, int)
        self.assertTrue(2 <= protocol_version < 4)

    def testAttributeServerVersion(self):
        server_version = self.connection.server_version
        self.assertIsInstance(server_version, int)
        # Before PostgreSQL 10 the number is major * 10000 + minor * 100
        # + patch (9.6.5 -> 90605).  From 10 on it is major * 10000 + minor
        # (10.1 -> 100001), so an upper bound of 100000 would wrongly
        # reject every release starting with PostgreSQL 10.
        self.assertTrue(70400 <= server_version < 1000000)

    def testAttributeStatus(self):
        status_ok = 1
        self.assertIsInstance(self.connection.status, int)
        self.assertEqual(self.connection.status, status_ok)

    def testAttributeUser(self):
        no_user = 'Deprecated facility'
        user = self.connection.user
        self.assertTrue(user)
        self.assertIsInstance(user, str)
        self.assertNotEqual(user, no_user)

    def testMethodQuery(self):
        query = self.connection.query
        query("select 1+1")
        query("select 1+$1", (1,))
        query("select 1+$1+$2", (2, 3))
        query("select 1+$1+$2", [2, 3])

    def testMethodQueryEmpty(self):
        self.assertRaises(ValueError, self.connection.query, '')

    def testMethodEndcopy(self):
        # Outside of a COPY, endcopy() may raise IOError; that is accepted.
        try:
            self.connection.endcopy()
        except IOError:
            pass

    def testMethodClose(self):
        self.connection.close()
        try:
            self.connection.reset()
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Reset should give an error for a closed connection')
        self.assertRaises(pg.InternalError, self.connection.close)
        try:
            self.connection.query('select 1')
        except (pg.Error, TypeError):
            pass
        else:
            self.fail('Query should give an error for a closed connection')
        # Reconnect so that tearDown finds an open connection.
        self.connection = connect()

    def testMethodReset(self):
        query = self.connection.query
        # check that client encoding gets reset
        encoding = query('show client_encoding').getresult()[0][0].upper()
        changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
        self.assertNotEqual(encoding, changed_encoding)
        self.connection.query("set client_encoding=%s" % changed_encoding)
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertEqual(new_encoding, changed_encoding)
        self.connection.reset()
        new_encoding = query('show client_encoding').getresult()[0][0].upper()
        self.assertNotEqual(new_encoding, changed_encoding)
        self.assertEqual(new_encoding, encoding)

    def testMethodCancel(self):
        r = self.connection.cancel()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)

    def testCancelLongRunningThread(self):
        errors = []

        def sleep():
            # Runs in the worker thread; the cancelled query is expected
            # to raise pg.ProgrammingError, which is collected here.
            try:
                self.connection.query('select pg_sleep(5)').getresult()
            except pg.ProgrammingError as error:
                errors.append(str(error))

        thread = threading.Thread(target=sleep)
        t1 = time.time()
        thread.start()  # run the query
        while 1:  # make sure the query is really running
            time.sleep(0.1)
            if thread.is_alive() or time.time() - t1 > 5:
                break
        r = self.connection.cancel()  # cancel the running query
        thread.join()  # wait for the thread to end
        t2 = time.time()

        self.assertIsInstance(r, int)
        self.assertEqual(r, 1)  # return code should be 1
        # The query sleeps 5 seconds, so finishing within 3 seconds
        # shows that it was really cancelled.
        self.assertLessEqual(t2 - t1, 3)
        self.assertTrue(errors)

    def testMethodFileNo(self):
        r = self.connection.fileno()
        self.assertIsInstance(r, int)
        self.assertGreaterEqual(r, 0)
258
259
class TestSimpleQueries(unittest.TestCase):
    """Test simple queries via a basic pg connection."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # Run the registered cleanups now, while the connection is still
        # open.  unittest would otherwise run them only after tearDown has
        # closed the connection, and the cleanup queries would then fail.
        self.doCleanups()
        self.c.close()

    def testClassName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__name__, 'Query')

    def testModuleName(self):
        r = self.c.query("select 1")
        self.assertEqual(r.__class__.__module__, 'pg')

    def testStr(self):
        q = ("select 1 as a, 'hello' as h, 'w' as world"
            " union select 2, 'xyz', 'uvw'")
        r = self.c.query(q)
        self.assertEqual(str(r),
            'a|  h  |world\n'
            '-+-----+-----\n'
            '1|hello|w    \n'
            '2|xyz  |uvw  \n'
            '(2 rows)')

    def testRepr(self):
        r = repr(self.c.query("select 1"))
        self.assertTrue(r.startswith('<pg.Query object'), r)

    def testSelect0(self):
        q = "select 0"
        self.c.query(q)

    def testSelect0Semicolon(self):
        q = "select 0;"
        self.c.query(q)

    def testSelectDotSemicolon(self):
        q = "select .;"
        self.assertRaises(pg.ProgrammingError, self.c.query, q)

    def testGetresult(self):
        q = "select 0"
        result = [(0,)]
        r = self.c.query(q).getresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, tuple)
        self.assertIsInstance(v[0], int)
        self.assertEqual(r, result)

    def testGetresultLong(self):
        q = "select 9876543210"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testGetresultDecimal(self):
        q = "select 98765432109876543210"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testGetresultString(self):
        result = 'Hello, world!'
        q = "select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresult(self):
        q = "select 0 as alias0"
        result = [{'alias0': 0}]
        r = self.c.query(q).dictresult()
        self.assertIsInstance(r, list)
        v = r[0]
        self.assertIsInstance(v, dict)
        self.assertIsInstance(v['alias0'], int)
        self.assertEqual(r, result)

    def testDictresultLong(self):
        q = "select 9876543210 as longjohnsilver"
        result = long(9876543210)
        self.assertIsInstance(result, long)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, long)
        self.assertEqual(v, result)

    def testDictresultDecimal(self):
        q = "select 98765432109876543210 as longjohnsilver"
        result = Decimal(98765432109876543210)
        v = self.c.query(q).dictresult()[0]['longjohnsilver']
        self.assertIsInstance(v, Decimal)
        self.assertEqual(v, result)

    def testDictresultString(self):
        result = 'Hello, world!'
        q = "select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testNamedresult(self):
        q = "select 0 as alias0"
        result = [(0,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('alias0',))
        self.assertEqual(v.alias0, 0)

    def testGet3Cols(self):
        q = "select 1,2,3"
        result = [(1, 2, 3)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [dict(a=1, b=2, c=3)]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedCols(self):
        q = "select 1 as a,2 as b,3 as c"
        result = [(1, 2, 3)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('a', 'b', 'c'))
        self.assertEqual(v.b, 2)

    def testGet3Rows(self):
        q = "select 3 union select 1 union select 2 order by 1"
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).getresult()
        self.assertEqual(r, result)

    def testGet3DictRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [{'alias3': 1}, {'alias3': 2}, {'alias3': 3}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testGet3NamedRows(self):
        q = ("select 3 as alias3"
            " union select 1 union select 2 order by 1")
        result = [(1,), (2,), (3,)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        for v in r:
            self.assertEqual(v._fields, ('alias3',))

    def testDictresultNames(self):
        # Unquoted aliases are folded to lower case, quoted ones are kept.
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [{'mixedcasealias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        result = [{'MixedCaseAlias': 'MixedCase'}]
        r = self.c.query(q).dictresult()
        self.assertEqual(r, result)

    def testNamedresultNames(self):
        q = "select 'MixedCase' as MixedCaseAlias"
        result = [('MixedCase',)]
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('mixedcasealias',))
        self.assertEqual(v.mixedcasealias, 'MixedCase')
        q = "select 'MixedCase' as \"MixedCaseAlias\""
        r = self.c.query(q).namedresult()
        self.assertEqual(r, result)
        v = r[0]
        self.assertEqual(v._fields, ('MixedCaseAlias',))
        self.assertEqual(v.MixedCaseAlias, 'MixedCase')

    def testBigGetresult(self):
        num_cols = 100
        num_rows = 100
        q = "select " + ','.join(map(str, range(num_cols)))
        q = ' union all '.join((q,) * num_rows)
        r = self.c.query(q).getresult()
        result = [tuple(range(num_cols))] * num_rows
        self.assertEqual(r, result)

    def testListfields(self):
        q = ('select 0 as a, 0 as b, 0 as c,'
            ' 0 as c, 0 as b, 0 as a,'
            ' 0 as lowercase, 0 as UPPERCASE,'
            ' 0 as MixedCase, 0 as "MixedCase",'
            ' 0 as a_long_name_with_underscores,'
            ' 0 as "A long name with Blanks"')
        r = self.c.query(q).listfields()
        result = ('a', 'b', 'c', 'c', 'b', 'a',
            'lowercase', 'uppercase', 'mixedcase', 'MixedCase',
            'a_long_name_with_underscores',
            'A long name with Blanks')
        self.assertEqual(r, result)

    def testFieldname(self):
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldname(2)
        self.assertEqual(r, 'x')
        r = self.c.query(q).fieldname(3)
        self.assertEqual(r, 'y')

    def testFieldnum(self):
        q = "select 1 as x"
        self.assertRaises(ValueError, self.c.query(q).fieldnum, 'y')
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = "select 0 as z, 0 as a, 0 as x, 0 as y"
        r = self.c.query(q).fieldnum('x')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        r = self.c.query(q).fieldnum('y')
        self.assertIsInstance(r, int)
        self.assertEqual(r, 3)

    def testNtuples(self):
        q = "select 1 where false"
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 0)
        q = ("select 1 as a, 2 as b, 3 as c, 4 as d"
            " union select 5 as a, 6 as b, 7 as c, 8 as d")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 2)
        q = ("select 1 union select 2 union select 3"
            " union select 4 union select 5 union select 6")
        r = self.c.query(q).ntuples()
        self.assertIsInstance(r, int)
        self.assertEqual(r, 6)

    def testQuery(self):
        query = self.c.query
        query("drop table if exists test_table")
        q = "create table test_table (n integer) with oids"
        r = query(q)
        # Register the cleanup only once the table exists, so a failing
        # CREATE is not followed by a second, misleading error from
        # dropping a table that was never created.
        self.addCleanup(query, "drop table test_table")
        self.assertIsNone(r)
        # A single-row insert returns the oid of the new row as an int.
        q = "insert into test_table values (1)"
        r = query(q)
        self.assertIsInstance(r, int)
        q = "insert into test_table select 2"
        r = query(q)
        self.assertIsInstance(r, int)
        oid = r
        q = "select oid from test_table where n=2"
        r = query(q).getresult()
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertEqual(len(r), 1)
        r = r[0]
        self.assertIsInstance(r, int)
        self.assertEqual(r, oid)
        # Multi-row inserts, updates and deletes return the number of
        # affected rows as a string.
        q = "insert into test_table select 3 union select 4 union select 5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '3')
        q = "update test_table set n=4 where n<5"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '4')
        q = "delete from test_table"
        r = query(q)
        self.assertIsInstance(r, str)
        self.assertEqual(r, '5')
541
542
class TestUnicodeQueries(unittest.TestCase):
    """Test unicode strings as queries via a basic pg connection.

    Every test sends the query once as a unicode string and once encoded
    to bytes in the current client encoding, and checks the round trip.
    """

    def setUp(self):
        self.c = connect()
        self.c.query('set client_encoding=utf8')

    def tearDown(self):
        self.c.close()

    def testGetresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s'" % result
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultAscii(self):
        result = u'Hello, world!'
        q = u"select '%s' as greeting" % result
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('utf8')
        # pass the query as unicode
        try:
            v = self.c.query(q).getresult()[0][0]
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        # pass the query as bytes
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultUtf8(self):
        result = u'Hello, wörld & ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('utf8')
        try:
            v = self.c.query(q).dictresult()[0]['greeting']
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('utf8')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    # This getresult() variant used to share the name testDictresultLatin1
    # with the dictresult() test below, which shadowed it so it never ran.
    def testGetresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin1(self):
        try:
            self.c.query('set client_encoding=latin1')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        result = u'Hello, wörld!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin1')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except pg.ProgrammingError:
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultCyrillic(self):
        try:
            self.c.query('set client_encoding=iso_8859_5')
        except pg.ProgrammingError:
            self.skipTest("database does not support cyrillic")
        result = u'Hello, ЌОр!'
        q = u"select '%s' as greeting" % result
        if not unicode_strings:
            result = result.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('cyrillic')
        v = self.c.query(q).dictresult()[0]['greeting']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testGetresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        # NOTE(review): this literal looks mis-encoded (probably meant
        # "pražská šunka" and "¥"), but every character in it is still
        # representable in latin9, so the round trip is still exercised.
        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
        q = u"select '%s'" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).getresult()[0][0]
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)

    def testDictresultLatin9(self):
        try:
            self.c.query('set client_encoding=latin9')
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        result = u'smœrebrœd with praÅŸská Å¡unka (pay in ¢, £, €, or Â¥)'
        q = u"select '%s' as menu" % result
        if not unicode_strings:
            result = result.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
        q = q.encode('latin9')
        v = self.c.query(q).dictresult()[0]['menu']
        self.assertIsInstance(v, str)
        self.assertEqual(v, result)
702
703
704class TestParamQueries(unittest.TestCase):
705    """Test queries with parameters via a basic pg connection."""
706
707    def setUp(self):
708        self.c = connect()
709        self.c.query('set client_encoding=utf8')
710
711    def tearDown(self):
712        self.c.close()
713
714    def testQueryWithNoneParam(self):
715        self.assertRaises(TypeError, self.c.query, "select $1", None)
716        self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None)
717        self.assertEqual(self.c.query("select $1::integer", (None,)
718            ).getresult(), [(None,)])
719        self.assertEqual(self.c.query("select $1::text", [None]
720            ).getresult(), [(None,)])
721        self.assertEqual(self.c.query("select $1::text", [[None]]
722            ).getresult(), [(None,)])
723
724    def testQueryWithBoolParams(self, use_bool=None):
725        query = self.c.query
726        if use_bool is not None:
727            use_bool_default = pg.get_bool()
728            pg.set_bool(use_bool)
729        try:
730            v_false, v_true = (False, True) if use_bool else 'ft'
731            r_false, r_true = [(v_false,)], [(v_true,)]
732            self.assertEqual(query("select false").getresult(), r_false)
733            self.assertEqual(query("select true").getresult(), r_true)
734            q = "select $1::bool"
735            self.assertEqual(query(q, (None,)).getresult(), [(None,)])
736            self.assertEqual(query(q, ('f',)).getresult(), r_false)
737            self.assertEqual(query(q, ('t',)).getresult(), r_true)
738            self.assertEqual(query(q, ('false',)).getresult(), r_false)
739            self.assertEqual(query(q, ('true',)).getresult(), r_true)
740            self.assertEqual(query(q, ('n',)).getresult(), r_false)
741            self.assertEqual(query(q, ('y',)).getresult(), r_true)
742            self.assertEqual(query(q, (0,)).getresult(), r_false)
743            self.assertEqual(query(q, (1,)).getresult(), r_true)
744            self.assertEqual(query(q, (False,)).getresult(), r_false)
745            self.assertEqual(query(q, (True,)).getresult(), r_true)
746        finally:
747            if use_bool is not None:
748                pg.set_bool(use_bool_default)
749
750    def testQueryWithBoolParamsAndUseBool(self):
751        self.testQueryWithBoolParams(use_bool=True)
752
753    def testQueryWithIntParams(self):
754        query = self.c.query
755        self.assertEqual(query("select 1+1").getresult(), [(2,)])
756        self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)])
757        self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)])
758        self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)])
759        self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)])
760        self.assertEqual(query("select 1+$1::numeric", [1]).getresult(),
761            [(Decimal('2'),)])
762        self.assertEqual(query("select 1, $1::integer", (2,)
763            ).getresult(), [(1, 2)])
764        self.assertEqual(query("select 1 union select $1::integer", (2,)
765            ).getresult(), [(1,), (2,)])
766        self.assertEqual(query("select $1::integer+$2", (1, 2)
767            ).getresult(), [(3,)])
768        self.assertEqual(query("select $1::integer+$2", [1, 2]
769            ).getresult(), [(3,)])
770        self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))
771            ).getresult(), [(15,)])
772
773    def testQueryWithStrParams(self):
774        query = self.c.query
775        self.assertEqual(query("select $1||', world!'", ('Hello',)
776            ).getresult(), [('Hello, world!',)])
777        self.assertEqual(query("select $1||', world!'", ['Hello']
778            ).getresult(), [('Hello, world!',)])
779        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'),
780            ).getresult(), [('Hello, world!',)])
781        self.assertEqual(query("select $1::text", ('Hello, world!',)
782            ).getresult(), [('Hello, world!',)])
783        self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world')
784            ).getresult(), [('Hello', 'world')])
785        self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world']
786            ).getresult(), [('Hello', 'world')])
787        self.assertEqual(query("select $1::text union select $2::text",
788            ('Hello', 'world')).getresult(), [('Hello',), ('world',)])
789        self.assertEqual(query("select $1||', '||$2||'!'", ('Hello',
790            'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)])
791
792    def testQueryWithUnicodeParams(self):
793        query = self.c.query
794        try:
795            query('set client_encoding=utf8')
796            query("select 'wörld'").getresult()[0][0] == 'wörld'
797        except pg.ProgrammingError:
798            self.skipTest("database does not support utf8")
799        self.assertEqual(query("select $1||', '||$2||'!'",
800            ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)])
801
802    def testQueryWithUnicodeParamsLatin1(self):
803        query = self.c.query
804        try:
805            query('set client_encoding=latin1')
806            query("select 'wörld'").getresult()[0][0] == 'wörld'
807        except pg.ProgrammingError:
808            self.skipTest("database does not support latin1")
809        r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult()
810        if unicode_strings:
811            self.assertEqual(r, [('Hello, wörld!',)])
812        else:
813            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
814        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
815            ('Hello', u'ЌОр'))
816        query('set client_encoding=iso_8859_1')
817        r = query("select $1||', '||$2||'!'",
818            ('Hello', u'wörld')).getresult()
819        if unicode_strings:
820            self.assertEqual(r, [('Hello, wörld!',)])
821        else:
822            self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)])
823        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
824            ('Hello', u'ЌОр'))
825        query('set client_encoding=sql_ascii')
826        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
827            ('Hello', u'wörld'))
828
829    def testQueryWithUnicodeParamsCyrillic(self):
830        query = self.c.query
831        try:
832            query('set client_encoding=iso_8859_5')
833            query("select 'ЌОр'").getresult()[0][0] == 'ЌОр'
834        except pg.ProgrammingError:
835            self.skipTest("database does not support cyrillic")
836        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
837            ('Hello', u'wörld'))
838        r = query("select $1||', '||$2||'!'",
839            ('Hello', u'ЌОр')).getresult()
840        if unicode_strings:
841            self.assertEqual(r, [('Hello, ЌОр!',)])
842        else:
843            self.assertEqual(r, [(u'Hello, ЌОр!'.encode('cyrillic'),)])
844        query('set client_encoding=sql_ascii')
845        self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'",
846            ('Hello', u'ЌОр!'))
847
848    def testQueryWithMixedParams(self):
849        self.assertEqual(self.c.query("select $1+2,$2||', world!'",
850            (1, 'Hello'),).getresult(), [(3, 'Hello, world!')])
851        self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text",
852            (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')])
853
854    def testQueryWithDuplicateParams(self):
855        self.assertRaises(pg.ProgrammingError,
856            self.c.query, "select $1+$1", (1,))
857        self.assertRaises(pg.ProgrammingError,
858            self.c.query, "select $1+$1", (1, 2))
859
860    def testQueryWithZeroParams(self):
861        self.assertEqual(self.c.query("select 1+1", []
862            ).getresult(), [(2,)])
863
864    def testQueryWithGarbage(self):
865        garbage = r"'\{}+()-#[]oo324"
866        self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,)
867            ).dictresult(), [{'garbage': garbage}])
868
869
class TestInserttable(unittest.TestCase):
    """Test inserttable method."""

    @classmethod
    def setUpClass(cls):
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test ("
                "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time,"
                "d numeric, f4 real, f8 double precision, m money,"
                "c char(1), v4 varchar(4), c4 char(4), t text)")
            # Check whether the test database uses SQL_ASCII - this means
            # that it does not consider encoding when calculating lengths.
            c.query("set client_encoding=utf8")
            cls.has_encoding = c.query(
                "select length('À') - length('a')").getresult()[0][0] == 0
        finally:
            c.close()

    @classmethod
    def tearDownClass(cls):
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.c = connect()
        # Closing via a cleanup releases the connection even when one of
        # the following queries or the truncation in tearDown fails.
        self.addCleanup(self.c.close)
        self.c.query("set client_encoding=utf8")
        self.c.query("set datestyle='ISO,YMD'")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        # the connection is closed afterwards by the cleanup from setUp
        self.c.query("truncate table test")

    data = [
        (-1, -1, long(-1), True, '1492-10-12', '08:30:00',
            -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'),
        (0, 0, long(0), False, '1607-04-14', '09:00:00',
            0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'),
        (1, 1, long(1), True, '1801-03-04', '03:45:00',
            1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'),
        (2, 2, long(2), False, '1903-12-17', '11:22:00',
            2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')]

    @classmethod
    def db_len(cls, s, encoding):
        """Return the length of s as the database would count it.

        Characters are counted if the database considers the encoding,
        otherwise the bytes in the given encoding are counted.
        """
        if cls.has_encoding:
            s = s if isinstance(s, unicode) else s.decode(encoding)
        else:
            s = s.encode(encoding) if isinstance(s, unicode) else s
        return len(s)

    def _unicode_row(self, text):
        """Return a row of unicode strings with text in the text column."""
        # non-ascii chars do not fit in char(1) when there is no encoding
        c = u'€' if self.has_encoding else u'$'
        return (0, 0, long(0), False, u'1970-01-01', u'00:00:00',
            0.0, 0.0, 0.0, u'0.0', c, u'bÀd', u'bÀd', text)

    @staticmethod
    def _encode_row(row, encoding):
        """Return row with all unicode members encoded to bytes."""
        return tuple(s.encode(encoding) if isinstance(s, unicode) else s
            for s in row)

    def get_back(self, encoding='utf-8'):
        """Convert boolean and decimal values back."""
        data = []
        for row in self.c.query("select * from test order by 1").getresult():
            self.assertIsInstance(row, tuple)
            row = list(row)
            if row[0] is not None:  # smallint
                self.assertIsInstance(row[0], int)
            if row[1] is not None:  # integer
                self.assertIsInstance(row[1], int)
            if row[2] is not None:  # bigint
                self.assertIsInstance(row[2], long)
            if row[3] is not None:  # boolean
                self.assertIsInstance(row[3], str)
                row[3] = {'f': False, 't': True}.get(row[3])
            if row[4] is not None:  # date
                self.assertIsInstance(row[4], str)
                self.assertTrue(row[4].replace('-', '').isdigit())
            if row[5] is not None:  # time
                self.assertIsInstance(row[5], str)
                self.assertTrue(row[5].replace(':', '').isdigit())
            if row[6] is not None:  # numeric
                self.assertIsInstance(row[6], Decimal)
                row[6] = float(row[6])
            if row[7] is not None:  # real
                self.assertIsInstance(row[7], float)
            if row[8] is not None:  # double precision
                self.assertIsInstance(row[8], float)
            if row[9] is not None:  # money
                self.assertIsInstance(row[9], Decimal)
                row[9] = str(float(row[9]))
            if row[10] is not None:  # char(1)
                self.assertIsInstance(row[10], str)
                self.assertEqual(self.db_len(row[10], encoding), 1)
            if row[11] is not None:  # varchar(4)
                self.assertIsInstance(row[11], str)
                self.assertLessEqual(self.db_len(row[11], encoding), 4)
            if row[12] is not None:  # char(4)
                self.assertIsInstance(row[12], str)
                self.assertEqual(self.db_len(row[12], encoding), 4)
                row[12] = row[12].rstrip()
            if row[13] is not None:  # text
                self.assertIsInstance(row[13], str)
            data.append(tuple(row))
        return data

    def testInserttable1Row(self):
        data = self.data[2:3]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttable4Rows(self):
        data = self.data
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMultipleRows(self):
        num_rows = 100
        data = self.data[2:3] * num_rows
        self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableMultipleCalls(self):
        num_rows = 10
        data = self.data[2:3]
        for _i in range(num_rows):
            self.c.inserttable('test', data)
        r = self.c.query("select count(*) from test").getresult()[0][0]
        self.assertEqual(r, num_rows)

    def testInserttableNullValues(self):
        data = [(None,) * 14] * 100
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableMaxValues(self):
        data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1),
            True, '2999-12-31', '11:59:59', 1e99,
            1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None,
            "1", "1234", "1234", "1234" * 100)]
        self.c.inserttable('test', data)
        self.assertEqual(self.get_back(), data)

    def testInserttableByteValues(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.c.inserttable('test', data)
        if unicode_strings:
            data = [row_unicode] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeUtf8(self):
        try:
            self.c.query("select '€', 'kÀse', 'сыр', 'pont-l''évêque'")
        except pg.ProgrammingError:
            self.skipTest("database does not support utf8")
        row_unicode = self._unicode_row(u"kÀse сыр pont-l'évêque")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'utf-8')] * 2
        self.assertEqual(self.get_back(), data)

    def testInserttableUnicodeLatin1(self):
        try:
            self.c.query("set client_encoding=latin1")
            self.c.query("select '¥'")
        except pg.ProgrammingError:
            self.skipTest("database does not support latin1")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        # cannot encode € sign with latin1 encoding
        self.assertRaises(UnicodeEncodeError,
            self.c.inserttable, 'test', [row_unicode])
        # the yen sign is a single latin1 character, so it still
        # fits into the char(1) column where the € sign was used
        row_unicode = tuple(s.replace(u'€', u'¥')
            if isinstance(s, unicode) else s for s in row_unicode)
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin1')] * 2
        self.assertEqual(self.get_back('latin1'), data)

    def testInserttableUnicodeLatin9(self):
        try:
            self.c.query("set client_encoding=latin9")
            self.c.query("select '€'")
        except pg.ProgrammingError:
            self.skipTest("database does not support latin9")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        data = [row_unicode] * 2
        self.c.inserttable('test', data)
        if not unicode_strings:
            data = [self._encode_row(row_unicode, 'latin9')] * 2
        self.assertEqual(self.get_back('latin9'), data)

    def testInserttableNoEncoding(self):
        self.c.query("set client_encoding=sql_ascii")
        row_unicode = self._unicode_row(
            u"for kÀse and pont-l'évêque pay in €")
        # cannot encode non-ascii unicode without a specific encoding
        self.assertRaises(UnicodeEncodeError,
            self.c.inserttable, 'test', [row_unicode])
1099
1100
class TestDirectSocketAccess(unittest.TestCase):
    """Test copy command with direct socket access."""

    @classmethod
    def setUpClass(cls):
        c = connect()
        try:
            c.query("drop table if exists test cascade")
            c.query("create table test (i int, v varchar(16))")
        finally:
            c.close()

    @classmethod
    def tearDownClass(cls):
        c = connect()
        try:
            c.query("drop table test cascade")
        finally:
            c.close()

    def setUp(self):
        self.c = connect()
        # Closing via a cleanup releases the connection even when the
        # following query or the truncation in tearDown fails.
        self.addCleanup(self.c.close)
        self.c.query("set client_encoding=utf8")

    def tearDown(self):
        # the connection is closed afterwards by the cleanup from setUp
        self.c.query("truncate table test")

    def _endcopy_ignoring_ioerror(self):
        """Call endcopy() on the connection, ignoring any IOError."""
        try:
            self.c.endcopy()
        except IOError:
            pass

    def testPutline(self):
        putline = self.c.putline
        query = self.c.query
        data = list(enumerate("apple pear plum cherry banana".split()))
        query("copy test from stdin")
        try:
            for i, v in data:
                putline("%d\t%s\n" % (i, v))
            putline("\\.\n")
        finally:
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, data)

    def testPutlineBytesAndUnicode(self):
        putline = self.c.putline
        query = self.c.query
        query("copy test from stdin")
        try:
            putline(u"47\tkÀse\n".encode('utf8'))
            putline("35\twÃŒrstel\n")
            putline(b"\\.\n")
        finally:
            self.c.endcopy()
        r = query("select * from test").getresult()
        self.assertEqual(r, [(47, 'kÀse'), (35, 'wÃŒrstel')])

    def testGetline(self):
        getline = self.c.getline
        query = self.c.query
        data = list(enumerate("apple banana pear plum strawberry".split()))
        n = len(data)
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            # n data lines, then the end marker, then None
            for i in range(n + 2):
                v = getline()
                if i < n:
                    self.assertEqual(v, '%d\t%s' % data[i])
                elif i == n:
                    self.assertEqual(v, '\\.')
                else:
                    self.assertIsNone(v)
        finally:
            self._endcopy_ignoring_ioerror()

    def testGetlineBytesAndUnicode(self):
        getline = self.c.getline
        query = self.c.query
        data = [(54, u'kÀse'.encode('utf8')), (73, u'wÃŒrstel')]
        self.c.inserttable('test', data)
        query("copy test to stdout")
        try:
            v = getline()
            self.assertIsInstance(v, str)
            self.assertEqual(v, '54\tkÀse')
            v = getline()
            self.assertIsInstance(v, str)
            self.assertEqual(v, '73\twÃŒrstel')
            self.assertEqual(getline(), '\\.')
            self.assertIsNone(getline())
        finally:
            self._endcopy_ignoring_ioerror()

    def testParameterChecks(self):
        self.assertRaises(TypeError, self.c.putline)
        self.assertRaises(TypeError, self.c.getline, 'invalid')
        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
1189        finally:
1190            try:
1191                self.c.endcopy()
1192            except IOError:
1193                pass
1194
1195    def testParameterChecks(self):
1196        self.assertRaises(TypeError, self.c.putline)
1197        self.assertRaises(TypeError, self.c.getline, 'invalid')
1198        self.assertRaises(TypeError, self.c.endcopy, 'invalid')
1199
1200
class TestNotificatons(unittest.TestCase):
    """Test notification support."""

    def setUp(self):
        self.c = connect()

    def tearDown(self):
        # Run the registered cleanups now, while the connection is still
        # open, because they issue queries on it.
        self.doCleanups()
        self.c.close()

    def _assertNotification(self, r, payload):
        """Assert that r is a 3-tuple ('test_notify', <int>, payload)."""
        self.assertIsInstance(r, tuple)
        self.assertEqual(len(r), 3)
        self.assertIsInstance(r[0], str)
        self.assertIsInstance(r[1], int)
        self.assertIsInstance(r[2], str)
        self.assertEqual(r[0], 'test_notify')
        self.assertEqual(r[2], payload)

    def testGetNotify(self):
        getnotify = self.c.getnotify
        query = self.c.query
        self.assertIsNone(getnotify())
        query('listen test_notify')
        try:
            self.assertIsNone(getnotify())
            query("notify test_notify")
            self._assertNotification(getnotify(), '')
            self.assertIsNone(getnotify())
            query("notify test_notify, 'test_payload'")
            self._assertNotification(getnotify(), 'test_payload')
            self.assertIsNone(getnotify())
        finally:
            query('unlisten test_notify')

    def testGetNoticeReceiver(self):
        self.assertIsNone(self.c.get_notice_receiver())

    def testSetNoticeReceiver(self):
        self.assertRaises(TypeError, self.c.set_notice_receiver, None)
        self.assertRaises(TypeError, self.c.set_notice_receiver, 42)
        self.assertIsNone(self.c.set_notice_receiver(lambda notice: None))

    def testSetAndGetNoticeReceiver(self):
        r = lambda notice: None
        self.assertIsNone(self.c.set_notice_receiver(r))
        self.assertIs(self.c.get_notice_receiver(), r)

    def testNoticeReceiver(self):
        self.c.query('''create function bilbo_notice() returns void AS $$
            begin
                raise warning 'Bilbo was here!';
            end;
            $$ language plpgsql''')
        # register the drop only once the function has been created
        self.addCleanup(self.c.query, 'drop function bilbo_notice();')
        received = {}

        def notice_receiver(notice):
            """Record all public attributes of the notice."""
            for attr in dir(notice):
                if attr.startswith('__'):
                    continue
                value = getattr(notice, attr)
                if isinstance(value, str):
                    # normalize a German server message to English
                    value = value.replace('WARNUNG', 'WARNING')
                received[attr] = value

        self.c.set_notice_receiver(notice_receiver)
        self.c.query('select bilbo_notice()')
        self.assertEqual(received, dict(
            pgcnx=self.c, message='WARNING:  Bilbo was here!\n',
            severity='WARNING', primary='Bilbo was here!',
            detail=None, hint=None))
1278
1279
class TestConfigFunctions(unittest.TestCase):
    """Test the functions for changing default settings.

    To test the effect of most of these functions, we need a database
    connection.  That's why they are covered in this test module.

    """

    def setUp(self):
        self.c = connect()
        self.c.query("set client_encoding=utf8")
        self.c.query("set lc_monetary='C'")

    def tearDown(self):
        self.c.close()

    @staticmethod
    def _decimal_point_set_to(value, restore):
        """Set the decimal point to value, read it back, then restore it."""
        pg.set_decimal_point(value)
        try:
            return pg.get_decimal_point()
        finally:
            pg.set_decimal_point(restore)

    def _set_lc_monetary(self, locales):
        """Set lc_monetary to the first accepted locale; return success."""
        for lc in locales:
            try:
                self.c.query("set lc_monetary='%s'" % lc)
            except pg.ProgrammingError:
                pass
            else:
                return True
        return False

    def _check_money(self, result, decimal_point, restore, expected):
        """Fetch a money value using decimal_point and check it.

        The decimal point is reset to restore afterwards.  If expected is
        a tuple, the value must be a str contained in it; otherwise it
        must be a pg.Decimal equal to expected.
        """
        pg.set_decimal_point(decimal_point)
        try:
            value = result.getresult()[0][0]
        finally:
            pg.set_decimal_point(restore)
        if isinstance(expected, tuple):
            self.assertIsInstance(value, str)
            self.assertIn(value, expected)
        else:
            self.assertIsInstance(value, pg.Decimal)
            self.assertEqual(value, expected)

    def testGetDecimalPoint(self):
        point = pg.get_decimal_point()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal_point, point)
        self.assertIsInstance(point, str)
        self.assertEqual(point, '.')  # the default setting
        for value in (',', "'"):
            r = self._decimal_point_set_to(value, point)
            self.assertIsInstance(r, str)
            self.assertEqual(r, value)
        # an empty string and None both make get_decimal_point return None
        for value in ('', None):
            self.assertIsNone(self._decimal_point_set_to(value, point))

    def testSetDecimalPoint(self):
        d = pg.Decimal
        point = pg.get_decimal_point()
        self.assertRaises(TypeError, pg.set_decimal_point)
        # error if decimal point is not a string
        self.assertRaises(TypeError, pg.set_decimal_point, 0)
        # error if more than one decimal point passed
        self.assertRaises(TypeError, pg.set_decimal_point, '.', ',')
        self.assertRaises(TypeError, pg.set_decimal_point, '.,')
        # error if decimal point is not a punctuation character
        self.assertRaises(TypeError, pg.set_decimal_point, '0')
        query = self.c.query
        check = self._check_money
        # check that money values are interpreted as decimal values
        # only if decimal_point is set, and that the result is correct
        # only if it is set suitable for the current lc_monetary setting
        select_money = "select '34.25'::money"
        proper_money = d('34.25')
        bad_money = d('3425')
        en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8'
        en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar'
        de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8'
        de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25',
            'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM')
        # first try with English localization (using the point)
        if not self._set_lc_monetary(en_locales):
            self.skipTest("cannot set English money locale")
        try:
            r = query(select_money)
        except pg.ProgrammingError:
            # this can happen if the currency signs cannot be
            # converted using the encoding of the test database
            self.skipTest("database does not support English money")
        # each check runs the query first, then sets the decimal point
        check(r, None, point, en_money)
        check(query(select_money), '', point, en_money)
        check(query(select_money), '.', point, proper_money)
        check(query(select_money), ',', point, bad_money)
        check(query(select_money), "'", point, bad_money)
        # then try with German localization (using the comma)
        if not self._set_lc_monetary(de_locales):
            self.skipTest("cannot set German money locale")
        select_money = select_money.replace('.', ',')
        try:
            r = query(select_money)
        except pg.ProgrammingError:
            self.skipTest("database does not support German money")
        check(r, None, point, de_money)
        check(query(select_money), '', point, de_money)
        check(query(select_money), ',', point, proper_money)
        check(query(select_money), '.', point, bad_money)
        check(query(select_money), "'", point, bad_money)

    def testGetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_decimal, decimal_class)
        self.assertIs(decimal_class, pg.Decimal)  # the default setting
        pg.set_decimal(int)
        try:
            r = pg.get_decimal()
        finally:
            pg.set_decimal(decimal_class)
        self.assertIs(r, int)
        r = pg.get_decimal()
        self.assertIs(r, decimal_class)

    def testSetDecimal(self):
        decimal_class = pg.get_decimal()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_decimal)
        query = self.c.query
        try:
            r = query("select 3425::numeric")
        except pg.ProgrammingError:
            self.skipTest('database does not support numeric')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, decimal_class)
        self.assertEqual(r, decimal_class('3425'))
        r = query("select 3425::numeric")
        pg.set_decimal(int)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_decimal(decimal_class)
        self.assertNotIsInstance(r, decimal_class)
        self.assertIsInstance(r, int)
        self.assertEqual(r, int(3425))

    def testGetBool(self):
        use_bool = pg.get_bool()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_bool, use_bool)
        self.assertIsInstance(use_bool, bool)
        self.assertIs(use_bool, False)  # the default setting
        # ints as well as bools are read back as proper bool values
        cases = ((True, True), (False, False), (1, True), (0, False))
        for value, expected in cases:
            pg.set_bool(value)
            try:
                r = pg.get_bool()
            finally:
                pg.set_bool(use_bool)
            self.assertIsInstance(r, bool)
            self.assertIs(r, expected)

    def testSetBool(self):
        use_bool = pg.get_bool()
        # error if no parameter is passed
        self.assertRaises(TypeError, pg.set_bool)
        query = self.c.query
        try:
            r = query("select true::bool")
        except pg.ProgrammingError:
            self.skipTest('database does not support bool')
        r = r.getresult()[0][0]
        self.assertIsInstance(r, str)
        self.assertEqual(r, 't')
        r = query("select true::bool")
        pg.set_bool(True)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, bool)
        self.assertIs(r, True)
        r = query("select true::bool")
        pg.set_bool(False)
        try:
            r = r.getresult()[0][0]
        finally:
            pg.set_bool(use_bool)
        self.assertIsInstance(r, str)
        # compare strings by value; identity of equal strings is not assured
        self.assertEqual(r, 't')

    def testGetNamedresult(self):
        namedresult = pg.get_namedresult()
        # error if a parameter is passed
        self.assertRaises(TypeError, pg.get_namedresult, namedresult)
        self.assertIs(namedresult, pg._namedresult)  # the default setting

    def testSetNamedresult(self):
        namedresult = pg.get_namedresult()
        self.assertTrue(callable(namedresult))

        query = self.c.query

        r = query("select 1 as x, 2 as y").namedresult()[0]
        self.assertIsInstance(r, tuple)
        self.assertEqual(r, (1, 2))
        self.assertIsNot(type(r), tuple)
        self.assertEqual(r._fields, ('x', 'y'))
        self.assertEqual(r._asdict(), {'x': 1, 'y': 2})
        self.assertEqual(r.__class__.__name__, 'Row')

        def listresult(q):
            return [list(row) for row in q.getresult()]

        pg.set_namedresult(listresult)
        try:
            r = pg.get_namedresult()
            self.assertIs(r, listresult)
            r = query("select 1 as x, 2 as y").namedresult()[0]
            self.assertIsInstance(r, list)
            self.assertEqual(r, [1, 2])
            self.assertIsNot(type(r), tuple)
            self.assertFalse(hasattr(r, '_fields'))
            self.assertNotEqual(r.__class__.__name__, 'Row')
        finally:
            pg.set_namedresult(namedresult)

        r = pg.get_namedresult()
        self.assertIs(r, namedresult)
1598
1599
class TestStandaloneEscapeFunctions(unittest.TestCase):
    """Test pg escape functions.

    The libpq interface memorizes some parameters of the last opened
    connection that influence the result of these functions.  Therefore
    we need to open a connection with fixed parameters prior to testing
    in order to ensure that the tests always run under the same conditions.
    That's why these tests are included in this test module.

    """

    @classmethod
    def setUpClass(cls):
        # Keep a reference so the connection's lifetime is explicit;
        # it is closed in tearDownClass (or here if a setting fails).
        cls.c = connect()
        try:
            query = cls.c.query
            query('set client_encoding=sql_ascii')
            query('set standard_conforming_strings=off')
            query('set bytea_output=escape')
        except Exception:
            cls.c.close()
            raise

    @classmethod
    def tearDownClass(cls):
        cls.c.close()

    def testEscapeString(self):
        f = pg.escape_string
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' kÀse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, u"das is'' kÀse".encode('utf-8'))
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        # backslashes are doubled since standard_conforming_strings is off
        r = f(r"It's bad to have a \ inside.")
        self.assertIsInstance(r, str)
        self.assertEqual(r, r"It''s bad to have a \\ inside.")

    def testEscapeBytea(self):
        f = pg.escape_bytea
        r = f(b'plain')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'plain')
        r = f(u'plain')
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u'plain')
        r = f(u"das is' kÀse".encode('utf-8'))
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b"das is'' k\\\\303\\\\244se")
        r = f(u"that's cheesy")
        self.assertIsInstance(r, unicode)
        self.assertEqual(r, u"that''s cheesy")
        r = f(b'O\x00ps\xff!')
        self.assertIsInstance(r, bytes)
        self.assertEqual(r, b'O\\\\000ps\\\\377!')
1651
1652
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module when it is
    # executed directly as a script ('__main__' is unittest's default).
    unittest.main(module='__main__')
Note: See TracBrowser for help on using the repository browser.